Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
1858932a
Commit
1858932a
authored
Sep 30, 2025
by
Jiashi Li
Browse files
Code format
parent
7f55c715
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
66 additions
and
70 deletions
+66
-70
tests/quant.py
tests/quant.py
+12
-14
tests/test_flash_mla_decoding.py
tests/test_flash_mla_decoding.py
+32
-32
tests/test_flash_mla_prefill.py
tests/test_flash_mla_prefill.py
+13
-14
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+9
-10
No files found.
tests/quant.py
View file @
1858932a
import
enum
import
torch
def
quantize_k_cache
(
...
...
@@ -19,19 +17,19 @@ def quantize_k_cache(
input_k_cache
=
input_k_cache
.
squeeze
(
2
)
# [num_blocks, block_size, d]
input_elem_size
=
input_k_cache
.
element_size
()
result
=
torch
.
empty
((
num_blocks
,
block_size
,
dv
+
num_tiles
*
4
+
input_elem_size
*
(
d
-
dv
)),
dtype
=
torch
.
float8_e4m3fn
,
device
=
input_k_cache
.
device
)
result
=
torch
.
empty
((
num_blocks
,
block_size
,
dv
+
num_tiles
*
4
+
input_elem_size
*
(
d
-
dv
)),
dtype
=
torch
.
float8_e4m3fn
,
device
=
input_k_cache
.
device
)
result_k_nope_part
=
result
[...,
:
dv
]
result_k_scale_factor
=
result
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
result_k_rope_part
=
result
[...,
dv
+
num_tiles
*
4
:].
view
(
input_k_cache
.
dtype
)
result_k_scale_factor
=
result
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
result_k_rope_part
=
result
[...,
dv
+
num_tiles
*
4
:].
view
(
input_k_cache
.
dtype
)
result_k_rope_part
[:]
=
input_k_cache
[...,
dv
:]
for
tile_idx
in
range
(
0
,
num_tiles
):
cur_scale_factors_inv
=
torch
.
abs
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]).
max
(
dim
=-
1
).
values
/
448.0
# [num_blocks, block_size]
cur_scale_factors_inv
=
torch
.
abs
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]).
max
(
dim
=-
1
).
values
/
448.0
# [num_blocks, block_size]
result_k_scale_factor
[:,
:,
tile_idx
]
=
cur_scale_factors_inv
cur_scale_factors_inv
.
unsqueeze_
(
-
1
)
# [num_blocks, block_size, 1]
cur_quantized_nope
=
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
float
()
/
cur_scale_factors_inv
.
float
()).
to
(
torch
.
float8_e4m3fn
)
result_k_nope_part
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_quantized_nope
cur_quantized_nope
=
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
float
()
/
cur_scale_factors_inv
.
float
()).
to
(
torch
.
float8_e4m3fn
)
result_k_nope_part
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_quantized_nope
result
=
result
.
view
(
num_blocks
,
block_size
,
1
,
-
1
)
return
result
...
...
@@ -55,14 +53,14 @@ def dequantize_k_cache(
quant_k_cache
=
quant_k_cache
.
view
(
num_blocks
,
block_size
,
-
1
)
input_nope
=
quant_k_cache
[...,
:
dv
]
input_scale
=
quant_k_cache
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
input_rope
=
quant_k_cache
[...,
dv
+
num_tiles
*
4
:].
view
(
torch
.
bfloat16
)
input_scale
=
quant_k_cache
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
input_rope
=
quant_k_cache
[...,
dv
+
num_tiles
*
4
:].
view
(
torch
.
bfloat16
)
result
[...,
dv
:]
=
input_rope
for
tile_idx
in
range
(
0
,
num_tiles
):
cur_nope
=
input_nope
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
to
(
torch
.
float32
)
cur_nope
=
input_nope
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
to
(
torch
.
float32
)
cur_scales
=
input_scale
[...,
tile_idx
].
unsqueeze
(
-
1
)
result
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_nope
*
cur_scales
result
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_nope
*
cur_scales
result
=
result
.
view
(
num_blocks
,
block_size
,
1
,
d
)
return
result
tests/test_flash_mla_decoding.py
View file @
1858932a
...
...
@@ -2,13 +2,13 @@ import argparse
import
math
import
random
import
dataclasses
from
typing
import
Optional
,
Tuple
,
List
from
typing
import
Optional
,
Tuple
import
torch
import
triton
import
quant
import
flash_mla
import
quant
from
lib
import
cdiv
,
check_is_allclose
@
dataclasses
.
dataclass
...
...
@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
cur_num_blocks
=
cdiv
(
cur_len
,
t
.
block_size
)
blocked_k
[
block_table
[
i
][
cur_num_blocks
:]]
=
float
(
"nan"
)
if
cur_len
%
t
.
block_size
!=
0
:
blocked_k
[
block_table
[
i
][
cur_num_blocks
-
1
]][
cur_len
%
t
.
block_size
:]
=
float
(
"nan"
)
blocked_k
[
block_table
[
i
][
cur_num_blocks
-
1
]][
cur_len
%
t
.
block_size
:]
=
float
(
"nan"
)
block_table
[
i
][
cur_num_blocks
:]
=
2147480000
return
cache_seqlens
,
q
,
block_table
,
blocked_k
,
None
,
None
else
:
...
...
@@ -82,7 +82,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
# Generate indices
for
j
in
range
(
t
.
s_q
):
cur_abs_indices
=
torch
.
randperm
(
int
(
cache_seqlens_cpu
[
i
].
item
()),
device
=
"cpu"
)[:
t
.
topk
]
cur_blocked_indices
=
block_table_cpu
[
i
,
cur_abs_indices
//
t
.
block_size
]
*
t
.
block_size
+
(
cur_abs_indices
%
t
.
block_size
)
cur_blocked_indices
=
block_table_cpu
[
i
,
cur_abs_indices
//
t
.
block_size
]
*
t
.
block_size
+
(
cur_abs_indices
%
t
.
block_size
)
if
len
(
cur_abs_indices
)
<
t
.
topk
:
pad_len
=
t
.
topk
-
len
(
cur_abs_indices
)
cur_abs_indices
=
torch
.
cat
([
cur_abs_indices
,
torch
.
full
((
pad_len
,),
-
1
,
device
=
'cpu'
)])
...
...
@@ -109,7 +109,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
all_indices
=
torch
.
tensor
(
all_indices
,
dtype
=
torch
.
int32
,
device
=
'cpu'
)
blocked_k
=
blocked_k
.
view
(
-
1
,
t
.
h_kv
,
t
.
d
)
nonused_indices_mask
=
torch
.
ones
(
blocked_k
.
size
(
0
)
*
blocked_k
.
size
(
1
),
dtype
=
torch
.
bool
,
device
=
'cpu'
)
nonused_indices_mask
=
torch
.
ones
(
blocked_k
.
size
(
0
)
*
blocked_k
.
size
(
1
),
dtype
=
torch
.
bool
,
device
=
'cpu'
)
nonused_indices_mask
[
all_indices
]
=
False
blocked_k
[
nonused_indices_mask
,
:,
:]
=
float
(
"nan"
)
blocked_k
=
blocked_k
.
view
(
-
1
,
t
.
block_size
,
t
.
h_kv
,
t
.
d
)
...
...
@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam):
out_ans
,
lse_ans
=
run_flash_mla
()
out_ref
,
lse_ref
=
reference_torch
(
cache_seqlens
,
block_table
,
q
,
blocked_k
,
t
.
dv
,
t
.
is_causal
,
abs_indices
)
assert
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
assert
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
assert
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
assert
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
if
t
.
test_performance
:
time_usage
:
float
=
triton
.
testing
.
do_bench
(
run_flash_mla
)
/
1000
# type: ignore
time_usage
:
float
=
triton
.
testing
.
do_bench
(
run_flash_mla
)
/
1000
# type: ignore
mean_attended_seqlens
=
cache_seqlens
.
float
().
mean
().
item
()
if
t
.
topk
is
None
else
t
.
topk
compute_volume_flop
=
t
.
b
*
t
.
h_q
*
t
.
s_q
*
sum
([
2
*
t
.
d
*
mean_attended_seqlens
,
# Q * K^T
2
*
mean_attended_seqlens
*
t
.
dv
,
# attention * V
compute_volume_flop
=
t
.
b
*
t
.
h_q
*
t
.
s_q
*
sum
([
2
*
t
.
d
*
mean_attended_seqlens
,
# Q * K^T
2
*
mean_attended_seqlens
*
t
.
dv
,
# attention * V
])
q_elem_size
=
torch
.
bfloat16
.
itemsize
kv_token_size
=
656
if
t
.
is_fp8
else
t
.
d
*
torch
.
bfloat16
.
itemsize
memory_volume_B
=
t
.
b
*
sum
([
t
.
s_q
*
t
.
h_q
*
(
t
.
d
*
q_elem_size
),
# Q
(
t
.
s_q
if
t
.
topk
is
not
None
else
1
)
*
mean_attended_seqlens
*
t
.
h_kv
*
kv_token_size
,
# K/V
t
.
s_q
*
t
.
h_q
*
(
t
.
dv
*
q_elem_size
),
# Output
kv_token_size
=
656
if
t
.
is_fp8
else
t
.
d
*
torch
.
bfloat16
.
itemsize
memory_volume_B
=
t
.
b
*
sum
([
t
.
s_q
*
t
.
h_q
*
(
t
.
d
*
q_elem_size
),
# Q
(
t
.
s_q
if
t
.
topk
is
not
None
else
1
)
*
mean_attended_seqlens
*
t
.
h_kv
*
kv_token_size
,
# K/V
t
.
s_q
*
t
.
h_q
*
(
t
.
dv
*
q_elem_size
),
# Output
])
achieved_tflops
=
compute_volume_flop
/
time_usage
/
1e12
achieved_gBps
=
memory_volume_B
/
time_usage
/
1e9
print
(
f
"
{
time_usage
*
1000
:.
3
f
}
ms,
{
achieved_tflops
:.
0
f
}
TFLOPS,
{
achieved_gBps
:.
0
f
}
GB/s"
)
print
(
f
"
{
time_usage
*
1000
:.
3
f
}
ms,
{
achieved_tflops
:.
0
f
}
TFLOPS,
{
achieved_gBps
:.
0
f
}
GB/s"
)
def
main
(
torch_dtype
):
...
...
tests/test_flash_mla_prefill.py
View file @
1858932a
...
...
@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase:
torch
.
manual_seed
(
t
.
seed
)
torch
.
cuda
.
manual_seed
(
t
.
seed
)
random
.
seed
(
t
.
seed
)
q
=
torch
.
randn
((
t
.
b
,
t
.
s_q
,
t
.
h_q
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
kv
=
torch
.
randn
((
t
.
b
,
t
.
s_kv
,
t
.
h_kv
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
q
=
torch
.
randn
((
t
.
b
,
t
.
s_q
,
t
.
h_q
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
kv
=
torch
.
randn
((
t
.
b
,
t
.
s_kv
,
t
.
h_kv
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
q
.
clamp_
(
-
10
,
10
)
kv
.
clamp_
(
-
10
,
10
)
...
...
@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase:
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask
=
torch
.
randint
(
0
,
32
,
(
min
(
t
.
topk
,
t
.
s_kv
),))
<
31
cur_indices
=
torch
.
randperm
(
t
.
s_kv
)[:
t
.
topk
]
cur_indices
[
near_mask
]
=
torch
.
randint
(
max
(
0
,
t
.
s_kv
-
20000
),
t
.
s_kv
-
1
,
(
near_mask
.
sum
().
item
(),))
cur_indices
[
near_mask
]
=
torch
.
randint
(
max
(
0
,
t
.
s_kv
-
20000
),
t
.
s_kv
-
1
,
(
near_mask
.
sum
().
item
(),))
if
len
(
cur_indices
)
<
t
.
topk
:
cur_indices
=
torch
.
cat
([
cur_indices
,
torch
.
full
((
t
.
topk
-
len
(
cur_indices
),),
2147480000
)])
cur_indices
=
cur_indices
[
torch
.
randperm
(
t
.
topk
)]
...
...
@@ -110,9 +110,9 @@ def run_test(p: TestParam) -> bool:
if
p
.
benchmark
:
flop
=
get_flop
(
p
)
prefill_ans_time
:
float
=
triton
.
testing
.
do_bench
(
run_ans
,
warmup
=
10
,
rep
=
20
)
/
1000
# type: ignore
prefill_flops
=
flop
/
prefill_ans_time
/
1e12
print
(
f
"Prefill:
{
prefill_ans_time
*
1e6
:
4.0
f
}
us,
{
prefill_flops
:.
3
f
}
TFlops"
)
prefill_ans_time
:
float
=
triton
.
testing
.
do_bench
(
run_ans
,
warmup
=
10
,
rep
=
20
)
/
1000
# type: ignore
prefill_flops
=
flop
/
prefill_ans_time
/
1e12
print
(
f
"Prefill:
{
prefill_ans_time
*
1e6
:
4.0
f
}
us,
{
prefill_flops
:.
3
f
}
TFlops"
)
if
p
.
check_correctness
:
torch
.
cuda
.
synchronize
()
...
...
@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool:
torch
.
cuda
.
synchronize
()
is_correct
=
True
is_correct
&=
check_is_allclose
(
"out"
,
ans_out
,
ref_out
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
7e-6
)
is_correct
&=
check_is_allclose
(
"max_logits"
,
ans_max_logits
,
ref_max_logits
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
check_is_allclose
(
"lse"
,
ans_lse
,
ref_lse
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
check_is_allclose
(
"out"
,
ans_out
,
ref_out
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
7e-6
)
is_correct
&=
check_is_allclose
(
"max_logits"
,
ans_max_logits
,
ref_max_logits
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
check_is_allclose
(
"lse"
,
ans_lse
,
ref_lse
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
return
is_correct
else
:
...
...
@@ -194,4 +194,3 @@ if __name__ == '__main__':
print
(
f
"
{
case
}
"
)
else
:
print
(
f
"
\033
[32m
\033
[1mAll
{
len
(
testcases
)
}
cases passed!
\033
[0m"
)
tests/test_fmha_sm100.py
View file @
1858932a
...
...
@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint
import
triton
from
flash_mla
import
flash_attn_varlen_func
from
lib
import
check_is_allclose
def
get_window_size
(
causal
,
window
):
...
...
@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
q1
=
q
.
clone
().
requires_grad_
()
...
...
@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
if
check_correctness
:
out_torch
,
lse_torch
=
torch_attn
()
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
if
has_bwd
:
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
def
forward
():
return
flash_attn
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment