Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de98252f
Unverified
Commit
de98252f
authored
Aug 05, 2025
by
Woosuk Kwon
Committed by
GitHub
Aug 05, 2025
Browse files
Add GPT-OSS model code and config [1/N] (#22327)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
796bae07
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
503 additions
and
0 deletions
+503
-0
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+29
-0
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+472
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
No files found.
tests/models/registry.py
View file @
de98252f
...
@@ -197,6 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -197,6 +197,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
{
"6b"
:
"EleutherAI/gpt-j-6b"
}),
{
"6b"
:
"EleutherAI/gpt-j-6b"
}),
"GPTNeoXForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/pythia-70m"
,
"GPTNeoXForCausalLM"
:
_HfExamplesInfo
(
"EleutherAI/pythia-70m"
,
{
"1b"
:
"EleutherAI/pythia-1.4b"
}),
{
"1b"
:
"EleutherAI/pythia-1.4b"
}),
"GptOssForCausalLM"
:
_HfExamplesInfo
(
"openai/gpt-oss-20b"
),
"GraniteForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerLM-3b"
),
"GraniteForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerLM-3b"
),
"GraniteMoeForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerMoE-3b"
),
"GraniteMoeForCausalLM"
:
_HfExamplesInfo
(
"ibm/PowerMoE-3b"
),
"GraniteMoeHybridForCausalLM"
:
_HfExamplesInfo
(
"ibm-granite/granite-4.0-tiny-preview"
),
# noqa: E501
"GraniteMoeHybridForCausalLM"
:
_HfExamplesInfo
(
"ibm-granite/granite-4.0-tiny-preview"
),
# noqa: E501
...
...
vllm/model_executor/models/config.py
View file @
de98252f
...
@@ -247,6 +247,34 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
...
@@ -247,6 +247,34 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
config
.
max_model_len
)
config
.
max_model_len
)
class GptOssConfig(VerifyAndUpdateConfig):
    """Config hook that adjusts vLLM defaults for GPT-OSS models."""

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Fill in the reasoning backend and widen the CUDA graph sizes.

        Mutates ``vllm_config`` in place:

        * an empty ``decoding_config.reasoning_backend`` becomes ``"openai"``;
        * when ``scheduler_config.cuda_graph_sizes`` holds exactly one value
          and that value is below 1024, it is replaced by an explicit list
          of capture sizes that tops out at 1024.
        """
        decoding = vllm_config.decoding_config
        if decoding.reasoning_backend == "":
            decoding.reasoning_backend = "openai"

        # Increase the max capture size from 512 to 1024 for performance.
        # NOTE(woosuk): This will increase the number of CUDA graphs
        # from 67 to 83.
        scheduler = vllm_config.scheduler_config
        if len(scheduler.cuda_graph_sizes) != 1:
            # More than one entry (or none): leave the sizes untouched.
            return

        # FIXME(woosuk): When using full cuda graph with FA3, the max
        # supported size is 992.
        if scheduler.cuda_graph_sizes[0] >= 1024:
            return

        scheduler.cuda_graph_sizes = [
            1,
            2,
            4,
            # Step size 8 for small batch sizes
            *range(8, 256, 8),
            # Step size 16 for larger batch sizes
            *range(256, 1025, 16),
        ]
        logger.info(
            "Overriding max cuda graph capture size to "
            "%d for performance.", 1024)
class
HybridAttentionMambaModelConfig
(
VerifyAndUpdateConfig
):
class
HybridAttentionMambaModelConfig
(
VerifyAndUpdateConfig
):
@
classmethod
@
classmethod
...
@@ -345,4 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
...
@@ -345,4 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"JinaVLForRanking"
:
JinaVLForSequenceClassificationConfig
,
"JinaVLForRanking"
:
JinaVLForSequenceClassificationConfig
,
"JambaForSequenceClassification"
:
JambaForSequenceClassificationConfig
,
"JambaForSequenceClassification"
:
JambaForSequenceClassificationConfig
,
"GraniteMoeHybridForCausalLM"
:
GraniteMoeHybridModelConfig
,
"GraniteMoeHybridForCausalLM"
:
GraniteMoeHybridModelConfig
,
"GptOssForCausalLM"
:
GptOssConfig
,
}
}
vllm/model_executor/models/gpt_oss.py
0 → 100644
View file @
de98252f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
Optional
import
torch
import
torch.distributed
as
dist
from
torch
import
nn
from
transformers
import
GptOssConfig
from
vllm
import
envs
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
(
get_ep_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
cdiv
from
.utils
import
extract_layer_index
,
maybe_prefix
class OAIAttention(nn.Module):
    """Pre-norm self-attention block with attention sinks and a residual add.

    ``forward`` normalizes the input, projects it to Q/K/V, applies rotary
    embeddings to Q and K, runs the attention backend, projects the result
    back through ``o_proj`` and adds the original input.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # YaRN scaling parameters are assembled from the HF config fields.
        yarn_scaling = {
            "rope_type": "yarn",
            "factor": config.rope_scaling["factor"],
            "original_max_position_embeddings":
            config.rope_scaling["original_max_position_embeddings"],
            "beta_fast": config.rope_ntk_beta,
            "beta_slow": config.rope_ntk_alpha,
        }
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling=yarn_scaling,
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # The sink dtype is float32 when either TRT-LLM attention flag is
        # set, bfloat16 otherwise.
        if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
                or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
            attention_sink_dtype = torch.float32
        else:
            attention_sink_dtype = torch.bfloat16

        # NOTE(review): requires_grad=False is given to torch.empty, not to
        # nn.Parameter, so the Parameter keeps its default requires_grad.
        # Confirm whether that is intended.
        sink_storage = torch.empty(config.num_attention_heads // tp_size,
                                   dtype=attention_sink_dtype,
                                   requires_grad=False)
        self.sinks = torch.nn.Parameter(sink_storage)

        self.norm = RMSNorm(config.hidden_size, eps=1e-5)

        # Per-rank widths used to split the fused QKV projection output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        if self.layer_idx % 2 == 0:
            sliding_window = config.sliding_window
        else:
            sliding_window = None

        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` plus the attention output for it."""
        normed = self.norm(hidden_states)
        fused_qkv, _ = self.qkv(normed)

        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)

        query, key = self.rotary_emb(positions, query, key)
        value = value.contiguous()

        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)

        # Residual connection around the whole attention block.
        return projected + hidden_states
class MLPBlock(torch.nn.Module):
    """Pre-norm mixture-of-experts feed-forward block with a residual add.

    ``forward`` normalizes the input, computes router logits with a linear
    layer, runs the fused MoE experts and adds the result to the input.

    Args:
        config: HF GPT-OSS config; ``hidden_size``, ``intermediate_size``,
            ``num_local_experts`` and ``num_experts_per_tok`` are read.
        layer_idx: Index of the enclosing transformer layer (stored only).
        quant_config: Quantization config forwarded to ``FusedMoE``.
        prefix: Parameter-name prefix; experts use ``f"{prefix}.experts"``.
    """

    def __init__(
        self,
        config: GptOssConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
        assert config.intermediate_size % self.world_size == 0
        # top_k reuses the value read from ``config.num_experts_per_tok``
        # above. The previous ``config.num_experts_per_token`` spelling named
        # a different attribute from the one this constructor already reads.
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=self.experts_per_token,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the MoE output computed from the normed ``x``."""
        t = self.norm(x)
        g = self.router(t)
        t = self.experts(hidden_states=t, router_logits=g)
        return x + t
class TransformerBlock(torch.nn.Module):
    """One decoder layer: an attention block followed by an MoE MLP block.

    Both sub-blocks add their own residual, so ``forward`` just chains them.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)

        attn_prefix = f"{prefix}.attn"
        mlp_prefix = f"{prefix}.mlp"

        # NOTE(review): neither quant_config nor a cache config is passed to
        # the attention block here, so it uses its own defaults (None).
        # Confirm that this is intended.
        self.attn = OAIAttention(config, prefix=attn_prefix)
        self.mlp = MLPBlock(config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=mlp_prefix)

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Run attention and then the MLP on ``hidden_states``."""
        after_attention = self.attn(hidden_states, positions)
        return self.mlp(after_attention)
@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder stack: embedding, transformer blocks and final norm.

    Args:
        vllm_config: Engine config; ``model_config.hf_config`` and
            ``quant_config`` are read from it.
        prefix: Parameter-name prefix; layer ``i`` is built with the prefix
            ``maybe_prefix(prefix, f"block.{i}")``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        # The former ``self.config.hidden_size = self.config.hidden_size``
        # self-assignment was a no-op and has been dropped.
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.layers = torch.nn.ModuleList([
            TransformerBlock(
                self.config,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, f"block.{layer_idx}"),
            ) for layer_idx in range(self.config.num_hidden_layers)
        ])
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids``, run every layer in order, then normalize."""
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x, positions)
        x = self.norm(x)
        return x
class GptOssForCausalLM(nn.Module):
    """GPT-OSS causal LM: decoder stack, LM head and checkpoint loader."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder, LM head and logits processor.

        Args:
            vllm_config: Engine config; kept on ``self`` and also used for
                its ``model_config.hf_config``.
            prefix: Parameter-name prefix; the decoder is created under
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        self.vllm_config = vllm_config
        # Despite the attribute name, this holds the HF config object.
        self.model_config = vllm_config.model_config.hf_config
        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.model_config.vocab_size,
            self.model_config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(self.model_config.vocab_size)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the decoder on ``input_ids`` and ``positions``.

        ``intermediate_tensors`` and ``inputs_embeds`` are accepted in the
        signature but must both be ``None``; anything else trips an assert.
        """
        assert intermediate_tensors is None
        assert inputs_embeds is None
        return self.model(input_ids, positions)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the logits processor's output for ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is dispatched on substrings of its name:
        MoE expert blocks, scales and biases, attention sinks, separate
        q/k/v projections, and finally everything else via a rename table.
        Expert tensors are sliced per rank before loading: along the expert
        dimension when expert parallelism is enabled, otherwise along the
        intermediate dimension by tensor-parallel rank.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (as found in
            ``self.named_parameters()``) that were loaded.

        Raises:
            KeyError: if a name handled by one of the explicit branches does
                not map to an existing parameter. Only the final fallback
                branch skips unknown names.
        """
        rename_mapping = {
            "self_attn": "attn",
            "input_layernorm.weight": "attn.norm.weight",
            "post_attention_layernorm.weight": "mlp.norm.weight",
            "embed_tokens": "embedding",
        }

        def maybe_rename(name: str) -> str:
            # Applies only the FIRST matching entry of rename_mapping (in
            # dict order) and returns immediately.
            for remap_name, new_name in rename_mapping.items():
                if remap_name in name:
                    return name.replace(remap_name, new_name)
            return name

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Block granularity used to round the per-rank intermediate size.
        mxfp4_block = 32

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        intermediate_size = self.model_config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        # Ceil-divide whole blocks across ranks; the last rank may get less
        # (see the min() clamp on tp_rank_end below).
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        # Attention heads per rank
        heads_per_rank = self.model_config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.model_config.num_local_experts
        # Contiguous range of expert indices owned by this EP rank.
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        for name, weight in weights:
            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            if "gate_up_proj_blocks" in name:
                # Handle MLP gate and up projection weights
                new_name = name.replace("gate_up_proj_blocks", "w13_weight")

                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # Bounds are doubled because dim 1 has size
                    # 2 * intermediate_size (see the view above).
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_blocks" in name:
                # Handle MLP down projection weights
                new_name = name.replace("down_proj_blocks", "w2_weight")
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # Last dim is intermediate_size // 2 wide, so the TP
                    # bounds are halved to match.
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "gate_up_proj_scales" in name:
                # Handle MLP gate and up projection weights scale
                new_name = name.replace("gate_up_proj_scales",
                                        "w13_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_scales" in name:
                # Handle MLP down projection weight scales
                new_name = name.replace("down_proj_scales", "w2_weight_scale")
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # Bounds are converted from elements to blocks of
                    # mxfp4_block; presumably one scale per block -- confirm.
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "gate_up_proj_bias" in name:
                # Handle MLP gate and up projection biases
                new_name = name.replace("gate_up_proj_bias", "w13_bias")

                # Extract gate and up projection bias parts
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param,
                              narrow_weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)

            elif "down_proj_bias" in name:
                # Handle MLP down projection bias
                new_name = name.replace("down_proj_bias", "w2_bias")
                param = params_dict[new_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # (only load on rank 0 to avoid duplication)
                    # Other TP ranks zero the tensor in place and still load
                    # it, so the parameter is all zeros there.
                    if tp_rank != 0:
                        weight.zero_()
                weight_loader(param,
                              weight,
                              weight_name=new_name,
                              shard_id=None,
                              expert_id=None)
                loaded_params.add(new_name)
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                name = name.replace("self_attn", "attn")
                param = params_dict[name]
                # Copy this rank's heads_per_rank entries starting at
                # head_start along dim 0.
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
            elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
                # Separate q/k/v checkpoint tensors are loaded into the
                # single "qkv" parameter, one shard at a time.
                shard_id = ("q" if "q_proj" in name else
                            "k" if "k_proj" in name else "v")
                name = name.replace("self_attn", "attn")
                param_name = name.replace(f"{shard_id}_proj", "qkv")
                param = params_dict[param_name]
                # No default fallback here: the parameter must provide its
                # own weight_loader.
                weight_loader = param.weight_loader
                weight_loader(param, weight, loaded_shard_id=shard_id)
                loaded_params.add(param_name)
            else:
                # Handle all other weights with potential renaming
                renamed_name = maybe_rename(name)
                if renamed_name not in params_dict:
                    # Names with no matching parameter are silently skipped.
                    continue
                param = params_dict[renamed_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(renamed_name)

        return loaded_params
vllm/model_executor/models/registry.py
View file @
de98252f
...
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
"Glm4MoeForCausalLM"
:
(
"glm4_moe"
,
"Glm4MoeForCausalLM"
),
"Glm4MoeForCausalLM"
:
(
"glm4_moe"
,
"Glm4MoeForCausalLM"
),
"GptOssForCausalLM"
:
(
"gpt_oss"
,
"GptOssForCausalLM"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
"GPTBigCodeForCausalLM"
:
(
"gpt_bigcode"
,
"GPTBigCodeForCausalLM"
),
"GPTBigCodeForCausalLM"
:
(
"gpt_bigcode"
,
"GPTBigCodeForCausalLM"
),
"GPTJForCausalLM"
:
(
"gpt_j"
,
"GPTJForCausalLM"
),
"GPTJForCausalLM"
:
(
"gpt_j"
,
"GPTJForCausalLM"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment