Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e86c414d
Unverified
Commit
e86c414d
authored
Apr 02, 2025
by
rongfu.leng
Committed by
GitHub
Apr 02, 2025
Browse files
[Model] use AutoWeightsLoader in model load_weights (#15770)
Signed-off-by:
rongfu.leng
<
rongfu.leng@daocloud.io
>
parent
550b2801
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
189 additions
and
165 deletions
+189
-165
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+5
-0
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+55
-50
vllm/model_executor/models/exaone.py
vllm/model_executor/models/exaone.py
+74
-66
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+55
-49
No files found.
docs/source/models/supported_models.md
View file @
e86c414d
...
...
@@ -218,6 +218,11 @@ See [this page](#generative-models) for more information on how to use generativ
*
`baichuan-inc/Baichuan2-13B-Chat`
,
`baichuan-inc/Baichuan-7B`
, etc.
*
✅︎
*
✅︎
-
*
`BambaForCausalLM`
*
Bamba
*
`ibm-ai-platform/Bamba-9B-fp8`
,
`ibm-ai-platform/Bamba-9B`
*
*
-
*
`BloomForCausalLM`
*
BLOOM, BLOOMZ, BLOOMChat
*
`bigscience/bloom`
,
`bigscience/bloomz`
, etc.
...
...
vllm/model_executor/models/bamba.py
View file @
e86c414d
...
...
@@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsQuant
,
SupportsV0Only
)
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -363,6 +363,58 @@ class BambaModel(nn.Module):
hidden_states
,
_
=
self
.
final_layernorm
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
BambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
IsHybrid
,
SupportsV0Only
,
SupportsQuant
):
...
...
@@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
"A_log"
in
name
:
name
=
name
.
replace
(
"A_log"
,
"A"
)
if
".self_attn."
in
name
:
name
=
name
.
replace
(
".self_attn"
,
""
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
)
return
loader
.
load_weights
(
weights
)
\ No newline at end of file
vllm/model_executor/models/exaone.py
View file @
e86c414d
...
...
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs.exaone
import
ExaoneConfig
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -313,6 +313,7 @@ class ExaoneModel(nn.Module):
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
quant_config
=
quant_config
lora_vocab
=
((
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
)
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
...
...
@@ -384,6 +385,72 @@ class ExaoneModel(nn.Module):
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".c_fc_0"
,
0
),
(
".gate_up_proj"
,
".c_fc_1"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
ExaoneForCausalLM
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
packed_modules_mapping
=
{
...
...
@@ -481,71 +548,12 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".c_fc_0"
,
0
),
(
".gate_up_proj"
,
".c_fc_1"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
loader
=
AutoWeightsLoader
(
self
,
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
loaded_weight
=
(
loaded_weight
if
loaded_weight
.
dim
()
==
0
else
loaded_weight
[
0
])
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/falcon.py
View file @
e86c414d
...
...
@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from
vllm.transformers_utils.configs
import
RWConfig
from
.interfaces
import
SupportsPP
from
.utils
import
(
is_pp_missing_parameter
,
from
.utils
import
(
AutoWeightsLoader
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -395,6 +395,54 @@ class FalconModel(nn.Module):
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
total_num_heads
=
self
.
config
.
num_attention_heads
if
self
.
config
.
new_decoder_architecture
:
total_num_kv_heads
=
self
.
config
.
num_kv_heads
elif
self
.
config
.
multi_query
:
total_num_kv_heads
=
1
else
:
total_num_kv_heads
=
total_num_heads
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
loaded_weight_shape
=
loaded_weight
.
shape
if
output_dim
is
not
None
:
loaded_weight
=
loaded_weight
.
view
(
loaded_weight_shape
[:
output_dim
]
+
(
total_num_kv_heads
,
num_query_heads_per_kv_head
+
2
,
-
1
)
+
loaded_weight_shape
[
output_dim
+
1
:])
wq
=
loaded_weight
.
narrow
(
output_dim
+
1
,
0
,
num_query_heads_per_kv_head
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
wk
=
loaded_weight
.
narrow
(
output_dim
+
1
,
num_query_heads_per_kv_head
,
1
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
wv
=
loaded_weight
.
narrow
(
output_dim
+
1
,
num_query_heads_per_kv_head
+
1
,
1
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
loaded_weight
=
torch
.
cat
([
wq
,
wk
,
wv
],
dim
=
output_dim
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
class
FalconForCausalLM
(
nn
.
Module
,
SupportsPP
):
packed_modules_mapping
=
{
...
...
@@ -462,51 +510,9 @@ class FalconForCausalLM(nn.Module, SupportsPP):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
total_num_heads
=
self
.
config
.
num_attention_heads
if
self
.
config
.
new_decoder_architecture
:
total_num_kv_heads
=
self
.
config
.
num_kv_heads
elif
self
.
config
.
multi_query
:
total_num_kv_heads
=
1
else
:
total_num_kv_heads
=
total_num_heads
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
name
==
"lm_head.weight"
and
self
.
tie_word_embeddings
:
# Falcon uses tied embeddings except Falcon-11b.
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
if
"query_key_value"
in
name
:
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
loaded_weight_shape
=
loaded_weight
.
shape
if
output_dim
is
not
None
:
loaded_weight
=
loaded_weight
.
view
(
loaded_weight_shape
[:
output_dim
]
+
(
total_num_kv_heads
,
num_query_heads_per_kv_head
+
2
,
-
1
)
+
loaded_weight_shape
[
output_dim
+
1
:])
wq
=
loaded_weight
.
narrow
(
output_dim
+
1
,
0
,
num_query_heads_per_kv_head
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
wk
=
loaded_weight
.
narrow
(
output_dim
+
1
,
num_query_heads_per_kv_head
,
1
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
wv
=
loaded_weight
.
narrow
(
output_dim
+
1
,
num_query_heads_per_kv_head
+
1
,
1
).
reshape
(
*
loaded_weight_shape
[:
output_dim
],
-
1
,
*
loaded_weight_shape
[
output_dim
+
1
:])
loaded_weight
=
torch
.
cat
([
wq
,
wk
,
wv
],
dim
=
output_dim
)
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"lm_head."
]
if
self
.
config
.
tie_word_embeddings
else
None
),
)
return
loader
.
load_weights
(
weights
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment