sglang · Commit 0ff6d1fc (Unverified)
Authored Aug 14, 2025 by Ke Bao; committed by GitHub on Aug 13, 2025
Support FA3 backend for gpt-oss (#9028)
Parent: 4a16a71c

Showing 4 changed files with 24 additions and 6 deletions (+24, -6)
python/pyproject.toml                                           +1   -1
python/sglang/srt/layers/attention/flashattention_backend.py  +18   -0
python/sglang/srt/models/gpt_oss.py                             +1   -1
python/sglang/srt/server_args.py                                +4   -4
python/pyproject.toml

@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
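The only dependency change is the sgl-kernel bump from 0.3.4 to 0.3.4.post1, presumably so the FA3 kernel build accepts the new `sinks` argument used below. A minimal sketch (not part of the commit) for checking that a local environment satisfies the new pins; the package names and versions come from the diff, everything else is illustrative:

# Verify locally installed versions against the pins changed/kept in pyproject.toml.
from importlib.metadata import PackageNotFoundError, version

pins = {
    "sgl-kernel": "0.3.4.post1",
    "torch": "2.8.0",
    "torchaudio": "2.8.0",
}

for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{name}: {installed} {status}")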
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             )
             o, _ = merge_state_v2(
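The pattern repeated across these hunks is the same: both the prefill and decode paths gain an optional `sinks` tensor, and it is forwarded to the FA3 attention call only through a conditional `**kwargs` dict, so kernel builds whose interface has no `sinks` parameter keep working. Below is a minimal standalone sketch of that compatibility idiom; `fa3_attention` is a stand-in for the real kernel entry point (its name and signature are assumptions), and the `inspect.signature` check is an extra illustration on top of the commit's simpler `sinks is not None` guard:

# Sketch of the conditional-kwargs compatibility idiom used in the diff.
import inspect
from typing import Optional

import torch


def fa3_attention(q, k, v, *, sinks: Optional[torch.Tensor] = None):
    # Stand-in "kernel": newer builds accept `sinks`, older ones would lack the parameter.
    return q + k + v if sinks is None else q + k + v + sinks.mean()


def run_attention(q, k, v, sinks: Optional[torch.Tensor] = None):
    # For interface version compatibility, put new fields into conditional keyword args
    # and only pass them when both the caller and the kernel know about them.
    kwargs = {}
    if sinks is not None and "sinks" in inspect.signature(fa3_attention).parameters:
        kwargs["sinks"] = sinks
    return fa3_attention(q, k, v, **kwargs)


q = k = v = torch.ones(2, 4)
print(run_attention(q, k, v))                       # old-style call, no sinks
print(run_attention(q, k, v, sinks=torch.ones(4)))  # new-style call with sinks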
python/sglang/srt/models/gpt_oss.py

@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
         )
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )
         self.o_proj = RowParallelLinear(
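The model-side change is a one-liner: the per-head attention-sink parameter is now allocated in bfloat16 instead of float32, presumably to match the dtype the FA3 kernel expects for its `sinks` argument. A self-contained sketch of that kind of frozen auxiliary parameter; the class below is illustrative, not the real GptOssAttention module:

# Illustrative non-trainable per-head "sink" parameter kept in bfloat16.
import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    def __init__(self, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        # One frozen sink logit per head, stored in bf16 so it can be handed
        # straight to a bf16 attention kernel without a cast.
        self.sinks = nn.Parameter(
            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
        )


attn = TinyAttention(num_heads=8)
print(attn.sinks.dtype, attn.sinks.shape)  # torch.bfloat16 torch.Size([8])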
python/sglang/srt/server_args.py

@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
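The server-side change keeps `triton` as the default backend for GptOssForCausalLM but widens the allow-list to include `fa3` alongside `trtllm_mha`, so FA3 can be requested through the existing `attention_backend` server argument. A minimal sketch of the same default-then-validate pattern outside of sglang (the function and variable names below are illustrative):

# Default-then-validate pattern from the server_args.py hunk, as a standalone function.
from typing import Optional


def resolve_attention_backend(requested: Optional[str]) -> str:
    backend = requested if requested is not None else "triton"
    supported_backends = ["triton", "trtllm_mha", "fa3"]
    assert (
        backend in supported_backends
    ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{backend}'"
    return backend


print(resolve_attention_backend(None))   # -> "triton" (default)
print(resolve_attention_backend("fa3"))  # -> "fa3" (newly allowed by this commit)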