Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
17944550
Unverified
Commit
17944550
authored
Sep 29, 2025
by
Shengyu Liu
Committed by
GitHub
Sep 29, 2025
Browse files
Merge pull request #98 from deepseek-ai/open-source-h
Add Sparse Attention Kernels on Hopper
parents
ebf30641
3969f20b
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
42 deletions
+34
-42
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+34
-42
No files found.
tests/test_fmha_sm100.py
View file @
17944550
...
...
@@ -6,6 +6,7 @@ import triton
from
flash_mla
import
flash_attn_varlen_func
from
lib
import
check_is_allclose
def
get_window_size
(
causal
,
window
):
if
window
>
0
:
...
...
@@ -28,24 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window):
return
attn_bias
def
assert_close
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
close_tensor
=
torch
.
isclose
(
x
.
to
(
torch
.
float32
),
y
.
to
(
torch
.
float32
),
rtol
=
1e-5
,
atol
=
1e-5
)
if
close_tensor
.
all
():
return
x
,
y
=
x
.
double
(),
y
.
double
()
RMSE
=
((
x
-
y
)
*
(
x
-
y
)).
mean
().
sqrt
().
item
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
((
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
amax_diff
=
(
x
-
y
).
abs
().
max
().
item
()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert
cos_diff
<
1e-5
,
f
"
{
name
}
:
{
cos_diff
=
}
,
{
RMSE
=
}
,
{
amax_diff
=
}
"
def
sdpa
(
query
,
key
,
value
,
attn_bias
,
softmax_scale
=
None
):
query
=
query
.
float
().
transpose
(
-
3
,
-
2
)
key
=
key
.
float
().
transpose
(
-
3
,
-
2
)
value
=
value
.
float
().
transpose
(
-
3
,
-
2
)
key
=
key
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
value
=
value
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
if
softmax_scale
is
None
:
softmax_scale
=
query
.
shape
[
-
1
]
**
(
-
0.5
)
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
softmax_scale
attn_weight
=
(
query
@
key
.
transpose
(
-
2
,
-
1
)
)
*
softmax_scale
attn_weight
+=
attn_bias
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
...
...
@@ -56,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs):
return
checkpoint
(
sdpa
,
*
args
,
use_reentrant
=
False
,
**
kwargs
)
def
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
):
print
(
f
"
{
b
=
}
,
{
mean_sq
=
}
,
{
mean_sk
=
}
,
{
varlen
=
}
,
{
h
=
}
,
{
h_k
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
"
)
def
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
,
check_correctness
:
bool
=
True
):
print
(
f
"
{
b
=
}
,
{
mean_sq
=
}
,
{
mean_sk
=
}
,
{
varlen
=
}
,
{
h
=
}
,
{
h_k
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
has_bwd
=
}
,
{
check_correctness
=
}
"
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
...
...
@@ -79,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q
=
torch
.
randn
(
total_q
,
h
,
d
)
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
q1
=
q
.
clone
().
requires_grad_
()
k1
=
k
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
v2
=
v
.
clone
().
requires_grad_
()
if
check_correctness
:
q2
=
q
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
v2
=
v
.
clone
().
requires_grad_
()
def
flash_attn
():
q1
.
grad
=
k1
.
grad
=
v1
.
grad
=
None
...
...
@@ -109,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
lse
=
[]
for
i
in
range
(
b
):
OUT
,
LSE
=
sdpa_checkpoint
(
q2
[
cu_seqlens_q
[
i
].
item
():
cu_seqlens_q
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
k2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
v2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
q2
[
cu_seqlens_q
[
i
].
item
():
cu_seqlens_q
[
i
+
1
].
item
()],
k2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()],
v2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()],
attn_bias
=
get_attn_bias
(
seqlens_q
[
i
].
item
(),
seqlens_k
[
i
].
item
(),
causal
,
window
),
softmax_scale
=
softmax_scale
,
)
...
...
@@ -122,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
return
out
,
lse
out_flash
,
lse_flash
=
flash_attn
()
out_torch
,
lse_torch
=
torch_attn
()
assert_close
(
out_flash
,
out_torch
,
"out"
)
assert_close
(
lse_flash
,
lse_torch
,
"lse"
)
if
has_bwd
:
out_flash
.
backward
(
grad_out
,
retain_graph
=
True
)
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert_close
(
q1
.
grad
,
q2
.
grad
,
"dq"
)
assert_close
(
k1
.
grad
,
k2
.
grad
,
"dk"
)
assert_close
(
v1
.
grad
,
v2
.
grad
,
"dv"
)
dq1
=
q1
.
grad
.
clone
()
dk1
=
k1
.
grad
.
clone
()
dv1
=
v1
.
grad
.
clone
()
if
check_correctness
:
out_torch
,
lse_torch
=
torch_attn
()
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
if
has_bwd
:
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
def
forward
():
return
flash_attn
()
...
...
@@ -153,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
assert
torch
.
equal
(
k1
.
grad
,
dk1
),
"dk deterministic check failed!"
assert
torch
.
equal
(
v1
.
grad
,
dv1
),
"dv deterministic check failed!"
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
# forward()
# if has_bwd:
# backward()
# print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120))
def
timer
(
func
,
name
):
t
=
triton
.
testing
.
do_bench
(
func
,
warmup
=
2
,
rep
=
3
)
FLOPS
=
total_attn_compute
*
h
*
2
*
((
d
+
dv
)
if
name
==
"fwd"
else
((
d
*
3
+
dv
*
2
)))
...
...
@@ -176,18 +166,20 @@ if __name__ == "__main__":
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_float32_matmul_precision
(
"high"
)
b
=
4
b
=
2
window
=
0
has_bwd
=
False
for
(
mean_sq
,
mean_sk
)
in
[(
4096
,
4096
),
(
8192
,
8192
)]:
for
varlen
in
[
False
,
True
]:
for
(
h
,
h_k
)
in
[(
32
,
32
),
(
32
,
4
)]:
for
(
h
,
h_k
)
in
[(
128
,
128
),
(
32
,
4
)]:
if
h
!=
h_k
:
has_bwd
=
False
else
:
has_bwd
=
True
for
(
d
,
dv
)
in
[(
128
,
128
),
(
192
,
128
)]:
for
causal
in
[
False
,
True
]:
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
)
skip_correctness_check
=
mean_sq
==
8192
and
mean_sk
==
8192
and
h
==
128
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
,
not
skip_correctness_check
)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment