Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
c28eca99
"vscode:/vscode.git/clone" did not exist on "b8a42c41bca0866ecbb3ee43f063a644d05c4938"
Commit
c28eca99
authored
Sep 24, 2025
by
Shengyu Liu
Browse files
Reorganize files and add sparse prefill/decoding kernels on hopper
parent
261330bb
Changes
61
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
39 deletions
+34
-39
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+34
-39
No files found.
tests/test_fmha_sm100.py
View file @
c28eca99
...
@@ -6,6 +6,7 @@ import triton
...
@@ -6,6 +6,7 @@ import triton
from
flash_mla
import
flash_attn_varlen_func
from
flash_mla
import
flash_attn_varlen_func
from
lib
import
check_is_allclose
def
get_window_size
(
causal
,
window
):
def
get_window_size
(
causal
,
window
):
if
window
>
0
:
if
window
>
0
:
...
@@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window):
...
@@ -28,21 +29,15 @@ def get_attn_bias(s_q, s_k, causal, window):
return
attn_bias
return
attn_bias
def
assert_close
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
x
,
y
=
x
.
double
(),
y
.
double
()
RMSE
=
((
x
-
y
)
*
(
x
-
y
)).
mean
().
sqrt
().
item
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
((
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
amax_diff
=
(
x
-
y
).
abs
().
max
().
item
()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert
cos_diff
<
1e-5
,
f
"
{
name
}
:
{
cos_diff
=
}
,
{
RMSE
=
}
,
{
amax_diff
=
}
"
def
sdpa
(
query
,
key
,
value
,
attn_bias
,
softmax_scale
=
None
):
def
sdpa
(
query
,
key
,
value
,
attn_bias
,
softmax_scale
=
None
):
query
=
query
.
float
().
transpose
(
-
3
,
-
2
)
key
=
key
.
float
().
transpose
(
-
3
,
-
2
)
value
=
value
.
float
().
transpose
(
-
3
,
-
2
)
key
=
key
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
key
=
key
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
value
=
value
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
value
=
value
.
repeat_interleave
(
h
//
h_k
,
dim
=-
3
)
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
query
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
query
.
shape
[
-
1
]
**
(
-
0.5
)
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
*
softmax_scale
attn_weight
=
(
query
@
key
.
transpose
(
-
2
,
-
1
)
)
*
softmax_scale
attn_weight
+=
attn_bias
attn_weight
+=
attn_bias
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
...
@@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs):
...
@@ -53,8 +48,8 @@ def sdpa_checkpoint(*args, **kwargs):
return
checkpoint
(
sdpa
,
*
args
,
use_reentrant
=
False
,
**
kwargs
)
return
checkpoint
(
sdpa
,
*
args
,
use_reentrant
=
False
,
**
kwargs
)
def
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
):
def
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
,
check_correctness
:
bool
=
True
):
print
(
f
"
{
b
=
}
,
{
mean_sq
=
}
,
{
mean_sk
=
}
,
{
varlen
=
}
,
{
h
=
}
,
{
h_k
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
"
)
print
(
f
"
{
b
=
}
,
{
mean_sq
=
}
,
{
mean_sk
=
}
,
{
varlen
=
}
,
{
h
=
}
,
{
h_k
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
has_bwd
=
}
,
{
check_correctness
=
}
"
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
random
.
seed
(
0
)
...
@@ -76,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -76,19 +71,20 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
causal
,
window
)
==
0
).
sum
().
item
()
for
i
in
range
(
b
)])
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
# print(f"{total_q=}, {max_seqlen_q=}, {total_k=}, {max_seqlen_k=}, {total_attn_compute=}, {cu_seqlens_q.tolist()}, {cu_seqlens_k.tolist()}")
q
=
torch
.
randn
(
total_q
,
h
,
d
)
q
=
torch
.
randn
(
total_q
,
h
,
d
)
/
10
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
k
=
torch
.
randn
(
total_k
,
h_k
,
d
)
/
10
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
v
=
torch
.
randn
(
total_k
,
h_k
,
dv
)
/
10
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
/
10
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
q1
=
q
.
clone
().
requires_grad_
()
q1
=
q
.
clone
().
requires_grad_
()
k1
=
k
.
clone
().
requires_grad_
()
k1
=
k
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
if
check_correctness
:
k2
=
k
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
v2
=
v
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
v2
=
v
.
clone
().
requires_grad_
()
def
flash_attn
():
def
flash_attn
():
q1
.
grad
=
k1
.
grad
=
v1
.
grad
=
None
q1
.
grad
=
k1
.
grad
=
v1
.
grad
=
None
...
@@ -106,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -106,9 +102,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
lse
=
[]
lse
=
[]
for
i
in
range
(
b
):
for
i
in
range
(
b
):
OUT
,
LSE
=
sdpa_checkpoint
(
OUT
,
LSE
=
sdpa_checkpoint
(
q2
[
cu_seqlens_q
[
i
].
item
():
cu_seqlens_q
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
q2
[
cu_seqlens_q
[
i
].
item
():
cu_seqlens_q
[
i
+
1
].
item
()],
k2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
k2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()],
v2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()]
.
float
().
transpose
(
-
3
,
-
2
)
,
v2
[
cu_seqlens_k
[
i
].
item
():
cu_seqlens_k
[
i
+
1
].
item
()],
attn_bias
=
get_attn_bias
(
seqlens_q
[
i
].
item
(),
seqlens_k
[
i
].
item
(),
causal
,
window
),
attn_bias
=
get_attn_bias
(
seqlens_q
[
i
].
item
(),
seqlens_k
[
i
].
item
(),
causal
,
window
),
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
)
)
...
@@ -119,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -119,20 +115,23 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
return
out
,
lse
return
out
,
lse
out_flash
,
lse_flash
=
flash_attn
()
out_flash
,
lse_flash
=
flash_attn
()
out_torch
,
lse_torch
=
torch_attn
()
assert_close
(
out_flash
,
out_torch
,
"out"
)
assert_close
(
lse_flash
,
lse_torch
,
"lse"
)
if
has_bwd
:
if
has_bwd
:
out_flash
.
backward
(
grad_out
,
retain_graph
=
True
)
out_flash
.
backward
(
grad_out
,
retain_graph
=
True
)
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert_close
(
q1
.
grad
,
q2
.
grad
,
"dq"
)
assert_close
(
k1
.
grad
,
k2
.
grad
,
"dk"
)
assert_close
(
v1
.
grad
,
v2
.
grad
,
"dv"
)
dq1
=
q1
.
grad
.
clone
()
dq1
=
q1
.
grad
.
clone
()
dk1
=
k1
.
grad
.
clone
()
dk1
=
k1
.
grad
.
clone
()
dv1
=
v1
.
grad
.
clone
()
dv1
=
v1
.
grad
.
clone
()
if
check_correctness
:
out_torch
,
lse_torch
=
torch_attn
()
assert
check_is_allclose
(
"out"
,
out_flash
,
out_torch
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"lse"
,
lse_flash
,
lse_torch
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
if
has_bwd
:
out_torch
.
backward
(
grad_out
,
retain_graph
=
True
)
assert
check_is_allclose
(
"dq"
,
q1
.
grad
,
q2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dk"
,
k1
.
grad
,
k2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
assert
check_is_allclose
(
"dv"
,
v1
.
grad
,
v2
.
grad
,
abs_tol
=
1e-3
,
rel_tol
=
8.01
/
128
,
cos_diff_tol
=
7e-6
)
def
forward
():
def
forward
():
return
flash_attn
()
return
flash_attn
()
...
@@ -150,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -150,12 +149,6 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
assert
torch
.
equal
(
k1
.
grad
,
dk1
),
"dk deterministic check failed!"
assert
torch
.
equal
(
k1
.
grad
,
dk1
),
"dk deterministic check failed!"
assert
torch
.
equal
(
v1
.
grad
,
dv1
),
"dv deterministic check failed!"
assert
torch
.
equal
(
v1
.
grad
,
dv1
),
"dv deterministic check failed!"
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
# forward()
# if has_bwd:
# backward()
# print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=120))
def
timer
(
func
,
name
):
def
timer
(
func
,
name
):
t
=
triton
.
testing
.
do_bench
(
func
,
warmup
=
2
,
rep
=
3
)
t
=
triton
.
testing
.
do_bench
(
func
,
warmup
=
2
,
rep
=
3
)
FLOPS
=
total_attn_compute
*
h
*
2
*
((
d
+
dv
)
if
name
==
"fwd"
else
((
d
*
3
+
dv
*
2
)))
FLOPS
=
total_attn_compute
*
h
*
2
*
((
d
+
dv
)
if
name
==
"fwd"
else
((
d
*
3
+
dv
*
2
)))
...
@@ -173,18 +166,20 @@ if __name__ == "__main__":
...
@@ -173,18 +166,20 @@ if __name__ == "__main__":
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
set_float32_matmul_precision
(
"high"
)
b
=
4
b
=
2
window
=
0
window
=
0
has_bwd
=
False
has_bwd
=
False
for
(
mean_sq
,
mean_sk
)
in
[(
4096
,
4096
),
(
8192
,
8192
)]:
for
(
mean_sq
,
mean_sk
)
in
[(
4096
,
4096
),
(
8192
,
8192
)]:
for
varlen
in
[
False
,
True
]:
for
varlen
in
[
False
,
True
]:
for
(
h
,
h_k
)
in
[(
32
,
32
),
(
32
,
4
)]:
for
(
h
,
h_k
)
in
[(
128
,
128
),
(
32
,
4
)]:
if
h
!=
h_k
:
if
h
!=
h_k
:
has_bwd
=
False
has_bwd
=
False
else
:
else
:
has_bwd
=
True
has_bwd
=
True
for
(
d
,
dv
)
in
[(
128
,
128
),
(
192
,
128
)]:
for
(
d
,
dv
)
in
[(
128
,
128
),
(
192
,
128
)]:
for
causal
in
[
False
,
True
]:
for
causal
in
[
False
,
True
]:
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
)
skip_correctness_check
=
mean_sq
==
8192
and
mean_sk
==
8192
and
h
==
128
test_flash_attention
(
b
,
mean_sq
,
mean_sk
,
varlen
,
h
,
h_k
,
d
,
dv
,
causal
,
window
,
has_bwd
,
not
skip_correctness_check
)
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment