Unverified Commit f8ae600c authored by Lei Wang, committed by GitHub

[Bugfix] Do not force inline let stmt (#947)

* remove debug print

* Remove inline let expressions from the LowerAndLegalize function in phase.py

* add test

* Update sparse MLA examples to support SKV adjustment and correctness checks

- Changed SKV parameter from 32768 to 8192 in sparse MLA backward and forward tests.
- Added check_correctness parameter to test functions for validation of outputs.
- Updated test cases to reflect new SKV values and correctness checks.

* reduce test shape

* Update documentation structure and refactor main function parameters in example_fusedmoe_tilelang.py

- Added a new section for compiler internals in the documentation.
- Refactored the main function in example_fusedmoe_tilelang.py to accept parameters for hidden dimensions, expert configurations, and batch/sequence sizes, improving flexibility and readability.

* Update buffer access checks in merge_shared_memory_allocations.cc

- Changed the condition for buffer access from less than (<) to less than or equal to (<=) to allow access at the same scope level.
- Adjusted the logic for determining the access level when touching buffers to ensure correct handling of scope levels.

* lint fix

* Support pipeline with LetStmt

* lint fix

* • Fix LowerTileOp let handling to avoid LetInline dependency

  - inline let-bound BufferLoad nodes via resolver helpers and structured return
  - remap layouts/buffers using original data vars and only rewrite when needed
  - update pipeline planner to understand let-bound address_of buffers
  - document the new inline behaviour in docs/let_inline_fix.md

* fix for wgmma pipeline with let binding

* lint fix

* test fix

* reduce smem usage.

* let binding enhancement

* fix for dpgm

* fix simplify

* lint fix

* use tilelang.Simplify instead of tir.Simplify

* • Add TL_FORCE_LET_INLINE pass config and gate eager LetInline usage

  - register the new config in builtin headers/registration
  - add helper to pipeline enabling LetInline based on pass context
  - document LetStmt inlining controls and usage
parent 7cd0da99
# LetStmt Inlining in TileLang
This document explains how `LetStmt` inlining works in TileLang's simplification pipeline, which is an important optimization that affects code generation and performance.
## Overview
A `LetStmt` (Let Statement) is a temporary variable binding in the IR (Intermediate Representation). During compilation, TileLang's simplifier may choose to inline these temporary variables to simplify the code. TileLang also provides a standalone `LetInline` pass that performs eager substitution before the main legalization pipeline. However, not all `LetStmt` nodes can be safely inlined.
## When Does LetStmt Get Inlined?
The inlining logic is implemented in `src/transform/simplify.cc`. A `LetStmt` will be inlined if **both** of the following conditions are met:
### 1. The value satisfies `CanInlineLetStmt`
The `CanInlineLetStmt` helper returns `true` when:
- **The value is a constant** (`is_const_number(op->value)` returns true)
- **The value is a variable** (`op->value.as<VarNode>()` returns a node)
- **The value is an integer expression without side effects**:
  - The value has `int` dtype
  - The side effect level is `kPure` or lower (no observable side effects)
```cpp
bool CanInlineLetStmt(const LetStmtNode *op) {
  if (is_const_number(op->value))
    return true;
  if (op->value.as<VarNode>())
    return true;
  // Won't face the deep expression explosion problem as in Let expression.
  // Attempt to inline as much as possible if the value is an integer type
  // (it can be an index).
  if (!op->value.dtype().is_int())
    return false;
  return SideEffect(op->value) <= CallEffectKind::kPure;
}
```
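For intuition, the decision can be mirrored in a few lines of Python. The helper below is purely illustrative (it is not a TileLang API); each flag corresponds to one of the C++ checks above:
```python
# Illustrative Python mirror of CanInlineLetStmt; not a TileLang API.
def can_inline_let_stmt(is_const: bool, is_var: bool,
                        dtype_is_int: bool, is_pure: bool) -> bool:
    if is_const:
        return True   # constants are always inlined
    if is_var:
        return True   # variable-to-variable aliases are always inlined
    if not dtype_is_int:
        return False  # only integer-typed expressions are inlined eagerly
    return is_pure    # ...and only when they carry no side effects


# Assumed classifications for a few typical bindings:
assert can_inline_let_stmt(True, False, False, True)       # let c = 42
assert can_inline_let_stmt(False, True, False, True)       # let v = other_var
assert can_inline_let_stmt(False, False, True, True)       # let i = bx * 8 + tx
assert not can_inline_let_stmt(False, False, False, True)  # float expression
```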
### 2. The variable is NOT used in buffer definitions
Even if `CanInlineLetStmt` returns true, the variable will **not** be inlined if it's used in a buffer's definition (shape, strides, elem_offset, or data fields).
This protection exists because:
- Buffer definitions are not updated during the simplification pass
- If a variable used in a buffer definition is inlined, later references to that buffer would fail to find the variable definition
- This would cause compilation errors or incorrect behavior
The mutator checks this before dropping the binding:
```cpp
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());
if (can_inline && !used_in_buffer_def) {
  return body; // Inline: remove LetStmt and return body directly
}
```
## Example: Why Buffer Definition Variables Are Protected
Consider this code:
```python
let stride = M * 16
let buffer_a = Buffer(data, shape=[M, N], strides=[stride, 1])
buffer_a[i, j] = ...
```
- `stride` satisfies `CanInlineLetStmt` (it's an int expression with no side effects)
- However, `stride` is used in `buffer_a`'s `strides` field
- If we inlined it, references to `stride` in statements would be rewritten to `M * 16`
- But the Buffer object's fields are not updated during simplification, so its `strides` field would still refer to `stride`
- Later code accessing `buffer_a` would then fail to find the definition of `stride`
Therefore, `stride` is added to `used_in_buffer_def_` and will **not** be inlined.
## How Variables Are Collected
The `CollectVarsUsedInBufferDefinition` helper traverses all `BufferLoad` and `BufferStore` nodes and collects variables used in their buffer definitions:
```cpp
void VisitBuffer(const Buffer &buf) {
  // Collect variables that should remain defined
  VarUseDefAnalyzer usage(Array<Var>{});
  usage(buf->data);
  for (const auto &dim : buf->shape) {
    usage(dim);
  }
  for (const auto &dim : buf->strides) {
    usage(dim);
  }
  usage(buf->elem_offset);
  // Track for use in LetStmtNode mutator
  for (const auto &var : usage.undefined_) {
    used_in_buffer_def_.insert(var.get());
  }
}
```
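To make the collection step concrete, here is a minimal Python sketch of the same idea. The `BufferDesc` type and all names in it are hypothetical; the real pass walks TIR `Buffer` objects:
```python
# Hypothetical, simplified stand-in for the collector above.
from dataclasses import dataclass
from typing import List, Set


@dataclass
class BufferDesc:
    """Free variables referenced by each definition field of a buffer."""
    data_vars: Set[str]
    shape_vars: Set[str]
    stride_vars: Set[str]
    offset_vars: Set[str]


def collect_vars_used_in_buffer_def(buffers: List[BufferDesc]) -> Set[str]:
    used: Set[str] = set()
    for buf in buffers:
        # Every variable appearing in data/shape/strides/elem_offset must
        # keep its LetStmt definition, so record it.
        used |= buf.data_vars | buf.shape_vars | buf.stride_vars | buf.offset_vars
    return used


# The buffer from the stride example above: `stride` ends up protected.
buffer_a = BufferDesc(data_vars={"data"}, shape_vars={"M", "N"},
                      stride_vars={"stride"}, offset_vars=set())
assert "stride" in collect_vars_used_in_buffer_def([buffer_a])
```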
## Practical Example: Temporary Variable Issue
Consider this TileLang code:
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    tmp = T.max(A[idx], 1)
    B[idx] = tmp / 2
    A[idx] = tmp * 2
```
In this case:
- `tmp` is an integer-like temporary variable
- It satisfies `CanInlineLetStmt` (pure int expression)
- It's **not** used in any buffer definition
- Therefore, `tmp` **will be inlined**
This means the IR becomes:
```python
for i in T.Parallel(block_N):
    idx = bx * block_N + i
    B[idx] = T.max(A[idx], 1) / 2
    A[idx] = T.max(A[idx], 1) * 2
```
If this duplication causes issues (for example, `A[idx]` being read twice and observing different values because of an intervening write), it indicates a problem with either the inlining heuristic or the code pattern; the regression test added with this change (`test_issue_814`) exercises exactly this case.
## Controlling Let Inlining via Pass Config
TileLang exposes an explicit pass configuration key, `tilelang.PassConfigKey.TL_FORCE_LET_INLINE` (`"tl.force_let_inline"`), that allows users to force the eager `LetInline` pass to run before the legalization pipeline begins. When enabled, the pipeline invokes `tilelang.transform.LetInline()` at the start of `LowerAndLegalize` (see `tilelang/engine/phase.py`). This knob is useful when debugging LetStmt-related issues or when deterministic inlining behavior is desired across different environments.
```python
import tilelang
from tilelang import transform
from tilelang.engine.phase import LowerAndLegalize

with transform.PassContext(
        config={tilelang.PassConfigKey.TL_FORCE_LET_INLINE: True}):
    lowered_mod = LowerAndLegalize(input_mod, target)
```
If the flag is left unset (the default), the eager pass is only applied when downstream transforms opt in (for example, by calling `_Simplify(..., inline_let=True)` inside Tile operators). The guard in `tilelang/engine/phase.py` ensures the eager pass is only triggered when the user explicitly requests it.
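For reference, the guard itself is small. The sketch below paraphrases the helper added to `tilelang/engine/phase.py` in this change:
```python
import tilelang


# Paraphrased from tilelang/engine/phase.py: the eager pass runs only when
# the active PassContext explicitly sets tl.force_let_inline.
def should_force_let_inline(pass_ctx=None) -> bool:
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    return bool(pass_ctx and
                pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
```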
## Summary
The LetStmt inlining mechanism is a **conservative optimization** that:
1. Aggressively inlines simple, pure integer expressions to simplify the IR
2. Protects variables used in buffer definitions to avoid breaking buffer access
3. Helps reduce IR complexity and improve code generation
4. Can be forced through `TL_FORCE_LET_INLINE` when deterministic eager inlining is required
Understanding when inlining happens is crucial for:
- Debugging compilation issues
- Understanding generated code
- Writing efficient TileLang programs
- Identifying potential optimization opportunities or bugs
## Related Files
- `src/transform/simplify.cc`: Main Simplify implementation
- `src/transform/frontend_legalize.cc`: Standalone LetInline pass
- `tilelang/engine/phase.py`: Pipeline integration for eager LetInlining
- `testing/python/transform/test_tilelang_transform_let_inline.py`: Regression coverage for the pass
......@@ -35,6 +35,13 @@ deeplearning_operators/matmul
deeplearning_operators/deepseek_mla
:::
:::{toctree}
:maxdepth: 1
:caption: COMPILER INTERNALS
compiler_internals/letstmt_inline
:::
:::{toctree}
:maxdepth: 1
:caption: API Reference
......
......@@ -16,11 +16,11 @@ def test_example_tilelang_block_sparse_attn():
def test_example_tilelang_sparse_gqa_decode_varlen_indice():
example_tilelang_sparse_gqa_decode_varlen_indice.main()
example_tilelang_sparse_gqa_decode_varlen_indice.main(batch=1, max_cache_seqlen=2048)
def test_example_tilelang_sparse_gqa_decode_varlen_mask():
example_tilelang_sparse_gqa_decode_varlen_mask.main()
example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048)
def test_example_triton_sparse_gqa_decode_varlen_indice():
......
......@@ -521,15 +521,21 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return output
def main():
def main(d_hidden=7168,
d_expert=2048,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=8192):
config = {
"dhidden": 7168,
"dexpert": 2048,
"nroutedexperts": 8,
"nsharedexperts": 1,
"nexpertspertoken": 4,
"bs": 1,
"seqlen": 8192,
"dhidden": d_hidden,
"dexpert": d_expert,
"nroutedexperts": n_routed_experts,
"nsharedexperts": n_shared_experts,
"nexpertspertoken": n_experts_per_token,
"bs": batch_size,
"seqlen": seq_len,
"seed": 81394
}
......
......@@ -3,7 +3,14 @@ import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main()
example_fusedmoe_tilelang.main(
d_hidden=1024,
d_expert=256,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=1024)
if __name__ == "__main__":
......
......@@ -25,6 +25,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
......
......@@ -71,6 +71,14 @@ static constexpr const char *kDisableDynamicTailSplit =
static constexpr const char *kDisableThreadStorageSync =
"tl.disable_thread_storage_sync";
/*!
* \brief Force inline Let bindings during simplification.
*
* kForceLetInline = "tl.force_let_inline"
*
*/
static constexpr const char *kForceLetInline = "tl.force_let_inline";
/*!
* \brief The size of the vectorized dimension in buffer, designed by user
*
......
......@@ -26,6 +26,7 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
#include <functional>
#include <unordered_set>
#include <utility>
......@@ -845,7 +846,8 @@ private:
// Step 2: Find the body and buffer allocations of the pipeline. The body
// can be direct child of the for-loop. If the for-loop has BlockRealize as
// its child, the pipeline body will be the child of the block.
Stmt pipeline_body{nullptr};
Stmt pipeline_body_root{nullptr};
bool pipeline_body_from_block = false;
Array<Buffer> pipeline_allocs;
if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
......@@ -853,16 +855,68 @@ private:
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body = block->body;
pipeline_body_root = block->body;
pipeline_allocs = block->alloc_buffers;
pipeline_body_from_block = true;
} else {
pipeline_body = for_node->body;
pipeline_body_root = for_node->body;
}
const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
ObjectRef node = attr->node;
String attr_key = attr->attr_key;
PrimExpr value = attr->value;
Span span = attr->span;
rewrap_fns.emplace_back(
[node = std::move(node), attr_key = std::move(attr_key),
value = std::move(value), span](Stmt body) -> Stmt {
return AttrStmt(node, attr_key, value, body, span);
});
};
{
Stmt current = pipeline_body_root;
while (true) {
if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
pipeline_body_seq = seq_stmt;
break;
}
const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< pipeline_body->GetTypeKey();
if (const auto *if_then_else = current.as<IfThenElseNode>()) {
ICHECK(!if_then_else->else_case.defined())
<< "InjectSoftwarePipeline: Can't handle the body of the loop "
"because the IfThenElse node has an else branch";
PrimExpr condition = if_then_else->condition;
Span span = if_then_else->span;
rewrap_fns.emplace_back(
[condition = std::move(condition), span](Stmt body) -> Stmt {
return IfThenElse(condition, body, Stmt(), span);
});
current = if_then_else->then_case;
continue;
}
if (const auto *let_stmt = current.as<LetStmtNode>()) {
Var var = let_stmt->var;
PrimExpr value = let_stmt->value;
Span span = let_stmt->span;
rewrap_fns.emplace_back([var = std::move(var),
value = std::move(value),
span](Stmt body) -> Stmt {
return LetStmt(var, value, body, span);
});
current = let_stmt->body;
continue;
}
if (const auto *attr = current.as<AttrStmtNode>()) {
append_attr_wrapper(attr);
current = attr->body;
continue;
}
LOG(FATAL) << "ValueError: The body of the software pipeline should be "
<< "SeqStmt, got " << current->GetTypeKey();
}
}
ICHECK(pipeline_body_seq != nullptr);
// Step 3: Blockize the components of the pipeline. Each child of the
// pipelined loop will be converted into a block.
......@@ -934,6 +988,27 @@ private:
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info)
.BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
stmt = (*it)(stmt);
}
return stmt;
};
if (!rewrap_fns.empty()) {
if (pipeline_body_from_block) {
BlockRealize pipeline_realize = Downcast<BlockRealize>(pipeline);
Block pipeline_block = pipeline_realize->block;
{
BlockNode *block_node = pipeline_block.CopyOnWrite();
block_node->body = apply_wrappers(block_node->body);
}
pipeline = BlockRealize(pipeline_realize->iter_values,
pipeline_realize->predicate, pipeline_block,
pipeline_realize->span);
} else {
pipeline = apply_wrappers(pipeline);
}
}
if (const auto *realize = op->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
......
......@@ -5,9 +5,11 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <unordered_map>
#include "../layout/layout.h"
#include "../layout/utils.h"
......@@ -318,10 +320,16 @@ private:
return buffer_row_size;
}
PrimExpr
struct AccessPtrResult {
PrimExpr expr;
bool rewritten{false};
};
AccessPtrResult
HandleAccessPtrAndOffset(const PrimExpr &access_ptr,
const Optional<PrimExpr> &offset = std::nullopt,
DataType dtype = DataType::Int(32)) {
AccessPtrResult result{access_ptr, false};
// The 2th arg of T.tvm_access_ptr call is offset, we set it to 0 and
// accumulate it to smem_offset
CHECK(access_ptr->IsInstance<CallNode>())
......@@ -330,6 +338,16 @@ private:
if (access_ptr_call->op.same_as(builtin::tvm_access_ptr())) {
LOG(FATAL) << "Transformation for tvm_access_ptr is not implemented yet";
} else if (access_ptr_call->op.same_as(builtin::address_of())) {
Optional<PrimExpr> resolved = ResolveBufferLoad(access_ptr_call->args[0]);
ICHECK(resolved.defined())
<< "Invalid access op for permuted layout: " << access_ptr;
PrimExpr load_expr = resolved.value();
if (!load_expr.same_as(access_ptr_call->args[0])) {
auto node = access_ptr_call.CopyOnWrite();
node->args.Set(0, load_expr);
access_ptr_call = Call(access_ptr_call->dtype, access_ptr_call->op,
{load_expr}, access_ptr_call->span);
}
BufferLoad load = Downcast<BufferLoad>(access_ptr_call->args[0]);
Array<PrimExpr> indices = load->indices;
Array<PrimExpr> old_shape = load->buffer->shape;
......@@ -351,14 +369,17 @@ private:
PrimExpr smem_offset =
elem_offset + (offset.defined() ? offset.value() : 0);
auto new_buffer = buffer_remap_[load->buffer];
Buffer remap_key = FindRemapBuffer(load->buffer).value_or(load->buffer);
Optional<Layout> layout = FindLayout(remap_key);
if (!layout.defined() || !buffer_map_.count(remap_key->data)) {
return result;
}
auto new_buffer = buffer_remap_.count(remap_key)
? buffer_remap_[remap_key]
: load->buffer;
auto new_shape = new_buffer->shape;
auto buffer_map_iter =
buffer_map_.find(Downcast<Var>(load->buffer->data));
CHECK(buffer_map_iter != buffer_map_.end())
<< "The buffer corresponding to data Var " << access_ptr_call->args[0]
<< " is not found";
auto buffer_map_iter = buffer_map_.find(Downcast<Var>(remap_key->data));
int buffer_row_size = CheckAndGetBufferRowSize(buffer_map_iter->second);
(void)buffer_row_size;
......@@ -373,8 +394,7 @@ private:
remaining_offset = floordiv(remaining_offset, old_shape[i]);
}
auto forward_indices =
layout_map_[load->buffer]->Forward(multi_dim_indices);
auto forward_indices = layout.value()->Forward(multi_dim_indices);
PrimExpr new_offset = 0;
PrimExpr stride_offset = 1;
for (int i = static_cast<int>(new_shape.size()) - 1; i >= 0; --i) {
......@@ -390,14 +410,71 @@ private:
new_offset = floordiv(new_offset, new_shape[i]);
}
auto new_access_ptr = access_ptr_call.CopyOnWrite();
new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices));
layout_remap_.Set(new_buffer, layout_map_[load->buffer]);
Array<PrimExpr> new_args = {BufferLoad(new_buffer, new_indices)};
if (buffer_remap_.count(remap_key)) {
layout_remap_.Set(new_buffer, layout.value());
}
result.rewritten = true;
result.expr = Call(access_ptr_call->dtype, access_ptr_call->op, new_args,
access_ptr_call->span);
return result;
} else {
LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr;
}
return access_ptr_call;
return result;
}
Optional<PrimExpr> ResolveBufferLoad(const PrimExpr &expr) const {
if (expr->IsInstance<BufferLoadNode>()) {
return expr;
}
if (const auto *var_node = expr.as<VarNode>()) {
Var var = GetRef<Var>(var_node);
auto it = let_bindings_.find(var);
if (it != let_bindings_.end()) {
return it->second;
}
}
return Optional<PrimExpr>();
}
Optional<Buffer> FindRemapBuffer(const Buffer &buffer) const {
if (buffer_remap_.count(buffer)) {
return buffer;
}
auto it = buffer_map_.find(buffer->data);
if (it != buffer_map_.end() && buffer_remap_.count(it->second)) {
return it->second;
}
for (const auto &kv : buffer_remap_) {
if (kv.first->data.same_as(buffer->data)) {
return kv.first;
}
if (kv.first->name == buffer->name) {
return kv.first;
}
}
return Optional<Buffer>();
}
Optional<Layout> FindLayout(const Buffer &buffer) const {
if (layout_map_.count(buffer)) {
return layout_map_[buffer];
}
auto it = buffer_map_.find(buffer->data);
if (it != buffer_map_.end() && layout_map_.count(it->second)) {
return layout_map_[it->second];
}
for (const auto &kv : layout_map_) {
if (kv.first->data.same_as(buffer->data)) {
return kv.second;
}
if (kv.first->name == buffer->name) {
return kv.second;
}
}
return Optional<Layout>();
}
PrimExpr VisitExpr_(const tir::CallNode *op) final {
......@@ -422,18 +499,30 @@ private:
// form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
// smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
// or T.address_of(buffer, offset)
auto access_ptr = call->args[5];
PrimExpr access_ptr = call->args[5];
PrimExpr smem_offset = call->args[6];
Call address_of_call = Downcast<Call>(access_ptr);
if (!address_of_call->op.same_as(builtin::address_of())) {
LOG(FATAL) << "Invalid access ptr for permuted layout: " << access_ptr;
}
Optional<PrimExpr> resolved = ResolveBufferLoad(address_of_call->args[0]);
ICHECK(resolved.defined())
<< "Invalid address_of argument for permuted layout: "
<< address_of_call->args[0];
PrimExpr load_expr = resolved.value();
if (!load_expr.same_as(address_of_call->args[0])) {
auto call_node = call.CopyOnWrite();
call_node->args.Set(5, Call(address_of_call->dtype, address_of_call->op,
{load_expr}, address_of_call->span));
address_of_call = Downcast<Call>(call->args[5]);
access_ptr = call->args[5];
}
BufferLoad load = Downcast<BufferLoad>(address_of_call->args[0]);
if (buffer_remap_.count(load->buffer)) {
auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, smem_offset, call->dtype);
if (new_access_ptr.rewritten) {
auto new_call = call.CopyOnWrite();
new_call->args.Set(5, new_access_ptr);
new_call->args.Set(5, new_access_ptr.expr);
new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
}
} else if (call->op.same_as(builtin::mma_store())) {
......@@ -442,8 +531,10 @@ private:
auto access_ptr = call->args[2];
auto new_access_ptr =
HandleAccessPtrAndOffset(access_ptr, std::nullopt, call->dtype);
if (new_access_ptr.rewritten) {
auto new_call = call.CopyOnWrite();
new_call->args.Set(2, new_access_ptr);
new_call->args.Set(2, new_access_ptr.expr);
}
} else {
LOG(FATAL) << "Invalid call node: " << call;
}
......@@ -500,6 +591,30 @@ private:
return var;
}
Stmt VisitStmt_(const LetStmtNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
bool recorded = false;
if (value->IsInstance<BufferLoadNode>()) {
let_bindings_[op->var] = value;
recorded = true;
}
if (SideEffect(value) <= CallEffectKind::kPure) {
analyzer_->Bind(op->var, value);
}
Stmt body = this->VisitStmt(op->body);
if (recorded) {
let_bindings_.erase(op->var);
}
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
n->value = value;
n->body = body;
return Stmt(n);
}
}
/**
* @brief Handle an Evaluate node, lowering a detected tile operator to TIR.
*
......@@ -590,6 +705,8 @@ private:
// For ptx Node, we need to remap the buffer and indices
// By access CallNode instead of BufferLoad Node.
bool is_ptx_{false};
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
let_bindings_;
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
Map<Var, Var> var_remap_;
......
......@@ -194,14 +194,19 @@ public:
const VarNode *buf = op->buffer->data.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size())
// Allow buffer access at the same level or deeper scope
// Changed from < to <= to handle cases where buffer is accessed
// in expressions at the same scope level where it's allocated
ICHECK_LE(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
// When accessing at the same level, use that level
size_t access_level = std::min(it->second.level, scope_.size() - 1);
scope_[access_level].touched.push_back(buf);
}
}
}
......@@ -211,13 +216,16 @@ public:
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
// Allow buffer access at the same level or deeper scope
ICHECK_LE(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
scope_[it->second.level].touched.push_back(buf);
// When accessing at the same level, use that level
size_t access_level = std::min(it->second.level, scope_.size() - 1);
scope_[access_level].touched.push_back(buf);
}
}
}
......
......@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <functional>
#include <unordered_set>
#include <utility>
#include "../op/builtin.h"
......@@ -139,10 +141,40 @@ private:
Array<Buffer> GetVersionedBuffers(const Array<Stmt> &seq_stmt,
const Array<Buffer> &scoped_buffers) {
Array<Stmt> pipeline_stmts;
std::function<void(const Stmt &)> collect_stmts = [&](const Stmt &stmt) {
if (const auto *seq = stmt.as<SeqStmtNode>()) {
for (const Stmt &s : seq->seq) {
collect_stmts(s);
}
return;
}
if (const auto *let = stmt.as<LetStmtNode>()) {
collect_stmts(let->body);
return;
}
if (const auto *attr = stmt.as<AttrStmtNode>()) {
collect_stmts(attr->body);
return;
}
if (const auto *block_realize = stmt.as<BlockRealizeNode>()) {
collect_stmts(block_realize->block->body);
return;
}
if (const auto *block = stmt.as<BlockNode>()) {
collect_stmts(block->body);
return;
}
pipeline_stmts.push_back(stmt);
};
for (const Stmt &stmt : seq_stmt) {
collect_stmts(stmt);
}
std::vector<Role> roles;
Array<Array<BufferRegion>> reads, writes;
auto marker = WarpSpecializedRoleMarker_(buffer_data_to_buffer_);
for (auto stmt : seq_stmt) {
for (const Stmt &stmt : pipeline_stmts) {
marker(stmt);
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{},
/*name_hint=*/"", /*body*/ stmt);
......@@ -153,20 +185,68 @@ private:
}
std::unordered_set<const BufferNode *> consumer_used, producer_used;
for (size_t i = 0; i < seq_stmt.size(); i++) {
if (roles[i] == Role::kProducer) {
for (BufferRegion br : writes[i])
std::unordered_map<const BufferNode *, size_t> first_write_index;
std::unordered_map<const BufferNode *, size_t> last_read_index;
auto is_copy_stage = [&](size_t idx) {
bool has_shared_write = false;
for (const BufferRegion &wr : writes[idx]) {
auto scope = wr->buffer.scope();
if (scope == "shared" || scope == "shared.dyn") {
has_shared_write = true;
break;
}
}
if (!has_shared_write)
return false;
for (const BufferRegion &rd : reads[idx]) {
if (rd->buffer.scope() == "global") {
return true;
}
}
return false;
};
for (size_t i = 0; i < pipeline_stmts.size(); i++) {
bool copy_stage = is_copy_stage(i);
bool is_producer = roles[i] == Role::kProducer ||
(roles[i] == Role::kBoth && copy_stage);
bool is_consumer = roles[i] == Role::kConsumer ||
(roles[i] == Role::kBoth && !copy_stage);
if (is_producer) {
for (BufferRegion br : writes[i]) {
producer_used.insert(br->buffer.get());
} else {
for (BufferRegion br : reads[i])
}
}
if (is_consumer) {
for (BufferRegion br : reads[i]) {
consumer_used.insert(br->buffer.get());
}
}
for (BufferRegion br : writes[i]) {
const BufferNode *buf = br->buffer.get();
if (!first_write_index.count(buf)) {
first_write_index[buf] = i;
}
}
for (BufferRegion br : reads[i]) {
last_read_index[br->buffer.get()] = i;
}
}
Array<Buffer> versioned_buffers;
for (Buffer buffer : scoped_buffers) {
if (consumer_used.count(buffer.get()) &&
producer_used.count(buffer.get())) {
versioned_buffers.push_back(buffer);
continue;
}
// Fallback: if we saw a write before a later read, the buffer spans
// multiple stages even if role classification missed one side.
auto it_w = first_write_index.find(buffer.get());
auto it_r = last_read_index.find(buffer.get());
if (it_w != first_write_index.end() && it_r != last_read_index.end() &&
it_w->second < it_r->second) {
if (!is_copy_stage(it_w->second))
continue;
versioned_buffers.push_back(buffer);
}
}
return versioned_buffers;
......@@ -197,32 +277,112 @@ private:
}
}
block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
// Record the updated alloc list to recover buffers whose LCA is the block.
block_alloc_buffers_[op->block.get()] = block->alloc_buffers;
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
Stmt VisitStmt_(const BlockNode *op) final {
stmt_stack_.push_back(op);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
stmt_stack_.pop_back();
return stmt;
}
Stmt VisitStmt_(const ForNode *op) final {
stmt_stack_.push_back(op);
loop_stack_.emplace_back(op->loop_var, op->extent);
auto num_stages_anno = op->annotations.Get("num_stages");
if (!num_stages_anno) {
auto for_node = StmtExprMutator::VisitStmt_(op);
loop_stack_.pop_back();
stmt_stack_.pop_back();
return for_node;
}
ICHECK(num_stages_anno->as<IntImmNode>());
int num_stages = static_cast<int>(num_stages_anno->as<IntImmNode>()->value);
const SeqStmtNode *pipeline_body_seq = op->body.as<SeqStmtNode>();
CHECK(pipeline_body_seq) << "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< op->body->GetTypeKey();
Stmt pipeline_body_root{nullptr};
if (const auto *realize = op->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
pipeline_body_root = block->body;
} else {
pipeline_body_root = op->body;
}
const SeqStmtNode *pipeline_body_seq = nullptr;
{
// Traverse trivial wrappers (let/if) to find the actual SeqStmt body.
Stmt current = pipeline_body_root;
while (true) {
if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
pipeline_body_seq = seq_stmt;
break;
}
if (const auto *if_then_else = current.as<IfThenElseNode>()) {
ICHECK(!if_then_else->else_case.defined())
<< "MultiVersionBuffer: Can't handle the body of the loop "
"because the IfThenElse node has an else branch";
current = if_then_else->then_case;
continue;
}
if (const auto *let_stmt = current.as<LetStmtNode>()) {
current = let_stmt->body;
continue;
}
LOG(FATAL)
<< "MultiVersionBuffer: Can't handle the body of the loop because "
<< "it is not a SeqStmt, IfThenElse without else, "
<< "or LetStmt wrapping them, but got " << current->GetTypeKey();
}
}
ICHECK(pipeline_body_seq != nullptr);
Array<Buffer> scoped_buffers = {};
Array<Buffer> scoped_buffers;
std::unordered_set<const BufferNode *> seen;
for (auto [buffer, stmt] : buffer_lca_) {
if (stmt.defined() && stmt.value().get() == op)
if (!stmt.defined())
continue;
const StmtNode *lca = stmt.value().get();
bool in_scope = false;
for (const StmtNode *ancestor : stmt_stack_) {
if (ancestor == lca) {
in_scope = true;
break;
}
}
if (!in_scope)
continue;
// Only double-buffer shared allocations; locals do not need versioning.
auto scope = buffer.scope();
if (!(scope == "shared" || scope == "shared.dyn"))
continue;
if (seen.insert(buffer.get()).second) {
scoped_buffers.push_back(buffer);
}
}
for (auto it = stmt_stack_.rbegin(); it != stmt_stack_.rend(); ++it) {
if (!(*it)->IsInstance<BlockNode>())
continue;
const auto *block = static_cast<const BlockNode *>(*it);
auto map_it = block_alloc_buffers_.find(block);
if (map_it == block_alloc_buffers_.end())
continue;
for (const Buffer &buffer : map_it->second) {
auto scope = buffer.scope();
if (!(scope == "shared" || scope == "shared.dyn"))
continue;
if (seen.insert(buffer.get()).second) {
scoped_buffers.push_back(buffer);
}
}
}
Array<Buffer> versioned_buffers =
GetVersionedBuffers(pipeline_body_seq->seq, scoped_buffers);
......@@ -240,6 +400,7 @@ private:
version_index_ = FloorMod(linear_index, num_stages);
auto for_node = StmtExprMutator::VisitStmt_(op);
loop_stack_.pop_back();
stmt_stack_.pop_back();
return for_node;
}
......@@ -312,9 +473,15 @@ private:
PrimExpr version_index_;
std::vector<std::pair<Var, PrimExpr>> loop_stack_;
// Track ancestor statements to query whether an LCA is inside the current
// loop.
std::vector<const StmtNode *> stmt_stack_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Optional<Stmt>> buffer_lca_;
Map<Buffer, Buffer> buffer_remap_;
// Remember each block's alloc list so the loop can see buffers defined in
// parents.
std::unordered_map<const BlockNode *, Array<Buffer>> block_alloc_buffers_;
};
using namespace tir::transform;
......
......@@ -2,10 +2,12 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include <unordered_map>
#include <utility>
#include "../target/utils.h"
......@@ -204,10 +206,20 @@ private:
void VisitExpr_(const CallNode *op) final {
auto args = op->args;
if (op->op.same_as(builtin::address_of())) {
const BufferLoad load = Downcast<BufferLoad>(op->args[0]);
const BufferRegion buffer_region = BufferRegion::FullRegion(load->buffer);
BufferRegion buffer_region;
if (const auto *load = op->args[0].as<BufferLoadNode>()) {
buffer_region = BufferRegion::FullRegion(load->buffer);
} else if (const auto *var_node = op->args[0].as<VarNode>()) {
Var data_var = GetRef<Var>(var_node);
auto it = buffer_data_to_buffer_.find(data_var);
if (it != buffer_data_to_buffer_.end()) {
buffer_region = BufferRegion::FullRegion((*it).second);
}
}
if (buffer_region.defined()) {
// because we only care about the buffer itself instead of indices
reads_.push_back(buffer_region);
}
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode *buffer_var = op->args[1].as<VarNode>();
ICHECK(buffer_var);
......@@ -398,38 +410,49 @@ private:
if (!num_stages_anno)
return StmtExprMutator::VisitStmt_(loop);
int num_stages = num_stages_anno->as<IntImmNode>()->value;
Stmt pipeline_body{nullptr};
Stmt pipeline_body_root{nullptr};
if (const auto *realize = loop->body.as<BlockRealizeNode>()) {
const auto &block = realize->block;
for (const auto &buffer : block->alloc_buffers) {
ICHECK(buffer->IsInstance<BufferNode>());
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (const auto *seq_stmt = block->body.as<SeqStmtNode>()) {
pipeline_body = block->body;
} else if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
// should assert else case is nullptr
pipeline_body_root = block->body;
} else {
pipeline_body_root = loop->body;
}
const SeqStmtNode *pipeline_body_seq = nullptr;
{
Stmt current = pipeline_body_root;
while (true) {
if (const auto *seq_stmt = current.as<SeqStmtNode>()) {
pipeline_body_seq = seq_stmt;
break;
}
if (const auto *if_then_else = current.as<IfThenElseNode>()) {
ICHECK(!if_then_else->else_case.defined())
<< "Pipeline_Planning: Can't handle the body of the loop because "
"it is not a SeqStmt";
pipeline_body = if_then_else->then_case;
} else {
"the IfThenElse node has an else branch";
current = if_then_else->then_case;
continue;
}
if (const auto *let_stmt = current.as<LetStmtNode>()) {
current = let_stmt->body;
continue;
}
LOG(FATAL) << "Pipeline_Planning: Can't handle the body of the loop "
"because it is not a SeqStmt or IfThenElse";
<< "because it is not a SeqStmt, IfThenElse without else, "
<< "or LetStmt wrapping them, but got "
<< current->GetTypeKey();
}
} else {
pipeline_body = loop->body;
}
const SeqStmtNode *pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
CHECK(pipeline_body_seq)
<< "ValueError: The body of the software pipeline "
"should be SeqStmt, got "
<< pipeline_body->GetTypeKey() << " " << pipeline_body;
ICHECK(pipeline_body_seq != nullptr);
CHECK(num_stages >= 1);
CHECK(loop->kind == ForKind::kSerial);
AsyncDependencyChainBuilder chain_builder(buffer_data_to_buffer_);
chain_builder(pipeline_body);
chain_builder(pipeline_body_root);
std::vector<PipelineStageInfo> pipeline_stage_infos;
for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
......
......@@ -5,12 +5,14 @@
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <optional>
#include <utility>
#include "arith/ir_mutator_with_analyzer.h"
......@@ -327,31 +329,63 @@ private:
Stmt VisitStmt_(const LetStmtNode *op) override {
PrimExpr value = this->VisitExpr(op->value);
bool remove_buffer_alias = false;
// TileLang emits aliases like `X_shared = buffer[0:128, 0:32]` to annotate
// fragment types. TVM currently reinterprets vectorized/shared accesses as
// Let-bound BufferLoad/BufferRegion nodes. If these bindings survive, later
// passes (Layout rewrite, FlattenBuffer) substitute them with vector lanes
// that our layout can't handle. Force-inline (by dropping the let) whenever
// the alias spans more than 2 dims or carries vector lanes.
auto get_ranges = [&](const PrimExpr &expr) -> Array<Range> {
Array<Range> ranges;
if (const auto *load = expr.as<BufferLoadNode>()) {
for (const PrimExpr &index : load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
ranges.push_back(Range::FromMinExtent(ramp->base, ramp->lanes));
} else {
ranges.push_back(Range::FromMinExtent(index, Integer(1)));
}
}
} else if (const auto *region = expr.as<BufferRegionNode>()) {
for (const Range &range : region->region) {
ranges.push_back(range);
}
}
return ranges;
};
Array<Range> ranges = get_ranges(value);
if (!ranges.empty()) {
int non_unit_dims = 0;
for (const Range &range : ranges) {
PrimExpr extent = analyzer_->Simplify(range->extent);
if (is_const_int(extent, 1) || analyzer_->CanProveEqual(extent, 1)) {
continue;
}
++non_unit_dims;
if (non_unit_dims > 1) {
remove_buffer_alias = true;
break;
}
}
}
if (remove_buffer_alias) {
Stmt body = this->VisitStmt(op->body);
bool used = UsesVar(
body, [&](const VarNode *var) { return var == op->var.get(); });
ICHECK(!used) << "Let binding of BufferLoad is expected to be unused "
"before removal "
<< op->var << " : " << op->value << " .";
return body;
}
bool can_inline = CanInlineLetStmt(op);
if (can_inline) {
// It is usually fine to discard the let binding because the
// call to simplify will always inline the var.
//
// The exception is when the variable is used in a Buffer's
// definition, as these are not updated by the simplification.
// After DeclBuffer is required prior to use of a buffer,
// simplifying can update the buffer definition as well. The
// buffer can only be updated at its point of definition,
// because the points of use may occur within contexts that
// allow for additional simplifications (e.g. a buffer of shape
// [i,j] whose first use occurs within "if i==1" should not have
// its shape simplified to [1,j]).
analyzer_->Bind(op->var, value);
} else if (SideEffect(op->value) <= CallEffectKind::kPure) {
// Even if we aren't replacing all occurrences, they may be
// necessary for proving conditional statements.
non_inlined_bindings_.Set(op->var, value);
}
Stmt body = this->VisitStmt(op->body);
// TODO(Lunderberg): Update the Buffer object as part of
// DeclBuffer updates, which will first require
// https://github.com/apache/tvm/pull/14778.
bool used_in_buffer_def = used_in_buffer_def_.count(op->var.get());
if (can_inline && !used_in_buffer_def) {
......
import tilelang
import tilelang.testing
import tilelang.language as T
import torch
@tilelang.jit
def _tmp_var_kernel(N, block_N, dtype="float"):

    @T.prim_func
    def kernel(
            A: T.Tensor((N,), dtype),
            B: T.Tensor((N,), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:
            for i in T.Parallel(block_N):
                idx = bx * block_N + i
                tmp = T.max(A[idx], 1)
                B[idx] = tmp / 2
                A[idx] = tmp * 2

    return kernel


def run_tmp_var_test(N=1024, block_N=128):
    kernel = _tmp_var_kernel(N, block_N)
    a = torch.randn(N, device="cuda", dtype=torch.float)
    b = torch.empty(N, device="cuda", dtype=torch.float)
    a_ref = a.clone()
    kernel(a, b)

    # Reference computation
    tmp_ref = torch.maximum(a_ref, torch.tensor(1.0, dtype=torch.float, device="cuda"))
    b_ref = tmp_ref / 2
    a_ref = tmp_ref * 2

    # Validate correctness
    tilelang.testing.torch_assert_close(a, a_ref, rtol=1e-2, atol=1e-2)
    tilelang.testing.torch_assert_close(b, b_ref, rtol=1e-2, atol=1e-2)


def test_issue_814():
    """Test that temporary variables are correctly handled and not over-inlined"""
    run_tmp_var_test(N=1024, block_N=128)


if __name__ == "__main__":
    tilelang.testing.main()
......@@ -105,5 +105,34 @@ def test_multi_version_buffer():
_check(before, after)
def test_multi_version_buffer_with_let():

    @T.prim_func
    def before(scales: T.Tensor((4,), "float32")):
        with T.block("root"):
            shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
            accum = T.alloc_buffer((8,), "float32", scope="local")
            for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
                value: T.float32 = scales[k]
                for i in T.serial(8):
                    shared[i] = value
                for i in T.serial(8):
                    accum[i] = accum[i] + shared[i]

    @T.prim_func
    def after(scales: T.Tensor((4,), "float32")):
        with T.block("root"):
            shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
            accum = T.alloc_buffer((8,), "float32", scope="local")
            for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
                value: T.float32 = scales[k]
                for i in T.serial(8):
                    shared[k % 2, i] = value
                for i in T.serial(8):
                    accum[i] = accum[i] + shared[k % 2, i]

    _check(before, after)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -61,6 +61,12 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
return enable_aggressive_merge
def should_force_let_inline(pass_ctx: Optional[PassContext] = None) -> bool:
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module
"""
......@@ -85,14 +91,15 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
"""
mod = tir.transform.BindTarget(target)(mod)
# Inline let expressions and statements
if should_force_let_inline():
    # Force-inline let bindings whenever the pass config requests it.
    mod = tilelang.transform.LetInline()(mod)
# Add wrapper for single buf store
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# Inject assumes to speedup tvm prover
mod = tilelang.transform.InjectAssumes()(mod)
# Simplify the IR expressions
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.Simplify()(mod)
# Set layouts for reducers
mod = tilelang.transform.LayoutReducer()(mod)
# Infer memory layouts for fragments and shared memory
......
......@@ -66,6 +66,9 @@ class PassConfigKey(str, Enum):
optimization in cases where manual synchronization is preferred or when
synchronization is not needed. Default: False"""
TL_FORCE_LET_INLINE = "tl.force_let_inline"
"""Force TileLang to inline let bindings during simplification. Default: False"""
# TIR related configs
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""
......