Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c25435d
Unverified
Commit
0c25435d
authored
Aug 03, 2024
by
Isotr0py
Committed by
GitHub
Aug 02, 2024
Browse files
[Model] Refactor and decouple weight loading logic for InternVL2 model (#7067)
parent
a0d16456
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
55 deletions
+38
-55
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+10
-1
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+28
-54
No files found.
vllm/model_executor/models/intern_vit.py
View file @
0c25435d
...
...
@@ -4,7 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from
typing
import
Optional
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
...
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
NORM2FN
=
{
'rms_norm'
:
RMSNorm
,
...
...
@@ -268,3 +269,11 @@ class InternVisionModel(nn.Module):
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
return
encoder_outputs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Copy checkpoint tensors into this module's parameters.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs. Each ``name`` is
            looked up directly in ``self.named_parameters()``, so it must
            already be relative to this module.

    Raises:
        KeyError: If a ``name`` does not match any registered parameter.
    """
    named_params = dict(self.named_parameters())
    for param_name, tensor in weights:
        target = named_params[param_name]
        # Use the loader attached to the parameter when it has one;
        # otherwise fall back to the default loader.
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)
vllm/model_executor/models/internvl.py
View file @
0c25435d
...
...
@@ -4,6 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import
itertools
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
torch
...
...
@@ -414,58 +415,31 @@ class InternVLChatModel(nn.Module, SupportsVision):
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
(
".qkv_proj"
,
".k_proj"
,
"k"
),
(
".qkv_proj"
,
".v_proj"
,
"v"
),
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
(
".gate_up_proj"
,
".w1"
,
0
),
(
".gate_up_proj"
,
".w3"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
def
_filter_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
):
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
self
.
config
.
text_config
.
tie_word_embeddings
\
and
"lm_head.weight"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
# We only do sharding for language model
# and not vision model for now.
if
"vision_embed_tokens"
in
name
and
self
.
vision_embed_tokens
:
continue
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
if
"wqkv"
in
name
:
config
=
self
.
config
.
text_config
kv_groups
=
(
config
.
num_attention_heads
//
config
.
num_key_value_heads
)
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
loaded_weight
=
loaded_weight
.
view
(
-
1
,
2
+
kv_groups
,
head_dim
,
loaded_weight
.
shape
[
-
1
])
wq
,
wk
,
wv
=
torch
.
split
(
loaded_weight
,
[
kv_groups
,
1
,
1
],
dim
=
1
)
wq
=
wq
.
reshape
(
-
1
,
wq
.
shape
[
-
1
])
wk
=
wk
.
reshape
(
-
1
,
wk
.
shape
[
-
1
])
wv
=
wv
.
reshape
(
-
1
,
wv
.
shape
[
-
1
])
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
wq
,
'q'
)
weight_loader
(
param
,
wk
,
'k'
)
weight_loader
(
param
,
wv
,
'v'
)
continue
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
name
=
name
.
split
(
"."
)
if
prefix
==
name
.
pop
(
0
):
name
=
"."
.
join
(
name
)
yield
name
,
loaded_weight
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the three sub-components in turn.

    The incoming iterable is split into three independent iterators, one
    per component. Each iterator is passed through ``self._filter_weights``
    with that component's prefix ("vision_model", "mlp1",
    "language_model") before being consumed.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs for the whole model.

    Raises:
        KeyError: If a filtered "mlp1" weight name is not a parameter of
            ``self.mlp1``.
    """
    # One independent iterator per component, so ``weights`` itself is
    # only traversed once even if it is a one-shot generator.
    # NOTE(review): itertools.tee buffers items that the other iterators
    # have not consumed yet; since the components are loaded one after
    # another this presumably keeps every yielded pair referenced until
    # the last component finishes -- confirm memory impact with the caller.
    vit_stream, mlp_stream, llm_stream = itertools.tee(weights, 3)

    # Vision encoder: delegated to the sub-module's own loader.
    vit_stream = self._filter_weights(vit_stream, "vision_model")
    self.vision_model.load_weights(vit_stream)

    # MLP projector: loaded here, parameter by parameter.
    mlp_stream = self._filter_weights(mlp_stream, "mlp1")
    projector_params = dict(self.mlp1.named_parameters())
    for param_name, tensor in mlp_stream:
        target = projector_params[param_name]
        # Use the loader attached to the parameter when it has one;
        # otherwise fall back to the default loader.
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)

    # LLM backbone: delegated to the sub-module's own loader.
    llm_stream = self._filter_weights(llm_stream, "language_model")
    self.language_model.load_weights(llm_stream)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment