Commit 86f96f8a authored by Yu Cheng, committed by LeiWang1999
Browse files

[Dev] Implement IfStmtBinding and MergeIfStmt transformations (#211)



* [Dev] Implement IfStmtBinding and MergeIfStmt transformations

- Add IfStmtBinding to bind If statements to each statement in SeqStmt, enhancing the handling of conditional statements.
- Introduce MergeIfStmt to merge consecutive If statements within SeqStmt, optimizing the structure of conditional logic.
- Update phase.py to apply IfStmtBinding and MergeIfStmt transformations for the "sm_90" target.
- Enhance __init__.py with new functions for IfStmtBinding and MergeIfStmt, providing a clear interface for these transformations.

* Update license header in if_stmt_binding.cc

* Update license header in merge_if_stmt.cc

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent e2bc1cb6
/*!
* \file if_stmt_binding.cc
* \brief Bind the If Stmt to each Stmt in SeqStmt
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Mutator that pushes an IfThenElse condition down onto each statement
 *        of its branches.
 *
 * For `if (c) { s0; s1 } else { t0 }` the rewriter produces
 * `if (c) s0; if (c) s1; if (!c) t0;`. A branch that is not a SeqStmt is
 * simply wrapped in a single else-less IfThenElse. The branch bodies are
 * re-wrapped as they are; they are not visited recursively.
 */
class IfStmtBindingRewriter : public StmtExprMutator {
public:
  /*!
   * \brief Rewrite the body of \p f in place and return the function.
   * \param f The PrimFunc whose body is rewritten (copy-on-write).
   * \return The updated PrimFunc.
   */
  static PrimFunc Substitute(PrimFunc &f) {
    IfStmtBindingRewriter binder;
    f.CopyOnWrite()->body = binder(f->body);
    return f;
  }

private:
  IfStmtBindingRewriter() = default;

  /*!
   * \brief Guard \p branch with \p cond.
   *
   * When \p branch is a SeqStmt every child gets its own else-less
   * IfThenElse; otherwise the whole branch is wrapped once.
   */
  static Stmt BindCondition(const Stmt &branch, const PrimExpr &cond) {
    const auto *seq_node = branch.as<SeqStmtNode>();
    if (seq_node == nullptr) {
      return IfThenElse(cond, branch, Stmt());
    }
    Array<Stmt> guarded;
    for (const Stmt &item : seq_node->seq) {
      guarded.push_back(IfThenElse(cond, item, Stmt()));
    }
    return SeqStmt(std::move(guarded));
  }

  Stmt VisitStmt_(const IfThenElseNode *op) final {
    Array<Stmt> pieces;
    if (op->then_case.defined()) {
      pieces.push_back(BindCondition(op->then_case, op->condition));
    }
    if (op->else_case.defined()) {
      // The else branch runs under the negated condition.
      pieces.push_back(BindCondition(op->else_case.value(), !op->condition));
    }
    if (pieces.size() == 1) {
      return pieces[0];
    }
    return SeqStmt(std::move(pieces));
  }

  Stmt VisitStmt_(const SeqStmtNode *op) final {
    // Visit every child and always rebuild the sequence from the results.
    Array<Stmt> visited;
    for (const Stmt &child : op->seq) {
      visited.push_back(VisitStmt(child));
    }
    return SeqStmt(std::move(visited));
  }
};
using namespace tir::transform;
/*!
 * \brief Create the pass that runs IfStmtBindingRewriter on each PrimFunc.
 * \return A PrimFunc-level pass named "tl.IfStmtBinding" at opt level 0.
 */
tvm::transform::Pass IfStmtBinding() {
  // The module and pass context are part of the callback signature but are
  // not consulted by this pass.
  auto bind_pass = [](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    return IfStmtBindingRewriter::Substitute(func);
  };
  return CreatePrimFuncPass(bind_pass, 0, "tl.IfStmtBinding", {});
}
TVM_REGISTER_GLOBAL("tl.transform.IfStmtBinding").set_body_typed(IfStmtBinding);
} // namespace tl
} // namespace tvm
/*!
* \file merge_if_stmt.cc
* \brief Merge the If Stmt in SeqStmt
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
 * \brief Mutator that fuses runs of adjacent else-less IfThenElse statements
 *        inside a SeqStmt when their conditions are structurally equal.
 *
 * `if (c) a; if (c) b; x; if (d) e;` becomes `if (c) { a; b } x; if (d) e;`.
 * IfThenElse statements that carry an else branch, and every other statement
 * kind, terminate the current run and are emitted unchanged.
 */
class MergeIfStmtRewriter : public StmtExprMutator {
public:
  /*!
   * \brief Rewrite the body of \p f in place and return the function.
   * \param f The PrimFunc whose body is rewritten (copy-on-write).
   * \return The updated PrimFunc.
   */
  static PrimFunc Substitute(PrimFunc &f) {
    MergeIfStmtRewriter merger;
    f.CopyOnWrite()->body = merger(f->body);
    return f;
  }

private:
  MergeIfStmtRewriter() = default;

  Stmt VisitStmt_(const SeqStmtNode *op) final {
    Array<Stmt> merged;         // output sequence
    PrimExpr pending_cond;      // condition shared by the current run
    Array<Stmt> pending_bodies; // then-bodies collected for the current run

    // Emit the collected run as a single IfThenElse. Returns true when
    // something was emitted, false when there was no pending run.
    auto flush_pending = [&]() -> bool {
      if (pending_bodies.empty()) {
        return false;
      }
      Stmt body = pending_bodies.size() == 1 ? pending_bodies[0]
                                             : SeqStmt(pending_bodies);
      merged.push_back(IfThenElse(pending_cond, body, Stmt()));
      pending_bodies.clear();
      return true;
    };

    for (const Stmt &stmt : op->seq) {
      // Children are rewritten first so nested sequences are merged too.
      Stmt visited = this->VisitStmt(stmt);
      const auto *if_node = visited.as<IfThenElseNode>();
      const bool mergeable =
          if_node != nullptr && !if_node->else_case.defined();

      if (!mergeable) {
        // Any other statement ends the run and is passed through as is.
        if (flush_pending()) {
          pending_cond = PrimExpr();
        }
        merged.push_back(visited);
        continue;
      }

      const bool same_cond =
          pending_cond.defined() &&
          StructuralEqual()(pending_cond, if_node->condition);
      if (!same_cond) {
        // A different condition closes the old run and opens a new one.
        flush_pending();
        pending_cond = if_node->condition;
      }
      pending_bodies.push_back(if_node->then_case);
    }

    // Emit whatever run is still open at the end of the sequence.
    flush_pending();

    if (merged.size() == 1) {
      return merged[0];
    }
    return SeqStmt(merged);
  }
};
using namespace tir::transform;
/*!
 * \brief Create the pass that runs MergeIfStmtRewriter on each PrimFunc.
 * \return A PrimFunc-level pass named "tl.MergeIfStmt" at opt level 0.
 */
tvm::transform::Pass MergeIfStmt() {
  // The module and pass context are part of the callback signature but are
  // not consulted by this pass.
  auto merge_pass = [](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    return MergeIfStmtRewriter::Substitute(func);
  };
  return CreatePrimFuncPass(merge_pass, 0, "tl.MergeIfStmt", {});
}
TVM_REGISTER_GLOBAL("tl.transform.MergeIfStmt").set_body_typed(MergeIfStmt);
} // namespace tl
} // namespace tvm
......@@ -182,6 +182,46 @@ static Stmt makeParityWait(PrimExpr barrier_id, PrimExpr parity) {
// return is_gemm;
// }
/*!
 * \brief Mutator that places an expect-tx statement next to TMA loads.
 *
 * - An Evaluate of a TMA load (TMALoadOp / TMALoadIm2ColOp) that is visited
 *   while insertion-at-evaluate is enabled is replaced by
 *   `{expect_tx; evaluate}`.
 * - While a For loop is being scanned, insertion-at-evaluate is disabled; if
 *   the scan sees a TMA load, the result is `{expect_tx; <original loop>}`
 *   so that expect_tx is issued once before the loop.
 */
class TMAExpectTxRewriter : public StmtExprMutator {
public:
  /*!
   * \brief Construct a rewriter that inserts \p expect_tx.
   * \param expect_tx The statement to insert before TMA loads.
   */
  TMAExpectTxRewriter(Stmt expect_tx) : expect_tx_(std::move(expect_tx)) {}

  /*!
   * \brief Rewrite \p stmt, inserting \p expect_tx before its TMA loads.
   * \param stmt The statement to rewrite.
   * \param expect_tx The statement to insert.
   * \return The rewritten statement.
   */
  static Stmt Rewrite(Stmt stmt, Stmt expect_tx) {
    TMAExpectTxRewriter rewriter(expect_tx);
    return rewriter(stmt);
  }

private:
  Stmt VisitStmt_(const ForNode *op) final {
    // First traversal only detects TMA loads inside the loop; its result is
    // intentionally discarded and per-Evaluate insertion is switched off.
    insert_in_evaluate_ = false;
    StmtExprMutator::VisitStmt_(op);
    insert_in_evaluate_ = true;
    if (contain_tma_load_) {
      // Hoist expect_tx in front of the (unmodified) loop.
      Array<Stmt> new_seq = {expect_tx_, GetRef<For>(op)};
      contain_tma_load_ = false;
      return SeqStmt(std::move(new_seq));
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const CallNode *call = op->value.as<CallNode>()) {
      if (call->op.same_as(TMALoadOp()) ||
          call->op.same_as(TMALoadIm2ColOp())) {
        contain_tma_load_ = true;
        if (insert_in_evaluate_) {
          Array<Stmt> new_seq = {expect_tx_, GetRef<Evaluate>(op)};
          return SeqStmt(std::move(new_seq));
        }
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  /*! \brief Statement inserted before TMA loads. */
  Stmt expect_tx_;
  /*!
   * \brief Set when a TMA load has been seen. Initialized to false so the
   *        check after scanning a loop without TMA loads reads a defined
   *        value.
   */
  bool contain_tma_load_ = false;
  /*!
   * \brief Whether an Evaluate of a TMA load gets expect_tx inserted directly
   *        in front of it. Initialized to true, matching the value the For
   *        visitor restores after leaving a loop, so an Evaluate visited
   *        before any loop reads a defined value.
   */
  bool insert_in_evaluate_ = true;
};
class ProducerTraitsCollector : public StmtExprVisitor {
public:
ProducerTraitsCollector() { Clear(); }
......@@ -216,14 +256,29 @@ private:
loop_extents = old_loop_evtents;
}
void VisitStmt_(const IfThenElseNode *op) final {
bool old_in_if_cond = in_if_cond_;
in_if_cond_ = true;
VisitExpr(op->condition);
in_if_cond_ = old_in_if_cond;
VisitStmt(op->then_case);
if (op->else_case.defined()) {
VisitStmt(op->else_case.value());
}
}
/*!
 * \brief Record a buffer load as a SIMT copy unless it occurs while an
 *        IfThenElse condition is being visited.
 *
 * The flag must only be set under the in_if_cond_ guard; setting it
 * unconditionally would make the guard ineffective.
 */
void VisitExpr_(const BufferLoadNode *op) final {
  if (!in_if_cond_) {
    has_simt_copy = true;
  }
  StmtExprVisitor::VisitExpr_(op);
}
bool has_simt_copy;
PrimExpr bulk_copy_bytes;
PrimExpr loop_extents;
bool in_if_cond_ = false;
};
// Rewrite the producer Stmt to use the correct barrier index
......@@ -515,9 +570,10 @@ private:
auto expect_tx = IfThenElse(
EQ(thread_var_, 0),
makeExpectTX(release_barrier_id, collector.BulkCopyBytes()));
block_stmt.push_back(expect_tx);
block_stmt.push_back(TMAExpectTxRewriter::Rewrite(stmt, expect_tx));
} else {
block_stmt.push_back(stmt);
}
block_stmt.push_back(stmt);
if (collector.HasSimtCopy() > 0) {
block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
}
......
......@@ -32,10 +32,12 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# which may be introduced by the LegalizeSafeMemoryAccess
if target.arch == "sm_90":
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
mod = tilelang.transform.RewriteWgmmaSync()(mod)
# mod = tilelang.transform.WarpSpecializedPipeline()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
......
......@@ -139,6 +139,28 @@ def ThreadPartialSync(storage_scope: str):
return _ffi_api.ThreadPartialSync(storage_scope) # type: ignore
def IfStmtBinding():
    """Create the IfStmtBinding pass.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.IfStmtBinding()  # type: ignore
    return fpass
def MergeIfStmt():
    """Create the MergeIfStmt pass.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.MergeIfStmt()  # type: ignore
    return fpass
def MultiVersionBuffer():
"""WarpSpecializedPipeline
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment