Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51ba8395
Unverified
Commit
51ba8395
authored
Jul 20, 2025
by
Calvin Chen
Committed by
GitHub
Jul 20, 2025
Browse files
[Model] use AutoWeightsLoader for bart (#18299)
Signed-off-by:
calvin chen
<
120380290@qq.com
>
parent
d1fb65bd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
101 deletions
+71
-101
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+71
-101
No files found.
vllm/model_executor/models/bart.py
View file @
51ba8395
...
...
@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsQuant
,
SupportsV0Only
from
.utils
import
maybe_prefix
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -700,7 +700,8 @@ class BartDecoder(nn.Module):
class
BartModel
(
nn
.
Module
,
SupportsQuant
):
_tied_weights_keys
=
[
"encoder.embed_tokens.weight"
,
"decoder.embed_tokens.weight"
"encoder.embed_tokens.weight"
,
"decoder.embed_tokens.weight"
,
]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -763,10 +764,54 @@ class BartModel(nn.Module, SupportsQuant):
return
decoder_outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into this module's parameters.

    Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are written into
    the matching fused ``qkv_proj`` parameter through that parameter's own
    ``weight_loader``.  Every other tensor whose name matches an existing
    parameter is handed to ``AutoWeightsLoader``.  Tensors whose (possibly
    renamed) name matches no parameter are dropped silently.

    Args:
        weights: iterable of ``(name, tensor)`` pairs.

    Returns:
        The names reported by ``AutoWeightsLoader`` plus the fused
        parameter names that were loaded here.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    qkv_shards = (
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    )
    params_by_name = dict(self.named_parameters())
    fused_names = []
    passthrough = []

    for name, tensor in weights:
        consumed = False
        for fused_key, shard_key, shard_id in qkv_shards:
            if shard_key not in name:
                continue
            # The rename is kept even when the fused name turns out to be
            # unknown; the membership check after the loop then sees the
            # renamed key.
            name = name.replace(shard_key, fused_key)
            if name not in params_by_name:
                continue
            target = params_by_name[name]
            target.weight_loader(target, tensor, shard_id)
            fused_names.append(name)
            consumed = True
            break
        if not consumed and name in params_by_name:
            passthrough.append((name, tensor))

    loaded = AutoWeightsLoader(self).load_weights(passthrough)
    loaded.update(fused_names)
    return loaded
class
BartForConditionalGeneration
(
nn
.
Module
,
SupportsV0Only
,
SupportsQuant
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
]}
base_model_prefix
=
"model"
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"decoder."
:
"model.decoder."
,
"encoder."
:
"model.encoder."
,
"shared."
:
"model.shared."
},
orig_to_new_substr
=
{
"beta"
:
"bias"
,
"gamma"
:
"weight"
,
"LayerNorm"
:
"layernorm"
,
},
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -789,7 +834,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
self
.
lm_head
=
BartParallelLMHead
(
config
.
vocab_size
,
config
.
d_model
,
embed_scale
=
embed_scale
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
...
...
@@ -828,61 +872,12 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
sampling_metadata
)
return
logits
stacked_params_mapping
=
{
"q_proj"
:
{
"param_name"
:
"qkv_proj"
,
"shard_id"
:
"q"
,
},
"k_proj"
:
{
"param_name"
:
"qkv_proj"
,
"shard_id"
:
"k"
,
},
"v_proj"
:
{
"param_name"
:
"qkv_proj"
,
"shard_id"
:
"v"
,
},
}
params_mapping
=
{
"beta"
:
"bias"
,
"gamma"
:
"weight"
,
"LayerNorm"
:
"layernorm"
,
}
def _rename_key(self, key: str):
    """Normalise a checkpoint key.

    A leading ``"<base_model_prefix>."`` is removed when present, then each
    ``params_mapping`` entry is applied in turn as a substring replacement.
    Returns the rewritten key.
    """
    base = f"{self.base_model_prefix}."
    if key.startswith(base):
        key = key[len(base):]
    for old, new in self.params_mapping.items():
        key = key.replace(old, new)
    return key
