Commit 5c8de061 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Implement ParallelLoopTransformer for enhanced loop analysis (#295)

* [Feature] Implement ParallelLoopTransformer for enhanced loop analysis

- Introduced the ParallelLoopTransformer class to improve the handling of parallel loops in layout inference.
- Enhanced the analysis of loop variables and their extents, allowing for more accurate index range calculations.
- Added a BufferAccessCollector to gather buffer access information, ensuring correct index mapping and condition handling.
- Updated the LayoutInference pass to utilize the new transformer, improving overall performance and accuracy in loop transformations.

* test fix

* Fix typo in buffer variable documentation and enhance loop variable handling in layout inference. Added checks for related loop variables and improved condition handling for index mapping.

* Refactor loop variable handling in layout inference. Updated loop index variable from `i` to `j` for clarity and improved condition handling for index mapping by replacing `indices[i]` with `index` in predicate construction.
parent ff3cfa59
......@@ -4,6 +4,7 @@
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
......@@ -25,7 +26,7 @@ namespace tl {
using namespace tir;
/*!
* \brief collect the mapping from the buffer var to its allocate
* \brief collect the mapping from the buffer var to its allocated buffer
*/
class ThreadBindingCollector : public StmtExprVisitor {
public:
......@@ -44,6 +45,161 @@ public:
using namespace tir;
using arith::IRMutatorWithAnalyzer;
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
arith::Analyzer analyzer;
ParallelLoopTransformer transformer(&analyzer);
return transformer.VisitStmt(stmt);
}
ParallelLoopTransformer(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
Stmt VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
// Collect loop variables and ranges
auto for_node = GetRef<For>(op);
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Stmt body = op->body;
// Bind the range of outer loop variables
analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
loop_vars.push_back(op->loop_var);
loop_extents.push_back(op->extent);
// If there are inner loops, bind their ranges as well
while (const ForNode *inner = body.as<ForNode>()) {
analyzer_->Bind(inner->loop_var,
Range::FromMinExtent(0, inner->extent));
loop_vars.push_back(inner->loop_var);
loop_extents.push_back(inner->extent);
body = inner->body;
}
ICHECK(loop_vars.size() == loop_extents.size())
<< "loop_vars and loop_extents size mismatch";
// Collect buffer access information
BufferAccessCollector collector;
collector(op->body);
PrimExpr condition;
for (const auto &[buffer, indices] : collector.buffer_indices) {
ICHECK(indices.size() == buffer->shape.size())
<< "indices size mismatch with buffer shape";
for (size_t i = 0; i < indices.size(); ++i) {
auto index = indices[i];
auto bound = analyzer_->const_int_bound(index);
int64_t upper_bound = bound->max_value + 1;
int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
// Collect the variables that used in the index
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
// post order visit the index
PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v));
}
});
if (used_vars.size() == 0) {
continue;
}
// find related loop vars
Array<Var> related_loop_vars;
for (size_t j = 0; j < loop_vars.size(); ++j) {
auto loop_var = loop_vars[j];
// if find related, pop the loop_vars and loop_extents
if (used_vars.count(loop_var)) {
related_loop_vars.push_back(loop_var);
}
ICHECK(related_loop_vars.size() <= 1)
<< "Only one related loop var is supported currently, but got "
<< related_loop_vars
<< " implement multiple loop vars may not be "
<< "too hard, please send an issue if you need "
<< "came up with this message.";
auto bound = analyzer_->const_int_bound(index);
int64_t upper_bound = bound->max_value + 1;
int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
if (upper_bound < shape) {
PrimExpr predicate =
LT(index, IntImm(index.dtype(), upper_bound));
condition =
condition.defined() ? And(condition, predicate) : predicate;
// replace the buffer index from A[i, r * 2] with A[i, j]
// where r is the original index, j is the loop_var
auto index_map = tir::IndexMap({loop_var}, {index});
auto inverse_index_map = index_map.Inverse(
{Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
analyzer_);
loop_extents.Set(i, IntImm(index.dtype(), shape));
body = tir::Substitute(
body, {{loop_var, inverse_index_map->MapIndices(
{loop_var}, analyzer_)[0]}});
}
}
}
}
if (condition.defined()) {
body = IfThenElse(condition, body);
for (int j = loop_vars.size() - 1; j >= 0; --j) {
auto loop_var = loop_vars[j];
auto loop_extent = loop_extents[j];
body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
}
return Downcast<For>(body);
}
// Only traverse the outer loop
return for_node;
}
return StmtMutator::VisitStmt_(op);
}
private:
// Helper class for collecting buffer access information, only counts fragment
// buffer access
class BufferAccessCollector : public StmtExprVisitor {
public:
void VisitExpr_(const BufferLoadNode *op) final {
if (op->buffer.scope() == "local.fragment") {
if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
buffer_indices[op->buffer] = op->indices;
} else {
// check equal
ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
<< "indices mismatch for buffer: " << op->buffer;
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
if (op->buffer.scope() == "local.fragment") {
if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
buffer_indices[op->buffer] = op->indices;
} else {
// check equal
ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
<< "indices mismatch for buffer: " << op->buffer;
}
}
StmtExprVisitor::VisitStmt_(op);
}
std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
buffer_indices;
};
};
struct LayoutInferenceResult {
Map<Buffer, Layout> layout_map;
Map<For, Fragment> for_map;
......@@ -219,14 +375,12 @@ public:
// Check if base_infer is valid
ICHECK(base_infer != nullptr) << "Null pointer encountered in "
"infer_list_ while collecting for_map.";
if (auto for_infer = dynamic_cast<ParallelOp *>(base_infer.get())) {
// Check that the loop layout is defined
ICHECK(for_infer->GetLoopLayout().defined())
<< "The Layout for Parallel for cannot be inferred correctly:\n"
<< for_infer->GetRoot();
for_map.Set(for_infer->GetRoot(), for_infer->GetLoopLayout());
// thread_var_ should be defined if we rely on it
ICHECK(thread_var.defined())
<< "thread_var is not defined. Cannot retrieve predicate.";
......@@ -376,13 +530,13 @@ private:
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto loop_layout = result_.for_map[GetRef<For>(op)];
if (!skip_thread_partition_) {
// If none thread bindings are provided, partition the loop
for_node =
PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
}
for_node = VectorizeLoop(for_node);
if (result_.predicate_map.count(GetRef<For>(op))) {
return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node);
} else {
......@@ -412,6 +566,7 @@ private:
tvm::transform::Pass LayoutInference() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
ThreadBindingCollector collector;
collector(f->body);
bool has_thread_binding = collector.thread_binding_.size() > 0;
......
......@@ -412,4 +412,5 @@ def test_assert_tl_matmul_block_all_dynamic():
if __name__ == "__main__":
tilelang.testing.main()
# tilelang.testing.main()
test_assert_tl_matmul_macro()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment