Commit 73a6cb8b authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support Auto Layout Inference and Parallelism with variable constraint (#417)

* [Enhancement] Introduce thread range management in layout and operation handling

* Added `SetThreadRange` method to `FragmentNode` for managing thread ranges.
* Updated `LayoutNode::Inverse` to provide more informative error messages.
* Refactored layout inference and operation lowering to use a `thread_bounds` range instead of a flat `block_size`, enhancing flexibility for thread management (see the sketch after this list).
* Introduced new tests for tilelang operations to validate thread range functionality and ensure correctness in parallel execution scenarios.
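
The gist of the `thread_bounds` change, as a hedged standalone sketch (the struct and field names mirror the hunks below; the numeric values are assumptions, not taken from this commit):

```cpp
// Hedged sketch only: assumes a TVM build on the include path; mirrors how the
// lowered ops below read the new Range-based thread bounds.
#include <tvm/ir/expr.h>  // tvm::Range, tvm::IntImm
#include <tvm/tir/op.h>   // tvm::tir::as_const_int
#include <iostream>

int main() {
  using namespace tvm;
  // threadIdx.x restricted to [128, 256): min = 128, extent = 128 (assumed values).
  Range thread_bounds = Range::FromMinExtent(IntImm(DataType::Int(32), 128),
                                             IntImm(DataType::Int(32), 128));
  // Previously ops received a size_t block_size; now they derive it from the extent.
  const int64_t *extent = tir::as_const_int(thread_bounds->extent);
  int64_t block_size = extent ? *extent : 0;
  std::cout << "block_size = " << block_size            // 128
            << ", warps = " << block_size / 32 << "\n";  // 4
}
```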

* lint fix

* [Refactor] Improve thread variable handling in layout inference and operation lowering

* Removed workaround for undefined thread_var in layout inference, ensuring proper handling of thread bounds.
* Updated logic to define thread bounds based on the presence of thread_var, enhancing robustness in thread management.
* Refactored thread_var initialization in lower_tile_op to maintain consistency across the codebase.

* [Refactor] Update thread variable handling in layout inference and operation lowering

* Refactored thread variable checks to ensure bounds are only accessed when defined, improving safety and clarity.
* Initialized thread_var with a default range to prevent undefined behavior.
* Updated logic in lower_tile_op to align with new thread variable handling, enhancing consistency across the codebase.
parent 4f24d8de
-Subproject commit 4776d3119e43fe064d890b94be7587d53106b9e3
+Subproject commit b0e25c31754df1a204773303f435abc55ceba1cf
@@ -201,12 +201,18 @@ Fragment FragmentNode::DeReplicate() const {
                   int(*rep_size) / factor, NullOpt);
 }
 
+Fragment FragmentNode::SetThreadRange(Range thread_range) {
+  thread_range_ = thread_range;
+  return GetRef<Fragment>(this);
+}
+
 Layout LayoutNode::Inverse() const {
   arith::Analyzer analyzer;
   arith::IterMapResult res =
       arith::DetectIterMap(forward_index_, getVarMap(), 1,
                            arith::IterMapLevel::Bijective, &analyzer);
-  ICHECK(res->errors.empty()) << res->errors;
+  ICHECK(res->errors.empty())
+      << "Layout " << DebugOutput() << " has errors: " << res->errors;
   auto outputs_shape = OutputShape();
   Array<PrimExpr> outputs;
......
@@ -95,15 +95,21 @@ public:
   std::string DebugOutput() const final;
 
+  Fragment SetThreadRange(Range thread_range);
+
+  Range ThreadRange() const { return thread_range_; }
+
   bool IsEqual(const FragmentNode *other, bool skip_index = false) const;
 
   void VisitAttrs(tvm::AttrVisitor *v);
   bool SEqualReduce(const FragmentNode *other, SEqualReducer equal) const;
 
   static constexpr const char *_type_key = "tl.Fragment";
   TVM_DECLARE_FINAL_OBJECT_INFO(FragmentNode, LayoutNode);
 
 protected:
   Map<Var, Range> getVarMap() const final;
+  Range thread_range_;
   PrimExpr forward_thread_;
   PrimExpr replicate_size_;
 };
......
@@ -160,8 +160,9 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (is_cpu_target) {
     vectorized_thread_loop = VectorizeLoop(fused_loop);
   } else {
-    par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
-                        InferLevel::kFree);
+    par_op->InferLayout(
+        {T.target, T.thread_bounds, T.layout_map, T.buffer_remap},
+        InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
     vectorized_thread_loop = VectorizeLoop(thread_loop);
@@ -421,9 +422,9 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (dst.scope() == "local.fragment") {
     auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
-    par_op->InferLayout({T.target, T.block_size, T.layout_map},
+    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                         InferLevel::kFree);
-    par_op->InferLayout({T.target, T.block_size, T.layout_map},
+    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                         InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
@@ -439,7 +440,7 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return vectorized_thread_loop;
   } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
     auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
-    par_op->InferLayout({T.target, T.block_size, T.layout_map},
+    par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                         InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                      par_op->GetLoopLayout());
......
@@ -114,12 +114,12 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (TargetIsCDNA(T.target)) {
     warp_size = 64;
   }
+  auto block_size = *as_const_int(T.thread_bounds->extent);
   bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
-                     (T.block_size / warp_size % 4 == 0);
+                     (block_size / warp_size % 4 == 0);
   auto [warp_m, warp_n] =
-      ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);
+      ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
   std::stringstream ss;
   std::string op_name = "tl::gemm_ss";
@@ -161,11 +161,11 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     return {};
   LayoutMap results;
   ICHECK(C.scope() == "local.fragment");
+  auto block_size = *as_const_int(T.thread_bounds->extent);
   if (TargetIsVolta(T.target)) {
     const int warp_size = 32;
     auto [warp_m, warp_n] =
-        ComputeWarpPartition(T.block_size / warp_size, T.target);
+        ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
     results.Set(C, fragment);
@@ -187,7 +187,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   } else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target)) {
     const int warp_size = 32;
     auto [warp_m, warp_n] =
-        ComputeWarpPartition(T.block_size / warp_size, T.target);
+        ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
     results.Set(C, fragment);
@@ -219,13 +219,13 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     }
   } else if (TargetIsHopper(T.target)) {
     const int warp_size = 32;
-    bool maybe_wgmma = (this->M >= 64) && (T.block_size / warp_size % 4 == 0);
+    bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
     if (!maybe_wgmma) {
       LOG(WARNING)
          << "WGMMA is not enabled because M < 64 or block_size % 128 != 0";
     }
     auto [warp_m, warp_n] =
-        ComputeWarpPartition(T.block_size / warp_size, T.target, maybe_wgmma);
+        ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
     auto fragment =
         maybe_wgmma
             ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
@@ -257,7 +257,7 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   } else if (TargetIsCDNA(T.target)) {
     const int warp_size = 64;
     auto [warp_m, warp_n] =
-        ComputeWarpPartition(T.block_size / warp_size, T.target);
+        ComputeWarpPartition(block_size / warp_size, T.target);
     auto fragment =
         makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
......
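For orientation, a hedged sketch of how the Hopper gate above evaluates once `block_size` is read from `T.thread_bounds->extent`; the tile size and thread count below are assumptions, not values from this commit:

```cpp
// Assumed values only: how maybe_wgmma evaluates with a Range-derived block size.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t M = 128;           // GEMM tile rows (assumption)
  const int64_t block_size = 128;  // e.g. a 128-thread warp group from thread_bounds->extent
  const int warp_size = 32;
  // WGMMA needs M >= 64 and the thread count to be a multiple of 128 (four warps).
  const bool maybe_wgmma = (M >= 64) && (block_size / warp_size % 4 == 0);
  std::cout << "warps = " << block_size / warp_size                          // 4
            << ", maybe_wgmma = " << std::boolalpha << maybe_wgmma << "\n";  // true
}
```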
@@ -44,7 +44,7 @@ enum class InferLevel {
 struct LowerArgs {
   Target target;
-  size_t block_size;
+  Range thread_bounds;
   Var thread_var;
   AddWorkspaceCallback AddWorkspace;
   LayoutMap layout_map;
@@ -54,7 +54,7 @@ struct LowerArgs {
 struct LayoutInferArgs {
   Target target;
-  size_t block_size;
+  Range thread_bounds;
   LayoutMap layout_map;
   Map<Buffer, Buffer> buffer_remap;
 };
......
@@ -128,6 +128,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   if (level == InferLevel::kStrict)
     return {};
+  auto block_size = T.thread_bounds->extent - T.thread_bounds->min;
   // Step 1: try to infer loop's partition from a source fragment
   Buffer source_buffer, read_source_buffer;
   for (const auto &[buffer, _] : indice_map_) {
@@ -192,12 +193,10 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         LOG(FATAL) << "coalesced_width should be an IntImmNode.";
       }
     }
-      loop_layout_ = PlanLoopPartition(root_, T.block_size, vector_size);
+    loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
   }
   PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
-  if (!analyzer_.CanProveEqual(loop_thread_extent,
-                               static_cast<int>(T.block_size)))
+  if (!analyzer_.CanProveEqual(loop_thread_extent, block_size))
     AddPredicate(LT(InputPlaceholder(0), loop_thread_extent));
 } else {
   return {};
......
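A hedged sketch of the guard this predicate path produces when the planned loop layout covers fewer threads than the block described by `thread_bounds`; the variable name `tx` and the numbers are assumptions:

```cpp
// Sketch with assumed values; builds the same kind of guard that
// AddPredicate(LT(InputPlaceholder(0), loop_thread_extent)) attaches above.
#include <tvm/tir/expr.h>
#include <iostream>

int main() {
  using namespace tvm;
  tir::Var tx("tx", DataType::Int(32));
  // The loop layout uses 128 threads, but the block spans 256, so
  // CanProveEqual(128, 256) fails and a guard on the thread index is emitted.
  PrimExpr loop_thread_extent = IntImm(DataType::Int(32), 128);
  PrimExpr guard = tir::LT(tx, loop_thread_extent);
  std::cout << guard << "\n";  // prints roughly: (tx < 128)
}
```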
@@ -178,7 +178,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     Array<PrimExpr> thread_reduce_args = {
         StringImm(ss.str()), BufferLoad(dst_buffer, dst_indices)};
     if (reducing_threads >= 32) {
-      PrimExpr workspace = T.AddWorkspace(T.block_size, dst_buffer->dtype);
+      PrimExpr workspace = T.AddWorkspace(
+          *as_const_int(T.thread_bounds->extent), dst_buffer->dtype);
       thread_reduce_args.push_back(workspace);
     }
     auto call =
......
@@ -14,6 +14,7 @@
 #include "../op/parallel.h"
 #include "arith/ir_mutator_with_analyzer.h"
+#include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_fusion_utils.h"
 #include "loop_partition.h"
 #include "loop_vectorize.h"
@@ -44,6 +45,7 @@ public:
 using namespace tir;
 using arith::IRMutatorWithAnalyzer;
+using arith::IRVisitorWithAnalyzer;
 
 class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
 public:
@@ -57,111 +59,108 @@ public:
       : IRMutatorWithAnalyzer(analyzer) {}
 
   Stmt VisitStmt_(const ForNode *op) final {
-    if (op->kind == ForKind::kParallel) {
+    if (op->kind != ForKind::kParallel)
+      return StmtMutator::VisitStmt_(op);
 
     // Collect loop variables and ranges
     auto for_node = GetRef<For>(op);
     Array<Var> loop_vars;
     Array<PrimExpr> loop_extents;
     Stmt body = op->body;
 
     // Bind the range of outer loop variables
     analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
     loop_vars.push_back(op->loop_var);
     loop_extents.push_back(op->extent);
 
     // If there are inner loops, bind their ranges as well
     while (const ForNode *inner = body.as<ForNode>()) {
-        analyzer_->Bind(inner->loop_var,
-                        Range::FromMinExtent(0, inner->extent));
+      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
       loop_vars.push_back(inner->loop_var);
       loop_extents.push_back(inner->extent);
       body = inner->body;
     }
 
     ICHECK(loop_vars.size() == loop_extents.size())
         << "loop_vars and loop_extents size mismatch";
 
     // Collect buffer access information
     BufferAccessCollector collector;
     collector(op->body);
 
     PrimExpr condition;
 
     for (const auto &[buffer, indices] : collector.buffer_indices) {
       ICHECK(indices.size() == buffer->shape.size())
           << "indices size mismatch with buffer shape";
 
       for (size_t i = 0; i < indices.size(); ++i) {
         auto index = indices[i];
         auto bound = analyzer_->const_int_bound(index);
         int64_t upper_bound = bound->max_value + 1;
         int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
 
         // Collect the variables that used in the index
         std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
         // post order visit the index
         PostOrderVisit(index, [&](const ObjectRef &obj) {
           if (const VarNode *v = obj.as<VarNode>()) {
             used_vars.insert(GetRef<Var>(v));
           }
         });
         if (used_vars.size() == 0) {
           continue;
         }
 
         // find related loop vars
         Array<Var> related_loop_vars;
         for (size_t j = 0; j < loop_vars.size(); ++j) {
           auto loop_var = loop_vars[j];
           // if find related, pop the loop_vars and loop_extents
           if (used_vars.count(loop_var)) {
             related_loop_vars.push_back(loop_var);
           }
           ICHECK(related_loop_vars.size() <= 1)
               << "Only one related loop var is supported currently, but got "
               << related_loop_vars
               << " implement multiple loop vars may not be "
               << "too hard, please send an issue if you need "
               << "came up with this message.";
 
           auto bound = analyzer_->const_int_bound(index);
           int64_t upper_bound = bound->max_value + 1;
           int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
           if (upper_bound < shape) {
-            PrimExpr predicate =
-                LT(index, IntImm(index.dtype(), upper_bound));
+            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
             condition =
                 condition.defined() ? And(condition, predicate) : predicate;
 
             // replace the buffer index from A[i, r * 2] with A[i, j]
             // where r is the original index, j is the loop_var
             auto index_map = tir::IndexMap({loop_var}, {index});
             auto inverse_index_map = index_map.Inverse(
                 {Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
                 analyzer_);
             loop_extents.Set(i, IntImm(index.dtype(), shape));
-            body = tir::Substitute(
-                body, {{loop_var, inverse_index_map->MapIndices(
-                           {loop_var}, analyzer_)[0]}});
+            body = tir::Substitute(body,
+                                   {{loop_var, inverse_index_map->MapIndices(
+                                                   {loop_var}, analyzer_)[0]}});
           }
         }
       }
     }
 
     if (condition.defined()) {
       body = IfThenElse(condition, body);
       for (int j = loop_vars.size() - 1; j >= 0; --j) {
         auto loop_var = loop_vars[j];
         auto loop_extent = loop_extents[j];
         body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
       }
       return Downcast<For>(body);
     }
     // Only traverse the outer loop
     return for_node;
-    }
-    return StmtMutator::VisitStmt_(op);
   }
 
 private:
@@ -206,7 +205,7 @@ struct LayoutInferenceResult {
   Map<For, PrimExpr> predicate_map;
 };
 
-class BufferUseDefCollector : public StmtExprVisitor {
+class BufferUseDefCollector : public IRVisitorWithAnalyzer {
 public:
   BufferUseDefCollector(bool skip_thread_partition)
       : skip_thread_partition_(skip_thread_partition) {}
@@ -217,6 +216,9 @@ public:
     ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
         << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
            "length.";
+    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
+        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
+           "length.";
     // If needed, you can also check that annotated_layout_map_ is not empty, or
     // anything else relevant to your setup.
@@ -236,12 +238,6 @@ public:
       // Check that each thread_var_vec_ entry is defined
       if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
-        // TODO(lei): This is a hack for cpu backend
-        if (!thread_var_.defined()) {
-          // Fake thread var to inference predicate for the buffer
-          thread_var_ = IterVar(Range::FromMinExtent(PrimExpr(0), PrimExpr(1)),
-                                Var(""), IterVarType::kDataPar);
-        }
         thread_var_vec_[i] = thread_var_;
       }
       q.push(i);
@@ -259,6 +255,7 @@ public:
       // thread_var_vec_[cur_infer_id]
       auto &next = infer_list_[cur_infer_id];
       auto iter_var = thread_var_vec_[cur_infer_id];
+      auto thread_bounds = thread_bounds_vec_[cur_infer_id];
       // Double-check that 'next' is valid
       ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
@@ -281,9 +278,7 @@ public:
       // Run InferLayout
       auto updates = next->InferLayout(
-          LayoutInferArgs{target_, static_cast<size_t>(*extent_ptr),
-                          layout_map},
-          level);
+          LayoutInferArgs{target_, thread_bounds, layout_map}, level);
       // Process the returned updates
       for (const auto &[buffer, layout] : updates) {
         // Basic validity checks
@@ -407,7 +402,7 @@ public:
 private:
   void VisitExpr_(const CallNode *op) final {
-    StmtExprVisitor::VisitExpr_(op);
+    IRVisitorWithAnalyzer::VisitExpr_(op);
     // Do not analysis the call node to the global function.
     if (op->op.as<GlobalVarNode>())
       return;
@@ -421,6 +416,16 @@ private:
       }
       infer_list_.push_back(std::move(p));
       thread_var_vec_.push_back(thread_var_);
+      if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+        auto min_value = const_int_bound->min_value;
+        auto max_value = const_int_bound->max_value;
+        auto dtype = thread_var_->var.dtype();
+        thread_bounds_vec_.push_back(Range::FromMinExtent(
+            IntImm(dtype, min_value), IntImm(dtype, max_value + 1)));
+      } else {
+        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
+      }
     }
   }
@@ -449,8 +454,18 @@ private:
       }
       infer_list_.push_back(std::move(infer));
       thread_var_vec_.push_back(thread_var_);
+      if (thread_var_.defined() &&
+          analyzer_.const_int_bound.IsBound(thread_var_->var)) {
+        auto const_int_bound = analyzer_.const_int_bound(thread_var_);
+        auto dtype = thread_var_->var.dtype();
+        thread_bounds_vec_.push_back(Range::FromMinExtent(
+            IntImm(dtype, const_int_bound->min_value),
+            IntImm(dtype, const_int_bound->max_value + 1)));
+      } else {
+        thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
+      }
     } else {
-      StmtExprVisitor::VisitStmt(op->body);
+      IRVisitorWithAnalyzer::VisitStmt(op->body);
     }
   }
@@ -467,7 +482,7 @@ private:
         annotated_layout_map_.Set(buffer, layout);
       }
     }
-    StmtExprVisitor::VisitStmt_(op);
+    IRVisitorWithAnalyzer::VisitStmt_(op);
   }
 
   void VisitStmt_(const AttrStmtNode *op) final {
@@ -478,15 +493,19 @@ private:
         thread_var_ = iv;
       }
     }
-    StmtExprVisitor::VisitStmt_(op);
+    IRVisitorWithAnalyzer::VisitStmt_(op);
   }
 
   Map<Var, Buffer> buffer_data_to_buffer_;
   std::vector<std::unique_ptr<Operator>> infer_list_;
   std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
       use_list_;
-  IterVar thread_var_;
+  // This is a workaround for cpu backend,
+  // we need to define a thread_var for the serial loop.
+  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
+                                IterVarType::kDataPar);
   std::vector<IterVar> thread_var_vec_;
+  std::vector<Range> thread_bounds_vec_;
   Target target_;
   LayoutMap annotated_layout_map_;
   bool skip_thread_partition_{false};
@@ -597,7 +616,8 @@ private:
 private:
   const LayoutInferenceResult result_;
-  IterVar thread_var_;
+  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
+                                IterVarType::kDataPar);
   bool skip_thread_partition_{false};
 };
......
@@ -38,7 +38,7 @@ public:
 private:
   PrimExpr VisitExpr_(const BufferLoadNode *node) final {
     auto visited = StmtExprMutator::VisitExpr_(node);
-    auto n = visited.as<BufferLoad>().value();
+    auto n = Downcast<BufferLoad>(visited);
     auto nptr = n.CopyOnWrite();
     nptr->indices = nptr->indices.Map(
         [&](const auto &e) { return analyzer_->Simplify(e); });
@@ -46,7 +46,7 @@ private:
   }
   Stmt VisitStmt_(const BufferStoreNode *node) final {
     auto visited = StmtExprMutator::VisitStmt_(node);
-    auto n = visited.as<BufferStore>().value();
+    auto n = Downcast<BufferStore>(visited);
     auto nptr = n.CopyOnWrite();
     nptr->indices = nptr->indices.Map(
         [&](const auto &e) { return analyzer_->Simplify(e); });
@@ -74,11 +74,10 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
   Map<Var, PrimExpr> vmap;
   Stmt body = op;
   auto inv_loop = loop_layout->Inverse();
-  auto indices =
-      inv_loop->Forward(vars.Map([](const Var &v) { return PrimExpr(v); }));
+  auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
   for (int i = 0; i < old_loop_depth; i++) {
-    ICHECK(body.as<For>().defined());
-    For loop = body.as<For>().value();
+    const ForNode *loop = body.as<ForNode>();
+    ICHECK(loop != nullptr);
     vmap.Set(loop->loop_var, indices[i]);
     body = loop->body;
   }
@@ -94,7 +93,12 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
   body = BufferIndiceSimplify(analyzer)(body);
   auto for_node = LoopPragmaUnroll(Downcast<For>(body));
+  if (loop_layout->ThreadRange().defined()) {
+    auto range = loop_layout->ThreadRange();
+    auto thread_var_with_offset = thread_var - range->min;
+    for_node.CopyOnWrite()->body =
+        Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
+  }
   return for_node;
 }
@@ -161,6 +165,16 @@ Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size) {
   return partitioner.Partition(op, num_thread, vectorize_size);
 }
 
+Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
+  size_t num_thread =
+      *as_const_int(thread_range->extent) - *as_const_int(thread_range->min);
+  LoopPartitioner partitioner;
+  Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
+  auto node = make_object<FragmentNode>(*fragment.get());
+  node->SetThreadRange(thread_range);
+  return Fragment(node);
+}
+
 For LoopPragmaUnroll(For stmt) {
   LoopPramaUnroller unroller;
   For unrolled = Downcast<For>(unroller(stmt));
......
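A hedged sketch of the thread-index rebasing that the new `ThreadRange` handling in `PartitionLoop` performs; the range `[128, 256)` and the index expression are assumptions:

```cpp
// Hedged sketch; the Substitute call mirrors the thread_var_with_offset logic above.
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>            // PrimExpr arithmetic operators
#include <tvm/tir/stmt_functor.h>  // tvm::tir::Substitute
#include <iostream>

int main() {
  using namespace tvm;
  tir::Var tx("threadIdx.x", DataType::Int(32));
  Range thread_range = Range::FromMinExtent(IntImm(DataType::Int(32), 128),
                                            IntImm(DataType::Int(32), 128));
  // An index expression written as if the participating threads started at 0.
  PrimExpr index = tx * 4;
  // Rebase: threads 128..255 behave as logical threads 0..127.
  PrimExpr rebased = tir::Substitute(index, {{tx, tx - thread_range->min}});
  std::cout << rebased << "\n";  // prints roughly: ((threadIdx.x - 128)*4)
}
```

With this rebase, a fragment planned for threads [128, 256) indexes its data the same way a fragment planned for [0, 128) would.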
@@ -39,6 +39,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
 Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size);
 
+Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range);
+
 For LoopPragmaUnroll(For stmt);
 
 } // namespace tl
......
@@ -289,11 +289,23 @@ private:
     Optional<Bool> opt_disable_tma_lower =
         ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
     bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
+
+    Range thread_bounds;
+    if (analyzer_->const_int_bound.IsBound(thread_var_->var)) {
+      auto const_int_bound = analyzer_->const_int_bound(thread_var_);
+      auto min_value = const_int_bound->min_value;
+      auto max_value = const_int_bound->max_value;
+      thread_bounds =
+          Range::FromMinExtent(IntImm(thread_var_->var.dtype(), min_value),
+                               IntImm(thread_var_->var.dtype(), max_value + 1));
+    } else {
+      thread_bounds = Range::FromMinExtent(0, 1);
+    }
+
-    auto lowered = tile_op->Lower(LowerArgs{target_, thread_block_size_,
-                                            thread_var_, callback, layout_map_,
-                                            buffer_remap_, disable_tma_lower},
-                                  analyzer_);
+    auto lowered = tile_op->Lower(
+        LowerArgs{target_, thread_bounds, thread_var_->var, callback,
+                  layout_map_, buffer_remap_, disable_tma_lower},
+        analyzer_);
     return IRMutatorWithAnalyzer::VisitStmt(lowered);
   }
@@ -302,7 +314,7 @@ private:
     IterVar iv = Downcast<IterVar>(op->node);
     ICHECK_NE(iv->thread_tag.length(), 0U);
     if (iv->thread_tag == "threadIdx.x") {
-      thread_var_ = iv->var;
+      thread_var_ = iv;
       ICHECK(iv->dom->extent.as<IntImmNode>());
       thread_block_size_ = iv->dom->extent.as<IntImmNode>()->value;
     }
@@ -314,7 +326,10 @@ private:
   Map<Var, Buffer> buffer_data_to_buffer_;
   Map<Buffer, Layout> layout_map_;
   Map<Buffer, Buffer> buffer_remap_;
-  Var thread_var_;
+  // This is a workaround for cpu backend,
+  // we need to define a thread_var for the serial loop.
+  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
+                                IterVarType::kDataPar);
   size_t thread_block_size_ = 0;
   Array<Buffer> workspaces_;
   // For ptx Node, we need to remap the buffer and indices
......
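A hedged sketch of the `const_int_bound`-to-`Range` conversion above, assuming the enclosing IR has already constrained `threadIdx.x` to `[128, 256)`:

```cpp
// Hedged sketch; mirrors how lower_tile_op derives thread_bounds from the analyzer.
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <iostream>

int main() {
  using namespace tvm;
  arith::Analyzer analyzer;
  tir::Var tx("threadIdx.x", DataType::Int(32));
  // Assume the surrounding IR constrained tx to [128, 256).
  analyzer.Bind(tx, Range::FromMinExtent(IntImm(DataType::Int(32), 128),
                                         IntImm(DataType::Int(32), 128)));
  auto bound = analyzer.const_int_bound(tx);
  Range thread_bounds = Range::FromMinExtent(
      IntImm(DataType::Int(32), bound->min_value),
      IntImm(DataType::Int(32), bound->max_value + 1));  // min = 128, extent = 128
  std::cout << thread_bounds << "\n";
}
```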
import tilelang
import tilelang.language as T
import torch
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
tx = T.get_thread_binding(0)
if tx < 128:
for i, k in T.Parallel(block_M, block_N):
A_shared[i, k] = A[by * block_M + i, bx * block_N + k]
T.copy(A_shared, B[by * block_M, bx * block_N])
return main
def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
def test_tilelang_copy_mask_parallel():
run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128)
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
tx = T.get_thread_binding(0)
if tx < 128:
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, B[by * block_M, bx * block_N])
return main
def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
def test_tilelang_copy_mask_copy():
run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128)
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
tx = T.get_thread_binding(0)
if tx >= 128 and tx < 256:
for i, k in T.Parallel(block_M, block_N):
A_shared[i, k] = A[by * block_M + i, bx * block_N + k]
T.copy(A_shared, B[by * block_M, bx * block_N])
return main
def run_tilelang_copy_mask_parallel_range(M=1024,
N=1024,
block_M=128,
block_N=128,
dtype="float16"):
program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
def test_tilelang_copy_mask_parallel_range():
run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128)
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
tx = T.get_thread_binding(0)
if tx >= 128 and tx < 256:
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, B[by * block_M, bx * block_N])
return main
def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
def test_tilelang_copy_mask_copy_range():
run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128)
if __name__ == "__main__":
test_tilelang_copy_mask_copy_range()
@@ -3,6 +3,15 @@ from tvm.target import Target
 import tilelang
 
 
+def allow_tma_and_warp_specialized(target: Target) -> bool:
+    if target.arch not in {"sm_90"}:
+        return False
+    cur_pass_ctx = tilelang.transform.get_pass_context()
+    disable_tma_lower = cur_pass_ctx.config.get("tl.disable_tma_lower", False)
+    disable_warp_specialized = cur_pass_ctx.config.get("tl.disable_warp_specialized", False)
+    return not (disable_tma_lower and disable_warp_specialized)
+
+
 def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
     # Bind the target device information to the module
     mod = tir.transform.BindTarget(target)(mod)
@@ -30,7 +39,7 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
 def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     # which may be introduced by the LegalizeSafeMemoryAccess
-    if target.arch == "sm_90":
+    if allow_tma_and_warp_specialized(target):
         mod = tilelang.transform.IfStmtBinding()(mod)
         mod = tilelang.transform.MultiVersionBuffer()(mod)
         mod = tilelang.transform.WarpSpecialized()(mod)
......
@@ -5,6 +5,12 @@ from . import _ffi_api
 from .simplify import Simplify, simplify_prim_func  # noqa: F401
 
 
+def get_pass_context():
+    """Get the current pass context"""
+    from tilelang import tvm as tvm
+    return tvm.transform.PassContext.current()
+
+
 def ClusterPlanning():
     """ClusterPlanning
......