Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
16ee07f2
Unverified
Commit
16ee07f2
authored
Nov 30, 2024
by
Isotr0py
Committed by
GitHub
Nov 30, 2024
Browse files
[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
40bc2425
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
111 additions
and
102 deletions
+111
-102
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+111
-102
No files found.
vllm/model_executor/models/molmo.py
View file @
16ee07f2
...
...
@@ -3,7 +3,7 @@ import re
from
array
import
array
from
dataclasses
import
dataclass
from
functools
import
lru_cache
,
partial
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
TypedDict
import
torch
from
einops
import
rearrange
...
...
@@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from
vllm.transformers_utils.processor
import
get_processor
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
get_vit_attn_backend
,
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
get_vit_attn_backend
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module):
# image_features: (batch_size, num_image, num_patch, d_model)
return
image_features
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this module's parameters.

    Tensors whose name contains ``gate_proj`` or ``up_proj`` are routed
    into the fused ``gate_up_proj`` parameter as shard 0 and shard 1
    respectively; every other tensor is loaded under its own name.

    Args:
        weights: ``(name, tensor)`` pairs. Names are looked up in
            ``self.named_parameters()``, so they must be relative to
            this module.

    Returns:
        The parameter names (after any renaming) that were loaded.

    Raises:
        KeyError: If a name that is not skipped has no matching parameter.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()

    for name, loaded_weight in weights:
        # Resolve the fused-parameter mapping exactly once. Rewriting the
        # name and then continuing to scan the mapping would re-match the
        # rewritten name ("up_proj" is a substring of "gate_up_proj") and
        # corrupt it into "gate_gate_up_proj".
        shard_id = None
        for param_name, shard_name, candidate_id in stacked_params_mapping:
            if shard_name in name:
                name = name.replace(shard_name, param_name)
                shard_id = candidate_id
                break

        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            continue
        if is_pp_missing_parameter(name, self):
            continue

        param = params_dict[name]
        if shard_id is None:
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        else:
            # Fused parameters must provide a shard-aware loader.
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
        loaded_params.add(name)

    return loaded_params
@
support_torch_compile
class
MolmoModel
(
nn
.
Module
):
...
...
@@ -804,6 +841,28 @@ class MolmoModel(nn.Module):
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Copy checkpoint tensors into the matching parameters of this module.

    Args:
        weights: ``(name, tensor)`` pairs; names are looked up in
            ``self.named_parameters()``.

    Returns:
        The names of all parameters that received a tensor.
    """
    known_params = dict(self.named_parameters())
    loaded: Set[str] = set()

    for name, tensor in weights:
        if "gate_up_proj" in name:
            # The two halves along dim 0 are swapped before loading; going
            # by the original local names the checkpoint stores (up, gate)
            # and the fused parameter wants (gate, up) -- TODO confirm.
            up_half, gate_half = tensor.chunk(2, dim=0)
            tensor = torch.cat([gate_half, up_half], dim=0)

        # A bias with no matching parameter is ignored rather than an error.
        is_extra_bias = name.endswith(".bias") and name not in known_params
        if is_extra_bias or is_pp_missing_parameter(name, self):
            continue

        target = known_params[name]
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)
        loaded.add(name)

    return loaded
cached_get_processor
=
lru_cache
(
get_processor
)
...
...
@@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a checkpoint by renaming its tensors and delegating the copy.

    Builds a ``WeightsMapper`` from the substring and prefix rename tables
    below, passes the incoming stream through
    ``_get_weights_with_merged_embedding``, and hands both to
    ``AutoWeightsLoader``.

    Args:
        weights: ``(name, tensor)`` pairs using the checkpoint's names.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns.
    """
    substr_renames = {
        # vision backbone mapping
        "image_projector.w1.": "image_projector.gate_proj.",
        "image_projector.w3.": "image_projector.up_proj.",
        "image_projector.w2.": "image_projector.down_proj.",
        # language backbone mapping
        "att_proj": "self_attn.qkv_proj",
        "attn_out": "self_attn.o_proj",
        "q_norm": "self_attn.q_norm",
        "k_norm": "self_attn.k_norm",
        "ff_proj": "mlp.gate_up_proj",
        "ff_out": "mlp.down_proj",
        "attn_norm": "input_layernorm",
        "ff_norm": "post_attention_layernorm",
    }
    prefix_renames = {
        # vision backbone mapping
        "model.vision_backbone.": "vision_backbone.",
        # language backbone mapping
        "model.transformer.blocks.": "model.layers.",
        "model.transformer.ln_f.": "model.norm.",
        # The substring table first turns the lm_head tensor's name into
        # model.transformer.mlp.down_proj, so it needs a second renaming
        # here to end up under lm_head.
        "model.transformer.mlp.down_proj.": "lm_head.",
    }
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr=substr_renames,
        orig_to_new_prefix=prefix_renames,
    )

    loader = AutoWeightsLoader(self)
    merged_weights = _get_weights_with_merged_embedding(weights)
    return loader.load_weights(merged_weights, mapper=hf_to_vllm_mapper)
