Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0ff6d1fc
"examples/vscode:/vscode.git/clone" did not exist on "cbee4278390f2b009e8807022fff37ff9934a17c"
Unverified
Commit
0ff6d1fc
authored
Aug 14, 2025
by
Ke Bao
Committed by
GitHub
Aug 13, 2025
Browse files
Support FA3 backend for gpt-oss (#9028)
parent
4a16a71c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
6 deletions
+24
-6
python/pyproject.toml
python/pyproject.toml
+1
-1
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+18
-0
python/sglang/srt/models/gpt_oss.py
python/sglang/srt/models/gpt_oss.py
+1
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+4
-4
No files found.
python/pyproject.toml
View file @
0ff6d1fc
...
@@ -58,7 +58,7 @@ runtime_common = [
...
@@ -58,7 +58,7 @@ runtime_common = [
srt
=
[
srt
=
[
"sglang[runtime_common]"
,
"sglang[runtime_common]"
,
"sgl-kernel==0.3.4"
,
"sgl-kernel==0.3.4
.post1
"
,
"torch==2.8.0"
,
"torch==2.8.0"
,
"torchaudio==2.8.0"
,
"torchaudio==2.8.0"
,
"torchvision"
,
"torchvision"
,
...
...
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
0ff6d1fc
...
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
# For multi-head latent attention
# For multi-head latent attention
q_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
q_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
k_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
k_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
sinks
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
if
k
is
not
None
:
if
k
is
not
None
:
assert
v
is
not
None
assert
v
is
not
None
...
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch
.
forward_mode
.
is_target_verify
()
and
self
.
topk
>
1
forward_batch
.
forward_mode
.
is_target_verify
()
and
self
.
topk
>
1
)
)
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs
=
{}
if
sinks
is
not
None
:
kwargs
[
"sinks"
]
=
sinks
# Get the appropriate page table based on whether we're using local attention
# Get the appropriate page table based on whether we're using local attention
if
use_local_attn
:
if
use_local_attn
:
local_metadata
=
metadata
.
local_attn_metadata
local_metadata
=
metadata
.
local_attn_metadata
...
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
use_cascade_attn
,
return_softmax_lse
=
use_cascade_attn
,
**
kwargs
,
)
)
if
use_cascade_attn
:
if
use_cascade_attn
:
...
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
**
kwargs
,
)
)
o
,
_
=
merge_state_v2_wrapper
(
o
,
_
=
merge_state_v2_wrapper
(
o
,
o
,
...
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
# For multi-head latent attention
# For multi-head latent attention
q_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
q_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
k_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
k_rope
:
Optional
[
torch
.
Tensor
]
=
None
,
sinks
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
k
is
not
None
:
if
k
is
not
None
:
assert
v
is
not
None
assert
v
is
not
None
...
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
)
)
causal
=
not
layer
.
is_cross_attention
causal
=
not
layer
.
is_cross_attention
# For fa3 interface version compatibility, we put new fields into conditional keyword args
kwargs
=
{}
if
sinks
is
not
None
:
kwargs
[
"sinks"
]
=
sinks
k_descale
,
v_descale
=
None
,
None
k_descale
,
v_descale
=
None
,
None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# has corresponding quantization method so that layer.k_scale is not None,
...
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap
=
layer
.
logit_cap
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
**
kwargs
,
)
)
elif
use_local_attn
:
elif
use_local_attn
:
# Use chunked (local) attention batching for self-attention
# Use chunked (local) attention batching for self-attention
...
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
softcap
=
layer
.
logit_cap
,
softcap
=
layer
.
logit_cap
,
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
**
kwargs
,
)
)
else
:
else
:
page_table
=
metadata
.
page_table
page_table
=
metadata
.
page_table
...
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
use_cascade_attn
,
return_softmax_lse
=
use_cascade_attn
,
**
kwargs
,
)
)
if
use_cascade_attn
:
if
use_cascade_attn
:
o
,
softmax_lse
,
*
rest
=
result
o
,
softmax_lse
,
*
rest
=
result
...
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
**
kwargs
,
)
)
)
)
o
,
_
=
merge_state_v2
(
o
,
_
=
merge_state_v2
(
...
...
python/sglang/srt/models/gpt_oss.py
View file @
0ff6d1fc
...
@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
...
@@ -294,7 +294,7 @@ class GptOssAttention(nn.Module):
)
)
self
.
sinks
=
nn
.
Parameter
(
self
.
sinks
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_heads
,
dtype
=
torch
.
float
32
),
requires_grad
=
False
torch
.
empty
(
self
.
num_heads
,
dtype
=
torch
.
b
float
16
),
requires_grad
=
False
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
...
...
python/sglang/srt/server_args.py
View file @
0ff6d1fc
...
@@ -2106,10 +2106,10 @@ class ServerArgs:
...
@@ -2106,10 +2106,10 @@ class ServerArgs:
if
model_arch
in
[
"GptOssForCausalLM"
]:
if
model_arch
in
[
"GptOssForCausalLM"
]:
if
self
.
attention_backend
is
None
:
if
self
.
attention_backend
is
None
:
self
.
attention_backend
=
"triton"
self
.
attention_backend
=
"triton"
assert
self
.
attention_backend
in
[
supported_backends
=
[
"triton"
,
"trtllm_mha"
,
"fa3"
]
"triton"
,
assert
(
"trtllm_mha"
,
self
.
attention_backend
in
supported_backends
]
,
f
"GptOssForCausalLM requires
'trit
on
'
o
r 'trtllm_mha'
attention backend, but got
{
self
.
attention_backend
}
"
)
,
f
"GptOssForCausalLM requires on
e
o
f
{
supported_backends
}
attention backend, but got
'
{
self
.
attention_backend
}
'
"
quantization_config
=
getattr
(
hf_config
,
"quantization_config"
,
None
)
quantization_config
=
getattr
(
hf_config
,
"quantization_config"
,
None
)
is_mxfp4_quant_format
=
(
is_mxfp4_quant_format
=
(
quantization_config
is
not
None
quantization_config
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment