Unverified Commit e11880de authored by Kebe's avatar Kebe Committed by GitHub
Browse files

[Bugfix] Remove triton do_bench fast_flush arg (#16256)


Signed-off-by: default avatarKebe <mail@kebe7jun.com>
parent 9351f91b
......@@ -124,7 +124,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla, fast_flush=False)
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment