Commit f23c4d30 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Introduce padding annotation and improve memory access validation (#511)

* Added a new attribute `kPaddingMap` in `builtin.h` for managing padding annotations.
* Enhanced `SafeMemorysRewriter` to utilize an annotated padding map for buffer stores, improving memory access safety.
* Implemented checks in `layout_inference.cc` to ensure buffers are correctly referenced during layout mapping.
* Introduced a new test file for validating the padding annotation functionality in TileLang.
parent dbe8689f
......@@ -12,6 +12,11 @@
namespace tvm {
namespace tl {
namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
} // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations =
"tl.debug_merge_shared_memory_allocations";
static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
......
......@@ -479,6 +479,8 @@ private:
auto map =
op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var];
ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
annotated_layout_map_.Set(buffer, layout);
......
......@@ -135,16 +135,33 @@ class SafeMemorysRewriter : public StmtExprMutator {
arith::Analyzer *analyzer_;
public:
explicit SafeMemorysRewriter(arith::Analyzer *analyzer)
: analyzer_(analyzer) {}
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
arith::Analyzer *analyzer)
: annotated_padding_map_(annotated_padding_map), analyzer_(analyzer) {}
private:
Stmt VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
GlobalMemChecker checker(analyzer_);
checker(store);
Array<PrimExpr> conditions = checker.GetConditions();
// Skip boundary check if the store value is an IfThenElse
if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) {
if (conditions.size() > 0) {
LOG(WARNING)
<< "Skipping boundary check for store with IfThenElse value: "
<< store->value
<< "\nAs manual boundary check detected, potential out-of-bounds "
"access may occur."
<< "\nAuto detect boundaries are " << conditions;
return store;
}
return store;
}
if (conditions.size() == 0) {
return store;
}
......@@ -161,7 +178,7 @@ private:
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, make_zero(value->dtype));
value = if_then_else(cond, value, GetPadding(store->buffer));
}
store.CopyOnWrite()->value = value;
return store;
......@@ -170,7 +187,7 @@ private:
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, make_zero(value->dtype));
value = if_then_else(cond, value, GetPadding(store->buffer));
}
store.CopyOnWrite()->value = value;
return store;
......@@ -224,6 +241,15 @@ private:
String scope = buffer.scope();
return scope == "global";
}
// Get the padding of the buffer
PrimExpr GetPadding(const Buffer &buffer) {
if (annotated_padding_map_.count(buffer)) {
return annotated_padding_map_[buffer];
}
return make_zero(buffer->dtype);
}
Map<Buffer, PrimExpr> annotated_padding_map_;
};
// Class to legalize safe memory access by transforming them appropriately
......@@ -236,6 +262,9 @@ public:
SafeMemoryLegalizer substituter(&analyzer);
// Get a mutable copy of the function node
PrimFuncNode *fptr = f.CopyOnWrite();
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
// Apply the legalizer to the function body
fptr->body = substituter.VisitStmt(f->body);
return f;
......@@ -252,7 +281,7 @@ private:
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
SafeMemorysRewriter rewriter(analyzer_);
SafeMemorysRewriter rewriter(annotated_padding_map_, analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
// // Detect Buffer Load Node in the loop body, collect the indices and
// buffer size
......@@ -276,11 +305,32 @@ private:
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
Stmt VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kPaddingMap)) {
auto map = op->annotations.Get(attr::kPaddingMap)
.as<Map<Var, PrimExpr>>()
.value();
for (const auto &[var, padding] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var];
annotated_padding_map_.Set(buffer, padding);
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
static bool HasInnerLoop(const Stmt &stmt) {
LeafForFinder finder;
finder(stmt);
return finder.leaf_for_nodes.size() > 0;
}
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, PrimExpr> annotated_padding_map_;
};
// Create a pass that legalizes vectorized loops in the IRModule
......
import tilelang
import tilelang.language as T
import tilelang.testing
import torch
# Disable the compilation cache so every run re-lowers the kernel and
# exercises the padding-annotation pass under test.
tilelang.disable_cache()

# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
    """Build a TileLang kernel that copies A into B through shared memory.

    The shared-memory load deliberately reads A at row offset ``i - 10``, so
    the first 10 global rows are out of bounds; ``T.annotate_padding`` asks
    the safe-memory pass to substitute ``pad_value`` for those reads instead
    of the default zero.

    Args:
        M, N: logical tensor shape.
        block_M, block_N: per-block tile sizes.
        dtype: element dtype name (default ``"float16"``).
        pad_value: fill value for out-of-bounds reads of ``A``.

    Returns:
        The TVMScript PrimFunc implementing the shifted copy.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Register pad_value as the out-of-bounds fill for loads into A_shared.
            T.annotate_padding({A_shared: pad_value})
            for i, j in T.Parallel(block_M, block_N):
                # Row index may be negative for by == 0 and i < 10 (intentional OOB).
                A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

    return main
def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
    """Compile and run the shifted-copy kernel and validate padding behavior.

    Expected result: ``B[i] == A[i - 10]`` for ``i >= 10`` and
    ``B[i] == pad_value`` for ``i < 10`` (the kernel reads A with a fixed
    -10 row offset; see ``tilelang_copy``).

    Args:
        M, N: tensor shape.
        block_M, block_N: tile sizes.
        dtype: torch/TileLang dtype name.
        pad_value: expected fill for the out-of-bounds rows.
    """
    program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            # Keep the lowering simple so the padding rewrite is observable.
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True
        })
    print(kernel.get_kernel_source())
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    # Vectorized reference: rows >= 10 are A shifted down by 10 rows; rows
    # < 10 are filled with pad_value. Replaces the original O(M) Python row
    # loop and also degrades gracefully when M <= 10 (empty slices).
    ref_b = torch.full_like(a, pad_value)
    ref_b[10:] = a[:-10]
    torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)
def test_tilelang_copy():
    # Use a nonzero pad value so annotated padding is distinguishable from
    # the default zero fill.
    run_tilelang_copy(1024, 1024, 128, 128, pad_value=10)
# Allow running this test file directly via the tilelang test runner.
if __name__ == "__main__":
    tilelang.testing.main()
......@@ -79,12 +79,76 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
f"tl::{device_func}<{panel_size}>") if enable else None
def annotate_layout(layout_map: Dict):
    """Annotate the layout of the buffer

    Args:
        layout_map (Dict): a dictionary of buffer to layout

    Returns:
        block_attr: a block attribute

    Example:
        @T.prim_func
        def main(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
        ):
            # Initialize Kernel Context
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_N), dtype)
                T.annotate_layout({A_shared: layout})
                for i, j in T.Parallel(block_M, block_N):
                    A_shared[i, j] = A[by * block_M + i, bx * block_N + j]
                for i, j in T.Parallel(block_M, block_N):
                    B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
        return main
    """
    # Block attributes reference buffers by their underlying data Var, so
    # re-key the user-supplied {Buffer: layout} mapping by buffer.data.
    layout_map = {buffer.data: layout for buffer, layout in layout_map.items()}
    return block_attr({"layout_map": layout_map})
def annotate_padding(padding_map: Dict):
    """Annotate the padding of the buffer

    Args:
        padding_map (dict): a dictionary of buffer to padding value

    Returns:
        block_attr: a block attribute

    Example:
        @T.prim_func
        def main(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
        ):
            # Initialize Kernel Context
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_N), dtype)
                T.annotate_padding({A_shared: pad_value})
                for i, j in T.Parallel(block_M, block_N):
                    A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
                for i, j in T.Parallel(block_M, block_N):
                    B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
        return main
    """
    # Re-key by the buffer's underlying data Var, rejecting global buffers:
    # padding applies to on-chip destinations, never to global memory.
    _padding_map = {}
    for buffer, padding_value in padding_map.items():
        # BUGFIX: the message previously said padding "can only be applied to
        # global buffers", contradicting the asserted condition.
        assert buffer.scope() != "global", "padding cannot be applied to global buffers"
        _padding_map[buffer.data] = padding_value
    return block_attr({"padding_map": _padding_map})
def import_source(source: Optional[str] = None):
    """Attach external C source code to the block as a pragma attribute.

    Args:
        source: the C source code to import; when None, no attribute is
            emitted and None is returned.
    """
    if source is None:
        return None
    return block_attr({"pragma_import_c": source})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment