"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "81d2d168e9fac509fba8518b978dbc5a4444d009"
Commit f23c4d30 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Introduce padding annotation and improve memory access validation (#511)

* Added a new attribute `kPaddingMap` in `builtin.h` for managing padding annotations.
* Enhanced `SafeMemorysRewriter` to utilize an annotated padding map for buffer stores, improving memory access safety.
* Implemented checks in `layout_inference.cc` to ensure buffers are correctly referenced during layout mapping.
* Introduced a new test file for validating the padding annotation functionality in TileLang (see the condensed usage sketch below).
parent dbe8689f
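
For orientation, here is a condensed usage sketch of the new padding annotation, adapted from the test file added in this commit. The shapes, block sizes, and the shifted row index `i - 10` are illustrative values taken from that test, not requirements of the API; the key point is that out-of-bounds reads from the global buffer `A` are filled with `pad_value` instead of the default zero.

```python
import tilelang
import tilelang.language as T


def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            # Out-of-bounds loads of A caused by the shifted row index are
            # replaced with pad_value rather than zero.
            T.annotate_padding({A_shared: pad_value})
            for i, j in T.Parallel(block_M, block_N):
                A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

    return main
```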
@@ -12,6 +12,11 @@
 namespace tvm {
 namespace tl {
 
+namespace attr {
+static constexpr const char *kPaddingMap = "padding_map";
+} // namespace attr
+
 static constexpr const char *kDebugMergeSharedMemoryAllocations =
     "tl.debug_merge_shared_memory_allocations";
 static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
...
@@ -479,6 +479,8 @@ private:
     auto map =
         op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
     for (const auto &[var, layout] : map) {
+      ICHECK(buffer_data_to_buffer_.count(var))
+          << "buffer " << var << " is not found in the block";
       auto buffer = buffer_data_to_buffer_[var];
       ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
       annotated_layout_map_.Set(buffer, layout);
...
@@ -135,16 +135,33 @@ class SafeMemorysRewriter : public StmtExprMutator
   arith::Analyzer *analyzer_;
 
 public:
-  explicit SafeMemorysRewriter(arith::Analyzer *analyzer)
-      : analyzer_(analyzer) {}
+  explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
+                               arith::Analyzer *analyzer)
+      : annotated_padding_map_(annotated_padding_map), analyzer_(analyzer) {}
 
 private:
   Stmt VisitStmt_(const BufferStoreNode *op) final {
     // Check if the buffer is in global scope
     auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
     GlobalMemChecker checker(analyzer_);
     checker(store);
 
     Array<PrimExpr> conditions = checker.GetConditions();
+
+    // Skip boundary check if the store value is an IfThenElse
+    if (const IfThenElseNode *if_node = store->value.as<IfThenElseNode>()) {
+      if (conditions.size() > 0) {
+        LOG(WARNING)
+            << "Skipping boundary check for store with IfThenElse value: "
+            << store->value
+            << "\nAs manual boundary check detected, potential out-of-bounds "
+               "access may occur."
+            << "\nAuto detect boundaries are " << conditions;
+        return store;
+      }
+      return store;
+    }
+
     if (conditions.size() == 0) {
       return store;
     }
@@ -161,7 +178,7 @@ private:
       for (auto cond : conditions) {
         ICHECK(cond.dtype() == DataType::Bool(1))
             << "condition is not a boolean: " << cond;
-        value = if_then_else(cond, value, make_zero(value->dtype));
+        value = if_then_else(cond, value, GetPadding(store->buffer));
       }
       store.CopyOnWrite()->value = value;
       return store;
@@ -170,7 +187,7 @@ private:
       for (auto cond : conditions) {
         ICHECK(cond.dtype() == DataType::Bool(1))
             << "condition is not a boolean: " << cond;
-        value = if_then_else(cond, value, make_zero(value->dtype));
+        value = if_then_else(cond, value, GetPadding(store->buffer));
       }
       store.CopyOnWrite()->value = value;
       return store;
@@ -224,6 +241,15 @@ private:
     String scope = buffer.scope();
     return scope == "global";
   }
 
+  // Get the padding of the buffer
+  PrimExpr GetPadding(const Buffer &buffer) {
+    if (annotated_padding_map_.count(buffer)) {
+      return annotated_padding_map_[buffer];
+    }
+    return make_zero(buffer->dtype);
+  }
+
+  Map<Buffer, PrimExpr> annotated_padding_map_;
 };
 
 // Class to legalize safe memory access by transforming them appropriately
@@ -236,6 +262,9 @@ public:
     SafeMemoryLegalizer substituter(&analyzer);
     // Get a mutable copy of the function node
     PrimFuncNode *fptr = f.CopyOnWrite();
+    for (const auto &[_, buffer] : f->buffer_map) {
+      substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
     // Apply the legalizer to the function body
     fptr->body = substituter.VisitStmt(f->body);
     return f;
@@ -252,7 +281,7 @@ private:
     For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
     auto has_inner_loop = HasInnerLoop(for_node->body);
     if (!has_inner_loop) {
-      SafeMemorysRewriter rewriter(analyzer_);
+      SafeMemorysRewriter rewriter(annotated_padding_map_, analyzer_);
       for_node.CopyOnWrite()->body = rewriter(for_node->body);
       // // Detect Buffer Load Node in the loop body, collect the indices and
       // buffer size
@@ -276,11 +305,32 @@ private:
     return IRMutatorWithAnalyzer::VisitStmt_(op);
   }
 
+  Stmt VisitStmt_(const BlockNode *op) final {
+    for (auto buffer : op->alloc_buffers) {
+      buffer_data_to_buffer_.Set(buffer->data, buffer);
+    }
+    if (op->annotations.count(attr::kPaddingMap)) {
+      auto map = op->annotations.Get(attr::kPaddingMap)
+                     .as<Map<Var, PrimExpr>>()
+                     .value();
+      for (const auto &[var, padding] : map) {
+        ICHECK(buffer_data_to_buffer_.count(var))
+            << "buffer " << var << " is not found in the block";
+        auto buffer = buffer_data_to_buffer_[var];
+        annotated_padding_map_.Set(buffer, padding);
+      }
+    }
+    return IRMutatorWithAnalyzer::VisitStmt_(op);
+  }
+
   static bool HasInnerLoop(const Stmt &stmt) {
     LeafForFinder finder;
     finder(stmt);
     return finder.leaf_for_nodes.size() > 0;
   }
+
+  Map<Var, Buffer> buffer_data_to_buffer_;
+  Map<Buffer, PrimExpr> annotated_padding_map_;
 };
 
 // Create a pass that legalizes vectorized loops in the IRModule
...
import tilelang
import tilelang.language as T
import tilelang.testing
import torch

tilelang.disable_cache()


# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            T.annotate_padding({A_shared: pad_value})
            for i, j in T.Parallel(block_M, block_N):
                A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A_shared[i, j]

    return main


def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
    program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={
            "tl.disable_warp_specialized": True,
            "tl.disable_tma_lower": True
        })
    print(kernel.get_kernel_source())
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    ref_b = torch.zeros_like(a)
    for i in range(M):
        if i >= 10:
            ref_b[i, :] = a[i - 10, :]
        else:
            ref_b[i, :] = pad_value
    torch.testing.assert_close(b, ref_b, rtol=1e-2, atol=1e-2)


def test_tilelang_copy():
    run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, pad_value=10)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -79,12 +79,76 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
         f"tl::{device_func}<{panel_size}>") if enable else None
 
 
-def annotate_layout(layout_map):
+def annotate_layout(layout_map: Dict):
+    """Annotate the layout of a buffer.
+
+    Args:
+        layout_map (Dict): a dictionary mapping buffers to layouts
+
+    Returns:
+        block_attr: a block attribute
+
+    Example:
+        @T.prim_func
+        def main(
+                A: T.Tensor((M, N), dtype),
+                B: T.Tensor((M, N), dtype),
+        ):
+            # Initialize Kernel Context
+            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+                A_shared = T.alloc_shared((block_M, block_N), dtype)
+                T.annotate_layout({A_shared: layout})
+                for i, j in T.Parallel(block_M, block_N):
+                    A_shared[i, j] = A[by * block_M + i, bx * block_N + j]
+                for i, j in T.Parallel(block_M, block_N):
+                    B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
+
+        return main
+    """
     # layout_map is a dictionary of buffer to layout
     layout_map = {buffer.data: layout for buffer, layout in layout_map.items()}
     return block_attr({"layout_map": layout_map})
 
 
+def annotate_padding(padding_map: Dict):
+    """Annotate the padding value of a buffer.
+
+    Args:
+        padding_map (dict): a dictionary mapping buffers to padding values
+
+    Returns:
+        block_attr: a block attribute
+
+    Example:
+        @T.prim_func
+        def main(
+                A: T.Tensor((M, N), dtype),
+                B: T.Tensor((M, N), dtype),
+        ):
+            # Initialize Kernel Context
+            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
+                A_shared = T.alloc_shared((block_M, block_N), dtype)
+                T.annotate_padding({A_shared: pad_value})
+                for i, j in T.Parallel(block_M, block_N):
+                    A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
+                for i, j in T.Parallel(block_M, block_N):
+                    B[by * block_M + i, bx * block_N + j] = A_shared[i, j]
+
+        return main
+    """
+    # padding_map is a dictionary of buffer to padding value
+    _padding_map = {}
+    for buffer, padding_value in padding_map.items():
+        # padding only makes sense for non-global buffers
+        assert buffer.scope() != "global", "padding cannot be applied to global buffers"
+        _padding_map[buffer.data] = padding_value
+    return block_attr({"padding_map": _padding_map})
+
+
 def import_source(source: Optional[str] = None):
     # source is the source code to be imported
     return block_attr({"pragma_import_c": source}) if source is not None else None