Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
1858932a
Commit
1858932a
authored
Sep 30, 2025
by
Jiashi Li
Browse files
Code format
parent
7f55c715
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
66 additions
and
70 deletions
+66
-70
tests/quant.py
tests/quant.py
+12
-14
tests/test_flash_mla_decoding.py
tests/test_flash_mla_decoding.py
+32
-32
tests/test_flash_mla_prefill.py
tests/test_flash_mla_prefill.py
+13
-14
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+9
-10
No files found.
tests/quant.py
View file @
1858932a
import
enum
import
torch
def
quantize_k_cache
(
...
...
@@ -19,20 +17,20 @@ def quantize_k_cache(
input_k_cache
=
input_k_cache
.
squeeze
(
2
)
# [num_blocks, block_size, d]
input_elem_size
=
input_k_cache
.
element_size
()
result
=
torch
.
empty
((
num_blocks
,
block_size
,
dv
+
num_tiles
*
4
+
input_elem_size
*
(
d
-
dv
)),
dtype
=
torch
.
float8_e4m3fn
,
device
=
input_k_cache
.
device
)
result
=
torch
.
empty
((
num_blocks
,
block_size
,
dv
+
num_tiles
*
4
+
input_elem_size
*
(
d
-
dv
)),
dtype
=
torch
.
float8_e4m3fn
,
device
=
input_k_cache
.
device
)
result_k_nope_part
=
result
[...,
:
dv
]
result_k_scale_factor
=
result
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
result_k_rope_part
=
result
[...,
dv
+
num_tiles
*
4
:].
view
(
input_k_cache
.
dtype
)
result_k_scale_factor
=
result
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
result_k_rope_part
=
result
[...,
dv
+
num_tiles
*
4
:].
view
(
input_k_cache
.
dtype
)
result_k_rope_part
[:]
=
input_k_cache
[...,
dv
:]
for
tile_idx
in
range
(
0
,
num_tiles
):
cur_scale_factors_inv
=
torch
.
abs
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]).
max
(
dim
=-
1
).
values
/
448.0
# [num_blocks, block_size]
cur_scale_factors_inv
=
torch
.
abs
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]).
max
(
dim
=-
1
).
values
/
448.0
# [num_blocks, block_size]
result_k_scale_factor
[:,
:,
tile_idx
]
=
cur_scale_factors_inv
cur_scale_factors_inv
.
unsqueeze_
(
-
1
)
# [num_blocks, block_size, 1]
cur_quantized_nope
=
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
float
()
/
cur_scale_factors_inv
.
float
()).
to
(
torch
.
float8_e4m3fn
)
result_k_nope_part
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_quantized_nope
cur_quantized_nope
=
(
input_k_cache
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
float
()
/
cur_scale_factors_inv
.
float
()).
to
(
torch
.
float8_e4m3fn
)
result_k_nope_part
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_quantized_nope
result
=
result
.
view
(
num_blocks
,
block_size
,
1
,
-
1
)
return
result
...
...
@@ -55,14 +53,14 @@ def dequantize_k_cache(
quant_k_cache
=
quant_k_cache
.
view
(
num_blocks
,
block_size
,
-
1
)
input_nope
=
quant_k_cache
[...,
:
dv
]
input_scale
=
quant_k_cache
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
input_rope
=
quant_k_cache
[...,
dv
+
num_tiles
*
4
:].
view
(
torch
.
bfloat16
)
input_scale
=
quant_k_cache
[...,
dv
:
dv
+
num_tiles
*
4
].
view
(
torch
.
float32
)
input_rope
=
quant_k_cache
[...,
dv
+
num_tiles
*
4
:].
view
(
torch
.
bfloat16
)
result
[...,
dv
:]
=
input_rope
for
tile_idx
in
range
(
0
,
num_tiles
):
cur_nope
=
input_nope
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
to
(
torch
.
float32
)
cur_nope
=
input_nope
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
].
to
(
torch
.
float32
)
cur_scales
=
input_scale
[...,
tile_idx
].
unsqueeze
(
-
1
)
result
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_nope
*
cur_scales
result
[...,
tile_idx
*
tile_size
:(
tile_idx
+
1
)
*
tile_size
]
=
cur_nope
*
cur_scales
result
=
result
.
view
(
num_blocks
,
block_size
,
1
,
d
)
return
result
tests/test_flash_mla_decoding.py
View file @
1858932a
...
...
@@ -2,20 +2,20 @@ import argparse
import
math
import
random
import
dataclasses
from
typing
import
Optional
,
Tuple
,
List
from
typing
import
Optional
,
Tuple
import
torch
import
triton
import
quant
import
flash_mla
import
quant
from
lib
import
cdiv
,
check_is_allclose
@
dataclasses
.
dataclass
class
TestParam
:
b
:
int
# Batch size
s_q
:
int
# Number of queries for one request
s_k
:
int
# Seq len, or mean seq len if varlen == True
b
:
int
# Batch size
s_q
:
int
# Number of queries for one request
s_k
:
int
# Seq len, or mean seq len if varlen == True
is_varlen
:
bool
is_causal
:
bool
is_fp8
:
bool
...
...
@@ -24,8 +24,8 @@ class TestParam:
is_all_indices_invalid
:
bool
=
False
have_zero_seqlen_k
:
bool
=
False
block_size
:
int
=
64
h_q
:
int
=
128
# Number of q heads
h_kv
:
int
=
1
# Number of kv heads
h_q
:
int
=
128
# Number of q heads
h_kv
:
int
=
1
# Number of kv heads
d
:
int
=
576
# Q/K head dim (= dv + RoPE dim)
dv
:
int
=
512
# V head dim
seed
:
int
=
0
...
...
@@ -71,7 +71,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
cur_num_blocks
=
cdiv
(
cur_len
,
t
.
block_size
)
blocked_k
[
block_table
[
i
][
cur_num_blocks
:]]
=
float
(
"nan"
)
if
cur_len
%
t
.
block_size
!=
0
:
blocked_k
[
block_table
[
i
][
cur_num_blocks
-
1
]][
cur_len
%
t
.
block_size
:]
=
float
(
"nan"
)
blocked_k
[
block_table
[
i
][
cur_num_blocks
-
1
]][
cur_len
%
t
.
block_size
:]
=
float
(
"nan"
)
block_table
[
i
][
cur_num_blocks
:]
=
2147480000
return
cache_seqlens
,
q
,
block_table
,
blocked_k
,
None
,
None
else
:
...
...
@@ -82,12 +82,12 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
# Generate indices
for
j
in
range
(
t
.
s_q
):
cur_abs_indices
=
torch
.
randperm
(
int
(
cache_seqlens_cpu
[
i
].
item
()),
device
=
"cpu"
)[:
t
.
topk
]
cur_blocked_indices
=
block_table_cpu
[
i
,
cur_abs_indices
//
t
.
block_size
]
*
t
.
block_size
+
(
cur_abs_indices
%
t
.
block_size
)
cur_blocked_indices
=
block_table_cpu
[
i
,
cur_abs_indices
//
t
.
block_size
]
*
t
.
block_size
+
(
cur_abs_indices
%
t
.
block_size
)
if
len
(
cur_abs_indices
)
<
t
.
topk
:
pad_len
=
t
.
topk
-
len
(
cur_abs_indices
)
cur_abs_indices
=
torch
.
cat
([
cur_abs_indices
,
torch
.
full
((
pad_len
,),
-
1
,
device
=
'cpu'
)])
cur_blocked_indices
=
torch
.
cat
([
cur_blocked_indices
,
torch
.
full
((
pad_len
,),
-
1
,
device
=
'cpu'
)])
# Mask KV
perm
=
torch
.
randperm
(
t
.
topk
,
device
=
'cpu'
)
cur_abs_indices
=
cur_abs_indices
[
perm
]
...
...
@@ -100,7 +100,7 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
abs_indices
[
i
,
j
,
:]
=
cur_abs_indices
indices_in_kvcache
[
i
,
j
,
:]
=
cur_blocked_indices
# Mask nonused KV as NaN
all_indices
=
indices_in_kvcache
.
flatten
().
tolist
()
all_indices
=
list
(
set
(
all_indices
))
...
...
@@ -109,11 +109,11 @@ def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.
all_indices
=
torch
.
tensor
(
all_indices
,
dtype
=
torch
.
int32
,
device
=
'cpu'
)
blocked_k
=
blocked_k
.
view
(
-
1
,
t
.
h_kv
,
t
.
d
)
nonused_indices_mask
=
torch
.
ones
(
blocked_k
.
size
(
0
)
*
blocked_k
.
size
(
1
),
dtype
=
torch
.
bool
,
device
=
'cpu'
)
nonused_indices_mask
=
torch
.
ones
(
blocked_k
.
size
(
0
)
*
blocked_k
.
size
(
1
),
dtype
=
torch
.
bool
,
device
=
'cpu'
)
nonused_indices_mask
[
all_indices
]
=
False
blocked_k
[
nonused_indices_mask
,
:,
:]
=
float
(
"nan"
)
blocked_k
=
blocked_k
.
view
(
-
1
,
t
.
block_size
,
t
.
h_kv
,
t
.
d
)
abs_indices
=
abs_indices
.
to
(
q
.
device
)
indices_in_kvcache
=
indices_in_kvcache
.
to
(
q
.
device
)
...
...
@@ -139,7 +139,7 @@ def reference_torch(
valid_indices
=
cur_indices
[
cur_indices
!=
-
1
]
mask
[
i
,
valid_indices
]
=
True
return
mask
def
scaled_dot_product_attention
(
batch_idx
:
int
,
query
:
torch
.
Tensor
,
# [h_q, s_q, d]
...
...
@@ -157,7 +157,7 @@ def reference_torch(
if
h_kv
!=
1
:
kv
=
kv
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
kv
[
kv
!=
kv
]
=
0.0
attn_weight
=
query
@
kv
.
transpose
(
-
2
,
-
1
)
# [h_q, s_q, s_k]
attn_weight
=
query
@
kv
.
transpose
(
-
2
,
-
1
)
# [h_q, s_q, s_k]
if
(
is_causal
and
query
.
size
(
1
)
>
1
)
or
indices
is
not
None
:
mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
)
if
is_causal
:
...
...
@@ -169,14 +169,14 @@ def reference_torch(
attn_bias
.
masked_fill_
(
mask
.
logical_not
(),
float
(
"-inf"
))
attn_weight
+=
attn_bias
.
to
(
q
.
dtype
)
attn_weight
/=
math
.
sqrt
(
query
.
size
(
-
1
))
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
# [h_q, s_q]
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
# [h_q, s_q]
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
output
=
attn_weight
@
kv
[...,
:
dv
]
# [h_q, s_q, dv]
# Correct for q tokens which have no attendable k
lonely_q_mask
=
(
lse
==
float
(
"-inf"
))
output
[
lonely_q_mask
.
unsqueeze
(
-
1
).
broadcast_to
(
h_q
,
s_q
,
dv
)]
=
0.0
lse
[
lonely_q_mask
]
=
float
(
"+inf"
)
return
output
,
lse
b
,
s_q
,
h_q
,
d
=
q
.
size
()
...
...
@@ -202,7 +202,7 @@ def reference_torch(
lse_ref
[
i
]
=
cur_lse
out_ref
=
out_ref
.
to
(
torch
.
bfloat16
)
return
out_ref
,
lse_ref
@
torch
.
inference_mode
()
def
test_flash_mla
(
t
:
TestParam
):
...
...
@@ -235,7 +235,7 @@ def test_flash_mla(t: TestParam):
def
run_flash_mla
():
return
flash_mla
.
flash_mla_with_kvcache
(
q
,
blocked_k
if
not
t
.
is_fp8
else
blocked_k_quantized
,
# type: ignore
blocked_k
if
not
t
.
is_fp8
else
blocked_k_quantized
,
# type: ignore
block_table
,
cache_seqlens
,
t
.
dv
,
...
...
@@ -248,27 +248,27 @@ def test_flash_mla(t: TestParam):
out_ans
,
lse_ans
=
run_flash_mla
()
out_ref
,
lse_ref
=
reference_torch
(
cache_seqlens
,
block_table
,
q
,
blocked_k
,
t
.
dv
,
t
.
is_causal
,
abs_indices
)
assert
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
assert
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
assert
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
assert
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
if
t
.
test_performance
:
time_usage
:
float
=
triton
.
testing
.
do_bench
(
run_flash_mla
)
/
1000
# type: ignore
time_usage
:
float
=
triton
.
testing
.
do_bench
(
run_flash_mla
)
/
1000
# type: ignore
mean_attended_seqlens
=
cache_seqlens
.
float
().
mean
().
item
()
if
t
.
topk
is
None
else
t
.
topk
compute_volume_flop
=
t
.
b
*
t
.
h_q
*
t
.
s_q
*
sum
([
2
*
t
.
d
*
mean_attended_seqlens
,
# Q * K^T
2
*
mean_attended_seqlens
*
t
.
dv
,
# attention * V
compute_volume_flop
=
t
.
b
*
t
.
h_q
*
t
.
s_q
*
sum
([
2
*
t
.
d
*
mean_attended_seqlens
,
# Q * K^T
2
*
mean_attended_seqlens
*
t
.
dv
,
# attention * V
])
q_elem_size
=
torch
.
bfloat16
.
itemsize
kv_token_size
=
656
if
t
.
is_fp8
else
t
.
d
*
torch
.
bfloat16
.
itemsize
memory_volume_B
=
t
.
b
*
sum
([
t
.
s_q
*
t
.
h_q
*
(
t
.
d
*
q_elem_size
),
# Q
(
t
.
s_q
if
t
.
topk
is
not
None
else
1
)
*
mean_attended_seqlens
*
t
.
h_kv
*
kv_token_size
,
# K/V
t
.
s_q
*
t
.
h_q
*
(
t
.
dv
*
q_elem_size
),
# Output
kv_token_size
=
656
if
t
.
is_fp8
else
t
.
d
*
torch
.
bfloat16
.
itemsize
memory_volume_B
=
t
.
b
*
sum
([
t
.
s_q
*
t
.
h_q
*
(
t
.
d
*
q_elem_size
),
# Q
(
t
.
s_q
if
t
.
topk
is
not
None
else
1
)
*
mean_attended_seqlens
*
t
.
h_kv
*
kv_token_size
,
# K/V
t
.
s_q
*
t
.
h_q
*
(
t
.
dv
*
q_elem_size
),
# Output
])
achieved_tflops
=
compute_volume_flop
/
time_usage
/
1e12
achieved_gBps
=
memory_volume_B
/
time_usage
/
1e9
print
(
f
"
{
time_usage
*
1000
:.
3
f
}
ms,
{
achieved_tflops
:.
0
f
}
TFLOPS,
{
achieved_gBps
:.
0
f
}
GB/s"
)
print
(
f
"
{
time_usage
*
1000
:.
3
f
}
ms,
{
achieved_tflops
:.
0
f
}
TFLOPS,
{
achieved_gBps
:.
0
f
}
GB/s"
)
def
main
(
torch_dtype
):
...
...
@@ -324,7 +324,7 @@ def main(torch_dtype):
cc_major
,
cc_minor
=
torch
.
cuda
.
get_device_capability
()
if
cc_major
==
10
:
testcases
=
[
t
for
t
in
testcases
if
(
t
.
is_fp8
and
t
.
topk
is
not
None
)]
for
testcase
in
testcases
:
test_flash_mla
(
testcase
)
...
...
tests/test_flash_mla_prefill.py
View file @
1858932a
...
...
@@ -35,8 +35,8 @@ def generate_testcase(t: TestParam) -> Testcase:
torch
.
manual_seed
(
t
.
seed
)
torch
.
cuda
.
manual_seed
(
t
.
seed
)
random
.
seed
(
t
.
seed
)
q
=
torch
.
randn
((
t
.
b
,
t
.
s_q
,
t
.
h_q
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
kv
=
torch
.
randn
((
t
.
b
,
t
.
s_kv
,
t
.
h_kv
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
q
=
torch
.
randn
((
t
.
b
,
t
.
s_q
,
t
.
h_q
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
kv
=
torch
.
randn
((
t
.
b
,
t
.
s_kv
,
t
.
h_kv
,
t
.
d_qk
),
dtype
=
torch
.
bfloat16
)
/
10
q
.
clamp_
(
-
10
,
10
)
kv
.
clamp_
(
-
10
,
10
)
...
...
@@ -48,7 +48,7 @@ def generate_testcase(t: TestParam) -> Testcase:
# NOTE We use the following method to generate indices so that most indices lie within [s_kv-20000, s_kv), which is more realistic for sparse attention
near_mask
=
torch
.
randint
(
0
,
32
,
(
min
(
t
.
topk
,
t
.
s_kv
),))
<
31
cur_indices
=
torch
.
randperm
(
t
.
s_kv
)[:
t
.
topk
]
cur_indices
[
near_mask
]
=
torch
.
randint
(
max
(
0
,
t
.
s_kv
-
20000
),
t
.
s_kv
-
1
,
(
near_mask
.
sum
().
item
(),))
cur_indices
[
near_mask
]
=
torch
.
randint
(
max
(
0
,
t
.
s_kv
-
20000
),
t
.
s_kv
-
1
,
(
near_mask
.
sum
().
item
(),))
if
len
(
cur_indices
)
<
t
.
topk
:
cur_indices
=
torch
.
cat
([
cur_indices
,
torch
.
full
((
t
.
topk
-
len
(
cur_indices
),),
2147480000
)])
cur_indices
=
cur_indices
[
torch
.
randperm
(
t
.
topk
)]
...
...
@@ -72,9 +72,9 @@ def get_flop(p: TestParam) -> float:
def
reference_torch
(
p
:
TestParam
,
t
:
Testcase
,
sm_scale
:
float
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
def
log2sumexp2
(
a
:
torch
.
Tensor
,
dim
:
int
)
->
torch
.
Tensor
:
return
torch
.
logsumexp
(
a
*
math
.
log
(
2
),
dim
=
dim
)
*
math
.
log2
(
math
.
e
)
assert
p
.
b
==
1
indices
=
t
.
indices
[
0
,
:,
0
,
:]
# [s_q, topk]
indices
=
t
.
indices
[
0
,
:,
0
,
:]
# [s_q, topk]
invalid_indices_mask
=
(
indices
<
0
)
|
(
indices
>=
p
.
s_kv
)
qs
=
t
.
q
[
0
,
:,
:,
:].
float
()
# [s_q, h_q, d_qk]
kvs
=
t
.
kv
[
0
,
:,
0
,
:].
float
()
# [s_kv, d_qk]
...
...
@@ -104,15 +104,15 @@ def run_test(p: TestParam) -> bool:
return
flash_mla_sparse_fwd
(
t
.
q
.
squeeze
(
0
),
t
.
kv
.
squeeze
(
0
),
t
.
indices
.
squeeze
(
0
),
sm_scale
=
sm_scale
)
ans_out
,
ans_max_logits
,
ans_lse
=
run_ans
()
torch
.
cuda
.
synchronize
()
if
p
.
benchmark
:
flop
=
get_flop
(
p
)
prefill_ans_time
:
float
=
triton
.
testing
.
do_bench
(
run_ans
,
warmup
=
10
,
rep
=
20
)
/
1000
# type: ignore
prefill_flops
=
flop
/
prefill_ans_time
/
1e12
print
(
f
"Prefill:
{
prefill_ans_time
*
1e6
:
4.0
f
}
us,
{
prefill_flops
:.
3
f
}
TFlops"
)
prefill_ans_time
:
float
=
triton
.
testing
.
do_bench
(
run_ans
,
warmup
=
10
,
rep
=
20
)
/
1000
# type: ignore
prefill_flops
=
flop
/
prefill_ans_time
/
1e12
print
(
f
"Prefill:
{
prefill_ans_time
*
1e6
:
4.0
f
}
us,
{
prefill_flops
:.
3
f
}
TFlops"
)
if
p
.
check_correctness
:
torch
.
cuda
.
synchronize
()
...
...
@@ -120,9 +120,9 @@ def run_test(p: TestParam) -> bool:
torch
.
cuda
.
synchronize
()
is_correct
=
True
is_correct
&=
check_is_allclose
(
"out"
,
ans_out
,
ref_out
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
7e-6
)
is_correct
&=
check_is_allclose
(
"max_logits"
,
ans_max_logits
,
ref_max_logits
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
check_is_allclose
(
"lse"
,
ans_lse
,
ref_lse
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
check_is_allclose
(
"out"
,
ans_out
,
ref_out
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
7e-6
)
is_correct
&=
check_is_allclose
(
"max_logits"
,
ans_max_logits
,
ref_max_logits
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
check_is_allclose
(
"lse"
,
ans_lse
,
ref_lse
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
return
is_correct
else
:
...
...
@@ -187,11 +187,10 @@ if __name__ == '__main__':
is_correct
=
run_test
(
test
)
if
not
is_correct
:
failed_cases
.
append
(
test
)
if
len
(
failed_cases
)
>
0
:
print
(
f
"
\033
[31m
\033
[1m
{
len
(
failed_cases
)
}
/
{
len
(
testcases
)
}
cases failed:
\033
[0m"
)
for
case
in
failed_cases
:
print
(
f
"
{
case
}
"
)
else
:
print
(
f
"
\033
[32m
\033
[1mAll
{
len
(
testcases
)
}
cases passed!
\033
[0m"
)
tests/test_fmha_sm100.py
View file @
1858932a
...
...
@@ -5,7 +5,6 @@ from torch.utils.checkpoint import checkpoint
import
triton
from
flash_mla
import
flash_attn_varlen_func
from
lib
import
check_is_allclose
def
get_window_size
(
causal
,
window
):
...
...
@@ -71,10 +70,10 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
q1
=
q
.
clone
().
requires_grad_
()
...
...
@@ -123,14 +122,14 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
if
check_correctness
:
out_torch
,
lse_torch
=
torch_attn
()
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
if
has_bwd
:
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
def
forward
():
return
flash_attn
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment