"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "71214b48548b1dcb6ebd581dd36a9d0e60af6837"
Commit e937faa6 authored by Yu Cheng, committed by LeiWang1999

[Dev] Update benchmark and decoding scripts to refine condition checks and optimize tensor operations (#637)

- Extended the condition in `compare_ab` so the lse comparison is skipped when either the baseline or the target does not return a comparable lse, not just the target.
- Removed the unnecessary `acc_s_cast` fragment allocation in `mla_decode_tilelang`, reducing memory usage and improving performance by feeding the shared `S_shared` tensor directly into the GEMM.
parent 02a0cf59
@@ -496,7 +496,8 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
                              s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
-    if target not in ["flash_infer", "flash_mla_triton", "flash_mla_tilelang"]:
+    if target not in ["flash_infer", "flash_mla_triton", "tilelang"
+                     ] and baseline not in ["flash_infer", "flash_mla_triton", "tilelang"]:
         # flash_infer has a different lse return value
         # flash_mla_triton and flash_mla_tilelang doesn't return lse
         torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
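The same exclusion logic, sketched as a small helper for readability (illustrative only, not part of the commit; `NO_LSE_IMPLS` and `returns_lse` are assumed names):

```python
# Hypothetical refactor sketch: the lse comparison only makes sense when
# both implementations actually return a comparable lse tensor.
NO_LSE_IMPLS = {"flash_infer", "flash_mla_triton", "tilelang"}

def returns_lse(impl: str) -> bool:
    # flash_infer defines lse differently; the triton and tilelang MLA
    # decode kernels do not return an lse at all.
    return impl not in NO_LSE_IMPLS

# The guarded check in compare_ab would then read:
#   if returns_lse(target) and returns_lse(baseline):
#       torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2)
```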
@@ -36,7 +36,6 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
             K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
             O_shared = T.alloc_shared([block_H, dv], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
-            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
             scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
@@ -86,12 +85,11 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
                     acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                 T.reduce_sum(acc_s, scores_sum, dim=1)
                 T.copy(acc_s, S_shared)
-                T.copy(S_shared, acc_s_cast)
                 for i in T.Parallel(block_H):
                     logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                 for i, j in T.Parallel(block_H, dv):
                     acc_o[i, j] *= scores_scale[i]
-                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
                 for i, j in T.Parallel(block_H, dv):
                     acc_o[i, j] /= logsum[i]
                 T.copy(acc_o, O_shared)
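The tensor-op change removes an extra fragment copy: `acc_s` is already staged into `S_shared` in the kernel dtype, so the value GEMM can read `S_shared` directly instead of a separate `acc_s_cast` fragment. For orientation, a minimal NumPy sketch of the online-softmax block update this code implements (assumed shapes, ignores the separate rope dimensions; `decode_block_update` is a made-up name, not part of the commit):

```python
import numpy as np

def decode_block_update(acc_o, logsum, scores_max, q_block, kv_block, scale):
    # One block_N step of flash-style decoding for a group of block_H heads.
    # q_block: [block_H, dv], kv_block: [block_N, dv] (latent KV used as both K and V).
    s = q_block @ kv_block.T                           # [block_H, block_N] raw scores
    new_max = np.maximum(scores_max, s.max(axis=1))    # running row maximum
    scores_scale = np.exp2((scores_max - new_max) * scale)
    p = np.exp2((s - new_max[:, None]) * scale)        # softmax numerator for this block
    logsum = logsum * scores_scale + p.sum(axis=1)     # running denominator
    acc_o = acc_o * scores_scale[:, None] + p @ kv_block  # value GEMM reads p directly
    return acc_o, logsum, new_max
```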