Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13d88d41
Unverified
Commit
13d88d41
authored
Sep 22, 2024
by
Isotr0py
Committed by
GitHub
Sep 22, 2024
Browse files
[Bugfix] Refactor composite weight loading logic (#8656)
parent
d66ac628
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
70 additions
and
61 deletions
+70
-61
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+6
-10
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+6
-10
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+7
-13
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+6
-11
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+5
-9
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+5
-7
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+35
-1
No files found.
vllm/model_executor/models/internvl.py
View file @
13d88d41
...
...
@@ -4,7 +4,6 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import
itertools
import
re
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -33,8 +32,8 @@ from vllm.utils import is_list_of
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
(
f
ilter_weights
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.utils
import
(
f
latten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
IMG_START
=
'<img>'
IMG_END
=
'</img>'
...
...
@@ -518,21 +517,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_
weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
weights
_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_model"
)
self
.
vision_model
.
load_weights
(
vit_weights
)
self
.
vision_model
.
load_weights
(
weights_group
[
"vision_model"
])
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"mlp1"
)
mlp_params_dict
=
dict
(
self
.
mlp1
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"mlp1"
]
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/llava.py
View file @
13d88d41
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
input_processor_for_siglip
)
from
.utils
import
(
f
ilter_weights
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.utils
import
(
f
latten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
class
LlavaImagePixelInputs
(
TypedDict
):
...
...
@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_
weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
weights
_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"multi_modal_projector"
]
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/llava_next.py
View file @
13d88d41
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -30,8 +29,8 @@ from .llava import LlavaMultiModalProjector
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
f
ilter_weights
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.utils
import
(
f
latten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
...
...
@@ -637,25 +636,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_weights
,
mlp_weights
,
newline_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
4
)
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"multi_modal_projector"
]
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load newline
newline_weights
=
filter_weights
(
newline_weights
,
"image_newline"
)
for
name
,
loaded_weight
in
newline_weights
:
for
name
,
loaded_weight
in
weights_group
[
"image_newline"
]:
assert
name
==
""
param
=
self
.
image_newline
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
@@ -663,5 +658,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/llava_next_video.py
View file @
13d88d41
import
itertools
import
math
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -30,7 +29,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
)
from
.utils
import
(
filter
_weights
,
init_vllm_registered_model
,
from
.utils
import
(
group
_weights
_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
...
...
@@ -449,23 +448,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators
vit_weights
,
mlp_weights
,
newline_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
4
)
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"multi_modal_projector"
]
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/paligemma.py
View file @
13d88d41
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
filter
_weights
,
merge_multimodal_embeddings
from
.utils
import
group
_weights
_with_prefix
,
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -286,21 +285,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
vit_
weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
weights
_group
=
group_weights_with_prefix
(
weights
)
# load vision tower
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
self
.
vision_tower
.
load_weights
(
vit_weights
)
self
.
vision_tower
.
load_weights
(
weights_group
[
"vision_tower"
])
# load mlp projector
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
mlp_
weights
:
for
name
,
loaded_weight
in
weights
_group
[
"multi_modal_projector"
]
:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/ultravox.py
View file @
13d88d41
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import
itertools
import
math
from
array
import
array
from
functools
import
lru_cache
...
...
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
(
filter_weights
,
flatten_bn
,
from
vllm.model_executor.models.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# prepare weight iterators for components
projector_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
2
)
weights_group
=
group_weights_with_prefix
(
weights
)
# load projector weights
projector_weights
=
filter_weights
(
projector_weights
,
"multi_modal_projector"
)
projector_weights
=
weights_group
[
"multi_modal_projector"
]
projector_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
projector_weights
:
...
...
@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
self
.
language_model
.
load_weights
(
llm_weights
)
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/utils.py
View file @
13d88d41
import
itertools
from
collections
import
UserDict
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
...
...
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
def
filter_weights
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
):
class
WeightsGroup
(
UserDict
):
"""
Wraps grouped weights dictionary for a more informative error message
when attempting to access a weight component that does not exist.
"""
def
__getitem__
(
self
,
key
:
str
)
->
int
:
try
:
return
super
().
__getitem__
(
key
)
except
KeyError
as
exc
:
msg
=
(
f
"There is no weights named with the prefix:
{
key
}
. "
f
"Available prefix:
{
set
(
self
.
keys
())
}
"
)
raise
KeyError
(
msg
)
from
exc
def
filter_weights
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]],
prefix
:
str
)
->
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]:
"""
Helper function to load weights for inner vLLM models.
...
...
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
yield
name
,
loaded_weight
def
group_weights_with_prefix
(
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]
)
->
Dict
[
str
,
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]]:
"""
Helper function to group weights with prefix
"""
init_weights
,
repeated_weights
=
itertools
.
tee
(
weights
,
2
)
weights_prefix
=
{
name
.
split
(
"."
)[
0
]
for
name
,
_
in
init_weights
}
repeated_weights
=
itertools
.
tee
(
repeated_weights
,
len
(
weights_prefix
))
return
WeightsGroup
({
prefix
:
filter_weights
(
component
,
prefix
)
for
component
,
prefix
in
zip
(
repeated_weights
,
weights_prefix
)
})
def
init_vllm_registered_model
(
hf_config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment