Unverified commit ff35fc08 authored by Wenhao Xie, committed by GitHub

[Bugfix] Address PassContext contamination from CI and fix incorrect rewrites in warp specialized pass (#767)

* fix ci and pass bug

* fix

* try

* lint
parent 37051417
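
The failure mode named in the title: TVM's PassContext is a scoped, stack-like context, so a context (and its config) entered by one test is silently inherited by every compilation nested under it in the same CI process. A minimal sketch of the effect and of the fix pattern the diffs below apply (the config key here is only an example, not one this PR touches):

import tvm

# An earlier test enters a customized context; everything nested inherits it.
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    print(tvm.transform.PassContext.current().config)  # polluted config

    # The fix used below: open a fresh default context around the
    # compilation so the ambient (possibly polluted) config is shadowed.
    with tvm.transform.PassContext():
        print(tvm.transform.PassContext.current().config)  # defaults again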
@@ -376,14 +376,25 @@ private:
         eq_op->b.as<VarNode>() == thread_var_.get()) {
       maybe_thread_opt_ = true;
     }
-    maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_;
+    auto then_case = StmtExprMutator::VisitStmt(op->then_case);
+    maybe_thread_opt_ = do_shuffle_ && maybe_thread_opt_ && has_tma_op_;
+    has_tma_op_ = false;
+    if (maybe_thread_opt_) {
+      return IfThenElse(
+          Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
+          then_case, std::nullopt);
+    }
   }
-  if (maybe_thread_opt_)
-    return IfThenElse(
-        Call(DataType::Bool(), tl_shuffle_elect(), {thread_extent_}),
-        StmtExprMutator::VisitStmt(op->then_case), std::nullopt);
-  else
-    return StmtExprMutator::VisitStmt_(op);
+  return StmtExprMutator::VisitStmt_(op);
+}
+
+PrimExpr VisitExpr_(const CallNode *op) final {
+  if (op->op.same_as(tl::tma_load()) ||
+      op->op.same_as(tl::tma_load_im2col()) ||
+      op->op.same_as(tl::tma_store())) {
+    has_tma_op_ = true;
+  }
+  return StmtExprMutator::VisitExpr_(op);
 }

 Var thread_var_;

@@ -391,6 +402,7 @@ private:
 PrimExpr thread_extent_;
 bool maybe_thread_opt_ = false;
 bool do_shuffle_;
+bool has_tma_op_ = false;
 };

 Block MakeGroupBlock(const Stmt &stmt,
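
The hunk above makes two changes to the rewriter: the then-branch is now visited before the decision is made, and the thread-id guard is replaced by a tl_shuffle_elect guard only when that visit actually encountered a TMA intrinsic (tma_load, tma_load_im2col, tma_store); the flag is reset afterwards so one branch's TMA ops cannot leak into the decision for a sibling branch. A toy Python model of just that control flow (hand-rolled AST; the real pass works on TIR's IfThenElseNode/CallNode and also checks do_shuffle_ and the thread-var equality):

from dataclasses import dataclass

@dataclass
class Call:
    op: str

@dataclass
class If:
    cond: str
    body: list

TMA_OPS = {"tma_load", "tma_load_im2col", "tma_store"}

class ElectRewriter:
    def __init__(self):
        self.has_tma = False  # mirrors has_tma_op_

    def visit(self, stmt):
        if isinstance(stmt, Call):
            if stmt.op in TMA_OPS:
                self.has_tma = True  # mirrors VisitExpr_ on CallNode
            return stmt
        if isinstance(stmt, If) and stmt.cond == "tid == 0":
            body = [self.visit(s) for s in stmt.body]       # visit then_case first
            do_rewrite, self.has_tma = self.has_tma, False  # consume and reset the flag
            if do_rewrite:
                return If("tl_shuffle_elect(extent)", body)
            return If(stmt.cond, body)
        return stmt

r = ElectRewriter()
print(r.visit(If("tid == 0", [Call("tma_load")])).cond)  # tl_shuffle_elect(extent)
print(r.visit(If("tid == 0", [Call("add")])).cond)       # tid == 0, left alone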
@@ -64,9 +64,10 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
                 bx * block_N + t % (block_N // vec_load_b) *
                 (block_N // vec_load_b) + vec], T.float16(0))

-    mod = tvm.tir.transform.BindTarget(auto_target)(Before)
-    mod = tl.transform.LowerTileOp()(mod)
-    mod = tvm.tir.transform.Simplify()(mod)
+    with tvm.transform.PassContext():
+        mod = tvm.tir.transform.BindTarget(auto_target)(Before)
+        mod = tl.transform.LowerTileOp()(mod)
+        mod = tvm.tir.transform.Simplify()(mod)
     ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
     ref_mod = tvm.tir.transform.Simplify()(ref_mod)
     # Note(tzj): The structures are equal except the argument in "T.reads" function.
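
The test now builds mod under a fresh PassContext so the comparison against ref_mod cannot be skewed by configuration left over from earlier tests. If many tests need the same isolation, a module-wide fixture is an alternative; a sketch of that option (an assumption on my part, not something this PR adds):

import pytest
import tvm

@pytest.fixture(autouse=True)
def fresh_pass_context():
    # Every test in the module runs inside a clean default PassContext.
    with tvm.transform.PassContext():
        yield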
@@ -43,8 +43,9 @@ def assert_gemm_codegen(
     accum_dtype="float",
 ):
     func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)

-    artifact = tilelang.lower(func, target="webgpu")
+    # Because the current pass context may have been polluted by previous tests.
+    with tvm.transform.PassContext():
+        artifact = tilelang.lower(func, target="webgpu")

     src_code = artifact.kernel_source
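
Same pattern for the WebGPU codegen test: lowering happens inside a throwaway default context. Extracted into a helper it might look like the sketch below (lower_isolated is a hypothetical name; tilelang.lower and kernel_source are taken from the diff above):

import tvm
import tilelang

def lower_isolated(func, target="webgpu"):
    # A fresh default PassContext shields this compilation from any
    # config left active by tests that ran earlier in the process.
    with tvm.transform.PassContext():
        return tilelang.lower(func, target=target)

# artifact = lower_isolated(func)
# src_code = artifact.kernel_source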