".github/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "0dc50a547ac7f10fbd09ef0e09dba445233c1913"
Unverified Commit 85218bd9 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Enhane LetStmt Handling in Pipeline Transform (#1212)

* [Enhancement] Introduce LetWrapper for handling loop variable substitutions in pipeline rewriting

* Added LetWrapper struct to encapsulate variable and value pairs for loop variable substitutions.
* Updated PipelineRewriter to accept a vector of LetWrapper instances, allowing for proper handling of Let statements that depend on the pipeline loop variable.
* Enhanced the BuildPipeline method to incorporate LetWrapper instances into rewritten blocks, ensuring correct substitutions during pipeline execution.
* Refactored logic for processing Let statements to differentiate between those that use the loop variable and those that do not, improving the flexibility of the pipeline transformation.

* Refactor lambda expression for clarity in loop variable usage check in inject_pipeline.cc

* [Test] Add regression test for loop variable handling in kernel compilation

* Introduced a new test case to verify correct handling of loop variables in the kernel compilation process, addressing a regression issue with InjectSoftwarePipeline.
* The test ensures that the loop variable is not left as a free variable, which previously caused failures in MakePackedAPI.
* Configurations are set to disable warp specialization and TMA lowering to align with the original issue reproduction.

* Remove unused import in regression test for loop variable handling in kernel compilation
parent 918a21bd
...@@ -40,6 +40,11 @@ using namespace tir; ...@@ -40,6 +40,11 @@ using namespace tir;
using namespace ffi; using namespace ffi;
namespace software_pipeline { namespace software_pipeline {
struct LetWrapper {
Var var;
PrimExpr value;
};
/*! /*!
* \brief Create a block and infer the access region with the given body. * \brief Create a block and infer the access region with the given body.
* *
...@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator { ...@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
public: public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs, const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info) const For &pipeline_loop, const PipelineInfo &pipeline_info,
const std::vector<LetWrapper> &loop_var_let_wrappers)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {} pipeline_info_(pipeline_info),
loop_var_let_wrappers_(loop_var_let_wrappers) {}
Stmt BuildPipeline() { Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the // Step 1: Analyze accesses to the buffers in the pipeline and compute the
...@@ -677,6 +684,20 @@ private: ...@@ -677,6 +684,20 @@ private:
new_block = Downcast<Block>(Substitute( new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
// If there were Let-wrappers outside the original pipeline body that
// depended on the pipeline loop var, push them into each rewritten
// block with the correct per-block substitution.
if (!loop_var_let_wrappers_.empty()) {
BlockNode *n = new_block.CopyOnWrite();
Stmt inner = n->body;
for (const auto &lw : loop_var_let_wrappers_) {
PrimExpr substituted = Substitute(
lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
inner = LetStmt(lw.var, substituted, inner);
}
n->body = inner;
}
if (pipeline_info_[block].async) { if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage]; auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index; local_state.producer_head = normalized_access_index;
...@@ -738,6 +759,7 @@ private: ...@@ -738,6 +759,7 @@ private:
Map<Buffer, Buffer> buffer_remap_; Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_; Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states; std::map<int, AsyncStateGlobal> async_states;
std::vector<LetWrapper> loop_var_let_wrappers_;
}; };
/*! /*!
...@@ -865,6 +887,7 @@ private: ...@@ -865,6 +887,7 @@ private:
const SeqStmtNode *pipeline_body_seq = nullptr; const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns; std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
std::vector<LetWrapper> loop_var_let_wrappers;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) { auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
Any node = attr->node; Any node = attr->node;
String attr_key = attr->attr_key; String attr_key = attr->attr_key;
...@@ -897,14 +920,25 @@ private: ...@@ -897,14 +920,25 @@ private:
continue; continue;
} }
if (const auto *let_stmt = current.as<LetStmtNode>()) { if (const auto *let_stmt = current.as<LetStmtNode>()) {
Var var = let_stmt->var; // If this Let value uses the pipeline loop var, record it and push
PrimExpr value = let_stmt->value; // inside each rewritten block later so the loop var can be
Span span = let_stmt->span; // substituted with the correct per-iteration index. Otherwise, keep
rewrap_fns.emplace_back([var = std::move(var), // it as a normal wrapper.
value = std::move(value), bool uses_loop_var = UsesVar(
span](Stmt body) -> Stmt { let_stmt->value,
return LetStmt(var, value, body, span); [v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
}); if (uses_loop_var) {
loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
} else {
Var var = let_stmt->var;
PrimExpr value = let_stmt->value;
Span span = let_stmt->span;
rewrap_fns.emplace_back([var = std::move(var),
value = std::move(value),
span](Stmt body) -> Stmt {
return LetStmt(var, value, body, span);
});
}
current = let_stmt->body; current = let_stmt->body;
continue; continue;
} }
...@@ -982,7 +1016,8 @@ private: ...@@ -982,7 +1016,8 @@ private:
// Step 4: Rewrite the pipeline body. // Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
tvm::ffi::GetRef<For>(op), pipeline_info) tvm::ffi::GetRef<For>(op), pipeline_info,
loop_var_let_wrappers)
.BuildPipeline(); .BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) { auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) { for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
......
import tilelang
import tilelang.language as T
import tilelang.testing
def _make_kernel(M, N):
dtype = "bfloat16"
@T.prim_func
def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)
# Regression for a bug where InjectSoftwarePipeline left the loop
# variable as a free var, causing MakePackedAPI to fail
for i in T.Pipelined(4, num_stages=1):
_id = ids[i]
T.copy(KV[_id, :], A)
T.clear(B)
return fwd_main
def test_make_packed_api_no_free_loop_var():
func = _make_kernel(4, 4)
# Keep warp-specialization/TMA disabled to match the original repro
cfg = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}
tilelang.compile(func, pass_configs=cfg)
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment