zhaoyu6/sglang

Commit 0ff6d1fc (Unverified), authored Aug 14, 2025 by Ke Bao, committed by GitHub Aug 13, 2025
Parent: 4a16a71c

Support FA3 backend for gpt-oss (#9028)
Showing 4 changed files with 24 additions and 6 deletions:

python/pyproject.toml (+1, -1)
python/sglang/srt/layers/attention/flashattention_backend.py (+18, -0)
python/sglang/srt/models/gpt_oss.py (+1, -1)
python/sglang/srt/server_args.py (+4, -4)
python/pyproject.toml

@@ -58,7 +58,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.3.4",
+    "sgl-kernel==0.3.4.post1",
     "torch==2.8.0",
     "torchaudio==2.8.0",
     "torchvision",
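
The sgl-kernel pin moves from 0.3.4 to 0.3.4.post1; presumably the post release ships the FA3-side kernel changes (such as sinks support) that the backend code below relies on. A hypothetical helper to confirm which build is installed (not part of this commit; standard library only):

import importlib.metadata

def check_sgl_kernel(expected: str = "0.3.4.post1") -> None:
    # Compare the installed sgl-kernel distribution against the pin in the srt extra.
    installed = importlib.metadata.version("sgl-kernel")
    if installed != expected:
        print(f"sgl-kernel {installed} is installed; this commit pins {expected}")

check_sgl_kernel()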
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None

@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata

@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )

             if use_cascade_attn:

@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,

@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None

@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention

+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,

@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention

@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table

@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result

@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        **kwargs,
                     )
                 )
                 o, _ = merge_state_v2(
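
The recurring `**kwargs` additions above all follow one compatibility pattern: the new `sinks` tensor is forwarded only when it is actually provided, so the call still works against an older FA3 kernel whose flash-attention entry point has no `sinks` parameter. A minimal standalone sketch of the idea (the `attn_fn` callable is a stand-in, not the real SGLang/sgl-kernel API):

from typing import Callable, Optional

import torch


def run_attention(
    attn_fn: Callable[..., torch.Tensor],
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Build keyword arguments conditionally: kernels that do not accept
    # `sinks` never see the argument, so their signature is not violated.
    kwargs = {}
    if sinks is not None:
        kwargs["sinks"] = sinks
    return attn_fn(q, k, v, **kwargs)

Passing `sinks=None` unconditionally would instead raise a TypeError on kernels that predate the parameter, which is why the field is kept out of the call entirely when unset.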
python/sglang/srt/models/gpt_oss.py

@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
         )

         self.sinks = nn.Parameter(
-            torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
+            torch.empty(self.num_heads, dtype=torch.bfloat16), requires_grad=False
         )

         self.o_proj = RowParallelLinear(
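
The changed declaration keeps the per-head attention-sink values as a frozen parameter but switches the dtype from float32 to bfloat16, presumably to match what the FA3 path expects. A hedged, self-contained restatement (the head count is an example value, not taken from the model config):

import torch
import torch.nn as nn

num_heads = 64  # example value only

# Non-trainable per-head sink logits, now allocated in bfloat16 instead of float32.
sinks = nn.Parameter(torch.empty(num_heads, dtype=torch.bfloat16), requires_grad=False)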
python/sglang/srt/server_args.py

@@ -2106,10 +2106,10 @@ class ServerArgs:
         if model_arch in ["GptOssForCausalLM"]:
             if self.attention_backend is None:
                 self.attention_backend = "triton"
-            assert self.attention_backend in [
-                "triton",
-                "trtllm_mha",
-            ], f"GptOssForCausalLM requires 'triton' or 'trtllm_mha' attention backend, but got {self.attention_backend}"
+            supported_backends = ["triton", "trtllm_mha", "fa3"]
+            assert (
+                self.attention_backend in supported_backends
+            ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"
             quantization_config = getattr(hf_config, "quantization_config", None)
             is_mxfp4_quant_format = (
                 quantization_config is not None
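
The validation change widens the accepted backends for GptOssForCausalLM from triton/trtllm_mha to also include fa3, while triton remains the default when nothing is specified. A standalone sketch of the same check (simplified; `attention_backend` here is a plain variable rather than a ServerArgs field):

attention_backend = None  # e.g. the value supplied via the attention-backend server argument

if attention_backend is None:
    attention_backend = "triton"

supported_backends = ["triton", "trtllm_mha", "fa3"]
assert attention_backend in supported_backends, (
    f"GptOssForCausalLM requires one of {supported_backends} attention backend, "
    f"but got '{attention_backend}'"
)

In practice this means a gpt-oss server can now be launched with the FA3 attention backend selected, assuming the pinned sgl-kernel build is installed.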