Commit 6cede73d authored by Lei Wang, committed by LeiWang1999

[Enhancement] Update warp specialization checking (#580)

* Fix L2 cache size calculation to handle symbolic expressions and ensure float conversion of hit ratios in annotation

* [Enhancement] Update warp specialization check in phase.py

* lint fix

* [Enhancement] Add ContainsSeqStmt method to improve statement handling in merge_shared_memory_allocations.cc

* [Refactor] Simplify memory copy operations in GEMM kernel tests

- Updated memory copy operations in `test_tilelang_kernel_gemm.py` to use shared memory allocations for both A and B matrices, improving clarity and performance.
- Adjusted the main execution block to include a new `run_gemm_rs` function call for testing, enhancing the test structure.

* revert memory reuse pass.

* revert the memory reuse and thread sync passes.

* Update test_tilelang_kernel_gemm.py

* Update test_tilelang_kernel_mha_bwd.py
parent 44508e59
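The first bullet ("Fix L2 cache size calculation ...") touches code that does not appear in the hunks below. As a minimal sketch of the idea only, using a hypothetical helper whose name and signature are not from this commit: skip the L2 hit-ratio annotation when the buffer size is symbolic, and coerce the ratio to a Python float otherwise.

```python
from tvm import tir

def l2_hit_ratio_annotation(buffer_bytes, l2_cache_bytes):
    """Return a float L2 hit ratio, or None when the buffer size is symbolic."""
    if isinstance(buffer_bytes, tir.PrimExpr):
        if not isinstance(buffer_bytes, tir.IntImm):
            # A symbolic size cannot be compared with the cache size at
            # compile time, so no hit-ratio annotation is emitted.
            return None
        buffer_bytes = buffer_bytes.value
    # Coerce to a Python float so the annotation never carries an integer 0/1.
    return float(min(1.0, l2_cache_bytes / max(int(buffer_bytes), 1)))
```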
@@ -150,7 +150,8 @@ public:
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[scope_.size() - 1].touched.push_back(buf);
// set into scope_.size() - 1 for aggressive memory reuse
scope_[it->second.level].touched.push_back(buf);
}
}
StmtEntry e = scope_.back();
@@ -184,7 +185,7 @@ public:
ICHECK_LT(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[scope_.size() - 1].touched.push_back(buf);
scope_[it->second.level].touched.push_back(buf);
}
}
}
@@ -195,7 +196,7 @@ public:
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
scope_[scope_.size() - 1].touched.push_back(buf);
scope_[it->second.level].touched.push_back(buf);
}
}
}
@@ -242,8 +243,20 @@ public:
void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); }
bool ContainsSeqStmt(const Stmt &stmt) {
if (stmt->IsInstance<SeqStmtNode>()) {
return true;
}
if (const auto *if_node = stmt.as<IfThenElseNode>()) {
return ContainsSeqStmt(if_node->then_case) ||
(if_node->else_case.defined() &&
ContainsSeqStmt(if_node->else_case.value()));
}
return false;
}
void VisitStmt_(const ForNode *op) final {
if (op->body->IsInstance<SeqStmtNode>()) {
if (ContainsSeqStmt(op->body)) {
scope_level_++;
VisitNewScope(op);
scope_level_--;
......
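The new `ContainsSeqStmt` helper decides whether a loop body contains a `SeqStmt`, looking through `IfThenElse` wrappers, so that only such loops open a new liveness scope. A rough Python equivalent over TVM's TIR nodes (a sketch for illustration, not code from this commit):

```python
from tvm import tir

def contains_seq_stmt(stmt: tir.Stmt) -> bool:
    # A SeqStmt means the body groups several statements, so the enclosing
    # For should open its own scope for liveness tracking.
    if isinstance(stmt, tir.SeqStmt):
        return True
    # Recurse through IfThenElse wrappers, checking both branches.
    if isinstance(stmt, tir.IfThenElse):
        return contains_seq_stmt(stmt.then_case) or (
            stmt.else_case is not None and contains_seq_stmt(stmt.else_case)
        )
    return False
```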
@@ -337,10 +337,10 @@ def matmul_sr(
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B[bx * block_N, k * block_K], B_local)
T.copy(B_shared, B_local)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B[k * block_K, bx * block_N], B_local)
T.copy(B_shared, B_local)
T.gemm(A_shared, B_local, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
@@ -443,18 +443,18 @@ def matmul_rs(
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
A_local = T.alloc_fragment(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A[k * block_K, by * block_M], A_local)
T.copy(A_shared, A_local)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A[by * block_M, k * block_K], A_local)
T.copy(A_shared, A_local)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
......
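Both gemm test hunks above replace the second global-to-fragment copy with a copy out of the shared-memory staging buffer, so each tile is read from global memory once and the fragment load is fed from shared memory, which the pipeliner can stage. A compact sketch of the resulting pattern (block sizes, dtypes, and num_stages here are illustrative, not taken from the test file):

```python
import tilelang.language as T

@T.prim_func
def staged_copy_gemm(
        A: T.Tensor((1024, 1024), "float16"),
        B: T.Tensor((1024, 1024), "float16"),
        C: T.Tensor((1024, 1024), "float32"),
):
    with T.Kernel(T.ceildiv(1024, 128), T.ceildiv(1024, 128), threads=128) as (bx, by):
        A_shared = T.alloc_shared((128, 32), "float16")
        A_local = T.alloc_fragment((128, 32), "float16")
        B_shared = T.alloc_shared((32, 128), "float16")
        C_local = T.alloc_fragment((128, 128), "float32")
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(1024, 32), num_stages=2):
            # Stage the A tile in shared memory first ...
            T.copy(A[by * 128, k * 32], A_shared)
            # ... then load the fragment from shared memory, not from global.
            T.copy(A_shared, A_local)
            T.copy(B[k * 32, bx * 128], B_shared)
            T.gemm(A_local, B_shared, C_local)
        T.copy(C_local, C[by * 128, bx * 128])
```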
@@ -134,6 +134,11 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
return flash_bwd_post
@tilelang.jit(
out_idx=[7, 8],
pass_configs={
tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
})
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
@@ -254,9 +259,9 @@ class _attention(torch.autograd.Function):
mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2])
mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1])
delta = mod_prep(o, do)
mod = cached(
flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N), [6, 7, 8])
dq, dk, dv = mod(q, k, v, do, lse, delta)
mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
dq = torch.zeros_like(q, dtype=torch.float32)
dk, dv = mod(q, k, v, do, lse, delta, dq)
dq = mod_post(dq)
return dq, dk, dv, None
@@ -302,8 +307,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
def test_mha_bwd():
assert_mha_equal(8, 32, 1024, 64, False)
assert_mha_equal(8, 32, 1024, 64, True)
assert_mha_equal(8, 32, 256, 64, False)
assert_mha_equal(8, 32, 256, 64, True)
if __name__ == "__main__":
......
@@ -13,7 +13,8 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if not is_cuda_target(target):
# Warp specialized pass is recommended for Hopper or later architectures
if not is_cuda_target(target) or not have_tma(target):
return False
disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
return not disable_warp_specialized
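The updated condition gates warp specialization on TMA support rather than on the target merely being a CUDA GPU. Assuming `have_tma` is true only for Hopper-class (sm_90 and later) targets, a quick usage sketch (the import path is an assumption, not shown in this diff):

```python
from tvm.target import Target
from tilelang.engine.phase import allow_warp_specialized  # assumed module path

for arch in ("sm_80", "sm_90"):
    target = Target(f"cuda -arch={arch}")
    # With the new check, sm_80 is expected to return False (no TMA), while
    # sm_90 returns True unless tl.disable_warp_specialized is set.
    print(arch, allow_warp_specialized(target=target))
```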
@@ -148,13 +149,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
# This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
# when warp specialization is enabled, as different warp threads may access different
# buffers, but the liveness analysis is hard because we need to do pipeline.
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
else:
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
......