params_mapping
=
[
(
"model.transformer.ln_f.weight"
,
"model.norm.weight"
),
(
"attn_out"
,
"self_attn.o_proj"
),
(
"att_proj"
,
"self_attn.qkv_proj"
),
(
"q_norm"
,
"self_attn.q_norm"
),
(
"k_norm"
,
"self_attn.k_norm"
),
(
"attn_norm"
,
"input_layernorm"
),
(
"ff_norm"
,
"post_attention_layernorm"
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
embedding_weight
=
dict
()
projector_weight
=
dict
()
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
def
_get_weights_with_merged_embedding
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]
)
->
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]:
embedding_weights
=
{}
for
name
,
weight
in
weights
:
if
"wte.embedding"
in
name
:
embedding_weight
[
"embedding"
]
=
loaded_weight
continue
if
"wte.new_embedding"
in
name
:
embedding_weight
[
"new_embedding"
]
=
loaded_weight
continue
if
"vision_backbone"
in
name
:
if
name
.
startswith
(
"model"
):
name
=
name
[
len
(
"model."
):]
if
'image_projector'
in
name
:
if
'w1'
in
name
:
projector_weight
[
'gate_proj'
]
=
loaded_weight
elif
'w3'
in
name
:
projector_weight
[
'up_proj'
]
=
loaded_weight
elif
'w2'
in
name
:
projector_weight
[
'down_proj'
]
=
loaded_weight
else
:
raise
ValueError
(
f
"Unexpected projector weight:
{
name
}
"
)
continue
embedding_weights
[
"embedding"
]
=
weight
elif
"wte.new_embedding"
in
name
:
embedding_weights
[
"new_embedding"
]
=
weight
else
:
if
"transformer.blocks"
in
name
:
name
=
name
.
replace
(
"transformer.blocks"
,
"layers"
)
if
"ff_proj"
in
name
:
name
=
name
.
replace
(
"ff_proj"
,
"mlp.gate_up_proj"
)
assert
'weight'
in
name
up_weight
,
gate_weight
=
loaded_weight
.
chunk
(
2
,
dim
=
0
)
loaded_weight
=
torch
.
cat
([
gate_weight
,
up_weight
],
dim
=
0
)
elif
"ff_out"
in
name
:
if
"layers"
in
name
:
name
=
name
.
replace
(
"ff_out"
,
"mlp.down_proj"
)
else
:
# lm head
name
=
name
.
replace
(
"model.transformer.ff_out"
,
"lm_head"
)
else
:
for
(
param_name
,
weight_name
)
in
params_mapping
:
if
param_name
in
name
:
name
=
name
.
replace
(
param_name
,
weight_name
)
break
try
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
except
KeyError
:
raise
ValueError
(
f
"Unexpected weight:
{
name
}
"
)
from
None
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
gate_up_proj_weight
=
torch
.
cat
(
[
projector_weight
[
"gate_proj"
],
projector_weight
[
"up_proj"
]],
dim
=
0
)
name
=
"vision_backbone.image_projector.gate_up_proj.weight"
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
gate_up_proj_weight
)
down_proj_weight
=
projector_weight
[
"down_proj"
]
name
=
"vision_backbone.image_projector.down_proj.weight"
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
down_proj_weight
)
embedding_weight
=
torch
.
cat
(
[
embedding_weight
[
"embedding"
],
embedding_weight
[
"new_embedding"
]],
dim
=
0
)
name
=
"model.embed_tokens.weight"
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
embedding_weight
)
yield
(
name
,
weight
)
# this is compatible with most of quantization,
# because they won't quantize embed_tokens
embedding_weights
=
torch
.
cat
(
[
embedding_weights
[
"embedding"
],
embedding_weights
[
"new_embedding"
]],
dim
=
0
,
)
yield
(
"model.embed_tokens.weight"
,
embedding_weights
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment