Commit 732079ea authored by zhanghj2's avatar zhanghj2
Browse files

更新性能测试方式,仅测试flash_fwd_splitkv_mla_qkvfp8_kernel的性能

parent a4fdef4c
...@@ -4,6 +4,7 @@ import random ...@@ -4,6 +4,7 @@ import random
import torch import torch
import triton import triton
import kernelkit as kk
from flash_mla import flash_mla_with_kvcache_qkvfp8, get_mla_metadata from flash_mla import flash_mla_with_kvcache_qkvfp8, get_mla_metadata
torch.set_printoptions(precision=4, profile="default", sci_mode=False) torch.set_printoptions(precision=4, profile="default", sci_mode=False)
...@@ -163,12 +164,31 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa ...@@ -163,12 +164,31 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse") cal_diff(lse_flash, lse_torch, "lse")
if is_prof: return if is_prof: return
t = triton.testing.do_bench(flash_mla) # t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 # FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) # bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print( # print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" # f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
) # )
time_usage = kk.bench_kineto(flash_mla, 10).get_kernel_time("flash_fwd_splitkv_mla_qkvfp8_kernel")
mean_attended_seqlens = cache_seqlens.float().mean().item()
compute_volume_flop = b * h_q * s_q * sum([
2 * d * mean_attended_seqlens, # Q * K^T
2 * mean_attended_seqlens * dv, # attention * V
])
q_elem_size = 1
kv_token_size = d * 1
memory_volume_B = b * sum([
s_q * h_q * (d * q_elem_size), # Q
mean_attended_seqlens * h_kv * kv_token_size, # K/V
s_q * h_q * (dv * 2), # Output
])
achieved_tflops = compute_volume_flop / time_usage / 1e12
achieved_gBps = memory_volume_B / time_usage / 1e9
print(f"{time_usage * 1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment