Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
252dc4e1
Unverified
Commit
252dc4e1
authored
Oct 20, 2025
by
Johnny
Committed by
GitHub
Oct 19, 2025
Browse files
[NVIDIA] FA3/FA4 Fix (#11606)
Co-authored-by:
Baizhou Zhang
<
sobereddiezhang@gmail.com
>
parent
cbb5fc2e
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
381 additions
and
218 deletions
+381
-218
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+10
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+8
-0
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+2
-2
sgl-kernel/csrc/flash_extension.cc
sgl-kernel/csrc/flash_extension.cc
+29
-26
sgl-kernel/include/sgl_flash_kernel_ops.h
sgl-kernel/include/sgl_flash_kernel_ops.h
+36
-37
sgl-kernel/python/sgl_kernel/_fa4_interface.py
sgl-kernel/python/sgl_kernel/_fa4_interface.py
+150
-76
sgl-kernel/python/sgl_kernel/flash_attn.py
sgl-kernel/python/sgl_kernel/flash_attn.py
+19
-14
sgl-kernel/tests/test_flash_attention_4.py
sgl-kernel/tests/test_flash_attention_4.py
+70
-63
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_flash_attention_4.py
test/srt/test_flash_attention_4.py
+56
-0
No files found.
python/sglang/srt/server_args.py
View file @
252dc4e1
...
@@ -1071,6 +1071,16 @@ class ServerArgs:
...
@@ -1071,6 +1071,16 @@ class ServerArgs:
self
.
enable_mixed_chunk
=
False
self
.
enable_mixed_chunk
=
False
self
.
disable_radix_cache
=
True
self
.
disable_radix_cache
=
True
if
self
.
attention_backend
==
"fa4"
or
self
.
decode_attention_backend
==
"fa4"
:
raise
ValueError
(
"FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead."
)
if
self
.
prefill_attention_backend
==
"fa4"
:
logger
.
warning
(
f
"FA4 backend only supports page size 128, changing page_size from
{
self
.
page_size
}
to 128."
)
self
.
page_size
=
128
def
_handle_page_size
(
self
):
def
_handle_page_size
(
self
):
if
self
.
page_size
is
None
:
if
self
.
page_size
is
None
:
self
.
page_size
=
1
self
.
page_size
=
1
...
...
python/sglang/test/test_utils.py
View file @
252dc4e1
...
@@ -129,6 +129,11 @@ def is_in_amd_ci():
...
@@ -129,6 +129,11 @@ def is_in_amd_ci():
return
get_bool_env_var
(
"SGLANG_IS_IN_CI_AMD"
)
return
get_bool_env_var
(
"SGLANG_IS_IN_CI_AMD"
)
def
is_blackwell_system
():
"""Return whether it is running on a Blackwell (B200) system."""
return
get_bool_env_var
(
"IS_BLACKWELL"
)
def
_use_cached_default_models
(
model_repo
:
str
):
def
_use_cached_default_models
(
model_repo
:
str
):
cache_dir
=
os
.
getenv
(
"DEFAULT_MODEL_CACHE_DIR"
)
cache_dir
=
os
.
getenv
(
"DEFAULT_MODEL_CACHE_DIR"
)
if
cache_dir
and
model_repo
:
if
cache_dir
and
model_repo
:
...
@@ -151,6 +156,9 @@ DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 10
...
@@ -151,6 +156,9 @@ DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 10
if
is_in_amd_ci
():
if
is_in_amd_ci
():
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
=
3000
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
=
3000
if
is_blackwell_system
():
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
=
3000
def
call_generate_lightllm
(
prompt
,
temperature
,
max_tokens
,
stop
=
None
,
url
=
None
):
def
call_generate_lightllm
(
prompt
,
temperature
,
max_tokens
,
stop
=
None
,
url
=
None
):
assert
url
is
not
None
assert
url
is
not
None
...
...
sgl-kernel/CMakeLists.txt
View file @
252dc4e1
...
@@ -91,7 +91,7 @@ FetchContent_Populate(repo-flashinfer)
...
@@ -91,7 +91,7 @@ FetchContent_Populate(repo-flashinfer)
FetchContent_Declare
(
FetchContent_Declare
(
repo-flash-attention
repo-flash-attention
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
GIT_TAG f
9af0c2a1d82ab1812e6987e9338363cc2bf0f8d
GIT_TAG f
f87110aad048bb8c4e6effea4c563ddae88b0eb
GIT_SHALLOW OFF
GIT_SHALLOW OFF
)
)
FetchContent_Populate
(
repo-flash-attention
)
FetchContent_Populate
(
repo-flash-attention
)
...
@@ -100,7 +100,7 @@ FetchContent_Populate(repo-flash-attention)
...
@@ -100,7 +100,7 @@ FetchContent_Populate(repo-flash-attention)
FetchContent_Declare
(
FetchContent_Declare
(
repo-flash-attention-origin
repo-flash-attention-origin
GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
GIT_TAG
203b9b3dba39d5d08dffb49c09aa622984dff07d
GIT_TAG
04adaf0e9028d4bec7073f69e4dfa3f6d3357189
GIT_SHALLOW OFF
GIT_SHALLOW OFF
)
)
FetchContent_Populate
(
repo-flash-attention-origin
)
FetchContent_Populate
(
repo-flash-attention-origin
)
...
...
sgl-kernel/csrc/flash_extension.cc
View file @
252dc4e1
...
@@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
...
@@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
* From flash-attention
* From flash-attention
*/
*/
m
.
def
(
m
.
def
(
"fwd(Tensor
!
q,"
"fwd(Tensor
q,"
// (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
" Tensor k,"
" Tensor k,"
// (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged
" Tensor v,"
" Tensor v,"
// (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged
" Tensor? k_new,"
" Tensor? k_new,"
// (b, s_k_new, h_k, d) or (total_k_new, h_k, d)
" Tensor? v_new,"
" Tensor? v_new,"
// (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv)
" Tensor? q_v,"
" Tensor? q_v,"
// (b, s_q, h, dv) or (total_q_new, h, dv)
" Tensor
!
? out,"
" Tensor?
out,"
// (b, s_q, h, dv) or (total_q, h, dv)
" Tensor? cu_seqlens_q,"
" Tensor? cu_seqlens_q,"
// b+1
" Tensor? cu_seqlens_k,"
" Tensor? cu_seqlens_k,"
// b+1
" Tensor? cu_seqlens_k_new,"
" Tensor? cu_seqlens_k_new,"
// b+1
" Tensor? seqused_q,"
" Tensor? seqused_q,"
// b
" Tensor? seqused_k,"
" Tensor? seqused_k,"
// b
" int? max_seqlen_q,"
" int? max_seqlen_q,"
" int? max_seqlen_k,"
" int? max_seqlen_k,"
// TODO: check if needed
" Tensor? page_table,"
" Tensor? page_table,"
// (b_k, max_num_pages_per_seq)
" Tensor? kv_batch_idx,"
" Tensor? kv_batch_idx,"
// b
" Tensor? leftpad_k,"
" Tensor? leftpad_k,"
// b
" Tensor? rotary_cos,"
" Tensor? rotary_cos,"
// seqlen_ro x (rotary_dim / 2)
" Tensor? rotary_sin,"
" Tensor? rotary_sin,"
// seqlen_ro x (rotary_dim / 2)
" Tensor? seqlens_rotary,"
" Tensor? seqlens_rotary,"
// b
" Tensor? q_descale,"
" Tensor? q_descale,"
// (b, h_k)
" Tensor? k_descale,"
" Tensor? k_descale,"
// (b, h_k)
" Tensor? v_descale,"
" Tensor? v_descale,"
// (b, h_k)
" float
softmax_scale,"
" float
?
softmax_scale,"
// now optional
" bool is_causal,"
" bool is_causal,"
" int window_size_left,"
" int window_size_left,"
" int window_size_right,"
" int window_size_right,"
" float softcap,"
" int attention_chunk,"
// NEW
" float softcap,"
// promoted to double in C++; schema float is fine
" bool is_rotary_interleaved,"
" bool is_rotary_interleaved,"
" Tensor? scheduler_metadata,"
" Tensor? scheduler_metadata,"
// (b + 1)
" int num_splits,"
" int num_splits,"
" bool? pack_gqa,"
" bool? pack_gqa,"
" int sm_margin,"
" int sm_margin,"
" Tensor? sinks) -> Tensor[]"
);
" Tensor? sinks"
") -> (Tensor, Tensor, Tensor, Tensor)"
);
// NEW return type: tuple of 4 tensors
m
.
impl
(
"fwd"
,
torch
::
kCUDA
,
make_pytorch_shim
(
&
mha_fwd
));
m
.
impl
(
"fwd"
,
torch
::
kCUDA
,
make_pytorch_shim
(
&
mha_fwd
));
}
}
...
...
sgl-kernel/include/sgl_flash_kernel_ops.h
View file @
252dc4e1
...
@@ -42,45 +42,44 @@ limitations under the License.
...
@@ -42,45 +42,44 @@ limitations under the License.
/*
/*
* From flash-attention
* From flash-attention
*/
*/
std
::
vector
<
at
::
Tensor
>
mha_fwd
(
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
,
at
::
Tensor
>
mha_fwd
(
at
::
Tensor
&
q
,
// (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
at
::
Tensor
q
,
// (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const
at
::
Tensor
&
k
,
// (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
at
::
Tensor
k
,
// (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
// h_k, d) if there is page_table.
const
at
::
Tensor
&
v
,
// (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
at
::
Tensor
v
,
// (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
// page_size, h_k, dv) if there is page_table.
// page_size, h_k, dv) if there is page_table.
std
::
optional
<
const
at
::
Tensor
>&
std
::
optional
<
at
::
Tensor
>
k_new_
,
// (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
k_new_
,
// (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
std
::
optional
<
at
::
Tensor
>
v_new_
,
// (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std
::
optional
<
const
at
::
Tensor
>&
std
::
optional
<
at
::
Tensor
>
q_v_
,
// (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
v_new_
,
// (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
std
::
optional
<
at
::
Tensor
>
out_
,
// (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std
::
optional
<
const
at
::
Tensor
>&
q_v_
,
// (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
std
::
optional
<
at
::
Tensor
>
cu_seqlens_q_
,
// b+1
std
::
optional
<
at
::
Tensor
>&
out_
,
// (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
std
::
optional
<
at
::
Tensor
>
cu_seqlens_k_
,
// b+1
std
::
optional
<
const
at
::
Tensor
>&
cu_seqlens_q_
,
// b+1
std
::
optional
<
at
::
Tensor
>
cu_seqlens_k_new_
,
// b+1
std
::
optional
<
const
at
::
Tensor
>&
cu_seqlens_k_
,
// b+1
std
::
optional
<
at
::
Tensor
>
std
::
optional
<
const
at
::
Tensor
>&
cu_seqlens_k_new_
,
// b+1
std
::
optional
<
const
at
::
Tensor
>&
seqused_q_
,
// b. If given, only this many elements of each batch element's queries and outputs are used.
seqused_q_
,
// b. If given, only this many elements of each batch element's queries and outputs are used.
std
::
optional
<
const
at
::
Tensor
>
&
std
::
optional
<
at
::
Tensor
>
seqused_k_
,
// b. If given, only this many elements of each batch element's keys are used.
seqused_k_
,
// b. If given, only this many elements of each batch element's keys are used.
std
::
optional
<
int
>
max_seqlen_q_
,
std
::
optional
<
int
64_t
>
max_seqlen_q_
,
// TODO: check if we need max_seqlen_k
// TODO: check if we need max_seqlen_k
std
::
optional
<
int
>
max_seqlen_k_
,
std
::
optional
<
int
64_t
>
max_seqlen_k_
,
std
::
optional
<
const
at
::
Tensor
>
&
page_table_
,
// (b_k, max_num_pages_per_seq)
std
::
optional
<
at
::
Tensor
>
page_table_
,
// (b_k, max_num_pages_per_seq)
std
::
optional
<
const
at
::
Tensor
>
&
kv_batch_idx_
,
// b. indices to index into the KV cache
std
::
optional
<
at
::
Tensor
>
kv_batch_idx_
,
// b. indices to index into the KV cache
std
::
optional
<
const
at
::
Tensor
>
&
leftpad_k_
,
// b
std
::
optional
<
at
::
Tensor
>
leftpad_k_
,
// b
std
::
optional
<
const
at
::
Tensor
>
&
rotary_cos_
,
// seqlen_ro x (rotary_dim / 2)
std
::
optional
<
at
::
Tensor
>
rotary_cos_
,
// seqlen_ro x (rotary_dim / 2)
std
::
optional
<
const
at
::
Tensor
>
&
rotary_sin_
,
// seqlen_ro x (rotary_dim / 2)
std
::
optional
<
at
::
Tensor
>
rotary_sin_
,
// seqlen_ro x (rotary_dim / 2)
std
::
optional
<
const
at
::
Tensor
>
&
seqlens_rotary_
,
// b
std
::
optional
<
at
::
Tensor
>
seqlens_rotary_
,
// b
std
::
optional
<
at
::
Tensor
>
&
q_descale_
,
// (b, h_k), not (b, h)
std
::
optional
<
at
::
Tensor
>
q_descale_
,
// (b, h_k), not (b, h)
std
::
optional
<
at
::
Tensor
>
&
k_descale_
,
// (b, h_k)
std
::
optional
<
at
::
Tensor
>
k_descale_
,
// (b, h_k)
std
::
optional
<
at
::
Tensor
>
&
v_descale_
,
// (b, h_k)
std
::
optional
<
at
::
Tensor
>
v_descale_
,
// (b, h_k)
float
const
softmax_scale
,
std
::
optional
<
double
>
softmax_scale
_
,
bool
is_causal
,
bool
is_causal
,
int
window_size_left
,
int64_t
window_size_left
,
int
window_size_right
,
int64_t
window_size_right
,
float
const
softcap
,
int64_t
attention_chunk
,
bool
const
is_rotary_interleaved
,
// if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
double
softcap
,
std
::
optional
<
at
::
Tensor
>&
scheduler_metadata_
,
// (b + 1)
bool
is_rotary_interleaved
,
// if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int
num_splits
,
std
::
optional
<
at
::
Tensor
>
scheduler_metadata_
,
// (b + 1)
int64_t
num_splits
,
std
::
optional
<
bool
>
pack_gqa_
,
std
::
optional
<
bool
>
pack_gqa_
,
int
cons
t
sm_margin
,
int
64_
t
sm_margin
,
std
::
optional
<
const
at
::
Tensor
>&
sinks_
);
std
::
optional
<
const
at
::
Tensor
>&
sinks_
);
// (h)
sgl-kernel/python/sgl_kernel/_fa4_interface.py
View file @
252dc4e1
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/
203b9b3dba39d5d08dffb49c09aa622984dff07
d/flash_attn/cute/interface.py
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/
54d8aa6751fc9d5f0357854079261913d5df1f9
d/flash_attn/cute/interface.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# [2025-
07-0
4] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.
1.0
.
# [2025-
10-1
4] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.
2.1
.
import
copy
import
copy
import
gc
import
gc
import
logging
import
logging
import
math
import
math
from
typing
import
Optional
,
Tuple
from
typing
import
Callable
,
Optional
,
Tuple
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -18,6 +18,7 @@ import cutlass
...
@@ -18,6 +18,7 @@ import cutlass
import
cutlass.cute
as
cute
import
cutlass.cute
as
cute
import
torch
import
torch
from
cutlass.cute.runtime
import
from_dlpack
from
cutlass.cute.runtime
import
from_dlpack
from
flash_attn.cute
import
utils
from
flash_attn.cute.flash_fwd
import
FlashAttentionForwardSm90
from
flash_attn.cute.flash_fwd
import
FlashAttentionForwardSm90
from
flash_attn.cute.flash_fwd_sm100
import
FlashAttentionForwardSm100
from
flash_attn.cute.flash_fwd_sm100
import
FlashAttentionForwardSm100
...
@@ -26,22 +27,6 @@ def maybe_contiguous(x):
...
@@ -26,22 +27,6 @@ def maybe_contiguous(x):
return
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
return
x
.
contiguous
()
if
x
is
not
None
and
x
.
stride
(
-
1
)
!=
1
else
x
def
_reason_recompile
(
compile_key
,
jit_func
):
compile_cache
=
jit_func
.
compile_cache
compile_key_map
=
jit_func
.
compile_key_map
if
not
compile_cache
:
return
"not compiled yet"
for
k
,
v
in
compile_cache
.
items
():
if
k
==
compile_key
:
continue
if
len
(
k
)
!=
len
(
compile_key
):
continue
for
i
in
range
(
len
(
k
)):
if
k
[
i
]
!=
compile_key
[
i
]:
return
f
"diff at '
{
compile_key_map
[
i
]
}
':
{
k
[
i
]
}
vs
{
compile_key
[
i
]
}
"
return
"unknown reason"
torch2cute_dtype_map
=
{
torch2cute_dtype_map
=
{
torch
.
float16
:
cutlass
.
Float16
,
torch
.
float16
:
cutlass
.
Float16
,
torch
.
bfloat16
:
cutlass
.
BFloat16
,
torch
.
bfloat16
:
cutlass
.
BFloat16
,
...
@@ -72,7 +57,11 @@ def _flash_attn_fwd(
...
@@ -72,7 +57,11 @@ def _flash_attn_fwd(
num_threads
:
int
=
384
,
num_threads
:
int
=
384
,
pack_gqa
:
Optional
[
bool
]
=
None
,
pack_gqa
:
Optional
[
bool
]
=
None
,
_compute_capability
:
Optional
[
int
]
=
None
,
_compute_capability
:
Optional
[
int
]
=
None
,
return_softmax_lse
:
Optional
[
bool
]
=
False
,
score_mod
:
Callable
|
None
=
None
,
return_lse
:
bool
=
False
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
lse
:
Optional
[
torch
.
Tensor
]
=
None
,
buffers
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
k
,
v
=
[
maybe_contiguous
(
t
)
for
t
in
(
q
,
k
,
v
)]
q
,
k
,
v
=
[
maybe_contiguous
(
t
)
for
t
in
(
q
,
k
,
v
)]
num_head
,
head_dim
=
q
.
shape
[
-
2
:]
num_head
,
head_dim
=
q
.
shape
[
-
2
:]
...
@@ -169,6 +158,14 @@ def _flash_attn_fwd(
...
@@ -169,6 +158,14 @@ def _flash_attn_fwd(
q_batch_seqlen_shape
=
(
q_batch_seqlen_shape
=
(
(
batch_size
,
seqlen_q
)
if
cu_seqlens_q
is
None
else
(
total_q
,)
(
batch_size
,
seqlen_q
)
if
cu_seqlens_q
is
None
else
(
total_q
,)
)
)
lse_shape
=
(
(
batch_size
,
num_head
,
seqlen_q
)
if
cu_seqlens_q
is
None
else
(
num_head
,
total_q
)
)
requires_grad
=
q
.
requires_grad
or
k
.
requires_grad
or
v
.
requires_grad
if
out
is
None
:
out
=
torch
.
empty
(
out
=
torch
.
empty
(
*
q_batch_seqlen_shape
,
*
q_batch_seqlen_shape
,
num_head
,
num_head
,
...
@@ -176,16 +173,36 @@ def _flash_attn_fwd(
...
@@ -176,16 +173,36 @@ def _flash_attn_fwd(
dtype
=
out_torch_dtype
,
dtype
=
out_torch_dtype
,
device
=
device
,
device
=
device
,
)
)
lse_shape
=
(
else
:
(
batch_size
,
num_head
,
seqlen_q
)
expected_out_shape
=
(
*
q_batch_seqlen_shape
,
num_head
,
head_dim_v
)
if
cu_seqlens_q
is
None
assert
(
else
(
num_head
,
total_q
)
out
.
shape
==
expected_out_shape
)
),
f
"out tensor shape
{
out
.
shape
}
does not match expected shape
{
expected_out_shape
}
"
assert
(
out
.
dtype
==
out_torch_dtype
),
f
"out tensor dtype
{
out
.
dtype
}
does not match expected dtype
{
out_torch_dtype
}
"
assert
(
out
.
device
==
device
),
f
"out tensor device
{
out
.
device
}
does not match input device
{
device
}
"
assert
out
.
is_cuda
,
"out tensor must be on CUDA device"
if
lse
is
None
:
lse
=
(
lse
=
(
torch
.
empty
(
lse_shape
,
dtype
=
torch
.
float32
,
device
=
device
)
torch
.
empty
(
lse_shape
,
dtype
=
torch
.
float32
,
device
=
device
)
if
re
turn_softmax
_lse
if
re
quires_grad
or
return
_lse
else
None
else
None
)
)
elif
lse
is
not
None
:
assert
(
lse
.
shape
==
lse_shape
),
f
"lse tensor shape
{
lse
.
shape
}
does not match expected shape
{
lse_shape
}
"
assert
(
lse
.
dtype
==
torch
.
float32
),
f
"lse tensor dtype
{
lse
.
dtype
}
does not match expected dtype torch.float32"
assert
(
lse
.
device
==
device
),
f
"lse tensor device
{
lse
.
device
}
does not match input device
{
device
}
"
assert
lse
.
is_cuda
,
"lse tensor must be on CUDA device"
dtype
=
torch2cute_dtype_map
[
q
.
dtype
]
dtype
=
torch2cute_dtype_map
[
q
.
dtype
]
q_tensor
,
k_tensor
,
v_tensor
,
o_tensor
=
[
q_tensor
,
k_tensor
,
v_tensor
,
o_tensor
=
[
...
@@ -242,6 +259,7 @@ def _flash_attn_fwd(
...
@@ -242,6 +259,7 @@ def _flash_attn_fwd(
current_stream
=
cuda
.
CUstream
(
torch
.
cuda
.
current_stream
().
cuda_stream
)
current_stream
=
cuda
.
CUstream
(
torch
.
cuda
.
current_stream
().
cuda_stream
)
if
compute_capability
==
9
:
# TODO: tune block size according to hdim
if
compute_capability
==
9
:
# TODO: tune block size according to hdim
# Perf heuristic from upstream: hdim=128, noncausal, non-local benefits from larger n_block
if
head_dim
==
head_dim_v
==
128
and
not
causal
and
not
local
:
if
head_dim
==
head_dim_v
==
128
and
not
causal
and
not
local
:
n_block_size
=
192
n_block_size
=
192
if
compute_capability
==
10
:
if
compute_capability
==
10
:
...
@@ -253,13 +271,34 @@ def _flash_attn_fwd(
...
@@ -253,13 +271,34 @@ def _flash_attn_fwd(
):
):
pack_gqa
=
False
pack_gqa
=
False
if
softcap
is
not
None
:
assert
score_mod
is
None
,
"softcap and score_mod cannot be used together"
score_mod
=
utils
.
create_softcap_scoremod
(
softcap
)
if
score_mod
is
not
None
:
is_varlen
=
(
cu_seqlens_q
is
not
None
or
cu_seqlens_k
is
not
None
or
seqused_q
is
not
None
or
seqused_k
is
not
None
)
if
is_varlen
:
raise
NotImplementedError
(
"score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR."
)
cute_buffers
=
None
if
buffers
is
not
None
:
cute_buffers
=
[
from_dlpack
(
buf
)
for
buf
in
buffers
]
compile_key
=
(
compile_key
=
(
dtype
,
dtype
,
head_dim
,
head_dim
,
head_dim_v
,
head_dim_v
,
qhead_per_kvhead
,
qhead_per_kvhead
,
causal
,
causal
,
softcap
is
not
None
,
utils
.
hash_callable
(
score_mod
)
if
score_mod
is
not
None
else
None
,
buffers
is
not
None
,
lse
is
None
,
lse
is
None
,
cu_seqlens_q
is
None
,
cu_seqlens_q
is
None
,
cu_seqlens_k
is
None
,
cu_seqlens_k
is
None
,
...
@@ -276,9 +315,6 @@ def _flash_attn_fwd(
...
@@ -276,9 +315,6 @@ def _flash_attn_fwd(
compute_capability
,
compute_capability
,
)
)
if
compile_key
not
in
_flash_attn_fwd
.
compile_cache
:
if
compile_key
not
in
_flash_attn_fwd
.
compile_cache
:
logger
.
info
(
f
"Compiling FA4 kernel with reason:
{
_reason_recompile
(
compile_key
,
_flash_attn_fwd
)
}
"
)
if
compute_capability
==
9
:
if
compute_capability
==
9
:
assert
page_table
is
None
,
"paged KV not supported on SM 9.0"
assert
page_table
is
None
,
"paged KV not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80(
# fa_fwd = FlashAttentionForwardSm80(
...
@@ -290,12 +326,14 @@ def _flash_attn_fwd(
...
@@ -290,12 +326,14 @@ def _flash_attn_fwd(
is_causal
=
causal
,
is_causal
=
causal
,
is_local
=
local
,
is_local
=
local
,
pack_gqa
=
pack_gqa
,
pack_gqa
=
pack_gqa
,
m_block_size
=
m_block_size
,
tile_m
=
m_block_size
,
n_block_size
=
n_block_size
,
tile_n
=
n_block_size
,
# num_stages=1,
# num_stages=1,
num_stages
=
2
,
num_stages
=
2
,
num_threads
=
num_threads
,
num_threads
=
num_threads
,
Q_in_regs
=
False
,
Q_in_regs
=
False
,
score_mod
=
score_mod
,
has_buffers
=
buffers
is
not
None
,
)
)
elif
compute_capability
==
10
:
elif
compute_capability
==
10
:
assert
page_size
in
[
assert
page_size
in
[
...
@@ -313,12 +351,15 @@ def _flash_attn_fwd(
...
@@ -313,12 +351,15 @@ def _flash_attn_fwd(
and
not
local
and
not
local
and
cu_seqlens_q
is
None
and
cu_seqlens_q
is
None
and
seqused_q
is
None
,
and
seqused_q
is
None
,
score_mod
=
score_mod
,
has_buffers
=
buffers
is
not
None
,
)
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported compute capability:
{
compute_capability
}
. Supported: 9.x, 10.x"
f
"Unsupported compute capability:
{
compute_capability
}
. Supported: 9.x, 10.x"
)
)
# TODO: check @can_implement
# TODO: check @can_implement
# TODO caching for buffers; cute_buffers
_flash_attn_fwd
.
compile_cache
[
compile_key
]
=
cute
.
compile
(
_flash_attn_fwd
.
compile_cache
[
compile_key
]
=
cute
.
compile
(
fa_fwd
,
fa_fwd
,
q_tensor
,
q_tensor
,
...
@@ -333,10 +374,10 @@ def _flash_attn_fwd(
...
@@ -333,10 +374,10 @@ def _flash_attn_fwd(
seqused_q_tensor
,
seqused_q_tensor
,
seqused_k_tensor
,
seqused_k_tensor
,
page_table_tensor
,
page_table_tensor
,
softcap
,
window_size_left
,
window_size_left
,
window_size_right
,
window_size_right
,
learnable_sink_tensor
,
learnable_sink_tensor
,
cute_buffers
,
)
)
_flash_attn_fwd
.
compile_cache
[
compile_key
](
_flash_attn_fwd
.
compile_cache
[
compile_key
](
q_tensor
,
q_tensor
,
...
@@ -351,46 +392,29 @@ def _flash_attn_fwd(
...
@@ -351,46 +392,29 @@ def _flash_attn_fwd(
seqused_q_tensor
,
seqused_q_tensor
,
seqused_k_tensor
,
seqused_k_tensor
,
page_table_tensor
,
page_table_tensor
,
softcap
,
window_size_left
,
window_size_left
,
window_size_right
,
window_size_right
,
learnable_sink_tensor
,
learnable_sink_tensor
,
cute_buffers
,
)
)
return
out
,
lse
return
out
,
lse
_flash_attn_fwd
.
compile_cache
=
{}
_flash_attn_fwd
.
compile_cache
=
{}
_flash_attn_fwd
.
compile_key_map
=
[
"dtype"
,
"head_dim"
,
"head_dim_v"
,
"qhead_per_kvhead"
,
"causal"
,
"softcap is not None"
,
"lse is None"
,
"cu_seqlens_q is None"
,
"cu_seqlens_k is None"
,
"seqused_q is None"
,
"seqused_k is None"
,
"page_table is not None"
,
"window_size_left is not None"
,
"window_size_right is not None"
,
"learnable_sink is not None"
,
"m_block_size"
,
"n_block_size"
,
"num_threads"
,
"pack_gqa"
,
"compute_capability"
,
]
def
warmup_flash_attn
(
f
):
def
warmup_flash_attn
(
f
):
"""
"""
Decorator for flash_attn_varlen_func:
Decorator for flash_attn_varlen_func:
- On the first call, run several warmup passes with different flag combinations
- On first call, run several warmup passes with different flag combinations:
- Warmups are executed sequentially to minimize peak GPU memory usage
* return_softmax_lse in {False, True}
- Does not modify user-provided tensors (clones data)
* global noncausal (window_size=(None,None))
- Easy to extend with more compile-key dimensions
* causal (window_size=(None,0))
* local sliding window (window_size=(64,64))
* optionally pack_gqa=True if qheads > kvheads and allowed
- No score_mod / softcap (not supported for varlen yet)
- Executes sequentially to minimize peak GPU mem
- Does not modify user tensors (clones)
"""
"""
done
=
False
done
=
False
...
@@ -399,30 +423,78 @@ def warmup_flash_attn(f):
...
@@ -399,30 +423,78 @@ def warmup_flash_attn(f):
def
maybe_clone
(
x
):
def
maybe_clone
(
x
):
if
isinstance
(
x
,
torch
.
Tensor
):
if
isinstance
(
x
,
torch
.
Tensor
):
return
x
.
clone
()
return
x
.
detach
().
clone
()
# detach to avoid autograd edges
return
copy
.
deepcopy
(
x
)
return
copy
.
deepcopy
(
x
)
return
tuple
(
maybe_clone
(
a
)
for
a
in
args
),
{
return
tuple
(
maybe_clone
(
a
)
for
a
in
args
),
{
k
:
maybe_clone
(
v
)
for
k
,
v
in
kwargs
.
items
()
k
:
maybe_clone
(
v
)
for
k
,
v
in
kwargs
.
items
()
}
}
def
_infer_heads
(
args
,
kwargs
):
"""Infer q and kv head counts from arguments."""
# Expect signature: (q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
q
=
args
[
0
]
if
len
(
args
)
>
0
else
kwargs
.
get
(
"q"
)
k
=
args
[
1
]
if
len
(
args
)
>
1
else
kwargs
.
get
(
"k"
)
try
:
qh
=
int
(
q
.
shape
[
-
2
])
kvh
=
int
(
k
.
shape
[
-
2
])
return
qh
,
kvh
except
Exception
:
return
None
,
None
def
_run_warmups
(
args
,
kwargs
):
def
_run_warmups
(
args
,
kwargs
):
"""Run warmup calls sequentially and release memory after each."""
"""Run warmup calls sequentially and release memory after each."""
base_args
,
base_kwargs
=
_clone_args
(
args
,
kwargs
)
base_args
,
base_kwargs
=
_clone_args
(
args
,
kwargs
)
# Warmup combinations for return_softmax_lse and causal
qh
,
kvh
=
_infer_heads
(
base_args
,
base_kwargs
)
combos
=
[
can_pack_gqa
=
(
dict
(
return_softmax_lse
=
False
,
causal
=
False
),
qh
is
not
None
and
kvh
is
not
None
and
qh
%
kvh
==
0
and
qh
//
kvh
>
1
dict
(
return_softmax_lse
=
False
,
causal
=
True
),
)
dict
(
return_softmax_lse
=
True
,
causal
=
False
),
has_page_table
=
(
dict
(
return_softmax_lse
=
True
,
causal
=
True
),
"page_table"
in
base_kwargs
and
base_kwargs
[
"page_table"
]
is
not
None
)
# Window presets covering global, causal, and local
window_presets
=
[
(
None
,
None
),
# global noncausal
(
None
,
0
),
# causal
(
64
,
64
),
# local sliding window
]
]
lse_flags
=
[
False
,
True
]
# Base combo list
combos
=
[]
for
ws
in
window_presets
:
for
return_lse_flag
in
lse_flags
:
combos
.
append
(
dict
(
window_size
=
ws
,
return_softmax_lse
=
return_lse_flag
))
# Optionally add a pack_gqa=True variant (FA4 may disable it internally for some varlen shapes/SMs)
if
can_pack_gqa
:
for
ws
in
window_presets
:
combos
.
append
(
dict
(
window_size
=
ws
,
return_softmax_lse
=
False
,
pack_gqa
=
True
)
)
# If page_table is present, warm one combo with it (page_table in compile key for SM100)
if
has_page_table
:
combos
.
append
(
dict
(
window_size
=
(
None
,
None
),
return_softmax_lse
=
False
))
# Run sequentially
for
combo
in
combos
:
for
combo
in
combos
:
wa
,
wk
=
_clone_args
(
base_args
,
base_kwargs
)
wa
,
wk
=
_clone_args
(
base_args
,
base_kwargs
)
# Keep user-provided softcap/score_mod OUT (varlen+score_mod unsupported)
wk
.
pop
(
"score_mod"
,
None
)
if
"softcap"
in
wk
and
wk
[
"softcap"
]:
wk
[
"softcap"
]
=
0.0
# Apply combo
wk
.
update
(
combo
)
wk
.
update
(
combo
)
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
try
:
f
(
*
wa
,
**
wk
)
f
(
*
wa
,
**
wk
)
except
Exception
as
e
:
# Some combos can be invalid for specific head dims / arch. Ignore and continue.
logger
.
debug
(
"Warmup combo skipped: %s"
,
e
)
del
wa
,
wk
del
wa
,
wk
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
gc
.
collect
()
...
@@ -430,7 +502,9 @@ def warmup_flash_attn(f):
...
@@ -430,7 +502,9 @@ def warmup_flash_attn(f):
def
wrapper
(
*
args
,
**
kwargs
):
def
wrapper
(
*
args
,
**
kwargs
):
nonlocal
done
nonlocal
done
if
not
done
:
if
not
done
:
logger
.
info
(
"Running flash_attn_varlen_func warmup passes..."
)
logger
.
info
(
"Running FA4 warmup (global/causal/local, LSE on/off, optional GQA pack)..."
)
_run_warmups
(
args
,
kwargs
)
_run_warmups
(
args
,
kwargs
)
done
=
True
done
=
True
return
f
(
*
args
,
**
kwargs
)
return
f
(
*
args
,
**
kwargs
)
...
@@ -472,7 +546,7 @@ def flash_attn_varlen_func(
...
@@ -472,7 +546,7 @@ def flash_attn_varlen_func(
learnable_sink
=
learnable_sink
,
learnable_sink
=
learnable_sink
,
softcap
=
softcap
,
softcap
=
softcap
,
pack_gqa
=
pack_gqa
,
pack_gqa
=
pack_gqa
,
return_
softmax_
lse
=
return_softmax_lse
,
return_lse
=
return_softmax_lse
,
)
)
return
(
out
,
lse
)
if
return_softmax_lse
else
out
return
(
out
,
lse
)
if
return_softmax_lse
else
out
sgl-kernel/python/sgl_kernel/flash_attn.py
View file @
252dc4e1
...
@@ -45,7 +45,7 @@ def flash_attn_with_kvcache(
...
@@ -45,7 +45,7 @@ def flash_attn_with_kvcache(
qv
=
None
,
qv
=
None
,
rotary_cos
=
None
,
rotary_cos
=
None
,
rotary_sin
=
None
,
rotary_sin
=
None
,
cache_seqlens
:
Optional
[
Union
[
(
int
,
torch
.
Tensor
)
]]
=
None
,
cache_seqlens
:
Optional
[
Union
[
int
,
torch
.
Tensor
]]
=
None
,
cache_batch_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_batch_idx
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_leftpad
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_leftpad
:
Optional
[
torch
.
Tensor
]
=
None
,
page_table
:
Optional
[
torch
.
Tensor
]
=
None
,
page_table
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -59,6 +59,7 @@ def flash_attn_with_kvcache(
...
@@ -59,6 +59,7 @@ def flash_attn_with_kvcache(
softmax_scale
=
None
,
softmax_scale
=
None
,
causal
=
False
,
causal
=
False
,
window_size
=
(
-
1
,
-
1
),
# -1 means infinite context window
window_size
=
(
-
1
,
-
1
),
# -1 means infinite context window
attention_chunk
:
Optional
[
int
]
=
None
,
softcap
=
0.0
,
# 0.0 means deactivated
softcap
=
0.0
,
# 0.0 means deactivated
rotary_interleaved
=
True
,
rotary_interleaved
=
True
,
scheduler_metadata
=
None
,
scheduler_metadata
=
None
,
...
@@ -137,6 +138,7 @@ def flash_attn_with_kvcache(
...
@@ -137,6 +138,7 @@ def flash_attn_with_kvcache(
Default to 1 / sqrt(headdim).
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory.
softcap: float. Anything > 0 activates softcapping attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
...
@@ -216,6 +218,7 @@ def flash_attn_with_kvcache(
...
@@ -216,6 +218,7 @@ def flash_attn_with_kvcache(
]
]
rotary_cos
,
rotary_sin
=
[
maybe_contiguous
(
x
)
for
x
in
(
rotary_cos
,
rotary_sin
)]
rotary_cos
,
rotary_sin
=
[
maybe_contiguous
(
x
)
for
x
in
(
rotary_cos
,
rotary_sin
)]
rotary_seqlens
=
maybe_contiguous
(
rotary_seqlens
)
rotary_seqlens
=
maybe_contiguous
(
rotary_seqlens
)
attention_chunk
=
0
if
attention_chunk
is
None
else
int
(
attention_chunk
)
out
,
softmax_lse
,
*
rest
=
torch
.
ops
.
sgl_kernel
.
fwd
.
default
(
out
,
softmax_lse
,
*
rest
=
torch
.
ops
.
sgl_kernel
.
fwd
.
default
(
q
,
q
,
...
@@ -245,6 +248,7 @@ def flash_attn_with_kvcache(
...
@@ -245,6 +248,7 @@ def flash_attn_with_kvcache(
causal
,
causal
,
window_size
[
0
],
window_size
[
0
],
window_size
[
1
],
window_size
[
1
],
attention_chunk
,
softcap
,
softcap
,
rotary_interleaved
,
rotary_interleaved
,
scheduler_metadata
,
scheduler_metadata
,
...
@@ -263,10 +267,11 @@ def flash_attn_varlen_func(
...
@@ -263,10 +267,11 @@ def flash_attn_varlen_func(
v
,
v
,
cu_seqlens_q
,
cu_seqlens_q
,
cu_seqlens_k
,
cu_seqlens_k
,
max_seqlen_q
,
max_seqlen_q
=
None
,
max_seqlen_k
,
max_seqlen_k
=
None
,
seqused_q
=
None
,
seqused_q
=
None
,
seqused_k
=
None
,
seqused_k
=
None
,
page_table
=
None
,
softmax_scale
=
None
,
softmax_scale
=
None
,
causal
=
False
,
causal
=
False
,
qv
=
None
,
qv
=
None
,
...
@@ -274,6 +279,7 @@ def flash_attn_varlen_func(
...
@@ -274,6 +279,7 @@ def flash_attn_varlen_func(
k_descale
=
None
,
k_descale
=
None
,
v_descale
=
None
,
v_descale
=
None
,
window_size
=
(
-
1
,
-
1
),
window_size
=
(
-
1
,
-
1
),
attention_chunk
=
0
,
softcap
=
0.0
,
softcap
=
0.0
,
num_splits
=
1
,
num_splits
=
1
,
pack_gqa
=
None
,
pack_gqa
=
None
,
...
@@ -293,25 +299,18 @@ def flash_attn_varlen_func(
...
@@ -293,25 +299,18 @@ def flash_attn_varlen_func(
q
,
q
,
k
,
k
,
v
,
v
,
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
,
cu_seqlens_k
=
cu_seqlens_k
,
# max_seqlen_q,
# max_seqlen_k,
seqused_q
=
seqused_q
,
seqused_q
=
seqused_q
,
seqused_k
=
seqused_k
,
seqused_k
=
seqused_k
,
page_table
=
page_table
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
causal
=
causal
,
# qv=qv,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
window_size
=
window_size
,
window_size
=
window_size
,
softcap
=
softcap
,
softcap
=
softcap
,
# num_splits=num_splits,
pack_gqa
=
pack_gqa
,
pack_gqa
=
pack_gqa
,
# sm_margin=sm_margin,
return_softmax_lse
=
return_softmax_lse
,
learnable_sink
=
sinks
,
learnable_sink
=
sinks
,
return_softmax_lse
=
return_softmax_lse
,
)
)
if
not
is_fa3_supported
():
if
not
is_fa3_supported
():
...
@@ -319,10 +318,15 @@ def flash_attn_varlen_func(
...
@@ -319,10 +318,15 @@ def flash_attn_varlen_func(
"flash_attn at sgl-kernel is only supported on sm90 and above"
"flash_attn at sgl-kernel is only supported on sm90 and above"
)
)
# FA3 requires max_seqlen_q and max_seqlen_k
if
max_seqlen_q
is
None
or
max_seqlen_k
is
None
:
raise
ValueError
(
"max_seqlen_q and max_seqlen_k are required for FA3"
)
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
(
q
.
shape
[
-
1
]
+
(
qv
.
shape
[
-
1
]
if
qv
is
not
None
else
0
))
**
(
softmax_scale
=
(
q
.
shape
[
-
1
]
+
(
qv
.
shape
[
-
1
]
if
qv
is
not
None
else
0
))
**
(
-
0.5
-
0.5
)
)
attention_chunk
=
0
if
attention_chunk
is
None
else
int
(
attention_chunk
)
out
,
softmax_lse
,
*
rest
=
torch
.
ops
.
sgl_kernel
.
fwd
.
default
(
out
,
softmax_lse
,
*
rest
=
torch
.
ops
.
sgl_kernel
.
fwd
.
default
(
q
,
q
,
...
@@ -352,6 +356,7 @@ def flash_attn_varlen_func(
...
@@ -352,6 +356,7 @@ def flash_attn_varlen_func(
causal
,
causal
,
window_size
[
0
],
window_size
[
0
],
window_size
[
1
],
window_size
[
1
],
attention_chunk
,
softcap
,
softcap
,
is_rotary_interleaved
=
False
,
is_rotary_interleaved
=
False
,
scheduler_metadata
=
None
,
scheduler_metadata
=
None
,
...
...
sgl-kernel/tests/test_flash_attention_4.py
View file @
252dc4e1
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/
b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a
/tests/cute/test_flash_attn.py
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/
8ecf128f683266735ba68e3c106ff67a2611886e
/tests/cute/test_flash_attn.py
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
...
@@ -10,12 +10,25 @@ import pytest
...
@@ -10,12 +10,25 @@ import pytest
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
try
:
from
flash_attn.layers.rotary
import
apply_rotary_emb
except
ImportError
:
apply_rotary_emb
=
None
from
sgl_kernel.flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
sgl_kernel.flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
utils
import
is_hopper
from
sgl_kernel.testing.rotary_embedding
import
_apply_rotary_emb
as
apply_rotary_emb
# from utils import is_hopper # Not used in this test
# Force sgl_kernel.flash_attn wrappers to use FA4 (Cute-DSL) implementations.
# The wrappers accept a superset of args; for FA4, extra args are ignored.
flash_attn_varlen_func
=
partial
(
flash_attn_varlen_func
,
ver
=
4
)
flash_attn_varlen_func
=
partial
(
flash_attn_varlen_func
,
ver
=
4
)
flash_attn_with_kvcache
=
partial
(
flash_attn_with_kvcache
,
ver
=
4
)
flash_attn_with_kvcache
=
partial
(
flash_attn_with_kvcache
,
ver
=
4
)
# Skip this test on Hopper machine
skip_condition
=
torch
.
cuda
.
get_device_capability
()
<
(
10
,
0
)
def
unpad_input
(
hidden_states
,
attention_mask
,
unused_mask
=
None
):
def
unpad_input
(
hidden_states
,
attention_mask
,
unused_mask
=
None
):
"""
"""
...
@@ -88,6 +101,11 @@ def generate_random_padding_mask(
...
@@ -88,6 +101,11 @@ def generate_random_padding_mask(
lengths
=
torch
.
randint
(
lengths
=
torch
.
randint
(
max_seqlen
//
3
,
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
max_seqlen
//
3
,
max_seqlen
+
1
,
(
batch_size
,
1
),
device
=
device
)
)
else
:
# This should never happen due to the assertion above, but for linter
lengths
=
torch
.
full
(
(
batch_size
,
1
),
max_seqlen
,
device
=
device
,
dtype
=
torch
.
int32
)
if
zero_lengths
:
if
zero_lengths
:
# Generate zero-lengths every 5 batches and the last batch.
# Generate zero-lengths every 5 batches and the last batch.
...
@@ -482,8 +500,7 @@ def attention_ref(
...
@@ -482,8 +500,7 @@ def attention_ref(
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
is_hopper
(),
skip_condition
,
reason
=
"FA4 Requires compute capability of 10 or above."
reason
=
"skip on hopper"
,
)
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
...
@@ -497,8 +514,8 @@ def attention_ref(
...
@@ -497,8 +514,8 @@ def attention_ref(
@
pytest
.
mark
.
parametrize
(
"deterministic"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"deterministic"
,
[
False
])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@
pytest
.
mark
.
parametrize
(
"softcap"
,
[
0.0
])
@
pytest
.
mark
.
parametrize
(
"softcap"
,
[
0.0
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
#
@pytest.mark.parametrize("local", [False, True])
#
@pytest.mark.parametrize("local", [False])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
False
,
True
])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
...
@@ -522,11 +539,11 @@ def attention_ref(
...
@@ -522,11 +539,11 @@ def attention_ref(
(
64
,
128
),
(
64
,
128
),
(
128
,
128
),
(
128
,
128
),
(
256
,
256
),
(
256
,
256
),
(
113
,
203
),
#
(113, 203),
(
128
,
217
),
#
(128, 217),
(
113
,
211
),
#
(113, 211),
(
108
,
256
),
#
(108, 256),
(
256
,
512
),
#
(256, 512),
(
307
,
256
),
(
307
,
256
),
(
640
,
128
),
(
640
,
128
),
(
512
,
256
),
(
512
,
256
),
...
@@ -658,25 +675,7 @@ def test_flash_attn_varlen_output(
...
@@ -658,25 +675,7 @@ def test_flash_attn_varlen_output(
if
causal
or
local
:
if
causal
or
local
:
key_padding_mask
=
query_padding_mask
key_padding_mask
=
query_padding_mask
(
result
=
generate_qkv
(
q_unpad
,
k_unpad
,
v_unpad
,
qv_unpad
,
cu_seqlens_q
,
cu_seqlens_k
,
seqused_q
,
seqused_k
,
max_seqlen_q
,
max_seqlen_k
,
q
,
k
,
v
,
qv
,
output_pad_fn
,
dq_pad_fn
,
dk_pad_fn
,
)
=
generate_qkv
(
q
,
q
,
k
,
k
,
v
,
v
,
...
@@ -687,6 +686,25 @@ def test_flash_attn_varlen_output(
...
@@ -687,6 +686,25 @@ def test_flash_attn_varlen_output(
query_unused_mask
=
query_unused_mask
,
query_unused_mask
=
query_unused_mask
,
key_unused_mask
=
key_unused_mask
,
key_unused_mask
=
key_unused_mask
,
)
)
(
q_unpad
,
# 0
k_unpad
,
# 1
v_unpad
,
# 2
qv_unpad
,
# 3
cu_seqlens_q
,
# 4
cu_seqlens_k
,
# 5
seqused_q
,
# 6
seqused_k
,
# 7
max_seqlen_q
,
# 8
max_seqlen_k
,
# 9
q
,
# 10
k
,
# 11
v
,
# 12
qv
,
# 13
output_pad_fn
,
# 14
dq_pad_fn
,
# 15
dk_pad_fn
,
# 16
)
=
result
q_unpad
,
k_unpad
,
v_unpad
=
[
q_unpad
,
k_unpad
,
v_unpad
=
[
x
.
detach
().
to
(
dtype
).
requires_grad_
()
for
x
in
(
q_unpad
,
k_unpad
,
v_unpad
)
x
.
detach
().
to
(
dtype
).
requires_grad_
()
for
x
in
(
q_unpad
,
k_unpad
,
v_unpad
)
]
]
...
@@ -746,20 +764,16 @@ def test_flash_attn_varlen_output(
...
@@ -746,20 +764,16 @@ def test_flash_attn_varlen_output(
v_unpad
,
v_unpad
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
None
,
# max_seqlen_q and max_seqlen_k not needed for FA4
max_seqlen_k
=
None
,
seqused_q
=
seqused_q
,
# seqused_q=seqused_q,
seqused_k
=
seqused_k
,
# seqused_k=seqused_k,
causal
=
causal
,
causal
=
causal
,
# qv=qv_unpad,
# q_descale=q_descale,
# k_descale=k_descale, v_descale=v_descale,
window_size
=
window_size
,
window_size
=
window_size
,
# attention_chunk=attention_chunk,
sinks
=
learnable_sink
,
softcap
=
softcap
,
softcap
=
softcap
,
sinks
=
learnable_sink
,
# FA4 uses learnable_sink, not sinks
pack_gqa
=
pack_gqa
,
pack_gqa
=
pack_gqa
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
ver
=
4
,
# Use FA4
)
)
out
=
output_pad_fn
(
out_unpad
)
out
=
output_pad_fn
(
out_unpad
)
if
query_unused_mask
is
not
None
:
if
query_unused_mask
is
not
None
:
...
@@ -875,8 +889,7 @@ def test_flash_attn_varlen_output(
...
@@ -875,8 +889,7 @@ def test_flash_attn_varlen_output(
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
is_hopper
(),
skip_condition
,
reason
=
"FA4 Requires compute capability of 10 or above."
reason
=
"skip on hopper"
,
)
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
...
@@ -887,8 +900,8 @@ def test_flash_attn_varlen_output(
...
@@ -887,8 +900,8 @@ def test_flash_attn_varlen_output(
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [False, True])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"new_kv"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
,
True
])
#
@pytest.mark.parametrize("local", [False, True])
#
@pytest.mark.parametrize("local", [False])
@
pytest
.
mark
.
parametrize
(
"local"
,
[
False
])
# @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False, True])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
...
@@ -900,8 +913,8 @@ def test_flash_attn_varlen_output(
...
@@ -900,8 +913,8 @@ def test_flash_attn_varlen_output(
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
])
# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
@
pytest
.
mark
.
parametrize
(
"page_size"
,
[
None
,
128
])
#
@pytest.mark.parametrize("page_size", [None, 128])
#
@pytest.mark.parametrize("page_size", [128])
@
pytest
.
mark
.
parametrize
(
"page_size"
,
[
128
])
# @pytest.mark.parametrize("has_leftpad", [False, True])
# @pytest.mark.parametrize("has_leftpad", [False, True])
@
pytest
.
mark
.
parametrize
(
"has_leftpad"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"has_leftpad"
,
[
False
])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
...
@@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache(
...
@@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache(
.
to
(
dtype_ref
)
.
to
(
dtype_ref
)
)
)
page_table
=
None
page_table
=
None
num_blocks
=
None
else
:
else
:
(
(
k_cache
,
k_cache
,
...
@@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache(
...
@@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache(
else
:
else
:
k_cache_paged
.
copy_
(
k_cache_saved
)
k_cache_paged
.
copy_
(
k_cache_saved
)
v_cache_paged
.
copy_
(
v_cache_saved
)
v_cache_paged
.
copy_
(
v_cache_saved
)
# out, lse, *rest = flash_attn_with_kvcache(
# For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache
out
,
lse
,
*
rest
=
flash_attn_with_kvcache
(
# This matches the pattern from the original FA4 test
out
,
lse
=
flash_attn_varlen_func
(
q
if
not
varlen_q
else
q_unpad
,
q
if
not
varlen_q
else
q_unpad
,
k_cache
if
page_size
is
None
else
k_cache_paged
,
k_cache
if
page_size
is
None
else
k_cache_paged
,
v_cache
if
page_size
is
None
else
v_cache_paged
,
v_cache
if
page_size
is
None
else
v_cache_paged
,
# k if not new_kv or not varlen_q else k_unpad,
# v if not new_kv or not varlen_q else v_unpad,
# qv=qv if not varlen_q else qv_unpad,
# rotary_cos=cos,
# rotary_sin=sin,
cache_seqlens
=
cache_seqlens
,
# cache_batch_idx=cache_batch_idx,
# cache_leftpad=cache_leftpad,
page_table
=
page_table
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
cu_seqlens_q
,
# cu_seqlens_k_new=cu_seqlens_k_new,
cu_seqlens_k
=
None
,
# FA4 doesn't use cu_seqlens_k for KV cache
# rotary_seqlens=rotary_seqlens,
# max_seqlen_q and max_seqlen_k not needed for FA4
seqused_k
=
cache_seqlens
,
# Use cache_seqlens as seqused_k
page_table
=
page_table
,
causal
=
causal
,
causal
=
causal
,
window_size
=
window_size
,
window_size
=
window_size
,
sinks
=
learnable_sink
,
sinks
=
learnable_sink
,
# FA4 uses learnable_sink, not sinks
# attention_chunk=attention_chunk,
softcap
=
0.0
,
# rotary_interleaved=rotary_interleaved,
pack_gqa
=
None
,
# scheduler_metadata=scheduler_metadata,
# num_splits=num_splits,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
ver
=
4
,
# Use FA4
)
)
if
varlen_q
:
if
varlen_q
:
out
=
output_pad_fn
(
out
)
out
=
output_pad_fn
(
out
)
...
...
test/srt/run_suite.py
View file @
252dc4e1
...
@@ -169,6 +169,7 @@ suites = {
...
@@ -169,6 +169,7 @@ suites = {
TestFile
(
"test_disaggregation_pp.py"
,
140
),
TestFile
(
"test_disaggregation_pp.py"
,
140
),
],
],
"per-commit-4-gpu-b200"
:
[
"per-commit-4-gpu-b200"
:
[
# TestFile("test_flash_attention_4.py"),
# TestFile("test_gpt_oss_4gpu.py", 600),
# TestFile("test_gpt_oss_4gpu.py", 600),
# TestFile("test_deepseek_v3_fp4_4gpu.py", 3600),
# TestFile("test_deepseek_v3_fp4_4gpu.py", 3600),
],
],
...
...
test/srt/test_flash_attention_4.py
0 → 100644
View file @
252dc4e1
import
unittest
from
types
import
SimpleNamespace
from
sglang.srt.environ
import
envs
from
sglang.srt.utils
import
get_device_sm
,
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
as
run_eval_few_shot_gsm8k
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
@
unittest
.
skipIf
(
get_device_sm
()
<
100
,
"Test requires CUDA SM 100 or higher"
)
class
TestFlashAttention4
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
other_args
=
[
"--trust-remote-code"
,
"--mem-fraction-static"
,
"0.8"
,
"--prefill-attention-backend"
,
"fa4"
,
]
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
other_args
,
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
4
,
data_path
=
None
,
num_questions
=
100
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
"http://127.0.0.1"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]),
)
metrics
=
run_eval_few_shot_gsm8k
(
args
)
print
(
metrics
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.65
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment