sglang commit 81c89111 (unverified)
Authored Apr 17, 2025 by Baizhou Zhang; committed by GitHub, Apr 17, 2025

Add test for flash_attn_varlen_func kernel (#5484)
parent 92d1561b

1 changed file with 461 additions and 0 deletions:
sgl-kernel/tests/test_flash_attention.py (+461, -0)

...
@@ -296,6 +296,152 @@ def attention_ref(
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
    add_unused_qkv=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q,
            query_padding_mask,
            query_unused_mask,
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(
                dqkv_unpad, indices_q, batch_size, seqlen_q
            )
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(
                dkv_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(
                dk_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(
                dk_unpad, "(b s) h d -> b s h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )

@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
...
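
For reference, the packed "varlen" layout that generate_qkv builds above (tokens of all sequences concatenated, with cu_seqlens giving cumulative start offsets) can be summarized with a minimal, CPU-only sketch. This snippet is illustrative only and is not part of the commit:

import torch

# Two sequences of lengths 2 and 3, described by a boolean padding mask.
padding_mask = torch.tensor([[True, True, False],
                             [True, True, True]])
seqlens = padding_mask.sum(dim=1, dtype=torch.int32)          # tensor([2, 3])
# Cumulative offsets with a leading 0, the format varlen attention kernels expect.
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)                                                             # tensor([0, 2, 5])
# Pack only the valid tokens of a (batch, seqlen, nheads, d) tensor.
x = torch.randn(2, 3, 4, 8)
x_unpad = x[padding_mask]                                     # shape (5, 4, 8)
print(cu_seqlens.tolist(), tuple(x_unpad.shape))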
@@ -855,5 +1001,320 @@ def _generate_block_kvcache(
...
@@ -855,5 +1001,320 @@ def _generate_block_kvcache(
return
k_cache
,
v_cache
,
page_table
,
k_cache_paged
,
v_cache_paged
,
num_blocks
return
k_cache
,
v_cache
,
page_table
,
k_cache_paged
,
v_cache_paged
,
num_blocks
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", COMPILED_HDIMS)
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local,
    softcap,
    deterministic,
    has_qv,
    mha_type,
    dtype,
):
    from sgl_kernel.flash_attn import flash_attn_varlen_func

    device = "cuda"
    # set seed
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    # batch_size = 40
    # nheads = 16
    batch_size = 9 if seqlen_q <= 2048 else 2
    nheads = 6
    # batch_size = 2
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    for dv in dv_vals:
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q, batch_size, device, mode="random", zero_lengths=False
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k, batch_size, device, mode="random", zero_lengths=True
        )
        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
            if add_unused:
                another_mask = generate_random_padding_mask(max_seq_len, bs, device)
                attn_mask = torch.logical_and(padding_mask, another_mask)
                unused_mask = torch.logical_xor(
                    torch.logical_or(padding_mask, another_mask), attn_mask
                )
            else:
                attn_mask = padding_mask
                unused_mask = None
            return attn_mask, unused_mask

        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            kvpacked=False,
            query_unused_mask=query_unused_mask,
            key_unused_mask=key_unused_mask,
        )
        q_unpad, k_unpad, v_unpad = [
            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
        ]
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

        # Numerical error if we just do any arithmetic on out_ref
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3
        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            out_unpad, lse, *rest = flash_attn_varlen_func(
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                causal=causal,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
                window_size=window_size,
                softcap=softcap,
                return_softmax_lse=True,
            )
            out = output_pad_fn(out_unpad)
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")

            # Check that FlashAttention's numerical error is at most 3x the numerical error
            # of a Pytorch implementation.
            assert (out - out_ref).abs().max().item() <= rtol * (
                out_pt - out_ref
            ).abs().max().item() + fwd_atol
        if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
            g_unpad = torch.randn_like(out_unpad)
            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
            )
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0
            g = output_pad_fn(g_unpad)
            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(
                out_ref, (q_ref, k_ref, v_ref), g
            )
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(
                out_pt, (q_ref, k_ref, v_ref), g
            )
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dq - dq_ref).abs().max().item() <= rtol * (
                dq_pt - dq_ref
            ).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dk - dk_ref).abs().max().item() <= rtol * (
                dk_pt - dk_ref
            ).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dv - dv_ref).abs().max().item() <= rtol * (
                dv_pt - dv_ref
            ).abs().max().item() + dv_atol


if __name__ == "__main__":
    pytest.main([__file__])
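
A possible way to run just the new test locally, assuming pytest is installed and an sm90-class GPU is present so the skipif guard above passes (sketch only, not part of the commit):

import pytest

# Select only the varlen test added by this commit; "-q" keeps the output short.
pytest.main(
    ["sgl-kernel/tests/test_flash_attention.py::test_flash_attn_varlen_output", "-q"]
)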