Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
242a637a
Unverified
Commit
242a637a
authored
Apr 06, 2025
by
rongfu.leng
Committed by
GitHub
Apr 06, 2025
Browse files
[Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2 (#16103)
Signed-off-by:
rongfu.leng
<
rongfu.leng@daocloud.io
>
parent
c2a96715
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
135 additions
and
121 deletions
+135
-121
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+50
-44
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+45
-39
vllm/model_executor/models/zamba2.py
vllm/model_executor/models/zamba2.py
+40
-38
No files found.
vllm/model_executor/models/stablelm.py
View file @
242a637a
...
...
@@ -44,7 +44,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -253,6 +253,45 @@ class StableLMEpochModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
StablelmForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
...
@@ -308,46 +347,13 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes
=
[
"rotary_emb.inv_freq"
,
"rotary_emb.cos_cached"
,
"rotary_emb.sin_cached"
],
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/starcoder2.py
View file @
242a637a
...
...
@@ -45,7 +45,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -256,6 +256,41 @@ class Starcoder2Model(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
Starcoder2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
...
@@ -319,41 +354,12 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes
=
([
"rotary_emb.inv_freq"
,
"lm_head.weight"
]
if
self
.
config
.
tie_word_embeddings
else
[
"rotary_emb.inv_freq"
]),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/zamba2.py
View file @
242a637a
...
...
@@ -39,7 +39,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsV0Only
from
.utils
import
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
class
Zamba2LoRA
(
nn
.
Module
):
...
...
@@ -777,6 +777,37 @@ class Zamba2Model(nn.Module):
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
chkpt_weight_name
,
loaded_weight
in
weights
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
chkpt_weight_name
:
continue
chkpt_weight_name
=
chkpt_weight_name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
chkpt_weight_name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
if
chkpt_weight_name
not
in
params_dict
:
continue
param
=
params_dict
[
chkpt_weight_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
chkpt_weight_name
)
return
loaded_params
class
Zamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsHybrid
,
SupportsV0Only
):
"""Zamba2 model with causal language modeling head.
...
...
@@ -787,6 +818,12 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
- Support for model parallelism and quantization
- Sampling capabilities for text generation
"""
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
"A_log"
:
"A"
,
"0.weight"
:
"A.weight"
,
"1.weight"
:
"B.weight"
,
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
)
->
None
:
"""Initialize the Zamba2 model for causal language modeling.
...
...
@@ -992,40 +1029,5 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
weights_dict
=
{}
for
key
,
loaded_weight
in
weights
:
if
"A_log"
in
key
:
key
=
key
.
replace
(
"A_log"
,
"A"
)
elif
"adapter_list"
in
key
:
key
=
key
.
replace
(
"0.weight"
,
"A.weight"
)
key
=
key
.
replace
(
"1.weight"
,
"B.weight"
)
weights_dict
[
key
]
=
loaded_weight
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
chkpt_weight_name
,
loaded_weight
in
weights_dict
.
items
():
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
chkpt_weight_name
:
continue
chkpt_weight_name
=
chkpt_weight_name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
chkpt_weight_name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
if
chkpt_weight_name
not
in
params_dict
:
continue
param
=
params_dict
[
chkpt_weight_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
chkpt_weight_name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment