"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "0850ddd6a26c2f78250cfe115dcd4df2645673ba"
Commit 6cede73d authored by Lei Wang, committed by LeiWang1999

[Enhancement] Update warp specialization checking (#580)

* Fix L2 cache size calculation to handle symbolic expressions and ensure float conversion of hit ratios in annotation (rough idea sketched below)

* [Enhancement] Update warp specialization check in phase.py

* lint fix

* [Enhancement] Add ContainsSeqStmt method to improve statement handling in merge_shared_memory_allocations.cc

* [Refactor] Simplify memory copy operations in GEMM kernel tests

- Updated memory copy operations in `test_tilelang_kernel_gemm.py` to use shared memory allocations for both A and B matrices, improving clarity and performance.
- Adjusted the main execution block to include a new `run_gemm_rs` function call for testing, enhancing the test structure.

* revert memory reuse pass.

* revert the memory reuse and thread sync passes.

* Update test_tilelang_kernel_gemm.py

* Update test_tilelang_kernel_mha_bwd.py
parent 44508e59
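The L2 hit-ratio bullet above is not covered by the diffs below. As a rough, hypothetical illustration of the idea (the helper name and fallback policy here are invented for this sketch, not taken from the commit), the fix amounts to refusing to fold a symbolic buffer size into the ratio and casting the final value to a Python float before it is attached as an annotation:

# Hypothetical sketch of the hit-ratio computation described in the first bullet;
# not the actual tilelang implementation.
from tvm import tir


def l2_hit_ratio(buffer_bytes, l2_cache_bytes):
    # A symbolic (non-constant) buffer size cannot be turned into a concrete
    # ratio at compile time, so fall back to a conservative default.
    if isinstance(buffer_bytes, tir.PrimExpr) and not isinstance(buffer_bytes, tir.IntImm):
        return 1.0
    if isinstance(buffer_bytes, tir.IntImm):
        buffer_bytes = buffer_bytes.value
    # The annotation consumer expects a plain Python float, so convert explicitly
    # rather than passing an int (or an integer-division expression) through.
    return float(min(1.0, l2_cache_bytes / max(int(buffer_bytes), 1)))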
merge_shared_memory_allocations.cc
@@ -150,7 +150,8 @@ public:
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[scope_.size() - 1].touched.push_back(buf);
+        // set into scope_.size() - 1 for aggressive memory reuse
+        scope_[it->second.level].touched.push_back(buf);
       }
     }
     StmtEntry e = scope_.back();
@@ -184,7 +185,7 @@ public:
       ICHECK_LT(it->second.level, scope_.size())
           << "Load memory in places other than store.";
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[scope_.size() - 1].touched.push_back(buf);
+        scope_[it->second.level].touched.push_back(buf);
       }
     }
   }
@@ -195,7 +196,7 @@ public:
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
-        scope_[scope_.size() - 1].touched.push_back(buf);
+        scope_[it->second.level].touched.push_back(buf);
       }
     }
   }
@@ -242,8 +243,20 @@ public:
   void VisitStmt_(const IfThenElseNode *op) final { VisitNewScope(op); }

+  bool ContainsSeqStmt(const Stmt &stmt) {
+    if (stmt->IsInstance<SeqStmtNode>()) {
+      return true;
+    }
+    if (const auto *if_node = stmt.as<IfThenElseNode>()) {
+      return ContainsSeqStmt(if_node->then_case) ||
+             (if_node->else_case.defined() &&
+              ContainsSeqStmt(if_node->else_case.value()));
+    }
+    return false;
+  }
+
   void VisitStmt_(const ForNode *op) final {
-    if (op->body->IsInstance<SeqStmtNode>()) {
+    if (ContainsSeqStmt(op->body)) {
       scope_level_++;
       VisitNewScope(op);
       scope_level_--;
...
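The new ContainsSeqStmt helper widens the scope check to loops whose body is an IfThenElse wrapping a sequence of statements rather than a bare SeqStmt. A small illustrative snippet using TVM's TIR Python API (not part of the commit) builds exactly that shape:

# Illustrative only: construct a For whose body hides a SeqStmt behind an
# IfThenElse, the case the new ContainsSeqStmt helper now recurses into.
from tvm import tir

i = tir.Var("i", "int32")
seq = tir.SeqStmt([tir.Evaluate(tir.const(0, "int32")),
                   tir.Evaluate(tir.const(1, "int32"))])        # two statements in sequence
cond = tir.EQ(tir.FloorMod(i, 2), tir.const(0, "int32"))
branch = tir.IfThenElse(cond, seq, None)                        # SeqStmt wrapped in a branch
loop = tir.For(i, 0, 8, tir.ForKind.SERIAL, branch)             # loop body is not a SeqStmt itself

# op->body->IsInstance<SeqStmtNode>() is false for this loop, while
# ContainsSeqStmt(op->body) is true, so the visitor still opens a new scope level.
print(type(loop.body))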
test_tilelang_kernel_gemm.py
@@ -337,10 +337,10 @@ def matmul_sr(
                 T.copy(A[by * block_M, k * block_K], A_shared)
                 if trans_B:
                     T.copy(B[bx * block_N, k * block_K], B_shared)
-                    T.copy(B[bx * block_N, k * block_K], B_local)
+                    T.copy(B_shared, B_local)
                 else:
                     T.copy(B[k * block_K, bx * block_N], B_shared)
-                    T.copy(B[k * block_K, bx * block_N], B_local)
+                    T.copy(B_shared, B_local)
                 T.gemm(A_shared, B_local, C_local, trans_A, trans_B)
             T.copy(C_local, C[by * block_M, bx * block_N])
@@ -443,18 +443,18 @@ def matmul_rs(
         C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope="shared")
             A_local = T.alloc_fragment(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope="shared")
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
             T.clear(C_local)
             for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)
-                    T.copy(A[k * block_K, by * block_M], A_local)
+                    T.copy(A_shared, A_local)
                 else:
                     T.copy(A[by * block_M, k * block_K], A_shared)
-                    T.copy(A[by * block_M, k * block_K], A_local)
+                    T.copy(A_shared, A_local)
                 if trans_B:
                     T.copy(B[bx * block_N, k * block_K], B_shared)
                 else:
...
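For context, the pattern the updated tests converge on is a single global -> shared -> fragment path per operand, with the shared buffers allocated explicitly in the "shared" scope. A condensed, self-contained sketch of a matmul_rs-style kernel (block sizes, dtypes, and thread count are placeholders, and trans_A/trans_B handling is omitted):

# Condensed sketch of the register-shared (rs) pattern from the updated test:
# stage each A tile in shared memory, then fill the register fragment from
# shared memory instead of issuing a second global-memory copy.
import tilelang
import tilelang.language as T


def matmul_rs_sketch(M=256, N=256, K=256, block_M=64, block_N=64, block_K=32):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), "float16"),
            B: T.Tensor((K, N), "float16"),
            C: T.Tensor((M, N), "float16"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float16", scope="shared")
            A_local = T.alloc_fragment((block_M, block_K), "float16")
            B_shared = T.alloc_shared((block_K, block_N), "float16", scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)  # global -> shared
                T.copy(A_shared, A_local)                       # shared -> fragment
                T.copy(B[k * block_K, bx * block_N], B_shared)  # global -> shared
                T.gemm(A_local, B_shared, C_local)              # fragment x shared
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main

Loading the fragment from the already-staged shared tile avoids a second read of the same global tile, which is the change applied to both matmul_sr and matmul_rs above.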
test_tilelang_kernel_mha_bwd.py
@@ -134,6 +134,11 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
     return flash_bwd_post


+@tilelang.jit(
+    out_idx=[7, 8],
+    pass_configs={
+        tilelang.PassConfigKey.TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS: True,
+    })
 def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
     sm_scale = (1.0 / dim)**0.5
     scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
@@ -254,9 +259,9 @@ class _attention(torch.autograd.Function):
         mod_prep = cached(flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD), [2])
         mod_post = cached(flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD), [1])
         delta = mod_prep(o, do)
-        mod = cached(
-            flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N), [6, 7, 8])
-        dq, dk, dv = mod(q, k, v, do, lse, delta)
+        mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
+        dq = torch.zeros_like(q, dtype=torch.float32)
+        dk, dv = mod(q, k, v, do, lse, delta, dq)
         dq = mod_post(dq)
         return dq, dk, dv, None
@@ -302,8 +307,8 @@ def assert_mha_equal(batch, h, n_ctx, d_head, causal):
 def test_mha_bwd():
-    assert_mha_equal(8, 32, 1024, 64, False)
-    assert_mha_equal(8, 32, 1024, 64, True)
+    assert_mha_equal(8, 32, 256, 64, False)
+    assert_mha_equal(8, 32, 256, 64, True)


 if __name__ == "__main__":
...
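The backward kernel now relies on tilelang.jit's out_idx convention: parameters listed in out_idx are allocated by the runtime and returned from the call (dk and dv at indices 7 and 8), while dq is preallocated by the caller and passed in so it can be accumulated in float32. A toy example of the same convention (unrelated to the attention kernel; the function, names, and shapes are invented for illustration):

import torch
import tilelang
import tilelang.language as T


# Toy illustration of out_idx: `src` is an input, `dst_inplace` is written through a
# caller-provided tensor, and `dst_out` (index 2) is allocated by the runtime and
# returned, mirroring how dq vs. dk/dv are handled in the diff above.
@tilelang.jit(out_idx=[2])
def copy_twice(N=1024, block=128):

    @T.prim_func
    def main(
            src: T.Tensor((N,), "float32"),
            dst_inplace: T.Tensor((N,), "float32"),
            dst_out: T.Tensor((N,), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block), threads=block) as bx:
            for i in T.Parallel(block):
                dst_inplace[bx * block + i] = src[bx * block + i]
                dst_out[bx * block + i] = src[bx * block + i]

    return main


kernel = copy_twice()                      # compile with the default N and block
src = torch.randn(1024, device="cuda")
dst_inplace = torch.empty_like(src)
dst_out = kernel(src, dst_inplace)         # returned because of out_idx=[2]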
phase.py
@@ -13,7 +13,8 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
-    if not is_cuda_target(target):
+    # Warp specialized pass is recommended for Hopper or later architectures
+    if not is_cuda_target(target) or not have_tma(target):
         return False
     disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
     return not disable_warp_specialized
@@ -148,13 +149,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
     mod = tir.transform.SplitHostDevice()(mod)
-    if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
-        # This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
-        # when warp specialization is enabled, as different warp threads may access different
-        # buffers, but the liveness analysis is hard because we need to do pipeline.
-        mod = tir.transform.MergeSharedMemoryAllocations()(mod)
-    else:
-        mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
+    mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
...
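With the branch removed, OptimizeForTarget always runs tilelang's own MergeSharedMemoryAllocations, and allow_warp_specialized now additionally requires TMA support, so warp specialization is only attempted on Hopper-class (or newer) CUDA targets. It can still be turned off explicitly through the pass-context key the function reads; a minimal sketch, assuming the key is registered with TVM's PassContext the way tilelang's transform module consumes it:

# Minimal sketch: force-disable warp specialization through the pass context,
# which is the config read by allow_warp_specialized above.
import tvm
import tilelang

with tvm.transform.PassContext(config={"tl.disable_warp_specialized": True}):
    # Any tilelang compilation performed inside this context sees
    # allow_warp_specialized(...) == False and skips the warp-specialized path.
    ...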