Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c5d004aa
Unverified
Commit
c5d004aa
authored
Aug 28, 2025
by
Isotr0py
Committed by
GitHub
Aug 28, 2025
Browse files
[Model] Add PP support and VLM backbone compatability for GPT-OSS (#23680)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
11a7fafa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
34 deletions
+87
-34
docs/models/supported_models.md
docs/models/supported_models.md
+1
-1
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+86
-33
No files found.
docs/models/supported_models.md
View file @
c5d004aa
...
...
@@ -358,7 +358,7 @@ th {
|
`GPTBigCodeForCausalLM`
| StarCoder, SantaCoder, WizardCoder |
`bigcode/starcoder`
,
`bigcode/gpt_bigcode-santacoder`
,
`WizardLM/WizardCoder-15B-V1.0`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GPTJForCausalLM`
| GPT-J |
`EleutherAI/gpt-j-6b`
,
`nomic-ai/gpt4all-j`
, etc. | | ✅︎ | ✅︎ |
|
`GPTNeoXForCausalLM`
| GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM |
`EleutherAI/gpt-neox-20b`
,
`EleutherAI/pythia-12b`
,
`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`
,
`databricks/dolly-v2-12b`
,
`stabilityai/stablelm-tuned-alpha-7b`
, etc. | | ✅︎ | ✅︎ |
|
`GptOssForCausalLM`
| GPT-OSS |
`openai/gpt-oss-120b`
,
`openai/gpt-oss-20b`
| | | ✅︎ |
|
`GptOssForCausalLM`
| GPT-OSS |
`openai/gpt-oss-120b`
,
`openai/gpt-oss-20b`
| |
✅︎
| ✅︎ |
|
`GraniteForCausalLM`
| Granite 3.0, Granite 3.1, PowerLM |
`ibm-granite/granite-3.0-2b-base`
,
`ibm-granite/granite-3.1-8b-instruct`
,
`ibm/PowerLM-3b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GraniteMoeForCausalLM`
| Granite 3.0 MoE, PowerMoE |
`ibm-granite/granite-3.0-1b-a400m-base`
,
`ibm-granite/granite-3.0-3b-a800m-instruct`
,
`ibm/PowerMoE-3b`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GraniteMoeHybridForCausalLM`
| Granite 4.0 MoE Hybrid |
`ibm-granite/granite-4.0-tiny-preview`
, etc. | ✅︎ | ✅︎ | ✅︎ |
...
...
vllm/model_executor/models/gpt_oss.py
View file @
c5d004aa
...
...
@@ -11,7 +11,8 @@ from transformers import GptOssConfig
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_ep_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_ep_group
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
...
@@ -27,7 +28,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
cdiv
from
.interfaces
import
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -75,8 +79,6 @@ class OAIAttention(nn.Module):
dtype
=
torch
.
bfloat16
,
requires_grad
=
False
))
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
q_size
=
self
.
num_attention_heads
*
self
.
head_dim
//
tp_size
self
.
kv_size
=
self
.
num_key_value_heads
*
self
.
head_dim
//
tp_size
self
.
scaling
=
self
.
head_dim
**-
0.5
...
...
@@ -119,16 +121,13 @@ class OAIAttention(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
t
=
self
.
norm
(
hidden_states
)
qkv
,
_
=
self
.
qkv
(
t
)
qkv
,
_
=
self
.
qkv
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
v
=
v
.
contiguous
()
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
+
hidden_states
return
output
class
MLPBlock
(
torch
.
nn
.
Module
):
...
...
@@ -145,7 +144,6 @@ class MLPBlock(torch.nn.Module):
self
.
num_experts
=
config
.
num_local_experts
self
.
experts_per_token
=
config
.
num_experts_per_tok
self
.
world_size
=
dist
.
get_world_size
()
if
dist
.
is_initialized
()
else
1
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
router
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_local_experts
,
dtype
=
torch
.
bfloat16
)
...
...
@@ -163,10 +161,9 @@ class MLPBlock(torch.nn.Module):
activation
=
"swigluoai"
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
t
=
self
.
norm
(
x
)
g
=
self
.
router
(
t
)
t
=
self
.
experts
(
hidden_states
=
t
,
router_logits
=
g
)
return
x
+
t
g
=
self
.
router
(
x
)
x
=
self
.
experts
(
hidden_states
=
x
,
router_logits
=
g
)
return
x
class
TransformerBlock
(
torch
.
nn
.
Module
):
...
...
@@ -187,12 +184,28 @@ class TransformerBlock(torch.nn.Module):
self
.
layer_idx
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
1e-5
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
attn_output
=
self
.
attn
(
hidden_states
,
positions
)
output
=
self
.
mlp
(
attn_output
)
return
output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
attn
(
hidden_states
,
positions
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
output
=
self
.
mlp
(
hidden_states
)
return
output
,
residual
@
support_torch_compile
...
...
@@ -214,22 +227,52 @@ class GptOssModel(nn.Module):
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
,
)
self
.
layers
=
torch
.
nn
.
ModuleList
([
TransformerBlock
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
config
.
num_hidden_layers
,
lambda
prefix
:
TransformerBlock
(
self
.
config
,
cache_config
=
self
.
cache_config
,
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
f
"block.
{
layer_idx
}
"
),
)
for
layer_idx
in
range
(
self
.
config
.
num_hidden_layers
)
])
prefix
=
prefix
,
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
1e-5
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
self
.
config
.
hidden_size
))
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
self
.
embedding
(
input_ids
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
,
positions
)
x
=
self
.
norm
(
x
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embedding
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
x
=
inputs_embeds
else
:
x
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
x
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
x
,
residual
=
layer
(
x
,
positions
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
x
,
"residual"
:
residual
})
x
,
_
=
self
.
norm
(
x
,
residual
)
return
x
def
_load_weights_mxfp4
(
...
...
@@ -264,6 +307,10 @@ class GptOssModel(nn.Module):
intermediate_size
)
for
name
,
weight
in
weights
:
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# FIXME(woosuk): Remove this after testing.
weight
=
weight
.
cuda
()
...
...
@@ -445,6 +492,10 @@ class GptOssModel(nn.Module):
intermediate_size
)
for
name
,
weight
in
weights
:
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
".w13_weight"
in
name
:
# Handle MLP gate and up projection weights
# Extract gate and up projection parts
...
...
@@ -562,18 +613,15 @@ class GptOssModel(nn.Module):
weights
,
stacked_params_mapping
)
class
GptOssForCausalLM
(
nn
.
Module
):
class
GptOssForCausalLM
(
nn
.
Module
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
".self_attn."
:
".attn."
,
".post_attention_layernorm."
:
".mlp.norm."
,
},
orig_to_new_suffix
=
{
".embed_tokens.weight"
:
".embedding.weight"
,
".input_layernorm.weight"
:
".attn.norm.weight"
,
".post_attention_layernorm.weight"
:
".mlp.norm.weight"
,
# MoE MXFP4 weights
".gate_up_proj_blocks"
:
".w13_weight"
,
...
...
@@ -609,6 +657,11 @@ class GptOssForCausalLM(nn.Module):
self
.
config
.
hidden_size
,
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment