Update fa3 interface and add unit test (#9150)

Commit 94f44b88 (unverified), parent 3b3b3baf
Authored by Ke Bao on Aug 13, 2025; committed via GitHub on Aug 13, 2025
4 changed files, with 54 additions and 12 deletions:

  sgl-kernel/csrc/flash_extension.cc            +2   -1
  sgl-kernel/include/sgl_flash_kernel_ops.h     +2   -1
  sgl-kernel/python/sgl_kernel/flash_attn.py    +4   -0
  sgl-kernel/tests/test_flash_attention.py      +46  -10
sgl-kernel/csrc/flash_extension.cc

@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " Tensor? scheduler_metadata,"
       " int num_splits,"
       " bool? pack_gqa,"
-      " int sm_margin) -> Tensor[]");
+      " int sm_margin,"
+      " Tensor? sinks) -> Tensor[]");
   m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
sgl-kernel/include/sgl_flash_kernel_ops.h

@@ -82,4 +82,5 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
     int num_splits,
     std::optional<bool> pack_gqa_,
-    int const sm_margin);
+    int const sm_margin,
+    std::optional<const at::Tensor>& sinks_);
sgl-kernel/python/sgl_kernel/flash_attn.py

@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
     pack_gqa=None,  # Can be tuned for speed
     sm_margin=0,  # Can be tuned if some SMs are used for communication
     return_softmax_lse=False,
+    sinks=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from

@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
         num_splits,
         pack_gqa,
         sm_margin,
+        sinks,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out

@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
     pack_gqa=None,
     sm_margin=0,
     return_softmax_lse=False,
+    sinks=None,
 ):
     if not is_fa3_supported():
         raise NotImplementedError(

@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
         num_splits=num_splits,
         pack_gqa=pack_gqa,
         sm_margin=sm_margin,
+        sinks=sinks,
     )
     return (out, softmax_lse, *rest) if return_softmax_lse else out
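The Python-side change is purely additive: both flash_attn_with_kvcache and flash_attn_varlen_func accept an optional sinks tensor and forward it to the kernel, so existing callers are unaffected. A minimal usage sketch follows; the tensor shapes, the cache_seqlens/causal keywords, and the Hopper-GPU requirement are assumptions based on the surrounding FA3 interface, not part of this diff.

# Hypothetical usage sketch (not from this commit): pass one sink logit per query head.
import torch
from sgl_kernel.flash_attn import flash_attn_with_kvcache

batch, seqlen_q, seqlen_k, nheads, nheads_k, d = 2, 1, 512, 8, 8, 128
device = "cuda"

q = torch.randn(batch, seqlen_q, nheads, d, dtype=torch.bfloat16, device=device)
k_cache = torch.randn(batch, seqlen_k, nheads_k, d, dtype=torch.bfloat16, device=device)
v_cache = torch.randn(batch, seqlen_k, nheads_k, d, dtype=torch.bfloat16, device=device)
cache_seqlens = torch.full((batch,), seqlen_k, dtype=torch.int32, device=device)

# One learned sink logit per query head, mirroring the unit test's construction.
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    sinks=sinks,  # new argument; sinks=None keeps the previous behavior
)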
sgl-kernel/tests/test_flash_attention.py

 # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
 import itertools
 import math
-import os
+from typing import Optional
 import pytest
 import torch
@@ -45,12 +45,12 @@ DISABLE_BACKWARD = True
 # or torch.cuda.get_device_capability("cuda")[0] < 9
 # )
-DISABLE_SPLIT = True
+DISABLE_SPLIT = False
 DISABLE_PAGEDKV = True
-DISABLE_APPENDKV = True
+DISABLE_APPENDKV = False
-DISABLE_LOCAL = True
+DISABLE_LOCAL = False
 DISABLE_SOFTCAP = True
-DISABLE_PACKGQA = True
+DISABLE_PACKGQA = False
 DISABLE_FP16 = True
 DISABLE_FP8 = True
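Flipping DISABLE_SPLIT, DISABLE_APPENDKV, DISABLE_LOCAL, and DISABLE_PACKGQA from True to False widens the test matrix rather than changing any logic: each flag only gates whether extra values are appended to a parametrize list further down the file. A small sketch of that pattern (flag name copied from the test, the rest illustrative):

# Illustrative only: how a DISABLE_* flag expands a pytest parameter list.
DISABLE_APPENDKV = True            # value before this commit
assert [False] + ([True] if not DISABLE_APPENDKV else []) == [False]

DISABLE_APPENDKV = False           # value after this commit
assert [False] + ([True] if not DISABLE_APPENDKV else []) == [False, True]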
@@ -199,6 +199,7 @@ def attention_ref(
     v_descale=None,
     window_size=(-1, -1),  # -1 means infinite window size
     sink_token_length=0,
+    sinks: Optional[torch.Tensor] = None,
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
@@ -271,7 +272,18 @@ def attention_ref(
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
         scores = scores + attn_bias
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    if sinks is None:
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    else:
+        scores_fp32 = scores.to(torch.float32)
+        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
+        sinks = rearrange(sinks, "h -> h 1 1")
+        logits_or_sinks_max = torch.maximum(sinks, logits_max)
+        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
+            sinks - logits_or_sinks_max
+        )
+        attention = (unnormalized_scores / normalizer).to(v.dtype)
     # We want to mask here so that the attention matrix doesn't have any NaNs
     # Otherwise we'll get NaN in dV
     if query_padding_mask is not None:
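This new branch is the reference semantics for attention sinks: each head carries one extra logit that enters the softmax denominator but has no value vector, so it absorbs probability mass without contributing to the output. A self-contained CPU sketch of the same math (shapes are illustrative; reshape stands in for the rearrange(sinks, "h -> h 1 1") call above):

# Sink-adjusted softmax, as in the updated attention_ref; runnable on CPU.
import torch

def sink_softmax(scores: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    # scores: (nheads, seqlen_q, seqlen_k); sinks: (nheads,) -- one logit per head.
    scores = scores.float()
    logits_max = torch.amax(scores, dim=-1, keepdim=True)
    sinks = sinks.float().reshape(-1, 1, 1)
    m = torch.maximum(sinks, logits_max)              # shared max for numerical stability
    unnormalized = torch.exp(scores - m)
    normalizer = unnormalized.sum(dim=-1, keepdim=True) + torch.exp(sinks - m)
    return unnormalized / normalizer

torch.manual_seed(0)
scores = torch.randn(4, 3, 5)                         # illustrative shapes
sinks = torch.randn(4)

attn = sink_softmax(scores, sinks)
# The sink soaks up probability mass, so each row sums to strictly less than 1.
assert (attn.sum(dim=-1) < 1).all()

# With the sink logit at -inf the extra term vanishes and plain softmax is recovered.
no_sink = torch.full((4,), float("-inf"))
assert torch.allclose(sink_softmax(scores, no_sink), torch.softmax(scores, dim=-1), atol=1e-6)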
@@ -459,8 +471,10 @@ def generate_qkv(
 )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
 # @pytest.mark.parametrize("new_kv", [True])
 # @pytest.mark.parametrize(
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
     new_kv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_with_kvcache
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
     if dtype == torch.float8_e4m3fn or not is_hopper():
         # for fp8 and ampere arch, we not support v head dim != qk head dim
         dv_vals = [d]
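For reference, the dv_vals expression above decides which value head dims are exercised for a given qk head dim d, and the fp8/non-Hopper branch collapses it to d only. A small sketch of that selection (the restrict flag is a stand-in for the dtype == torch.float8_e4m3fn or not is_hopper() condition):

# Illustrative: value head dims (dv) exercised for a given qk head dim (d).
def dv_candidates(d, restrict):
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if restrict:  # fp8 dtype or non-Hopper arch: v head dim must equal qk head dim
        dv_vals = [d]
    return dv_vals

assert dv_candidates(64, restrict=False) == [256, 512, 64]
assert dv_candidates(192, restrict=False) == [128, 192]
assert dv_candidates(128, restrict=False) == [128]
assert dv_candidates(192, restrict=True) == [192]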
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
         qv=qv,
         window_size=window_size,
         key_leftpad=cache_leftpad,
+        sinks=sinks,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
         reorder_ops=True,
         key_leftpad=cache_leftpad,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
     q = q.to(dtype)
     q_unpad = q_unpad.to(dtype) if varlen_q else None
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
         scheduler_metadata=scheduler_metadata,
         num_splits=num_splits,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     if varlen_q:
         out = output_pad_fn(out)
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
 )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 # @pytest.mark.parametrize("has_qv", [False, True])
 @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
     has_qv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -1131,6 +1158,12 @@ def test_flash_attn_varlen_output(
         qv_ref = None
     # Put window_size after QKV randn so that window_size changes from test to test
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
     if dtype == torch.float8_e4m3fn:
         q_descale, k_descale, v_descale = [
             torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
@@ -1209,6 +1242,7 @@ def test_flash_attn_varlen_output(
         v_descale=v_descale,
         window_size=window_size,
         softcap=softcap,
+        sinks=sinks,
     )
     out_pt, attn_pt = attention_ref(
         q_ref,
@@ -1226,6 +1260,7 @@ def test_flash_attn_varlen_output(
         upcast=False,
         reorder_ops=True,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1258,6 +1293,7 @@ def test_flash_attn_varlen_output(
         window_size=window_size,
         softcap=softcap,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     out = output_pad_fn(out_unpad)
     if query_unused_mask is not None:
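Both updated tests can be invoked directly; a minimal sketch (assumes a Hopper GPU and an installed sgl-kernel build, which this diff presumes but does not set up):

# Sketch: run the two tests touched by this commit.
import pytest

pytest.main([
    "sgl-kernel/tests/test_flash_attention.py::test_flash_attn_kvcache",
    "sgl-kernel/tests/test_flash_attention.py::test_flash_attn_varlen_output",
    "-q",
])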