Unverified commit 16561159, authored by Wenxuan Tan, committed by GitHub

[Bugfix] Fix flops comp and softmax scale in mla (#900)

* fix flops comp and softmax scale

* format
parent 54fc6ba0
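
Note on the softmax-scale part of the fix: in these MLA decode kernels the per-head query/key dimension is d = dv + dpe (the no-RoPE/value dim plus the RoPE dim), so the scale is now 1/sqrt(d) rather than the previous 1/sqrt(d + dv). A minimal, self-contained sketch of the corrected convention, assuming DeepSeek-V3-style dimensions (dv = 512, dpe = 64; these numbers are illustrative, not taken from the commit):

    # Assumed dims: dv = no-RoPE (value) head dim, dpe = RoPE head dim.
    dv, dpe = 512, 64
    d = dv + dpe                   # full Q/K head dim used by the kernels (576)
    softmax_scale = d**-0.5        # corrected: 1/sqrt(576) = 1/24 ~= 0.0417
    old_scale = (d + dv)**-0.5     # previous value: 1/sqrt(1088) ~= 0.0303, smaller than intended
    print(softmax_scale, old_scale)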
@@ -87,8 +87,8 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
 @torch.inference_mode()
-def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
-                    h_q, h_kv, d, dv, causal, dtype):
+def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
+                   h_q, h_kv, d, dv, causal, dtype):
     # pip install flashinfer-python
     import flashinfer
     assert d > dv, "mla with rope dim should be larger than no rope dim"
@@ -128,7 +128,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
         blocked_k.dtype,
     )
-    def flash_infer():
+    def flashinfer():
         output, lse = mla_wrapper.run(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, d - dv),
@@ -137,8 +137,8 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
             return_lse=True)
         return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
-    out_flash, lse_flash = flash_infer()
-    t = triton.testing.do_bench(flash_infer)
+    out_flash, lse_flash = flashinfer()
+    t = triton.testing.do_bench(flashinfer)
     return out_flash, lse_flash, t
@@ -459,7 +459,7 @@ FUNC_TABLE = {
     "torch": run_torch_mla,
     "tilelang": run_flash_mla_tilelang,
     "flash_mla": run_flash_mla,
-    "flash_infer": run_flash_infer,
+    "flashinfer": run_flashinfer,
     "flash_mla_triton": run_flash_mla_triton,
 }
@@ -496,9 +496,9 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
                                              s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
     torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
-    if target not in ["flash_infer", "flash_mla_triton", "tilelang"
-                      ] and baseline not in ["flash_infer", "flash_mla_triton", "tilelang"]:
-        # flash_infer has a different lse return value
+    if target not in ["flashinfer", "flash_mla_triton", "tilelang"
+                      ] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
+        # flashinfer has a different lse return value
         # flash_mla_triton and flash_mla_tilelang doesn't return lse
         torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
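
Side note on the comparison itself: torch.testing.assert_close raises on mismatch by itself, so the trailing "out" and "lse" strings only form throwaway tuples with its return value and are never used as assertion messages. A self-contained illustration of the 1e-2 tolerances used above (tensor values are made up):

    import torch
    ref = torch.tensor([1.000, 2.000])
    out = torch.tensor([1.004, 1.995])    # within atol=1e-2 + rtol=1e-2 * |ref|
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)  # passes; larger drift raises AssertionError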
@@ -554,7 +554,7 @@ available_targets = [
     "torch",
     "tilelang",
     "flash_mla",
-    "flash_infer",
+    "flashinfer",
     "flash_mla_triton",
 ]
@@ -11,8 +11,19 @@ import math
     out_idx=[8], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
     })
-def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
-                        block_size, softmax_scale):
+def mla_decode_tilelang(batch,
+                        h_q,
+                        h_kv,
+                        max_seqlen_pad,
+                        dv,
+                        dpe,
+                        block_N,
+                        block_H,
+                        num_split,
+                        block_size,
+                        softmax_scale=None):
+    if softmax_scale is None:
+        softmax_scale = (dv + dpe)**-0.5
     scale = float(softmax_scale * 1.44269504)  # log2(e)
     dtype = "float16"
     accum_dtype = "float"
@@ -322,7 +333,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
     num_kv_splits = 1
     BLOCK_N = 64
     BLOCK_H = min(64, h_q // h_kv)
-    softmax_scale = (d + dv)**-0.5
+    softmax_scale = d**-0.5
     out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
     glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
@@ -379,7 +390,7 @@ if __name__ == "__main__":
     max_seqlen = cache_seqlens.max().item()
     max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
-    total_flops = s_q * total_seqlens * h_q * (d + dv) * 2
+    total_flops = s_q * total_seqlens * h_q * d * 2
     q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
     block_table = torch.arange(
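
For context on the FLOPs part of the fix: the corrected estimate counts 2 FLOPs (one multiply-add) per element of the d-dimensional Q.K dot product, for every query token, cached key position, and query head. A small worked example with assumed values (the config numbers are illustrative, not from the commit):

    # Assumed config: 1 query token per request, 128 query heads, d = 576,
    # and 8192 cached tokens summed over the whole batch.
    s_q, h_q, d = 1, 128, 576
    total_seqlens = 8192
    total_flops = s_q * total_seqlens * h_q * d * 2   # formula from the fixed benchmark
    print(total_flops)   # 1,207,959,552, i.e. about 1.21 GFLOPs per decode step
    # Dividing total_flops by the measured kernel time (in seconds) would give the achieved FLOP/s.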