Commit 42c3b452 authored by Lei Wang, committed by LeiWang1999

[Refactor] Phaseout Pass ParallelLoopTransformer (#611)

* Refactor layout inference by removing the ParallelLoopTransformer class. Update the layout inference logic to streamline buffer-access collection and condition handling in parallel loops; this simplifies the code structure and improves maintainability.

* Update MHA backward test cases to use reduced dimensions for batch size and context length
parent d9ae74c6
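
For context on what is being removed: the pass detected parallel loops whose fragment-buffer indices cover only part of a buffer dimension by bounding each index expression under the loop-variable ranges, then widened the loop and guarded its body with a predicate. Below is a minimal sketch of the bound analysis it relied on, assuming a TVM C++ development environment; the loop extent, stride, and buffer shape are illustrative, not taken from the pass.

#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <iostream>

using namespace tvm;

int main() {
  arith::Analyzer analyzer;
  // A parallel loop variable r with extent 4, bound the way the pass did it.
  tir::Var r("r", DataType::Int(32));
  analyzer.Bind(r, Range::FromMinExtent(0, 4));
  // A strided index expression, e.g. the second index of B[i, r * 2].
  PrimExpr index = r * 2;
  arith::ConstIntBound bound = analyzer.const_int_bound(index);
  int64_t upper_bound = bound->max_value + 1;  // 6 + 1 = 7
  int64_t shape = 8;  // illustrative buffer extent for that dimension
  if (upper_bound < shape) {
    // The pass would emit the predicate (index < upper_bound) and widen
    // the loop extent to the full buffer dimension.
    std::cout << "partial coverage: " << upper_bound << " < " << shape << "\n";
  }
  return 0;
}
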
@@ -47,158 +47,6 @@ using namespace tir;
using arith::IRMutatorWithAnalyzer;
using arith::IRVisitorWithAnalyzer;
class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
public:
static Stmt Substitute(Stmt stmt, bool skip_thread_partition = false) {
arith::Analyzer analyzer;
ParallelLoopTransformer transformer(&analyzer);
return transformer.VisitStmt(stmt);
}
ParallelLoopTransformer(arith::Analyzer *analyzer)
: IRMutatorWithAnalyzer(analyzer) {}
Stmt VisitStmt_(const ForNode *op) final {
if (op->kind != ForKind::kParallel)
return StmtMutator::VisitStmt_(op);
// Collect loop variables and ranges
auto for_node = GetRef<For>(op);
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Stmt body = op->body;
// Bind the range of outer loop variables
analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
loop_vars.push_back(op->loop_var);
loop_extents.push_back(op->extent);
// If there are inner loops, bind their ranges as well
while (const ForNode *inner = body.as<ForNode>()) {
analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
loop_vars.push_back(inner->loop_var);
loop_extents.push_back(inner->extent);
body = inner->body;
}
ICHECK(loop_vars.size() == loop_extents.size())
<< "loop_vars and loop_extents size mismatch";
// Collect buffer access information
BufferAccessCollector collector;
collector(op->body);
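// collector.buffer_indices maps each "local.fragment" buffer to its access
// indices; the collector has verified that every access to a given buffer
// uses structurally identical indices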
PrimExpr condition;
for (const auto &[buffer, indices] : collector.buffer_indices) {
ICHECK(indices.size() == buffer->shape.size())
<< "indices size mismatch with buffer shape";
for (size_t i = 0; i < indices.size(); ++i) {
auto index = indices[i];
auto bound = analyzer_->const_int_bound(index);
int64_t upper_bound = bound->max_value + 1;
int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
// Collect the variables that are used in the index
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
// Post-order visit the index expression to gather its variables
PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v));
}
});
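// e.g. for an index r * 2 the set is {r}; a constant index has no
// used vars and is skipped below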
if (used_vars.size() == 0) {
continue;
}
// find related loop vars
Array<Var> related_loop_vars;
for (size_t j = 0; j < loop_vars.size(); ++j) {
auto loop_var = loop_vars[j];
// Record any loop var that appears in the index
if (used_vars.count(loop_var)) {
related_loop_vars.push_back(loop_var);
}
ICHECK(related_loop_vars.size() <= 1)
<< "Only one related loop var is supported currently, but got "
<< related_loop_vars
<< "; supporting multiple loop vars should not be too hard, "
<< "so please file an issue if you run into this message.";
if (upper_bound < shape) {
PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
condition =
condition.defined() ? And(condition, predicate) : predicate;
// Rewrite the buffer index, e.g. from A[i, r * 2] to A[i, j],
// where r is the original loop var and j is the new loop var
// over the full buffer dimension
auto index_map = tir::IndexMap({loop_var}, {index});
auto inverse_index_map = index_map.Inverse(
{Range::FromMinExtent(0, IntImm(index.dtype(), upper_bound))},
analyzer_);
loop_extents.Set(i, IntImm(index.dtype(), shape));
body = tir::Substitute(body,
{{loop_var, inverse_index_map->MapIndices(
{loop_var}, analyzer_)[0]}});
}
}
}
}
if (condition.defined()) {
body = IfThenElse(condition, body);
for (int j = loop_vars.size() - 1; j >= 0; --j) {
auto loop_var = loop_vars[j];
auto loop_extent = loop_extents[j];
body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
}
return Downcast<For>(body);
}
// Only the outermost parallel loop nest is considered; return it unchanged
return for_node;
}
private:
// Helper class for collecting buffer access information; only accesses to
// "local.fragment" buffers are recorded
class BufferAccessCollector : public StmtExprVisitor {
public:
void VisitExpr_(const BufferLoadNode *op) final {
if (op->buffer.scope() == "local.fragment") {
if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
buffer_indices[op->buffer] = op->indices;
} else {
// check equal
ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
<< "indices mismatch for buffer: " << op->buffer;
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
if (op->buffer.scope() == "local.fragment") {
if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
buffer_indices[op->buffer] = op->indices;
} else {
// check equal
ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
<< "indices mismatch for buffer: " << op->buffer;
}
}
StmtExprVisitor::VisitStmt_(op);
}
std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
buffer_indices;
};
};
struct LayoutInferenceResult {
Map<Buffer, Layout> layout_map;
Map<For, Fragment> for_map;
@@ -653,7 +501,6 @@ private:
tvm::transform::Pass LayoutInference() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
- f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
ThreadBindingCollector collector;
collector(f->body);
bool has_thread_binding = collector.thread_binding_.size() > 0;
......
@@ -307,8 +307,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
def test_mha_bwd():
- assert_mha_equal(8, 32, 256, 64, False)
- assert_mha_equal(8, 32, 256, 64, True)
+ assert_mha_equal(8, 32, 128, 64, False)
+ assert_mha_equal(8, 32, 128, 64, True)
if __name__ == "__main__":
......