def _rename_stacked_param(
    self,
    name: str,
) -> tuple[str, Optional[str]]:
    """Map a per-projection weight name onto its stacked parameter.

    The first key of ``stacked_params_mapping`` (in mapping order) that
    occurs in ``name`` is replaced by that entry's ``"param_name"``.

    Returns:
        ``(renamed, shard_id)`` for the matching entry, or ``(name, None)``
        when no key occurs in ``name``.
    """
    table = self.stacked_params_mapping
    hit = next((candidate for candidate in table if candidate in name), None)
    if hit is None:
        return name, None
    entry = table[hit]
    return name.replace(hit, entry["param_name"]), entry["shard_id"]
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
model_params_dict
=
dict
(
self
.
model
.
named_parameters
())
top_params_dict
=
dict
(
self
.
named_parameters
())
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
weights_tuple_list
=
list
(
weights
)
shared_embedding_weight
=
None
shared_embedding_shard_id
=
None
for
name
,
loaded_weight
in
weights_tuple_list
:
name
=
self
.
_rename_key
(
name
)
name
,
shard_id
=
self
.
_rename_stacked_param
(
name
)
if
(
'shared.weight'
in
name
or
'encoder.embed_tokens.weight'
in
name
or
'decoder.embed_tokens.weight'
in
name
...
...
@@ -890,49 +885,24 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
assert
shared_embedding_weight
is
None
,
(
"Conflicting embedding weights."
)
shared_embedding_weight
=
loaded_weight
shared_embedding_shard_id
=
shard_id
else
:
# Skip the specific downstream task weight.
if
name
.
startswith
(
'cls.'
):
continue
# use Pooler instead.
if
name
.
startswith
(
'pooler.'
):
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
model_params_dict
:
continue
param
=
model_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
shard_id
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
else
:
weight_loader
(
param
,
loaded_weight
)
# Assign shared weight values
encoder_in_param
=
model_params_dict
[
'encoder.embed_tokens.weight'
]
encoder_in_weight_loader
=
getattr
(
encoder_in_param
,
"weight_loader"
,
default_weight_loader
)
decoder_in_param
=
model_params_dict
[
'decoder.embed_tokens.weight'
]
decoder_in_weight_loader
=
getattr
(
decoder_in_param
,
"weight_loader"
,
default_weight_loader
)
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
([
"cls."
,
"pooler."
]),
)
loaded_params
=
loader
.
load_weights
(
weights_tuple_list
,
mapper
=
self
.
hf_to_vllm_mapper
)
lm_head_in_param
=
top_params_dict
[
'lm_head.weight'
]
lm_head_in_
weight_loader
=
getattr
(
lm_head
_in_param
,
"weight_loader"
,
if
shared_embedding_weight
is
not
None
:
weight_loader
=
getattr
(
self
.
lm_head
.
weight
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
self
.
lm_head
.
weight
,
shared_embedding_weight
)
assert
shared_embedding_weight
is
not
None
self
.
model
.
encoder
.
embed_tokens
.
weight
=
self
.
lm_head
.
weight
self
.
model
.
decoder
.
embed_tokens
.
weight
=
self
.
lm_head
.
weight
loaded_params
.
update
({
'model.encoder.embed_tokens.weight'
,
'lm_head.weight'
,
'model.decoder.embed_tokens.weight'
})
if
shared_embedding_shard_id
:
encoder_in_weight_loader
(
encoder_in_param
,
shared_embedding_weight
,
shared_embedding_shard_id
)
decoder_in_weight_loader
(
decoder_in_param
,
shared_embedding_weight
,
shared_embedding_shard_id
)
lm_head_in_weight_loader
(
lm_head_in_param
,
shared_embedding_weight
,
shared_embedding_shard_id
)
else
:
encoder_in_weight_loader
(
encoder_in_param
,
shared_embedding_weight
)
decoder_in_weight_loader
(
decoder_in_param
,
shared_embedding_weight
)
lm_head_in_weight_loader
(
lm_head_in_param
,
shared_embedding_weight
)
return
loaded_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment