Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
732079ea
Commit
732079ea
authored
Feb 24, 2026
by
zhanghj2
Browse files
更新性能测试方式,仅测试flash_fwd_splitkv_mla_qkvfp8_kernel的性能
parent
a4fdef4c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
6 deletions
+26
-6
tests/test_flash_mla_qkvfp8.py
tests/test_flash_mla_qkvfp8.py
+26
-6
No files found.
tests/test_flash_mla_qkvfp8.py
View file @
732079ea
...
...
@@ -4,6 +4,7 @@ import random
import
torch
import
triton
import
kernelkit
as
kk
from
flash_mla
import
flash_mla_with_kvcache_qkvfp8
,
get_mla_metadata
torch
.
set_printoptions
(
precision
=
4
,
profile
=
"default"
,
sci_mode
=
False
)
...
...
@@ -163,12 +164,31 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
cal_diff
(
out_flash
,
out_torch
,
"out"
,
use_fp8
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
if
is_prof
:
return
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
)
*
(
torch
.
finfo
(
torch_dtype
).
bits
//
8
)
+
(
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
init_dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
# t = triton.testing.do_bench(flash_mla)
# FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
# bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
# print(
# f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
# )
time_usage
=
kk
.
bench_kineto
(
flash_mla
,
10
).
get_kernel_time
(
"flash_fwd_splitkv_mla_qkvfp8_kernel"
)
mean_attended_seqlens
=
cache_seqlens
.
float
().
mean
().
item
()
compute_volume_flop
=
b
*
h_q
*
s_q
*
sum
([
2
*
d
*
mean_attended_seqlens
,
# Q * K^T
2
*
mean_attended_seqlens
*
dv
,
# attention * V
])
q_elem_size
=
1
kv_token_size
=
d
*
1
memory_volume_B
=
b
*
sum
([
s_q
*
h_q
*
(
d
*
q_elem_size
),
# Q
mean_attended_seqlens
*
h_kv
*
kv_token_size
,
# K/V
s_q
*
h_q
*
(
dv
*
2
),
# Output
])
achieved_tflops
=
compute_volume_flop
/
time_usage
/
1e12
achieved_gBps
=
memory_volume_B
/
time_usage
/
1e9
print
(
f
"
{
time_usage
*
1000
:.
3
f
}
ms,
{
achieved_tflops
:.
0
f
}
TFLOPS,
{
achieved_gBps
:.
0
f
}
GB/s"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment