Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
103f1ec8
Unverified
Commit
103f1ec8
authored
Aug 20, 2025
by
Calvin Chen
Committed by
GitHub
Aug 20, 2025
Browse files
[Model] use autoWeightsLoader for gptoss (#22446)
Signed-off-by:
calvin chen
<
wen.chen@dynamia.ai
>
parent
d983769c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
224 additions
and
208 deletions
+224
-208
vllm/model_executor/models/gpt_oss.py
vllm/model_executor/models/gpt_oss.py
+224
-208
No files found.
vllm/model_executor/models/gpt_oss.py
View file @
103f1ec8
...
...
@@ -27,7 +27,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
cdiv
from
.utils
import
extract_layer_index
,
maybe_prefix
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
extract_layer_index
,
maybe_prefix
)
class
OAIAttention
(
nn
.
Module
):
...
...
@@ -203,6 +204,7 @@ class GptOssModel(nn.Module):
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
self
.
quant_config
=
vllm_config
.
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
config
.
hidden_size
=
self
.
config
.
hidden_size
self
.
embedding
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
...
...
@@ -225,64 +227,26 @@ class GptOssModel(nn.Module):
x
=
self
.
norm
(
x
)
return
x
class
GptOssForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
.
hf_config
self
.
model
=
GptOssModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
),
)
self
.
lm_head
=
ParallelLMHead
(
self
.
model_config
.
vocab_size
,
self
.
model_config
.
hidden_size
,
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
model_config
.
vocab_size
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
intermediate_tensors
is
None
assert
inputs_embeds
is
None
return
self
.
model
(
input_ids
,
positions
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
_load_weights_mxfp4
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
rename_mapping
=
{
"self_attn"
:
"attn"
,
"input_layernorm.weight"
:
"attn.norm.weight"
,
"post_attention_layernorm.weight"
:
"mlp.norm.weight"
,
"embed_tokens"
:
"embedding"
,
}
def
maybe_rename
(
name
:
str
)
->
str
:
for
remap_name
,
new_name
in
rename_mapping
.
items
():
if
remap_name
in
name
:
return
name
.
replace
(
remap_name
,
new_name
)
return
name
self
,
ep_rank_end
:
int
,
ep_rank_start
:
int
,
heads_per_rank
:
int
,
head_start
:
int
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
stacked_params_mapping
:
list
[
tuple
[
str
,
...]],
)
->
set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
mxfp4_block
=
32
use_ep
=
self
.
parallel_config
.
enable_expert_parallel
num_experts
=
self
.
config
.
num_local_experts
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
intermediate_size
=
self
.
model_config
.
intermediate_size
intermediate_size
=
self
.
config
.
intermediate_size
intermediate_size_block
=
intermediate_size
//
mxfp4_block
per_rank_intermediate_size_block
=
cdiv
(
intermediate_size_block
,
tp_size
)
...
...
@@ -294,33 +258,12 @@ class GptOssForCausalLM(nn.Module):
tp_rank_end
=
min
((
tp_rank
+
1
)
*
per_rank_intermediate_size
,
intermediate_size
)
# Attention heads per rank
heads_per_rank
=
self
.
model_config
.
num_attention_heads
//
tp_size
head_start
=
tp_rank
*
heads_per_rank
use_ep
=
self
.
vllm_config
.
parallel_config
.
enable_expert_parallel
ep_size
=
get_ep_group
().
world_size
ep_rank
=
get_ep_group
().
rank
num_experts
=
self
.
model_config
.
num_local_experts
experts_per_rank
=
num_experts
//
ep_size
ep_rank_start
=
ep_rank
*
experts_per_rank
ep_rank_end
=
(
ep_rank
+
1
)
*
experts_per_rank
for
name
,
weight
in
weights
:
# FIXME(woosuk): Remove this after testing.
weight
=
weight
.
cuda
()
if
"gate_up_proj_blocks"
in
name
:
# Handle MLP gate and up projection weights
new_name
=
name
.
replace
(
"gate_up_proj_blocks"
,
"w13_weight"
)
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight
=
weight
.
view
(
num_experts
,
2
*
intermediate_size
,
-
1
).
contiguous
()
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if
".w13_weight_scale"
in
name
:
# Handle MLP gate and up projection weights scale
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
...
...
@@ -328,43 +271,44 @@ class GptOssForCausalLM(nn.Module):
2
*
tp_rank_start
:
2
*
tp_rank_end
,
...]
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
narrow_weight
,
weight_name
=
new_
name
,
weight_name
=
name
,
shard_id
=
None
,
expert_id
=
None
)
loaded_params
.
add
(
new_
name
)
elif
"
down_proj_blocks
"
in
name
:
loaded_params
.
add
(
name
)
continue
elif
"
.w2_weight_scale
"
in
name
:
# Handle MLP down projection weights
new_name
=
name
.
replace
(
"down_proj_blocks"
,
"w2_weight"
)
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight
=
weight
.
view
(
num_experts
,
-
1
,
intermediate_size
//
2
).
contiguous
()
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
narrow_weight
=
weight
[...,
tp_rank_start
//
2
:
tp_rank_end
//
2
]
narrow_weight
=
weight
[...,
tp_rank_start
//
mxfp4_block
:
tp_rank_end
//
mxfp4_block
]
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
narrow_weight
,
weight_name
=
new_
name
,
weight_name
=
name
,
shard_id
=
None
,
expert_id
=
None
)
loaded_params
.
add
(
new_name
)
loaded_params
.
add
(
name
)
continue
elif
".w13_weight"
in
name
:
# Handle MLP gate and up projection weights
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight
=
weight
.
view
(
num_experts
,
2
*
intermediate_size
,
-
1
).
contiguous
()
elif
"gate_up_proj_scales"
in
name
:
# Handle MLP gate and up projection weights scale
new_name
=
name
.
replace
(
"gate_up_proj_scales"
,
"w13_weight_scale"
)
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
...
...
@@ -372,39 +316,40 @@ class GptOssForCausalLM(nn.Module):
2
*
tp_rank_start
:
2
*
tp_rank_end
,
...]
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
narrow_weight
,
weight_name
=
new_
name
,
weight_name
=
name
,
shard_id
=
None
,
expert_id
=
None
)
loaded_params
.
add
(
new_
name
)
elif
"
down_proj_scales
"
in
name
:
loaded_params
.
add
(
name
)
continue
elif
"
.w2_weight
"
in
name
:
# Handle MLP down projection weights
new_name
=
name
.
replace
(
"down_proj_scales"
,
"w2_weight_scale"
)
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight
=
weight
.
view
(
num_experts
,
-
1
,
intermediate_size
//
2
).
contiguous
()
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
narrow_weight
=
weight
[...,
tp_rank_start
//
mxfp4_block
:
tp_rank_end
//
mxfp4_block
]
narrow_weight
=
weight
[...,
tp_rank_start
//
2
:
tp_rank_end
//
2
]
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
narrow_weight
,
weight_name
=
new_
name
,
weight_name
=
name
,
shard_id
=
None
,
expert_id
=
None
)
loaded_params
.
add
(
new_name
)
elif
"gate_up_proj_bias"
in
name
:
loaded_params
.
add
(
name
)
continue
elif
".w13_bias"
in
name
:
# Handle MLP gate and up projection biases
new_name
=
name
.
replace
(
"gate_up_proj_bias"
,
"w13_bias"
)
# Extract gate and up projection bias parts
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
...
...
@@ -412,20 +357,19 @@ class GptOssForCausalLM(nn.Module):
narrow_weight
=
weight
[:,
2
*
tp_rank_start
:
2
*
tp_rank_end
]
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
narrow_weight
,
weight_name
=
new_
name
,
weight_name
=
name
,
shard_id
=
None
,
expert_id
=
None
)
loaded_params
.
add
(
new_
name
)
elif
"
down_proj
_bias"
in
name
:
loaded_params
.
add
(
name
)
continue
elif
"
.w2
_bias"
in
name
:
# Handle MLP down projection bias
new_name
=
name
.
replace
(
"down_proj_bias"
,
"w2_bias"
)
param
=
params_dict
[
new_name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
use_ep
:
...
...
@@ -436,87 +380,69 @@ class GptOssForCausalLM(nn.Module):
weight
.
zero_
()
weight_loader
(
param
,
weight
,
weight_name
=
new_
name
,
weight_name
=
name
,
shard_id
=
None
,
expert_id
=
None
)
loaded_params
.
add
(
new_name
)
loaded_params
.
add
(
name
)
continue
elif
"sinks"
in
name
:
# Handle attention sinks (distributed across ranks)
name
=
name
.
replace
(
"self_attn"
,
"attn"
)
param
=
params_dict
[
name
]
narrow_weight
=
weight
.
narrow
(
0
,
head_start
,
heads_per_rank
)
param
.
data
.
copy_
(
narrow_weight
)
loaded_params
.
add
(
name
)
elif
"q_proj"
in
name
or
"k_proj"
in
name
or
"v_proj"
in
name
:
shard_id
=
(
"q"
if
"q_proj"
in
name
else
"k"
if
"k_proj"
in
name
else
"v"
)
name
=
name
.
replace
(
"self_attn"
,
"attn"
)
param_name
=
name
.
replace
(
f
"
{
shard_id
}
_proj"
,
"qkv"
)
param
=
params_dict
[
param_name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight
,
loaded_shard_id
=
shard_id
)
loaded_params
.
add
(
param_name
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
weight
)
else
:
weight_loader
(
param
,
weight
,
shard_id
)
break
else
:
# Handle all other weights with potential renaming
renamed_name
=
maybe_rename
(
name
)
if
renamed_name
not
in
params_dict
:
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
renamed_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
weight
)
loaded_params
.
add
(
renamed_name
)
loaded_params
.
add
(
name
)
return
loaded_params
def
_load_weights_other
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
rename_mapping
=
{
"self_attn"
:
"attn"
,
"input_layernorm.weight"
:
"attn.norm.weight"
,
"post_attention_layernorm.weight"
:
"mlp.norm.weight"
,
"embed_tokens"
:
"embedding"
,
}
def
maybe_rename
(
name
:
str
)
->
str
:
for
remap_name
,
new_name
in
rename_mapping
.
items
():
if
remap_name
in
name
:
return
name
.
replace
(
remap_name
,
new_name
)
return
name
self
,
ep_rank_start
:
int
,
ep_rank_end
:
int
,
heads_per_rank
:
int
,
head_start
:
int
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
stacked_params_mapping
:
list
[
tuple
[
str
,
...]],
)
->
set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
set
[
str
]
=
set
()
use_ep
=
self
.
parallel_config
.
enable_expert_parallel
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
intermediate_size
=
self
.
model_config
.
intermediate_size
intermediate_size
=
self
.
config
.
intermediate_size
per_rank_intermediate_size
=
cdiv
(
intermediate_size
,
tp_size
)
# Calculate common slicing bounds for current rank
tp_rank_start
=
tp_rank
*
per_rank_intermediate_size
tp_rank_end
=
min
((
tp_rank
+
1
)
*
per_rank_intermediate_size
,
intermediate_size
)
# Attention heads per rank
heads_per_rank
=
self
.
model_config
.
num_attention_heads
//
tp_size
head_start
=
tp_rank
*
heads_per_rank
use_ep
=
self
.
vllm_config
.
parallel_config
.
enable_expert_parallel
ep_size
=
get_ep_group
().
world_size
ep_rank
=
get_ep_group
().
rank
num_experts
=
self
.
model_config
.
num_local_experts
experts_per_rank
=
num_experts
//
ep_size
ep_rank_start
=
ep_rank
*
experts_per_rank
ep_rank_end
=
(
ep_rank
+
1
)
*
experts_per_rank
for
name
,
weight
in
weights
:
if
".
experts.gate_up_proj"
in
name
and
"bias"
not
in
name
:
if
".
w13_weight"
in
name
:
# Handle MLP gate and up projection weights
new_name
=
name
.
replace
(
".experts.gate_up_proj"
,
".experts.w13_weight"
)
# Extract gate and up projection parts
# since the weight is shuffled, we can slice directly
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
...
...
@@ -524,30 +450,25 @@ class GptOssForCausalLM(nn.Module):
2
*
tp_rank_start
:
2
*
tp_rank_end
]
narrow_weight
=
narrow_weight
.
permute
(
0
,
2
,
1
).
contiguous
()
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
param
.
copy_
(
narrow_weight
)
loaded_params
.
add
(
new_
name
)
elif
".
experts.down_proj"
in
name
and
"bias"
not
in
name
:
loaded_params
.
add
(
name
)
continue
elif
".
w2_weight"
in
name
:
# Handle MLP down projection weights
new_name
=
name
.
replace
(
".experts.down_proj"
,
".experts.w2_weight"
)
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
narrow_weight
=
weight
[:,
tp_rank_start
:
tp_rank_end
,
:]
narrow_weight
=
narrow_weight
.
permute
(
0
,
2
,
1
).
contiguous
()
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
param
.
copy_
(
narrow_weight
)
loaded_params
.
add
(
new_
name
)
elif
"
gate_up_proj
_bias"
in
name
:
loaded_params
.
add
(
name
)
continue
elif
"
.w13
_bias"
in
name
:
# Handle MLP gate and up projection biases
new_name
=
name
.
replace
(
"gate_up_proj_bias"
,
"w13_bias"
)
# Extract gate and up projection bias parts
if
use_ep
:
narrow_weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
...
...
@@ -555,60 +476,155 @@ class GptOssForCausalLM(nn.Module):
narrow_weight
=
weight
[:,
2
*
tp_rank_start
:
2
*
tp_rank_end
]
param
=
params_dict
[
new_name
]
param
=
params_dict
[
name
]
param
.
copy_
(
narrow_weight
)
loaded_params
.
add
(
new_
name
)
elif
"
down_proj
_bias"
in
name
:
loaded_params
.
add
(
name
)
continue
elif
"
.w2
_bias"
in
name
:
# Handle MLP down projection bias
new_name
=
name
.
replace
(
"down_proj_bias"
,
"w2_bias"
)
if
use_ep
:
weight
=
weight
[
ep_rank_start
:
ep_rank_end
,
...]
else
:
# (only load on rank 0 to avoid duplication)
if
tp_rank
!=
0
:
weight
.
zero_
()
param
=
params_dict
[
new_
name
]
param
=
params_dict
[
name
]
param
.
copy_
(
weight
)
loaded_params
.
add
(
new_name
)
loaded_params
.
add
(
name
)
continue
elif
"sinks"
in
name
:
# Handle attention sinks (distributed across ranks)
name
=
name
.
replace
(
"self_attn"
,
"attn"
)
param
=
params_dict
[
name
]
narrow_weight
=
weight
.
narrow
(
0
,
head_start
,
heads_per_rank
)
param
.
data
.
copy_
(
narrow_weight
)
loaded_params
.
add
(
name
)
elif
"q_proj"
in
name
or
"k_proj"
in
name
or
"v_proj"
in
name
:
shard_id
=
(
"q"
if
"q_proj"
in
name
else
"k"
if
"k_proj"
in
name
else
"v"
)
name
=
name
.
replace
(
"self_attn"
,
"attn"
)
param_name
=
name
.
replace
(
f
"
{
shard_id
}
_proj"
,
"qkv"
)
param
=
params_dict
[
param_name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight
,
loaded_shard_id
=
shard_id
)
loaded_params
.
add
(
param_name
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
weight
)
else
:
weight_loader
(
param
,
weight
,
shard_id
)
break
else
:
# Handle all other weights with potential renaming
renamed_name
=
maybe_rename
(
name
)
if
renamed_name
not
in
params_dict
:
if
name
not
in
params_dict
:
continue
param
=
params_dict
[
renamed_
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
weight
)
loaded_params
.
add
(
renamed_name
)
loaded_params
.
add
(
name
)
return
loaded_params
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
quant_method
=
(
self
.
model_config
.
quantization_config
[
'quant_method'
]
if
hasattr
(
self
.
model_config
,
"quantization_config"
)
else
None
)
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv"
,
".q_proj"
,
"q"
),
(
".qkv"
,
".k_proj"
,
"k"
),
(
".qkv"
,
".v_proj"
,
"v"
),
]
tp_rank
=
get_tensor_model_parallel_rank
()
tp_size
=
get_tensor_model_parallel_world_size
()
# Attention heads per rank
heads_per_rank
=
self
.
config
.
num_attention_heads
//
tp_size
head_start
=
tp_rank
*
heads_per_rank
ep_size
=
get_ep_group
().
world_size
ep_rank
=
get_ep_group
().
rank
num_experts
=
self
.
config
.
num_local_experts
experts_per_rank
=
num_experts
//
ep_size
ep_rank_start
=
ep_rank
*
experts_per_rank
ep_rank_end
=
(
ep_rank
+
1
)
*
experts_per_rank
quant_method
=
(
self
.
config
.
quantization_config
[
'quant_method'
]
if
hasattr
(
self
.
config
,
"quantization_config"
)
else
None
)
if
quant_method
==
"mxfp4"
:
return
self
.
_load_weights_mxfp4
(
weights
)
return
self
.
_load_weights_mxfp4
(
ep_rank_end
,
ep_rank_start
,
heads_per_rank
,
head_start
,
weights
,
stacked_params_mapping
)
else
:
return
self
.
_load_weights_other
(
weights
)
return
self
.
_load_weights_other
(
ep_rank_end
,
ep_rank_start
,
heads_per_rank
,
head_start
,
weights
,
stacked_params_mapping
)
class GptOssForCausalLM(nn.Module):
    """Causal-LM wrapper: a ``GptOssModel`` backbone plus LM head and logits
    processor, with checkpoint loading delegated to ``AutoWeightsLoader``."""

    # Fused "qkv" module and the three per-projection names packed into it.
    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    # Name mapper handed to the loader in ``load_weights``.
    # NOTE(review): entry order is kept exactly as written, in case the
    # mapper applies entries sequentially -- confirm before reordering.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
            ".post_attention_layernorm.": ".mlp.norm.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",
            ".input_layernorm.weight": ".attn.norm.weight",
            ".post_attention_layernorm.weight": ".mlp.norm.weight",

            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",

            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",

            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration; its
                ``model_config.hf_config`` is stored as ``self.config``.
            prefix: Parent module prefix; the backbone is registered under
                ``maybe_prefix(prefix, "model")``.
        """
        super().__init__()
        self.vllm_config = vllm_config
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        backbone_prefix = maybe_prefix(prefix, "model")
        self.model = GptOssModel(vllm_config=vllm_config,
                                 prefix=backbone_prefix)

        vocab_size = self.config.vocab_size
        self.lm_head = ParallelLMHead(vocab_size, self.config.hidden_size)
        self.logits_processor = LogitsProcessor(self.config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone on ``input_ids`` and ``positions``.

        ``intermediate_tensors`` and ``inputs_embeds`` must both be ``None``;
        anything else trips an assertion.
        """
        assert intermediate_tensors is None
        assert inputs_embeds is None
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Return the logits processor's output for the given hidden states."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load ``(name, tensor)`` pairs through ``AutoWeightsLoader``.

        When ``config.tie_word_embeddings`` is truthy, names starting with
        ``"lm_head."`` are passed as ``skip_prefixes``; otherwise no prefixes
        are skipped. Names are translated with ``hf_to_vllm_mapper``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        if self.config.tie_word_embeddings:
            prefixes_to_skip = ["lm_head."]
        else:
            prefixes_to_skip = None

        weights_loader = AutoWeightsLoader(self,
                                           skip_prefixes=prefixes_to_skip)
        return weights_loader.load_weights(weights,
                                           mapper=self.hf_to_vllm_mapper)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment