Commit 0ff6d1fc (unverified)

Support FA3 backend for gpt-oss (#9028)

Authored Aug 14, 2025 by Ke Bao; committed by GitHub on Aug 13, 2025
Parent: 4a16a71c

Showing 4 changed files with 24 additions and 6 deletions (+24, -6)
python/pyproject.toml                                           +1  -1
python/sglang/srt/layers/attention/flashattention_backend.py  +18  -0
python/sglang/srt/models/gpt_oss.py                             +1  -1
python/sglang/srt/server_args.py                                +4  -4
python/pyproject.toml

@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
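The pin is bumped to sgl-kernel 0.3.4.post1, presumably so the FA3 kernels in that build accept the new `sinks` argument used below. A minimal sketch for checking which sgl-kernel build is installed locally, using only the standard library:

    # Minimal sketch: confirm the installed sgl-kernel version matches the
    # "0.3.4.post1" pin from pyproject.toml.
    from importlib.metadata import version, PackageNotFoundError

    try:
        print("sgl-kernel:", version("sgl-kernel"))
    except PackageNotFoundError:
        print("sgl-kernel is not installed")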
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None

@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata

@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:

@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,

@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None

@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,

@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention

@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table

@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result

@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        **kwargs,
                     )
                 )
                 o, _ = merge_state_v2(
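The new `sinks` tensor is only forwarded to the FA3 kernel when it is actually provided, so older sgl-kernel builds whose attention function has no `sinks` parameter keep working unchanged. A self-contained sketch of that conditional-keyword-argument pattern, using stand-in kernel functions rather than the real sgl-kernel API:

    # Sketch of the compatibility pattern from the diff above: new optional fields
    # are only forwarded when set, so an older kernel signature is never handed an
    # unexpected keyword argument. The kernels here are stand-ins, not sgl-kernel.
    from typing import Optional
    import torch


    def old_kernel(q, k, v):                      # stand-in: signature without sinks
        return q


    def new_kernel(q, k, v, sinks=None):          # stand-in: signature that accepts sinks
        return q if sinks is None else q + sinks.mean()


    def run_attention(kernel, q, k, v, sinks: Optional[torch.Tensor] = None):
        # For interface version compatibility, put new fields into conditional kwargs.
        kwargs = {}
        if sinks is not None:
            kwargs["sinks"] = sinks
        return kernel(q, k, v, **kwargs)


    q = k = v = torch.zeros(2, 4)
    run_attention(old_kernel, q, k, v)                        # no sinks passed: still valid
    run_attention(new_kernel, q, k, v, sinks=torch.ones(4))   # sinks forwarded to the new kernel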
python/sglang/srt/models/gpt_oss.py

@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
         )
         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )

         self.o_proj = RowParallelLinear(
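The per-head sink parameter is now stored in bfloat16 rather than float32, matching what the FA3 path receives through the new `sinks` kwarg. For context, a rough illustrative sketch of one common attention-sink formulation, where each head contributes one extra learnable logit to the softmax normalization but no value vector; this is only to convey the idea and is not the FA3 kernel's exact math:

    # Illustrative only: per-head attention sinks as an extra softmax logit.
    import torch

    def sdpa_with_sinks(q, k, v, sinks):
        # q, k, v: [heads, seq, dim]; sinks: [heads]. Causal mask omitted for brevity.
        scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5             # [heads, q_len, kv_len]
        sink_col = sinks.to(scores.dtype).view(-1, 1, 1).expand(-1, scores.shape[1], 1)
        logits = torch.cat([scores, sink_col], dim=-1)                     # one extra logit per row
        probs = torch.softmax(logits, dim=-1)[..., :-1]                    # sink absorbs mass, adds no value
        return probs @ v

    h, s, d = 2, 5, 8
    out = sdpa_with_sinks(torch.randn(h, s, d), torch.randn(h, s, d),
                          torch.randn(h, s, d), torch.zeros(h, dtype=torch.bfloat16))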
python/sglang/srt/server_args.py

@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
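With "fa3" added to the allowed backends (the default for gpt-oss remains "triton" when unset), a gpt-oss server can presumably be started with FA3 selected explicitly. A hedged launch sketch, assuming the standard sglang.launch_server entry point and the --attention-backend flag derived from the ServerArgs field; the model path is a placeholder:

    # Sketch: launch a gpt-oss server with the newly allowed FA3 attention backend.
    import subprocess, sys

    subprocess.run([
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", "openai/gpt-oss-20b",   # placeholder model path
        "--attention-backend", "fa3",
    ])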