Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
29283eaa
Unverified
Commit
29283eaa
authored
Apr 06, 2025
by
Jonghyun Choe
Committed by
GitHub
Apr 05, 2025
Browse files
[Model] use AutoWeightsLoader for phi, gemma, deepseek (#16088)
Signed-off-by:
Jonghyun Choe
<
andy.choe729@gmail.com
>
parent
2fa66ef7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
147 additions
and
131 deletions
+147
-131
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+54
-48
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+47
-42
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+46
-41
No files found.
vllm/model_executor/models/deepseek.py
View file @
29283eaa
...
@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -51,7 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
.utils
import
(
extract_layer_index
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
...
@@ -385,6 +386,56 @@ class DeepseekModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
DeepseekForCausalLM
(
nn
.
Module
,
SupportsPP
):
...
@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
...
@@ -439,50 +490,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
self
)
# (param_name, shard_name, shard_id)
return
loader
.
load_weights
(
weights
)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
\ No newline at end of file
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/gemma.py
View file @
29283eaa
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
...
@@ -319,6 +319,46 @@ class GemmaModel(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
GemmaForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -385,44 +425,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
# (param_name, shard_name, shard_id)
self
,
(
"qkv_proj"
,
"q_proj"
,
"q"
),
skip_prefixes
=
([
"lm_head."
]
(
"qkv_proj"
,
"k_proj"
,
"k"
),
if
self
.
config
.
tie_word_embeddings
else
None
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
return
loader
.
load_weights
(
weights
)
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with embed_token.
# To prevent errors, skip loading lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
vllm/model_executor/models/phi.py
View file @
29283eaa
...
@@ -61,7 +61,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -61,7 +61,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
maybe_prefix
)
...
@@ -249,6 +249,49 @@ class PhiModel(nn.Module):
...
@@ -249,6 +249,49 @@ class PhiModel(nn.Module):
return
hidden_states
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
)
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# pylint: disable=E1136
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
PhiForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
PhiForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
...
@@ -317,43 +360,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -317,43 +360,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
loader
=
AutoWeightsLoader
(
self
)
# (param_name, shard_name, shard_id)
return
loader
.
load_weights
(
weights
)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
)
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# pylint: disable=E1136
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment