Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f06bae9
Unverified
Commit
3f06bae9
authored
Sep 24, 2024
by
Peter Salas
Committed by
GitHub
Sep 24, 2024
Browse files
[Core][Model] Support loading weights by ID within models (#7931)
parent
b8747e8a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
17 deletions
+73
-17
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+47
-13
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+26
-4
No files found.
vllm/model_executor/model_loader/loader.py
View file @
3f06bae9
# ruff: noqa: SIM117
# ruff: noqa: SIM117
import
collections
import
collections
import
copy
import
copy
import
dataclasses
import
fnmatch
import
fnmatch
import
glob
import
glob
import
json
import
json
...
@@ -8,7 +9,8 @@ import math
...
@@ -8,7 +9,8 @@ import math
import
os
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
(
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Type
,
cast
)
import
gguf
import
gguf
import
huggingface_hub
import
huggingface_hub
...
@@ -207,6 +209,22 @@ class BaseModelLoader(ABC):
...
@@ -207,6 +209,22 @@ class BaseModelLoader(ABC):
class
DefaultModelLoader
(
BaseModelLoader
):
class
DefaultModelLoader
(
BaseModelLoader
):
"""Model loader that can load different file types from disk."""
"""Model loader that can load different file types from disk."""
@
dataclasses
.
dataclass
class
Source
:
"""A source for weights."""
model_or_path
:
str
"""The model ID or path."""
revision
:
Optional
[
str
]
"""The optional model revision."""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
fall_back_to_pt
:
bool
=
True
"""Whether .pt weights can be used."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
if
load_config
.
model_loader_extra_config
:
...
@@ -313,17 +331,16 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -313,17 +331,16 @@ class DefaultModelLoader(BaseModelLoader):
return
hf_folder
,
hf_weights_files
,
use_safetensors
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
def
_get_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
self
,
source
:
"Source"
fall_back_to_pt
:
bool
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
model_
name_
or_path
,
revision
,
fall_back_to_pt
)
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
weights_iterator
=
np_cache_weights_iterator
(
model
_name
_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
source
.
model_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
)
hf_weights_files
)
elif
use_safetensors
:
elif
use_safetensors
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
...
@@ -341,7 +358,29 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -341,7 +358,29 @@ class DefaultModelLoader(BaseModelLoader):
xm
.
mark_step
()
xm
.
mark_step
()
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
return
weights_iterator
# Apply the prefix.
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
def
_get_all_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
primary_weights
=
DefaultModelLoader
.
Source
(
model_config
.
model
,
model_config
.
revision
,
prefix
=
""
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
))
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
secondary_weights
=
cast
(
Iterable
[
DefaultModelLoader
.
Source
],
getattr
(
model
,
"secondary_weights"
,
()))
for
source
in
secondary_weights
:
yield
from
self
.
_get_weights_iterator
(
source
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
self
.
_prepare_weights
(
model_config
.
model
,
...
@@ -360,13 +399,8 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -360,13 +399,8 @@ class DefaultModelLoader(BaseModelLoader):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
cache_config
,
lora_config
,
cache_config
,
scheduler_config
)
scheduler_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_config
.
model
,
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
model_config
.
revision
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
)),
)
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
...
...
vllm/model_executor/models/ultravox.py
View file @
3f06bae9
...
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
...
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
(
flatten_bn
,
from
vllm.model_executor.models.utils
import
(
flatten_bn
,
...
@@ -334,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -334,14 +335,23 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
self
.
multi_modal_config
=
multimodal_config
self
.
multi_modal_config
=
multimodal_config
assert
self
.
multi_modal_config
assert
self
.
multi_modal_config
self
.
secondary_weights
=
[]
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
if
config
.
audio_model_id
is
not
None
:
if
config
.
audio_model_id
is
not
None
:
self
.
audio_tower
=
ModifiedWhisperEncoder
.
from_pretrained
(
self
.
secondary_weights
.
append
(
config
.
audio_model_id
)
DefaultModelLoader
.
Source
(
else
:
model_or_path
=
config
.
audio_model_id
,
self
.
audio_tower
=
ModifiedWhisperEncoder
(
config
.
audio_config
)
revision
=
None
,
prefix
=
"audio_tower."
,
))
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
multi_modal_projector
=
UltravoxProjector
(
config
)
self
.
language_model
=
init_vllm_registered_model
(
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
config
.
text_config
,
cache_config
,
quant_config
)
if
config
.
text_model_id
is
not
None
:
self
.
secondary_weights
.
append
(
DefaultModelLoader
.
Source
(
model_or_path
=
config
.
text_model_id
,
revision
=
None
,
prefix
=
"language_model."
))
def
_audio_features_to_embeddings
(
def
_audio_features_to_embeddings
(
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
input_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -466,6 +476,18 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -466,6 +476,18 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
# prepare weight iterators for components
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
weights_group
=
group_weights_with_prefix
(
weights
)
# load audio tower weights
audio_tower_weights
=
weights_group
[
"audio_tower"
]
audio_tower_params_dict
=
dict
(
self
.
audio_tower
.
named_parameters
(
prefix
=
self
.
audio_tower
.
base_model_prefix
))
for
name
,
loaded_weight
in
audio_tower_weights
:
if
name
in
audio_tower_params_dict
:
param
=
audio_tower_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load projector weights
# load projector weights
projector_weights
=
weights_group
[
"multi_modal_projector"
]
projector_weights
=
weights_group
[
"multi_modal_projector"
]
projector_params_dict
=
dict
(
projector_params_dict
=
dict
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment