Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf7e3c51
Unverified
Commit
bf7e3c51
authored
Apr 05, 2025
by
Jonghyun Choe
Committed by
GitHub
Apr 04, 2025
Browse files
[Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (#15939)
Signed-off-by:
Jonghyun Choe
<
andy.choe729@gmail.com
>
parent
a35a8a83
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
119 additions
and
100 deletions
+119
-100
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+57
-48
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+42
-37
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+20
-15
No files found.
vllm/model_executor/models/baichuan.py
View file @
bf7e3c51
...
...
@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
SupportsQuant
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
...
...
@@ -321,6 +321,45 @@ class BaiChuanModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
BaiChuanBaseForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
):
...
...
@@ -353,6 +392,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
self
.
lm_head
.
weight
.
weight_loader
=
self
.
lm_head_weight_loader
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
...
...
@@ -393,17 +433,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
name
==
"lm_head.weight"
:
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
def
lm_head_weight_loader
(
self
,
param
:
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
...
...
@@ -412,34 +446,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
is_baichuan2
=
self
.
config
.
vocab_size
==
125696
if
is_baichuan2
:
loaded_weight
=
torch
.
nn
.
functional
.
normalize
(
loaded_weight
)
loaded_weight
=
torch
.
nn
.
functional
.
normalize
(
loaded_weight
)
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
default_weight_loader
(
param
,
loaded_weight
)
class
BaichuanForCausalLM
(
BaiChuanBaseForCausalLM
):
...
...
vllm/model_executor/models/gpt_neox.py
View file @
bf7e3c51
...
...
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -241,6 +241,45 @@ class GPTNeoXModel(nn.Module):
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
"attention.bias"
in
name
or
"attention.masked_bias"
in
name
or
"rotary_emb.inv_freq"
in
name
):
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
num_heads
=
self
.
config
.
num_attention_heads
if
output_dim
is
not
None
:
loaded_weight_shape
=
loaded_weight
.
shape
loaded_weight
=
loaded_weight
.
view
(
loaded_weight_shape
[:
output_dim
]
+
(
num_heads
,
3
,
-
1
)
+
loaded_weight_shape
[
output_dim
+
1
:])
loaded_weight
=
loaded_weight
.
transpose
(
output_dim
,
output_dim
+
1
)
loaded_weight
=
loaded_weight
.
reshape
(
loaded_weight_shape
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GPTNeoXForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
...
@@ -297,39 +336,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
(
"attention.bias"
in
name
or
"attention.masked_bias"
in
name
or
"rotary_emb.inv_freq"
in
name
):
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
num_heads
=
self
.
config
.
num_attention_heads
if
output_dim
is
not
None
:
loaded_weight_shape
=
loaded_weight
.
shape
loaded_weight
=
loaded_weight
.
view
(
loaded_weight_shape
[:
output_dim
]
+
(
num_heads
,
3
,
-
1
)
+
loaded_weight_shape
[
output_dim
+
1
:])
loaded_weight
=
loaded_weight
.
transpose
(
output_dim
,
output_dim
+
1
)
loaded_weight
=
loaded_weight
.
reshape
(
loaded_weight_shape
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/mpt.py
View file @
bf7e3c51
...
...
@@ -27,7 +27,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -266,6 +266,23 @@ class MPTModel(nn.Module):
hidden_states
=
self
.
norm_f
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
MPTForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
...
@@ -318,17 +335,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment