Unverified Commit 72111642 authored by Chaofan Lin's avatar Chaofan Lin Committed by GitHub

[Refactor] Refactor Pass `LegalizeSafeMemoryAccess` to support recursive load/store rewrite (#1050)



* [Refactor] Refactor Pass `LegalizeSafeMemoryAccess` to support recursive load/store rewrite

* lint

* recursive collect conds for call_extern

* fix name

* [Lint]: [pre-commit.ci] auto fixes [...]

* lint

* [Lint]: [pre-commit.ci] auto fixes [...]

* lint

* [Lint]: [pre-commit.ci] auto fixes [...]

* address comment

* rename pad_value to safe_value

* lint

* add oob store test

* [Lint]: [pre-commit.ci] auto fixes [...]

* fix

* fix

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 278c0fbf
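
For reviewers, a minimal usage sketch of the renamed annotation (editor's illustration adapted from the tilelang_copy test and the annotate_safe_value docstring further down in this diff; it is not part of the commit, and M, N, block_M, block_N, dtype, pad_value, and the import path are assumed to match those examples):

import tilelang.language as T  # assumed import path

@T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_N), dtype)
        # Before this commit: T.annotate_padding({A_shared: pad_value})
        # Now the safe value is attached to the global source buffer; it is used
        # whenever an access to A may be out of bounds.
        T.annotate_safe_value({A: pad_value})
        for i, j in T.Parallel(block_M, block_N):
            A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]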
@@ -22,7 +22,7 @@ namespace tvm {
namespace tl {
namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kSafeValueMap = "safe_value_map";
static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
static constexpr const char *kCustomWarpSpecialization =
......
@@ -50,8 +50,7 @@ private:
bool parent_has_child_for_ = false;
};
// We will create a visitor to check BufferLoad and BufferStore nodes
// within this loop body. This visitor will:
// GlobalMemChecker for a BufferLoad/BufferStore node:
// 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope.
// 3. For each index, compare against the buffer's shape.
@@ -59,22 +58,30 @@ private:
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {
GlobalMemChecker(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds)
: analyzer_(analyzer),
recursively_collect_conds_(recursively_collect_conds) {}
void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is in global scope
// This is because we are writing TilePrograms, where out-of-bounds
// accesses only happen on global buffers.
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
}
if (recursively_collect_conds_) {
StmtExprVisitor::VisitExpr_(op);
}
}
void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
}
if (recursively_collect_conds_) {
StmtExprVisitor::VisitStmt_(op);
}
}
// Helper function to determine if a buffer is global
bool IsGlobalBuffer(const Buffer &buffer) {
@@ -109,6 +116,7 @@ struct GlobalMemChecker : public StmtExprVisitor {
}
});
if (!has_variable) {
// If index is a constant, we can skip the check
continue;
}
@@ -134,23 +142,48 @@ private:
private:
Array<PrimExpr> _conditions;
arith::Analyzer *analyzer_;
bool recursively_collect_conds_;
};
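// Editor's note (illustration, not part of the commit): for a global buffer A
// of shape (M, N) accessed as A[tid + 2, j + 2], the checker collects the
// bound conditions `tid + 2 < M` and `j + 2 < N`; indices that contain no
// variables are skipped (see the has_variable check above).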
class SafeMemorysRewriter : public StmtExprMutator {
arith::Analyzer *analyzer_;
public:
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_padding_map,
explicit SafeMemorysRewriter(Map<Buffer, PrimExpr> annotated_safe_value_map,
arith::Analyzer *analyzer)
: annotated_padding_map_(std::move(annotated_padding_map)),
: annotated_safe_value_map_(std::move(annotated_safe_value_map)),
analyzer_(analyzer) {}
private:
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
// For Load/Store, we only check the current node, not its children,
// since the rewriter already visits the children recursively.
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(load);
Array<PrimExpr> conditions = checker.GetConditions();
if (conditions.empty()) {
return load;
}
// For loads, we can always fall back to the safe value when the access is
// out of bounds.
PrimExpr value = load;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, GetSafeValue(load->buffer));
}
return value;
}
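// Editor's illustration of the load rewrite above (TIR pseudo-code, mirroring
// the vectorize_access test in this commit): a potentially out-of-bounds load
//   A_shared[tid, j] = A[tid + 2, j + 2]
// becomes
//   A_shared[tid, j] = if_then_else(j + 2 < N,
//                          if_then_else(tid + 2 < M, A[tid + 2, j + 2], 0f), 0f)
// where 0f stands for the buffer's safe value (zero unless annotated otherwise).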
Stmt VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
GlobalMemChecker checker(analyzer_);
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(store);
Array<PrimExpr> conditions = checker.GetConditions();
@@ -172,49 +205,36 @@ private:
return store;
}
auto value = store->value;
if (IsGlobalBuffer(store->buffer)) {
// If a store is out of bounds, we skip the corresponding stmt directly.
Stmt store_with_conditions = store;
for (auto cond : conditions) {
store_with_conditions = IfThenElse(cond, store_with_conditions);
}
return store_with_conditions;
} else if (isSharedBuffer(store->buffer)) {
PrimExpr value = store->value;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, GetPadding(store->buffer));
}
store.CopyOnWrite()->value = value;
return store;
} else if (IsLocalBuffer(store->buffer)) {
PrimExpr value = store->value;
for (auto cond : conditions) {
ICHECK(cond.dtype() == DataType::Bool(1))
<< "condition is not a boolean: " << cond;
value = if_then_else(cond, value, GetPadding(store->buffer));
}
store.CopyOnWrite()->value = value;
return store;
} else {
LOG(FATAL) << "Check store buffer: " << store->buffer
<< " is not a global or shared or local buffer";
}
return store;
}
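// Editor's illustration of the store rewrite above (TIR pseudo-code, mirroring
// the new oob_store test): a potentially out-of-bounds store to a global buffer
//   A[tid + 2, j + 2] = 1f
// is guarded so the statement is skipped when out of bounds:
//   if (j + 2 < N)
//     if (tid + 2 < M)
//       A[tid + 2, j + 2] = 1f
// For shared or local destinations, the stored value is replaced with the
// buffer's safe value instead, as in the branches above.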
// Handle Call Nodes
// Recursively check Load/Store in the call arguments.
// For example
// T.call_extern("handle", "atomicAddx2", T.address_of(C),
// T.address_of(C_shared))
// NOTE(chaofan): This is currently not the most rigorous solution.
// The check here is primarily intended to handle extern functions like
// atomicAdd, which may involve memory access. Due to their special nature,
// the BufferLoad in their parameters might be used for boundary checks of the
// current statement. The current solution adopts a simplified approach:
// directly applying the boundary constraints of all parameters to the
// statement. While not entirely precise, it addresses most common scenarios.
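// Editor's illustration (TIR pseudo-code, mirroring the atomic-add test in
// this commit): the extern call
//   T.call_extern("handle", "AtomicAdd", A[tid + 2, j + 2], 1)
// is wrapped with the bound conditions collected from its arguments:
//   if (j + 2 < N)
//     if (tid + 2 < M)
//       T.call_extern("handle", "AtomicAdd", A[tid + 2, j + 2], 1)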
Stmt VisitStmt_(const EvaluateNode *op) final {
auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
auto evaluate = Downcast<Evaluate>(op);
if (const CallNode *call_op = op->value.as<CallNode>()) {
auto call = Downcast<Call>(evaluate->value);
auto call = Downcast<Call>(op->value);
if (call->op == builtin::call_extern()) {
GlobalMemChecker checker(analyzer_);
// For call_extern, we recursively collect conditions from all children,
// since we cannot rewrite any BufferLoad inside them (rewriting could
// trigger a null-pointer exception).
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true);
checker(call);
Array<PrimExpr> conditions = checker.GetConditions();
@@ -248,15 +268,15 @@ private:
String scope = buffer.scope();
return scope == "global";
}
// Get the padding of the buffer
PrimExpr GetPadding(const Buffer &buffer) {
if (annotated_padding_map_.count(buffer)) {
return annotated_padding_map_[buffer];
// Get the safe value of the buffer
PrimExpr GetSafeValue(const Buffer &buffer) {
if (annotated_safe_value_map_.count(buffer)) {
return annotated_safe_value_map_[buffer];
}
return make_zero(buffer->dtype);
}
Map<Buffer, PrimExpr> annotated_padding_map_;
Map<Buffer, PrimExpr> annotated_safe_value_map_;
};
// Class to legalize safe memory access by transforming them appropriately
@@ -288,7 +308,7 @@ private:
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto has_inner_loop = HasInnerLoop(for_node->body);
if (!has_inner_loop) {
SafeMemorysRewriter rewriter(annotated_padding_map_, analyzer_);
SafeMemorysRewriter rewriter(annotated_safe_value_map_, analyzer_);
for_node.CopyOnWrite()->body = rewriter(for_node->body);
// // Detect Buffer Load Node in the loop body, collect the indices and
// buffer size
@@ -316,16 +336,16 @@ private:
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kPaddingMap)) {
auto map = op->annotations.Get(attr::kPaddingMap)
if (op->annotations.count(attr::kSafeValueMap)) {
auto map = op->annotations.Get(attr::kSafeValueMap)
->as<Map<Var, PrimExpr>>()
.value();
for (const auto &[var, padding] : map) {
for (const auto &[var, safe_value] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block "
<< buffer_data_to_buffer_;
auto buffer = buffer_data_to_buffer_[var];
annotated_padding_map_.Set(buffer, padding);
annotated_safe_value_map_.Set(buffer, safe_value);
}
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
@@ -338,7 +358,7 @@ private:
}
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, PrimExpr> annotated_padding_map_;
Map<Buffer, PrimExpr> annotated_safe_value_map_;
};
// Create a pass that legalizes vectorized loops in the IRModule
......
@@ -179,7 +179,7 @@ private:
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
Stmt VisitStmt_(const BlockNode *op) final {
if (op->annotations.count(attr::kPaddingMap)) {
if (op->annotations.count(attr::kSafeValueMap)) {
return RewritePaddingMap(op);
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
@@ -191,18 +191,18 @@ private:
* \return The rewritten block.
*/
Stmt RewritePaddingMap(const BlockNode *op) {
auto padding_map = op->annotations.Get(attr::kPaddingMap);
if (!padding_map) {
auto safe_value_map = op->annotations.Get(attr::kSafeValueMap);
if (!safe_value_map) {
LOG(FATAL) << "Padding map annotation is missing";
}
Map<Var, Var> var_remap = CreateVarRemap();
Map<Var, PrimExpr> new_padding_map = RemapPaddingMap(
Downcast<Map<Var, PrimExpr>>(padding_map.value()), var_remap);
Map<Var, PrimExpr> new_safe_value_map = RemapPaddingMap(
Downcast<Map<Var, PrimExpr>>(safe_value_map.value()), var_remap);
auto block = Downcast<Block>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.Set(attr::kPaddingMap, new_padding_map);
block_ptr->annotations.Set(attr::kSafeValueMap, new_safe_value_map);
return block;
}
@@ -220,21 +220,21 @@ private:
/*!
* \brief Remap the padding map using the variable remapping.
* \param padding_map The original padding map.
* \param safe_value_map The original padding map.
* \param var_remap The variable remapping.
* \return The remapped padding map.
*/
Map<Var, PrimExpr> RemapPaddingMap(const Map<Var, PrimExpr> &padding_map,
Map<Var, PrimExpr> RemapPaddingMap(const Map<Var, PrimExpr> &safe_value_map,
const Map<Var, Var> &var_remap) const {
Map<Var, PrimExpr> new_padding_map;
for (const auto &[var, padding] : padding_map) {
Map<Var, PrimExpr> new_safe_value_map;
for (const auto &[var, padding] : safe_value_map) {
if (var_remap.count(var)) {
new_padding_map.Set(var_remap.at(var), padding);
new_safe_value_map.Set(var_remap.at(var), padding);
} else {
new_padding_map.Set(var, padding);
new_safe_value_map.Set(var, padding);
}
}
return new_padding_map;
return new_safe_value_map;
}
Map<Buffer, Buffer> buffer_remap_;
......
@@ -17,7 +17,7 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.annotate_padding({A_shared: pad_value})
T.annotate_safe_value({A: pad_value})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
......
@@ -8,7 +8,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
dtype = "float32"
@T.prim_func
def main(A: T.Tensor((M, N), dtype="float32"),):
def main(A: T.Tensor((M, N), dtype=dtype),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
@@ -16,7 +16,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_off
A_shared[tid, j] = A[tid + M_offset, j + N_offset]
@T.prim_func
def expected(A: T.Tensor((M, N), dtype="float32"),):
def expected(A: T.Tensor((M, N), dtype=dtype),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
@@ -38,9 +38,127 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens = T.symbolic('num_tokens')
num_threads = 128
@T.prim_func
def main(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var('int')
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += x[idx] == 2
# NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# and the padding value is not used. However, the current prover cannot handle this case.
# So for now the expected kernel uses an if_then_else expression to check the boundary.
@T.prim_func
def expected(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var('int')
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += T.Cast("int32",
T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64,
N: int = 64,
M_offset: int = 2,
N_offset: int = 2):
dtype = "float32"
@T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
for j in T.serial(N):
A_shared[tid, j] = A[tid + M_offset, j + N_offset]
T.atomic_add(A[tid + M_offset, j + N_offset], 1)
@T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N), dtype=dtype)
tid = T.get_thread_binding()
T.reads(A[tid + M_offset, N_offset:N + N_offset])
for j in T.serial(N):
A_shared[tid, j] = T.if_then_else(
j + N_offset < N,
T.if_then_else(tid + M_offset < M, A[tid + M_offset, j + N_offset],
T.float32(0)), T.float32(0))
# Nested if-then-else is expected; do not flatten it, so the structural equality check passes
if j + N_offset < N: # noqa: SIM102
if tid + M_offset < M:
T.call_extern("handle", "AtomicAdd", A[tid + M_offset, j + N_offset], 1)
return main, expected
def assert_vectorize_access_with_atmoic_add(M: int = 64, N: int = 64):
func, expected = vectorize_access_with_atmoic_add_legalize(M, N)
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32"
@T.prim_func
def main(A: T.Tensor((M, N), dtype=dtype),):
with T.Kernel(1, 1, threads=M) as (bx, by):
tid = T.get_thread_binding()
for j in T.serial(N):
A[tid + M_offset, j + N_offset] = 1
@T.prim_func
def expected(A: T.Tensor((M, N), dtype=dtype),):
with T.Kernel(1, 1, threads=M) as (bx, by):
tid = T.get_thread_binding()
T.writes(A[tid + M_offset, N_offset:N + N_offset])
for j in T.serial(N):
if j + N_offset < N: # noqa: SIM102
if tid + M_offset < M:
A[tid + M_offset, j + N_offset] = T.float32(1.0)
return main, expected
def assert_oob_store_legalize(M: int = 64, N: int = 64):
func, expected = oob_store_legalize(M, N)
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access():
assert_vectorize_access(64, 64)
def test_issue_1013():
func, expected = issue_1013_buggy_kernel()
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access_with_atmoic_add():
assert_vectorize_access_with_atmoic_add(64, 64)
def test_oob_store():
assert_oob_store_legalize(64, 64)
if __name__ == "__main__":
tilelang.testing.main()
@@ -146,11 +146,14 @@ def annotate_layout(layout_map: Dict):
return block_attr({"layout_map": _layout_map})
def annotate_padding(padding_map: Dict):
"""Annotate the padding of the buffer
def annotate_safe_value(safe_value_map: Dict):
"""Annotate the safe value of the buffer.
A safe value of a buffer is the value that will be used when the
buffer is accessed out of bounds.
Args:
padding_map (dict): a dictionary of buffer to padding value
safe_value_map (dict): a dictionary of buffer to safe value
Returns:
block_attr: a block attribute
@@ -165,7 +168,7 @@ def annotate_padding(padding_map: Dict):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), dtype)
T.annotate_padding({A_shared: pad_value})
T.annotate_safe_value({A: safe_value})
for i, j in T.Parallel(block_M, block_N):
A_shared[i, j] = A[by * block_M + i - 10, bx * block_N + j]
@@ -174,13 +177,11 @@ def annotate_padding(padding_map: Dict):
return main
"""
# padding_map is a dictionary of buffer to padding value
_padding_map = {}
for buffer, padding_value in padding_map.items():
# assert not global
assert buffer.scope() != "global", "padding can not be applied to global buffers"
_padding_map[buffer.data] = padding_value
return block_attr({"padding_map": _padding_map})
# safe_value_map is a dictionary of buffer to safe value
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
_safe_value_map[buffer.data] = safe_value
return block_attr({"safe_value_map": _safe_value_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
......