Unverified Commit 85218bd9 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Enhane LetStmt Handling in Pipeline Transform (#1212)

* [Enhancement] Introduce LetWrapper for handling loop variable substitutions in pipeline rewriting

* Added LetWrapper struct to encapsulate variable and value pairs for loop variable substitutions.
* Updated PipelineRewriter to accept a vector of LetWrapper instances, allowing for proper handling of Let statements that depend on the pipeline loop variable.
* Enhanced the BuildPipeline method to incorporate LetWrapper instances into rewritten blocks, ensuring correct substitutions during pipeline execution.
* Refactored logic for processing Let statements to differentiate between those that use the loop variable and those that do not, improving the flexibility of the pipeline transformation.

* Refactor lambda expression for clarity in loop variable usage check in inject_pipeline.cc

* [Test] Add regression test for loop variable handling in kernel compilation

* Introduced a new test case to verify correct handling of loop variables in the kernel compilation process, addressing a regression issue with InjectSoftwarePipeline.
* The test ensures that the loop variable is not left as a free variable, which previously caused failures in MakePackedAPI.
* Configurations are set to disable warp specialization and TMA lowering to align with the original issue reproduction.

* Remove unused import in regression test for loop variable handling in kernel compilation
parent 918a21bd
......@@ -40,6 +40,11 @@ using namespace tir;
using namespace ffi;
namespace software_pipeline {
struct LetWrapper {
Var var;
PrimExpr value;
};
/*!
* \brief Create a block and infer the access region with the given body.
*
......@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info)
const For &pipeline_loop, const PipelineInfo &pipeline_info,
const std::vector<LetWrapper> &loop_var_let_wrappers)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {}
pipeline_info_(pipeline_info),
loop_var_let_wrappers_(loop_var_let_wrappers) {}
Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
......@@ -677,6 +684,20 @@ private:
new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
// If there were Let-wrappers outside the original pipeline body that
// depended on the pipeline loop var, push them into each rewritten
// block with the correct per-block substitution.
if (!loop_var_let_wrappers_.empty()) {
BlockNode *n = new_block.CopyOnWrite();
Stmt inner = n->body;
for (const auto &lw : loop_var_let_wrappers_) {
PrimExpr substituted = Substitute(
lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
inner = LetStmt(lw.var, substituted, inner);
}
n->body = inner;
}
if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
......@@ -738,6 +759,7 @@ private:
Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
std::vector<LetWrapper> loop_var_let_wrappers_;
};
/*!
......@@ -865,6 +887,7 @@ private:
const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
std::vector<LetWrapper> loop_var_let_wrappers;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
Any node = attr->node;
String attr_key = attr->attr_key;
......@@ -897,6 +920,16 @@ private:
continue;
}
if (const auto *let_stmt = current.as<LetStmtNode>()) {
// If this Let value uses the pipeline loop var, record it and push
// inside each rewritten block later so the loop var can be
// substituted with the correct per-iteration index. Otherwise, keep
// it as a normal wrapper.
bool uses_loop_var = UsesVar(
let_stmt->value,
[v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
if (uses_loop_var) {
loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
} else {
Var var = let_stmt->var;
PrimExpr value = let_stmt->value;
Span span = let_stmt->span;
......@@ -905,6 +938,7 @@ private:
span](Stmt body) -> Stmt {
return LetStmt(var, value, body, span);
});
}
current = let_stmt->body;
continue;
}
......@@ -982,7 +1016,8 @@ private:
// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
tvm::ffi::GetRef<For>(op), pipeline_info)
tvm::ffi::GetRef<For>(op), pipeline_info,
loop_var_let_wrappers)
.BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
......
import tilelang
import tilelang.language as T
import tilelang.testing
def _make_kernel(M, N):
dtype = "bfloat16"
@T.prim_func
def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)
# Regression for a bug where InjectSoftwarePipeline left the loop
# variable as a free var, causing MakePackedAPI to fail
for i in T.Pipelined(4, num_stages=1):
_id = ids[i]
T.copy(KV[_id, :], A)
T.clear(B)
return fwd_main
def test_make_packed_api_no_free_loop_var():
func = _make_kernel(4, 4)
# Keep warp-specialization/TMA disabled to match the original repro
cfg = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}
tilelang.compile(func, pass_configs=cfg)
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment