Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
c566af36
Commit
c566af36
authored
Feb 25, 2026
by
zhanghj2
Browse files
恢复测试方式
parent
732079ea
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
27 deletions
+6
-27
tests/test_flash_mla_qkvfp8.py
tests/test_flash_mla_qkvfp8.py
+6
-27
No files found.
tests/test_flash_mla_qkvfp8.py
View file @
c566af36
...
@@ -164,33 +164,12 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
...
@@ -164,33 +164,12 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
cal_diff
(
out_flash
,
out_torch
,
"out"
,
use_fp8
)
cal_diff
(
out_flash
,
out_torch
,
"out"
,
use_fp8
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
if
is_prof
:
return
if
is_prof
:
return
# t = triton.testing.do_bench(flash_mla)
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
# FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
# bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
)
*
(
torch
.
finfo
(
torch_dtype
).
bits
//
8
)
+
(
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
init_dtype
).
bits
//
8
)
# print(
print
(
# f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
# )
)
time_usage
=
kk
.
bench_kineto
(
flash_mla
,
10
).
get_kernel_time
(
"flash_fwd_splitkv_mla_qkvfp8_kernel"
)
mean_attended_seqlens
=
cache_seqlens
.
float
().
mean
().
item
()
compute_volume_flop
=
b
*
h_q
*
s_q
*
sum
([
2
*
d
*
mean_attended_seqlens
,
# Q * K^T
2
*
mean_attended_seqlens
*
dv
,
# attention * V
])
q_elem_size
=
1
kv_token_size
=
d
*
1
memory_volume_B
=
b
*
sum
([
s_q
*
h_q
*
(
d
*
q_elem_size
),
# Q
mean_attended_seqlens
*
h_kv
*
kv_token_size
,
# K/V
s_q
*
h_q
*
(
dv
*
2
),
# Output
])
achieved_tflops
=
compute_volume_flop
/
time_usage
/
1e12
achieved_gBps
=
memory_volume_B
/
time_usage
/
1e9
print
(
f
"
{
time_usage
*
1000
:.
3
f
}
ms,
{
achieved_tflops
:.
0
f
}
TFLOPS,
{
achieved_gBps
:.
0
f
}
GB/s"
)
def
main
(
torch_dtype
,
is_prof
=
False
):
def
main
(
torch_dtype
,
is_prof
=
False
):
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment