Commit cfcbcf1e authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

Refactor MLA decode kernel: Replace T.If with native Python if statement (#162)

Simplify the control flow in the MLA decode kernel by replacing TileLang's T.If construct with a standard Python if statement. This change improves code readability and maintains the existing logic for handling sequence length constraints during block-wise computation.
parent 18be9e07
...@@ -73,7 +73,7 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc ...@@ -73,7 +73,7 @@ def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, bloc
policy=T.GemmWarpPolicy.FullCol) policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
with T.If(kr == 0), T.Then(): if kr == 0:
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment