change / sglang / Commits / f792e3c5

Unverified commit f792e3c5, authored Oct 13, 2025 by Yineng Zhang; committed via GitHub on Oct 13, 2025
Parent: 28f80b12

    Revert "[NVIDIA] BUMP FA3 (#11444)" (#11582)

Showing 4 changed files with 66 additions and 75 deletions (+66, -75):

    sgl-kernel/CMakeLists.txt                    +2   -2
    sgl-kernel/csrc/flash_extension.cc          +26  -29
    sgl-kernel/include/sgl_flash_kernel_ops.h   +37  -36
    sgl-kernel/python/sgl_kernel/flash_attn.py   +1   -8
sgl-kernel/CMakeLists.txt  (+2, -2)

@@ -90,7 +90,7 @@ FetchContent_Populate(repo-flashinfer)
 FetchContent_Declare(
   repo-flash-attention
   GIT_REPOSITORY https://github.com/sgl-project/sgl-attn
-  GIT_TAG        36f9456cd48ec57c8d75d8d6b90933d4bedffb6b
+  GIT_TAG        f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d
   GIT_SHALLOW    OFF
 )
 FetchContent_Populate(repo-flash-attention)

@@ -99,7 +99,7 @@ FetchContent_Populate(repo-flash-attention)
 FetchContent_Declare(
   repo-flash-attention-origin
   GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git
-  GIT_TAG        5a5a65b48dc99fc7483d2a7d5cfb1d8befa89389
+  GIT_TAG        203b9b3dba39d5d08dffb49c09aa622984dff07d
   GIT_SHALLOW    OFF
 )
 FetchContent_Populate(repo-flash-attention-origin)
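The revert pins sgl-attn back to f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d and the Dao-AILab flash-attention mirror back to 203b9b3dba39d5d08dffb49c09aa622984dff07d. A minimal sketch, not part of the repository, for reading the pins back out of a checkout; it assumes only the FetchContent layout shown above:

    # Hedged sketch: list each GIT_REPOSITORY and the GIT_TAG pin that follows it
    # in sgl-kernel/CMakeLists.txt, so a checkout can be checked against this revert.
    import re
    from pathlib import Path

    text = Path("sgl-kernel/CMakeLists.txt").read_text()
    # Non-greedy match keeps each GIT_REPOSITORY paired with its own block's GIT_TAG.
    for repo, tag in re.findall(r"GIT_REPOSITORY\s+(\S+).*?GIT_TAG\s+(\S+)", text, flags=re.S):
        print(f"{repo} -> {tag}")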
sgl-kernel/csrc/flash_extension.cc  (+26, -29)

@@ -23,43 +23,40 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    * From flash-attention
    */
   m.def(
-      "fwd(Tensor q,"                 // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-      " Tensor k,"                    // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged
-      " Tensor v,"                    // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged
-      " Tensor? k_new,"               // (b, s_k_new, h_k, d) or (total_k_new, h_k, d)
-      " Tensor? v_new,"               // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv)
-      " Tensor? q_v,"                 // (b, s_q, h, dv) or (total_q_new, h, dv)
-      " Tensor? out,"                 // (b, s_q, h, dv) or (total_q, h, dv)
-      " Tensor? cu_seqlens_q,"        // b+1
-      " Tensor? cu_seqlens_k,"        // b+1
-      " Tensor? cu_seqlens_k_new,"    // b+1
-      " Tensor? seqused_q,"           // b
-      " Tensor? seqused_k,"           // b
+      "fwd(Tensor! q,"
+      " Tensor k,"
+      " Tensor v,"
+      " Tensor? k_new,"
+      " Tensor? v_new,"
+      " Tensor? q_v,"
+      " Tensor!? out,"
+      " Tensor? cu_seqlens_q,"
+      " Tensor? cu_seqlens_k,"
+      " Tensor? cu_seqlens_k_new,"
+      " Tensor? seqused_q,"
+      " Tensor? seqused_k,"
       " int? max_seqlen_q,"
-      " int? max_seqlen_k,"           // TODO: check if needed
-      " Tensor? page_table,"          // (b_k, max_num_pages_per_seq)
-      " Tensor? kv_batch_idx,"        // b
-      " Tensor? leftpad_k,"           // b
-      " Tensor? rotary_cos,"          // seqlen_ro x (rotary_dim / 2)
-      " Tensor? rotary_sin,"          // seqlen_ro x (rotary_dim / 2)
-      " Tensor? seqlens_rotary,"      // b
-      " Tensor? q_descale,"           // (b, h_k)
-      " Tensor? k_descale,"           // (b, h_k)
-      " Tensor? v_descale,"           // (b, h_k)
-      " float? softmax_scale,"        // now optional
+      " int? max_seqlen_k,"
+      " Tensor? page_table,"
+      " Tensor? kv_batch_idx,"
+      " Tensor? leftpad_k,"
+      " Tensor? rotary_cos,"
+      " Tensor? rotary_sin,"
+      " Tensor? seqlens_rotary,"
+      " Tensor? q_descale,"
+      " Tensor? k_descale,"
+      " Tensor? v_descale,"
+      " float softmax_scale,"
       " bool is_causal,"
       " int window_size_left,"
       " int window_size_right,"
-      " int attention_chunk,"         // NEW
-      " float softcap,"               // promoted to double in C++; schema float is fine
+      " float softcap,"
       " bool is_rotary_interleaved,"
-      " Tensor? scheduler_metadata,"  // (b + 1)
+      " Tensor? scheduler_metadata,"
       " int num_splits,"
       " bool? pack_gqa,"
       " int sm_margin,"
-      " Tensor? sinks"
-      ") -> (Tensor, Tensor, Tensor, Tensor)");  // NEW return type: tuple of 4 tensors
+      " Tensor? sinks) -> Tensor[]");
   m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
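After the revert, fwd is again registered as returning Tensor[] (a std::vector<at::Tensor> exposed through make_pytorch_shim) instead of the four-tensor tuple introduced by the reverted bump. A minimal sketch for confirming which schema a given install carries, assuming a built sgl-kernel wheel is importable and that importing it loads the extension that runs the TORCH_LIBRARY_FRAGMENT above:

    # Hedged sketch: print the registered schema of sgl_kernel::fwd.
    import torch
    import sgl_kernel  # noqa: F401  # assumed to register the sgl_kernel ops on import

    schema = torch.ops.sgl_kernel.fwd.default._schema
    print(schema)  # with this revert the return should read "-> Tensor[]", not a 4-tuple

Either way, the Python wrappers below unpack only the first two results (out, softmax_lse, *rest), so they tolerate both return shapes.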
sgl-kernel/include/sgl_flash_kernel_ops.h  (+37, -36)

@@ -42,44 +42,45 @@ limitations under the License.
 /*
  * From flash-attention
  */
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_fwd(
-    at::Tensor q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    at::Tensor k,        // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                           // h_k, d) if there is page_table.
-    at::Tensor v,        // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
                           // page_size, h_k, dv) if there is page_table.
-    std::optional<at::Tensor> k_new_,           // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<at::Tensor> v_new_,           // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-    std::optional<at::Tensor> q_v_,             // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor> out_,             // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor> cu_seqlens_q_,      // b+1
-    std::optional<at::Tensor> cu_seqlens_k_,      // b+1
-    std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1
-    std::optional<at::Tensor> seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<at::Tensor> seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int64_t> max_seqlen_q_,
+    std::optional<const at::Tensor>& k_new_,    // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>& v_new_,    // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<const at::Tensor>& q_v_,      // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor>& out_,            // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
+    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
+    std::optional<const at::Tensor>& seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
+    std::optional<const at::Tensor>& seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
+    std::optional<int> max_seqlen_q_,
     // TODO: check if we need max_seqlen_k
-    std::optional<int64_t> max_seqlen_k_,
-    std::optional<at::Tensor> page_table_,      // (b_k, max_num_pages_per_seq)
-    std::optional<at::Tensor> kv_batch_idx_,    // b. indices to index into the KV cache
-    std::optional<at::Tensor> leftpad_k_,       // b
-    std::optional<at::Tensor> rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor> rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
-    std::optional<at::Tensor> seqlens_rotary_,  // b
-    std::optional<at::Tensor> q_descale_,       // (b, h_k), not (b, h)
-    std::optional<at::Tensor> k_descale_,       // (b, h_k)
-    std::optional<at::Tensor> v_descale_,       // (b, h_k)
-    std::optional<double> softmax_scale_,
+    std::optional<int> max_seqlen_k_,
+    std::optional<const at::Tensor>& page_table_,      // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& kv_batch_idx_,    // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_,       // b
+    std::optional<const at::Tensor>& rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_,  // b
+    std::optional<at::Tensor>& q_descale_,      // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_,      // (b, h_k)
+    std::optional<at::Tensor>& v_descale_,      // (b, h_k)
+    float const softmax_scale,
     bool is_causal,
-    int64_t window_size_left,
-    int64_t window_size_right,
-    int64_t attention_chunk,
-    double softcap,
-    bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-    std::optional<at::Tensor> scheduler_metadata_,  // (b + 1)
-    int64_t num_splits,
-    std::optional<bool> pack_gqa_,
-    int64_t sm_margin,
-    std::optional<const at::Tensor>& sinks_);  // (h)
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin,
+    std::optional<const at::Tensor>& sinks_);
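With the restored header, softmax_scale is a required float again (the reverted bump had made it std::optional<double>), so the Python wrapper has to supply a concrete value before calling the op. A small self-contained sketch of that default, mirroring the expression that appears unchanged in flash_attn.py below; qv_head_dim stands in for qv.shape[-1] and is 0 when no qv tensor is passed:

    # Hedged sketch of the default scale used by the wrapper:
    #   softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    import math

    def default_softmax_scale(head_dim: int, qv_head_dim: int = 0) -> float:
        return (head_dim + qv_head_dim) ** -0.5

    assert math.isclose(default_softmax_scale(128), 1.0 / math.sqrt(128))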
sgl-kernel/python/sgl_kernel/flash_attn.py  (+1, -8)

@@ -43,7 +43,7 @@ def flash_attn_with_kvcache(
     qv=None,
     rotary_cos=None,
     rotary_sin=None,
-    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
     cache_batch_idx: Optional[torch.Tensor] = None,
     cache_leftpad: Optional[torch.Tensor] = None,
     page_table: Optional[torch.Tensor] = None,
@@ -57,7 +57,6 @@ def flash_attn_with_kvcache(
     softmax_scale=None,
     causal=False,
     window_size=(-1, -1),  # -1 means infinite context window
-    attention_chunk: Optional[int] = None,
     softcap=0.0,  # 0.0 means deactivated
     rotary_interleaved=True,
     scheduler_metadata=None,
@@ -136,7 +135,6 @@ def flash_attn_with_kvcache(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-        attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory.
         softcap: float. Anything > 0 activates softcapping attention.
         rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
@@ -216,7 +214,6 @@ def flash_attn_with_kvcache(
     ]
     rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
     rotary_seqlens = maybe_contiguous(rotary_seqlens)
-    attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
@@ -246,7 +243,6 @@ def flash_attn_with_kvcache(
         causal,
         window_size[0],
         window_size[1],
-        attention_chunk,
         softcap,
         rotary_interleaved,
         scheduler_metadata,
@@ -276,7 +272,6 @@ def flash_attn_varlen_func(
     k_descale=None,
     v_descale=None,
     window_size=(-1, -1),
-    attention_chunk: Optional[int] = None,
     softcap=0.0,
     num_splits=1,
     pack_gqa=None,
@@ -326,7 +321,6 @@ def flash_attn_varlen_func(
         softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-    attention_chunk = 0 if attention_chunk is None else int(attention_chunk)
     out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
         q,
@@ -356,7 +350,6 @@ def flash_attn_varlen_func(
         causal,
         window_size[0],
         window_size[1],
-        attention_chunk,
         softcap,
         is_rotary_interleaved=False,
         scheduler_metadata=None,
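The net effect on callers: the attention_chunk keyword disappears from both flash_attn_with_kvcache and flash_attn_varlen_func, and everything else keeps the signature shown above. A hedged usage sketch; shapes are illustrative, a CUDA device is assumed, and the positional q, k_cache, v_cache prefix is not shown in this diff and is assumed to match the upstream flash-attention API:

    import torch
    from sgl_kernel.flash_attn import flash_attn_with_kvcache  # module path as in this diff

    # Illustrative decode-style shapes: batch 2, one new query token, 1024 cached KV tokens.
    b, s_q, s_k, h, d = 2, 1, 1024, 16, 128
    q = torch.randn(b, s_q, h, d, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(b, s_k, h, d, device="cuda", dtype=torch.bfloat16)
    v_cache = torch.randn(b, s_k, h, d, device="cuda", dtype=torch.bfloat16)
    cache_seqlens = torch.full((b,), s_k, device="cuda", dtype=torch.int32)

    out = flash_attn_with_kvcache(
        q,
        k_cache,
        v_cache,
        cache_seqlens=cache_seqlens,
        causal=True,
        # attention_chunk=...,  # removed by this revert; passing it now raises a TypeError
    )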