Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
c28eca99
Commit
c28eca99
authored
Sep 24, 2025
by
Shengyu Liu
Browse files
Reorganize files and add sparse prefill/decoding kernels on hopper
parent
261330bb
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
39 deletions
+34
-39
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+34
-39
No files found.
tests/test_fmha_sm100.py
View file @
c28eca99
...
@@ -6,6 +6,7 @@ import triton
...
@@ -6,6 +6,7 @@ import triton
from
flash_mla
import
flash_attn_varlen_func
from
flash_mla
import
flash_attn_varlen_func
from
lib
import
check_is_allclose
def
get_window_size
(
causal
,
window
):
def
get_window_size
(
causal
,
window
):
if
window
>
0
:
if
window
>
0
:
...
@@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window):
...
@@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window):
return
attn_bias
return
attn_bias
def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
    """Assert that two tensors agree under a cosine-style difference metric.

    Both tensors are promoted to float64 and compared with
    ``cos_diff = 1 - 2 * sum(x * y) / sum(x * x + y * y)``, which is 0 for
    identical tensors and grows as they diverge. The check passes when
    ``cos_diff < 1e-5``. RMSE and the maximum absolute difference are
    computed only to make the failure message more informative.

    Args:
        x: First tensor (e.g. the kernel output under test).
        y: Second tensor (e.g. the reference output).
        name: Label used as the prefix of the assertion message.

    Raises:
        AssertionError: If ``cos_diff`` is not below ``1e-5`` (this includes
            NaN, since any comparison with NaN is false).
    """
    x, y = x.double(), y.double()
    diff = x - y
    if diff.numel() == 0:
        # Nothing to compare; mean()/max() are undefined on empty input.
        return

    RMSE = (diff * diff).mean().sqrt().item()
    amax_diff = diff.abs().max().item()

    denom = (x * x + y * y).sum().item()
    if denom == 0:
        # Degenerate norm: the ratio is undefined. Exactly equal tensors
        # (e.g. both all-zero) must pass; anything else is a mismatch.
        cos_diff = 0.0 if amax_diff == 0 else 1.0
    else:
        # Divide by the true norm instead of clamping it to 1e-12, so that
        # identical tensors with very small magnitude still yield 0.
        cos_diff = 1 - 2 * (x * y).sum().item() / denom

    assert cos_diff < 1e-5, f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}"
def
sdpa
(
query
,
key
,
value
,
attn_bias
,
softmax_scale
=
None
):
def
sdpa
(
query
,
key
,
value
,
attn_bias
,
softmax_scale
=
None
):
query
=
query
.
float
().
transpose
(
-
3
,
-
2
)
key
=
key
.
float
().
transpose
(
-
3
,
-
2
)
value
=
value
.
float
().
transpose
(
-
3
,
-
2
)
key
=
key
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
key
=
key
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
value
=
value
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
value
=
value
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
query
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
query
.
shape
[
-
1
]
**
(
-
0.5
)
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
softmax_scale
attn_weight
=
(
query
@
key
.
transpose
(
-
2
,
-
1
)
)
*
softmax_scale
attn_weight
+=
attn_bias
attn_weight
+=
attn_bias
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
...
@@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs):
...
@@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs):
return
checkpoint
(
sdpa
,
*
args
,
use_reentrant
=
False
,
**
kwargs
)
return
checkpoint
(
sdpa
,
*
args
,
use_reentrant
=
False
,
**
kwargs
)
def
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
):
def
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
,
check_correctness
:
bool
=
True
):
print
(
f
"
{
b
=
}
,
{
mean_sq
=
}
,
{
mean_sk
=
}
,
{
varlen
=
}
,
{
h
=
}
,
{
h_k
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
"
)
print
(
f
"
{
b
=
}
,
{
mean_sq
=
}
,
{
mean_sk
=
}
,
{
varlen
=
}
,
{
h
=
}
,
{
h_k
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
has_bwd
=
}
,
{
check_correctness
=
}
"
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
random
.
seed
(
0
)
...
@@ -76,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -76,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q
=
torch
.
randn
(
total_q
,
h
,
d
)
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
q1
=
q
.
clone
().
requires_grad_
()
q1
=
q
.
clone
().
requires_grad_
()
k1
=
k
.
clone
().
requires_grad_
()
k1
=
k
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
if
check_correctness
:
k2
=
k
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
v2
=
v
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
v2
=
v
.
clone
().
requires_grad_
()
def
flash_attn
():
def
flash_attn
():
q1
.
grad
=
k1
.
grad
=
v1
.
grad
=
None
q1
.
grad
=
k1
.
grad
=
v1
.
grad
=
None
...
@@ -106,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -106,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
lse
=
[]
lse
=
[]
for
i
in
range
(
b
):
for
i
in
range
(
b
):
OUT
,
LSE
=
sdpa_checkpoint
(
OUT
,
LSE
=
sdpa_checkpoint
(
q2
[
cu_seqlens_q
[
i
].
item
():
cu_seqlens_q
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
q2
[
cu_seqlens_q
[
i
].
item
():
cu_seqlens_q
[
i
+
1
].
item
()],
k2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
k2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()],
v2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
v2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()],
attn_bias
=
get_attn_bias
(
seqlens_q
[
i
].
item
(),
seqlens_k
[
i
].
item
(),
causal
,
window
),
attn_bias
=
get_attn_bias
(
seqlens_q
[
i
].
item
(),
seqlens_k
[
i
].
item
(),
causal
,
window
),
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
)
)
...
@@ -119,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -119,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
return
out
,
lse
return
out
,
lse
out_flash
,
lse_flash
=
flash_attn
()
out_flash
,
lse_flash
=
flash_attn
()
out_torch
,
lse_torch
=
torch_attn
()
assert_close
(
out_flash
,
out_torch
,
"out"
)
assert_close
(
lse_flash
,
lse_torch
,
"lse"
)
if
has_bwd
:
if
has_bwd
:
out_flash
.
backward
(
grad_out
,
retain_graph
=
True
)
out_flash
.
backward
(
grad_out
,
retain_graph
=
True
)
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert_close
(
q1
.
grad
,
q2
.
grad
,
"dq"
)
assert_close
(
k1
.
grad
,
k2
.
grad
,
"dk"
)
assert_close
(
v1
.
grad
,
v2
.
grad
,
"dv"
)
dq1
=
q1
.
grad
.
clone
()
dq1
=
q1
.
grad
.
clone
()
dk1
=
k1
.
grad
.
clone
()
dk1
=
k1
.
grad
.
clone
()
dv1
=
v1
.
grad
.
clone
()
dv1
=
v1
.
grad
.
clone
()
if
check_correctness
:
out_torch
,
lse_torch
=
torch_attn
()
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
if
has_bwd
:
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
def
forward
():
def
forward
():
return
flash_attn
()
return
flash_attn
()
...
@@ -150,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -150,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
assert
torch
.
equal
(
k1
.
grad
,
dk1
),
"dk deterministic check failed!"
assert
torch
.
equal
(
k1
.
grad
,
dk1
),
"dk deterministic check failed!"
assert
torch
.
equal
(
v1
.
grad
,
dv1
),
"dv deterministic check failed!"
assert
torch
.
equal
(
v1
.
grad
,
dv1
),
"dv deterministic check failed!"
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
# forward()
# if has_bwd:
# backward()
# print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120))
def
timer
(
func
,
name
):
def
timer
(
func
,
name
):
t
=
triton
.
testing
.
do_bench
(
func
,
warmup
=
2
,
rep
=
3
)
t
=
triton
.
testing
.
do_bench
(
func
,
warmup
=
2
,
rep
=
3
)
FLOPS
=
total_attn_compute
*
h
*
2
*
((
d
+
dv
)
if
name
==
"fwd"
else
((
d
*
3
+
dv
*
2
)))
FLOPS
=
total_attn_compute
*
h
*
2
*
((
d
+
dv
)
if
name
==
"fwd"
else
((
d
*
3
+
dv
*
2
)))
...
@@ -173,18 +166,20 @@ if __name__ == "__main__":
...
@@ -173,18 +166,20 @@ if __name__ == "__main__":
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_float32_matmul_precision
(
"high"
)
b
=
4
b
=
2
window
=
0
window
=
0
has_bwd
=
False
has_bwd
=
False
for
(
mean_sq
,
mean_sk
)
in
[(
4096
,
4096
),
(
8192
,
8192
)]:
for
(
mean_sq
,
mean_sk
)
in
[(
4096
,
4096
),
(
8192
,
8192
)]:
for
varlen
in
[
False
,
True
]:
for
varlen
in
[
False
,
True
]:
for
(
h
,
h_k
)
in
[(
32
,
32
),
(
32
,
4
)]:
for
(
h
,
h_k
)
in
[(
128
,
128
),
(
32
,
4
)]:
if
h
!=
h_k
:
if
h
!=
h_k
:
has_bwd
=
False
has_bwd
=
False
else
:
else
:
has_bwd
=
True
has_bwd
=
True
for
(
d
,
dv
)
in
[(
128
,
128
),
(
192
,
128
)]:
for
(
d
,
dv
)
in
[(
128
,
128
),
(
192
,
128
)]:
for
causal
in
[
False
,
True
]:
for
causal
in
[
False
,
True
]:
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
)
skip_correctness_check
=
mean_sq
==
8192
and
mean_sk
==
8192
and
h
==
128
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
,
not
skip_correctness_check
)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment