zhaoyu6 / sglang, commit 81c89111 (Unverified)

Add test for flash_attn_varlen_func kernel (#5484)

Authored Apr 17, 2025 by Baizhou Zhang; committed by GitHub, Apr 17, 2025
Parent: 92d1561b
Showing 1 changed file with 461 additions and 0 deletions

sgl-kernel/tests/test_flash_attention.py (+461, -0), view file @ 81c89111
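For orientation: the new test exercises the variable-length ("varlen") attention path, where every sequence of a batch is packed into one tensor of shape (total_tokens, nheads, headdim) and delimited by a cumulative-sequence-length vector. Below is a minimal sketch of that cu_seqlens convention, assuming only PyTorch and made-up lengths; the diff builds the same layout inside generate_qkv, either via unpad_input or via torch.arange when no padding mask is given.

import torch

# Hypothetical per-sequence lengths for a batch of three variable-length sequences.
seqlens = torch.tensor([2, 3, 1], dtype=torch.int32)

# cu_seqlens is the exclusive prefix sum of the lengths; sequence i occupies
# rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen = int(seqlens.max())  # 3, passed alongside cu_seqlens
print(cu_seqlens.tolist())       # [0, 2, 5, 6]; total packed tokens = 6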
...
@@ -296,6 +296,152 @@ def attention_ref(
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
    add_unused_qkv=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q,
            query_padding_mask,
            query_unused_mask,
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(
                dqkv_unpad, indices_q, batch_size, seqlen_q
            )
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(
                dkv_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(
                dk_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(
                dk_unpad, "(b s) h d -> b s h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
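# NOTE: in the unpacked path used below (kvpacked=False, qkvpacked=False),
# generate_qkv returns, in order: the unpadded q/k/v, cu_seqlens_q/k,
# seqused_q/k (taken from unpad_input, or None when no padding mask is given),
# max_seqlen_q/k, the padded q/k/v with gradients re-enabled, and three helpers
# (output_pad_fn, dq_pad_fn, dk_pad_fn) that scatter unpadded results back to
# (batch_size, seqlen, nheads, d).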
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
...
@@ -855,5 +1001,320 @@ def _generate_block_kvcache(
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", COMPILED_HDIMS)
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local,
    softcap,
    deterministic,
    has_qv,
    mha_type,
    dtype,
):
    from sgl_kernel.flash_attn import flash_attn_varlen_func

    device = "cuda"
    # set seed
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    # batch_size = 40
    # nheads = 16
    batch_size = 9 if seqlen_q <= 2048 else 2
    nheads = 6
    # batch_size = 2
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
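    # Head-count mapping per attention variant: "mha" keeps nheads_kv == nheads,
    # "gqa" shares 2 KV heads across the 6 query heads, and "mqa" uses a single
    # KV head; only "mha" is enabled in the parametrization above.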
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    for dv in dv_vals:
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
                * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None
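        # For fp8 inputs, random per-(batch, kv-head) descale factors in [0, 2) are
        # drawn for q/k/v; these are presumably the dequantization scales applied
        # inside the kernel, and attention_ref receives the same values.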
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q, batch_size, device, mode="random", zero_lengths=False
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k, batch_size, device, mode="random", zero_lengths=True
        )

        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
            if add_unused:
                another_mask = generate_random_padding_mask(max_seq_len, bs, device)
                attn_mask = torch.logical_and(padding_mask, another_mask)
                unused_mask = torch.logical_xor(
                    torch.logical_or(padding_mask, another_mask), attn_mask
                )
            else:
                attn_mask = padding_mask
                unused_mask = None
            return attn_mask, unused_mask
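        # When add_unused_qkv is set, _gen_unused_masks shrinks the padding mask to
        # "padding AND a second random mask" and reports their XOR as the unused
        # mask, presumably so the kernel's seqused_q/seqused_k handling is exercised
        # with slots that exist in the unpadded buffer but are not attended.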
        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )
        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            kvpacked=False,
            query_unused_mask=query_unused_mask,
            key_unused_mask=key_unused_mask,
        )
        q_unpad, k_unpad, v_unpad = [
            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
        ]
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )
        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
        # Numerical error if we just do any arithmetic on out_ref
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3
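        # Tolerance scheme: fwd_atol estimates the rounding-noise floor of out_ref's
        # dtype (adding and subtracting 0.3 should be a no-op up to rounding), and
        # the checks below only require the kernel error to stay within rtol times
        # the error of the low-precision PyTorch reference (out_pt) plus that floor.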
        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
            out_unpad, lse, *rest = flash_attn_varlen_func(
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                causal=causal,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
                window_size=window_size,
                softcap=softcap,
                return_softmax_lse=True,
            )
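            # With return_softmax_lse=True the call also returns the per-head
            # log-sum-exp of the attention scores (lse); only the output is checked
            # below. Note that pack_gqa and num_splits are iterated but not forwarded
            # to this call in the version shown here.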
            out = output_pad_fn(out_unpad)
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
            # Check that FlashAttention's numerical error is at most 3x the numerical error
            # of a Pytorch implementation.
            assert (out - out_ref).abs().max().item() <= rtol * (
                out_pt - out_ref
            ).abs().max().item() + fwd_atol

        if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
            g_unpad = torch.randn_like(out_unpad)
            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
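            # do_o is the row-wise dot product of the upstream gradient with the
            # output; it feeds only the commented-out softmax_d sanity check below
            # and is unused by the active assertions.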
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
            )
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0
            g = output_pad_fn(g_unpad)
            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(
                out_ref, (q_ref, k_ref, v_ref), g
            )
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
        if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dq - dq_ref).abs().max().item() <= rtol * (
                dq_pt - dq_ref
            ).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dk - dk_ref).abs().max().item() <= rtol * (
                dk_pt - dk_ref
            ).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dv - dv_ref).abs().max().item() <= rtol * (
                dv_pt - dv_ref
            ).abs().max().item() + dv_atol
if __name__ == "__main__":
    pytest.main([__file__])
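Because of the __main__ guard above, the file can be run directly with python; to run only the new case, a filter such as pytest sgl-kernel/tests/test_flash_attention.py -k test_flash_attn_varlen_output should work, assuming a CUDA device that satisfies the sm90-and-above skipif marker.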