Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
822de7fb
Unverified
Commit
822de7fb
authored
May 07, 2025
by
Jee Jee Li
Committed by
GitHub
May 07, 2025
Browse files
[Misc] Split model loader (#17712)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
8d84d836
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1681 additions
and
22 deletions
+1681
-22
tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
...i_model_streamer_test/test_runai_model_streamer_loader.py
+2
-3
tests/test_sharded_state_loader.py
tests/test_sharded_state_loader.py
+1
-1
tests/utils.py
tests/utils.py
+1
-1
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+53
-5
vllm/model_executor/model_loader/base_loader.py
vllm/model_executor/model_loader/base_loader.py
+23
-0
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+568
-0
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+293
-0
vllm/model_executor/model_loader/dummy_loader.py
vllm/model_executor/model_loader/dummy_loader.py
+37
-0
vllm/model_executor/model_loader/gguf_loader.py
vllm/model_executor/model_loader/gguf_loader.py
+113
-0
vllm/model_executor/model_loader/runai_streamer_loader.py
vllm/model_executor/model_loader/runai_streamer_loader.py
+120
-0
vllm/model_executor/model_loader/sharded_state_loader.py
vllm/model_executor/model_loader/sharded_state_loader.py
+210
-0
vllm/model_executor/model_loader/tensorizer_loader.py
vllm/model_executor/model_loader/tensorizer_loader.py
+119
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+131
-2
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+2
-2
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+1
-1
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+2
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+1
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
No files found.
tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
View file @
822de7fb
...
@@ -2,8 +2,7 @@
...
@@ -2,8 +2,7 @@
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
LoadConfig
,
LoadFormat
from
vllm.config
import
LoadConfig
,
LoadFormat
from
vllm.model_executor.model_loader.loader
import
(
RunaiModelStreamerLoader
,
from
vllm.model_executor.model_loader
import
get_model_loader
get_model_loader
)
test_model
=
"openai-community/gpt2"
test_model
=
"openai-community/gpt2"
...
@@ -24,7 +23,7 @@ def get_runai_model_loader():
...
@@ -24,7 +23,7 @@ def get_runai_model_loader():
def
test_get_model_loader_with_runai_flag
():
def
test_get_model_loader_with_runai_flag
():
model_loader
=
get_runai_model_loader
()
model_loader
=
get_runai_model_loader
()
assert
isinstance
(
model_loader
,
RunaiModelStreamerLoader
)
assert
model_loader
.
__class__
.
__name__
==
"
RunaiModelStreamerLoader
"
def
test_runai_model_loader_download_files
(
vllm_runner
):
def
test_runai_model_loader_download_files
(
vllm_runner
):
...
...
tests/test_sharded_state_loader.py
View file @
822de7fb
...
@@ -10,7 +10,7 @@ import torch
...
@@ -10,7 +10,7 @@ import torch
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.model_executor.model_loader
.loader
import
ShardedStateLoader
from
vllm.model_executor.model_loader
import
ShardedStateLoader
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
...
...
tests/utils.py
View file @
822de7fb
...
@@ -29,7 +29,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -29,7 +29,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.model_executor.model_loader
.loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.utils
import
(
FlexibleArgumentParser
,
GB_bytes
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
GB_bytes
,
...
...
vllm/config.py
View file @
822de7fb
...
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
...
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.model_loader
.loader
import
BaseModelLoader
from
vllm.model_executor.model_loader
import
BaseModelLoader
ConfigType
=
type
[
DataclassInstance
]
ConfigType
=
type
[
DataclassInstance
]
else
:
else
:
...
...
vllm/model_executor/model_loader/__init__.py
View file @
822de7fb
...
@@ -2,19 +2,67 @@
...
@@ -2,19 +2,67 @@
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
LoadConfig
,
LoadFormat
,
VllmConfig
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
get_model_loader
)
from
vllm.model_executor.model_loader.bitsandbytes_loader
import
(
BitsAndBytesModelLoader
)
from
vllm.model_executor.model_loader.default_loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader.dummy_loader
import
DummyModelLoader
from
vllm.model_executor.model_loader.gguf_loader
import
GGUFModelLoader
from
vllm.model_executor.model_loader.runai_streamer_loader
import
(
RunaiModelStreamerLoader
)
from
vllm.model_executor.model_loader.sharded_state_loader
import
(
ShardedStateLoader
)
from
vllm.model_executor.model_loader.tensorizer_loader
import
TensorizerLoader
from
vllm.model_executor.model_loader.utils
import
(
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
)
get_architecture_class_name
,
get_model_architecture
)
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
if
isinstance
(
load_config
.
load_format
,
type
):
return
load_config
.
load_format
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
DUMMY
:
return
DummyModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
GGUF
:
return
GGUFModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER
:
return
RunaiModelStreamerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER_SHARDED
:
return
ShardedStateLoader
(
load_config
,
runai_model_streamer
=
True
)
return
DefaultModelLoader
(
load_config
)
def
get_model
(
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
get_model
(
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
loader
=
get_model_loader
(
vllm_config
.
load_config
)
loader
=
get_model_loader
(
vllm_config
.
load_config
)
return
loader
.
load_model
(
vllm_config
=
vllm_config
)
return
loader
.
load_model
(
vllm_config
=
vllm_config
)
__all__
=
[
__all__
=
[
"get_model"
,
"get_model_loader"
,
"BaseModelLoader"
,
"get_model"
,
"get_architecture_class_name"
,
"get_model_architecture"
"get_model_loader"
,
"get_architecture_class_name"
,
"get_model_architecture"
,
"BaseModelLoader"
,
"BitsAndBytesModelLoader"
,
"GGUFModelLoader"
,
"DefaultModelLoader"
,
"DummyModelLoader"
,
"RunaiModelStreamerLoader"
,
"ShardedStateLoader"
,
"TensorizerLoader"
,
]
]
vllm/model_executor/model_loader/base_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
import
torch.nn
as
nn
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
class
BaseModelLoader
(
ABC
):
"""Base class for model loaders."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
@
abstractmethod
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download a model so that it can be immediately loaded."""
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
raise
NotImplementedError
vllm/model_executor/model_loader/loader.py
→
vllm/model_executor/model_loader/
bitsandbytes_
loader.py
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
# ruff: noqa: SIM117
import
collections
import
copy
import
copy
import
dataclasses
import
fnmatch
import
fnmatch
import
glob
import
glob
import
inspect
import
itertools
import
itertools
import
math
import
math
import
os
import
os
import
time
from
typing
import
Any
,
Callable
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
import
warnings
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
)
import
gguf
import
huggingface_hub
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
huggingface_hub
import
HfApi
from
huggingface_hub
import
HfApi
from
torch
import
nn
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.attention
import
Attention
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.config
import
(
LoadConfig
,
LoadFormat
,
ModelConfig
,
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
# yapf: enable
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVCrossParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
)
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_tensorized
,
load_with_tensorizer
,
serialize_vllm_model
,
tensorizer_weights_iterator
)
from
vllm.model_executor.model_loader.utils
import
(
ParamMapping
,
from
vllm.model_executor.model_loader.utils
import
(
ParamMapping
,
configure_quant_config
,
initialize_model
,
get_model_architecture
,
set_default_torch_dtype
)
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
fastsafetensors_weights_iterator
,
filter_duplicate_safetensors_files
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
filter_files_not_needed_for_inference
,
get_gguf_extra_tensor_names
,
pt_weights_iterator
,
safetensors_weights_iterator
)
get_lock
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
runai_safetensors_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.s3_utils
import
glob
as
s3_glob
from
vllm.transformers_utils.utils
import
is_s3
from
vllm.utils
import
is_pin_memory_available
@
contextmanager
def
device_loading_context
(
module
:
torch
.
nn
.
Module
,
target_device
:
torch
.
device
):
if
target_device
.
type
==
"cpu"
:
# If target is CPU, no need to move anything
yield
module
return
original_device_states
:
Dict
[
str
,
torch
.
device
]
=
{}
# Store original device states and move parameters to GPU if they're on CPU
for
name
,
p
in
module
.
named_parameters
():
if
p
.
device
.
type
==
"cpu"
:
original_device_states
[
name
]
=
p
.
device
p
.
data
=
p
.
data
.
to
(
target_device
)
# Parameters already on target device are not touched
try
:
yield
module
finally
:
# Restore parameters to their original devices, ignoring new parameters
pin_memory
=
is_pin_memory_available
()
for
name
,
p
in
module
.
named_parameters
():
if
name
in
original_device_states
:
original_device
:
torch
.
device
=
original_device_states
[
name
]
if
original_device
.
type
==
"cpu"
:
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
else
:
p
.
data
=
p
.
data
.
to
(
original_device
)
# New parameters or parameters already on target device are untouched
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_initialize_model
(
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
model_class
:
Optional
[
type
[
nn
.
Module
]]
=
None
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
if
model_class
is
None
:
model_class
,
_
=
get_model_architecture
(
model_config
)
if
vllm_config
.
quant_config
is
not
None
:
configure_quant_config
(
vllm_config
.
quant_config
,
model_class
)
signatures
=
inspect
.
signature
(
model_class
.
__init__
)
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
# new-style model class
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
):
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
warnings
.
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
logger
.
warning
(
"Trying to guess the arguments for old-style model class %s"
,
model_class
,
)
# try to be compatible with old-style model class
kwargs
=
{}
if
"prefix"
in
all_params
:
kwargs
[
"prefix"
]
=
prefix
if
"config"
in
all_params
:
kwargs
[
"config"
]
=
model_config
.
hf_config
if
"cache_config"
in
all_params
:
kwargs
[
"cache_config"
]
=
vllm_config
.
cache_config
if
"quant_config"
in
all_params
:
kwargs
[
"quant_config"
]
=
vllm_config
.
quant_config
if
"lora_config"
in
all_params
:
kwargs
[
"lora_config"
]
=
vllm_config
.
lora_config
if
"scheduler_config"
in
all_params
:
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
):
return
model_class
(
**
kwargs
)
def
_process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
)
->
None
:
for
_
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
QKVCrossParallelLinear
):
# NOTE(Isotr0py): special case for cross QKV layer because
# q and kv proj aren't registered as submodules intentionally
module
.
process_weights_after_loading
()
continue
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
# Currently only used by MLA.
# NOTE: This intentionally happens after other modules so we can easily
# decompress the weights for MLA.
for
_
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
Attention
)
and
\
hasattr
(
module
,
"process_weights_after_loading"
):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module
.
process_weights_after_loading
(
model_config
.
dtype
)
class
BaseModelLoader
(
ABC
):
"""Base class for model loaders."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
@
abstractmethod
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download a model so that it can be immediately loaded."""
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
raise
NotImplementedError
class
DefaultModelLoader
(
BaseModelLoader
):
"""Model loader that can load different file types from disk."""
@
dataclasses
.
dataclass
class
Source
:
"""A source for weights."""
model_or_path
:
str
"""The model ID or path."""
revision
:
Optional
[
str
]
"""The optional model revision."""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
fall_back_to_pt
:
bool
=
True
"""Whether .pt weights can be used."""
allow_patterns_overrides
:
Optional
[
list
[
str
]]
=
None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights
:
float
=
0.0
counter_after_loading_weights
:
float
=
0.0
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
_maybe_download_from_modelscope
(
self
,
model
:
str
,
revision
:
Optional
[
str
])
->
Optional
[
str
]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
if
not
os
.
path
.
exists
(
model
):
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with
get_lock
(
model
,
self
.
load_config
.
download_dir
):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
ignore_file_pattern
=
self
.
load_config
.
ignore_patterns
,
)
else
:
model_path
=
model
return
model_path
return
None
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
,
allow_patterns_overrides
:
Optional
[
list
[
str
]],
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path
=
(
self
.
_maybe_download_from_modelscope
(
model_name_or_path
,
revision
)
or
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
(
load_format
==
LoadFormat
.
SAFETENSORS
or
load_format
==
LoadFormat
.
FASTSAFETENSORS
):
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
allow_patterns
=
[
"*.bin"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
allow_patterns_overrides
is
not
None
:
allow_patterns
=
allow_patterns_overrides
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
use_safetensors
:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
,
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
self
,
source
:
"Source"
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
source
.
model_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
elif
use_safetensors
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
FASTSAFETENSORS
:
weights_iterator
=
fastsafetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
else
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
pt_load_map_location
,
)
if
current_platform
.
is_tpu
():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import
torch_xla.core.xla_model
as
xm
def
_xla_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
xm
.
mark_step
()
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
elif
current_platform
.
is_hpu
():
import
habana_frameworks.torch.core
as
htcore
def
_hpu_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
htcore
.
mark_step
()
weights_iterator
=
_hpu_weights_iterator
(
weights_iterator
)
if
self
.
counter_before_loading_weights
==
0.0
:
self
.
counter_before_loading_weights
=
time
.
perf_counter
()
# Apply the prefix.
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
def
get_all_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
primary_weights
=
DefaultModelLoader
.
Source
(
model_config
.
model
,
model_config
.
revision
,
prefix
=
""
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
),
allow_patterns_overrides
=
getattr
(
model
,
"allow_patterns_overrides"
,
None
),
)
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
secondary_weights
=
cast
(
Iterable
[
DefaultModelLoader
.
Source
],
getattr
(
model
,
"secondary_weights"
,
()),
)
for
source
in
secondary_weights
:
yield
from
self
.
_get_weights_iterator
(
source
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
True
,
allow_patterns_overrides
=
None
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
))
self
.
counter_after_loading_weights
=
time
.
perf_counter
()
logger
.
info
(
"Loading weights took %.2f seconds"
,
self
.
counter_after_loading_weights
-
self
.
counter_before_loading_weights
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if
model_config
.
quantization
is
None
and
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
if
weights_not_loaded
:
raise
ValueError
(
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
class
DummyModelLoader
(
BaseModelLoader
):
"""Model loader that will set model weights to random values."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
class
TensorizerLoader
(
BaseModelLoader
):
"""Model loader using CoreWeave's tensorizer library."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
isinstance
(
load_config
.
model_loader_extra_config
,
TensorizerConfig
):
self
.
tensorizer_config
=
load_config
.
model_loader_extra_config
else
:
self
.
tensorizer_config
=
TensorizerConfig
(
**
load_config
.
model_loader_extra_config
)
def
_verify_config
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
):
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
self
.
tensorizer_config
.
verify_with_parallel_config
(
parallel_config
)
def
_get_weights_iterator
(
self
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
tensorizer_args
=
self
.
tensorizer_config
.
_construct_tensorizer_args
()
return
tensorizer_weights_iterator
(
tensorizer_args
)
def
_load_model_serialized_cpu
(
self
,
vllm_config
:
VllmConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer to the CPU.
This is only necessary when the model isn't vLLM-tensorized (see
examples/other/tensorize_vllm_model.py) This should still
be faster than default HuggingFace loading, but will be slower than
loading a vLLM-tensorized model.
"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
def
_load_model_serialized
(
self
,
vllm_config
:
VllmConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the
examples/other/tensorize_vllm_model.py example script
for serializing vLLM models."""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
tensorizer_config
.
hf_config
=
model_config
.
hf_config
tensorizer_config
.
dtype
=
model_config
.
dtype
model
=
load_with_tensorizer
(
tensorizer_config
,
vllm_config
=
vllm_config
)
return
model
.
eval
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
with
self
.
tensorizer_config
.
open_stream
():
pass
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
self
.
_verify_config
(
model_config
,
parallel_config
)
if
parallel_config
.
tensor_parallel_size
>
1
:
from
vllm.distributed
import
get_tensor_model_parallel_rank
self
.
tensorizer_config
.
tensorizer_uri
=
(
self
.
tensorizer_config
.
tensorizer_uri
%
get_tensor_model_parallel_rank
())
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
vllm_config
=
vllm_config
)
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
)
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
tensorizer_config
:
TensorizerConfig
,
)
->
None
:
serialize_vllm_model
(
model
=
model
,
tensorizer_config
=
tensorizer_config
,
)
class
ShardedStateLoader
(
BaseModelLoader
):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/offline_inference/save_sharded_state.py` for creating a sharded
checkpoint.
"""
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
,
runai_model_streamer
:
bool
=
False
):
super
().
__init__
(
load_config
)
self
.
runai_model_streamer
=
runai_model_streamer
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
if
extra_config
:
raise
ValueError
(
f
"Unexpected extra config keys for load format "
f
"
{
load_config
.
load_format
}
: "
f
"
{
load_config
.
model_loader_extra_config
.
keys
()
}
"
)
@
staticmethod
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
],
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups
:
Dict
[
Any
,
List
[
Tuple
[
str
,
torch
.
Tensor
]]]
=
(
collections
.
defaultdict
(
list
))
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
same_storage_groups
[
tensor
.
device
,
ptr
].
append
((
key
,
tensor
))
def
get_end_ptr
(
tensor
:
torch
.
Tensor
)
->
int
:
return
tensor
.
view
(
-
1
)[
-
1
].
data_ptr
()
+
tensor
.
element_size
()
result
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
group
in
same_storage_groups
.
values
():
for
k
,
t
in
group
:
a
,
b
=
t
.
data_ptr
(),
get_end_ptr
(
t
)
for
k2
,
t2
in
group
:
if
not
t2
.
is_contiguous
():
continue
a2
,
b2
=
t2
.
data_ptr
(),
get_end_ptr
(
t2
)
if
a
<
a2
or
b2
<
b
:
continue
if
a2
<
a
or
b
<
b2
or
not
t
.
is_contiguous
():
break
# t2 covers strictly more memory than t.
if
k2
<
k
:
# Same tensors, keep the one with the smaller key.
break
else
:
result
[
k
]
=
t
return
result
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]):
if
is_s3
(
model_name_or_path
)
or
os
.
path
.
isdir
(
model_name_or_path
):
return
model_name_or_path
else
:
allow_patterns
=
[
"*.safetensors"
]
return
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
from
vllm.distributed
import
get_tensor_model_parallel_rank
model_weights
=
model_config
.
model
if
hasattr
(
model_config
,
"model_weights"
):
model_weights
=
model_config
.
model_weights
local_model_path
=
model_weights
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
local_model_path
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
[]
if
is_s3
(
local_model_path
):
file_pattern
=
f
"*
{
self
.
pattern
.
format
(
rank
=
rank
,
part
=
' * '
)
}
"
filepaths
=
s3_glob
(
path
=
local_model_path
,
allow_pattern
=
[
file_pattern
])
else
:
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
key
,
tensor
in
self
.
iterate_over_files
(
filepaths
):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
def
iterate_over_files
(
self
,
paths
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
if
self
.
runai_model_streamer
:
yield
from
runai_safetensors_weights_iterator
(
paths
,
True
)
else
:
from
safetensors.torch
import
safe_open
for
path
in
paths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
yield
key
,
tensor
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
safetensors.torch
import
save_file
from
vllm.distributed
import
get_tensor_model_parallel_rank
if
pattern
is
None
:
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
rank
=
get_tensor_model_parallel_rank
()
part_idx
=
0
total_size
=
0
state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
model
.
state_dict
())
state_dict_part
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
tensor
in
state_dict
.
items
():
param_size
=
tensor
.
nelement
()
*
tensor
.
element_size
()
if
max_size
is
not
None
and
total_size
+
param_size
>
max_size
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
part_idx
+=
1
total_size
=
0
state_dict_part
=
{}
state_dict_part
[
key
]
=
tensor
total_size
+=
param_size
if
len
(
state_dict_part
)
>
0
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
"""Model loader to load model weights with BitAndBytes quantization."""
"""Model loader to load model weights with BitAndBytes quantization."""
...
@@ -1307,238 +557,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1307,238 +557,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
self
.
_load_weights
(
model_config
,
model
)
return
model
.
eval
()
class
GGUFModelLoader
(
BaseModelLoader
):
"""
Model loader that can load GGUF files. This is useful for loading models
that are quantized with GGUF and saved in the GGUF format. This loader
supports loading both full models and sharded models.
"""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
_prepare_weights
(
self
,
model_name_or_path
:
str
):
if
os
.
path
.
isfile
(
model_name_or_path
):
return
model_name_or_path
else
:
raise
ValueError
(
f
"
{
model_name_or_path
}
is not a file."
)
def
_get_gguf_weights_map
(
self
,
model_config
:
ModelConfig
):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
`blk.N.BB.weight` and `blk.N.BB.bias`
where N signifies the block number of a layer, and BB signifies the
attention/mlp layer components.
See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config
=
model_config
.
hf_config
model_type
=
config
.
model_type
gguf_to_hf_name_map
=
{}
# hack: ggufs have a different name than transformers
if
model_type
==
"cohere"
:
model_type
=
"command-r"
if
model_type
in
(
"deepseek_v3"
,
"deepseek_v2"
):
model_type
=
"deepseek2"
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for
idx
in
range
(
config
.
num_hidden_layers
):
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.exp_probs_b.bias"
]
=
\
f
"model.layers.
{
idx
}
.mlp.gate.e_score_correction_bias"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_down_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.down_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_gate_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_up_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.up_proj.weight"
arch
=
None
for
key
,
value
in
gguf
.
MODEL_ARCH_NAMES
.
items
():
if
value
==
model_type
:
arch
=
key
break
if
arch
is
None
:
raise
RuntimeError
(
f
"Unknown gguf model_type:
{
model_type
}
"
)
num_layers
=
config
.
num_hidden_layers
name_map
=
gguf
.
get_tensor_name_map
(
arch
,
num_layers
)
with
torch
.
device
(
"meta"
):
dummy_model
=
AutoModelForCausalLM
.
from_config
(
config
,
trust_remote_code
=
model_config
.
trust_remote_code
)
state_dict
=
dummy_model
.
state_dict
()
for
hf_name
in
state_dict
:
name
,
suffix
=
hf_name
.
rsplit
(
"."
,
1
)
gguf_name
=
name_map
.
get_name
(
name
)
gguf_to_hf_name_map
[
f
"
{
gguf_name
}
.
{
suffix
}
"
]
=
hf_name
return
gguf_to_hf_name_map
def
_get_weights_iterator
(
self
,
model_name_or_path
:
str
,
gguf_to_hf_name_map
:
Dict
[
str
,
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
return
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
# we can only know if tie word embeddings after mapping weights
if
"lm_head.weight"
in
get_gguf_extra_tensor_names
(
local_model_path
,
gguf_weights_map
):
model_config
.
hf_config
.
update
({
"tie_word_embeddings"
:
True
})
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
(
local_model_path
,
gguf_weights_map
))
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
class
RunaiModelStreamerLoader
(
BaseModelLoader
):
"""
Model loader that can load safetensors
files from local FS or S3 bucket.
"""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
extra_config
=
load_config
.
model_loader_extra_config
if
(
"concurrency"
in
extra_config
and
isinstance
(
extra_config
.
get
(
"concurrency"
),
int
)):
os
.
environ
[
"RUNAI_STREAMER_CONCURRENCY"
]
=
str
(
extra_config
.
get
(
"concurrency"
))
if
(
"memory_limit"
in
extra_config
and
isinstance
(
extra_config
.
get
(
"memory_limit"
),
int
)):
os
.
environ
[
"RUNAI_STREAMER_MEMORY_LIMIT"
]
=
str
(
extra_config
.
get
(
"memory_limit"
))
runai_streamer_s3_endpoint
=
os
.
getenv
(
'RUNAI_STREAMER_S3_ENDPOINT'
)
aws_endpoint_url
=
os
.
getenv
(
'AWS_ENDPOINT_URL'
)
if
(
runai_streamer_s3_endpoint
is
None
and
aws_endpoint_url
is
not
None
):
os
.
environ
[
"RUNAI_STREAMER_S3_ENDPOINT"
]
=
aws_endpoint_url
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path
=
is_s3
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
safetensors_pattern
=
"*.safetensors"
index_file
=
SAFE_WEIGHTS_INDEX_NAME
hf_folder
=
(
model_name_or_path
if
(
is_local
or
is_s3_path
)
else
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
[
safetensors_pattern
],
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
))
if
is_s3_path
:
hf_weights_files
=
s3_glob
(
path
=
hf_folder
,
allow_pattern
=
[
safetensors_pattern
])
else
:
hf_weights_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
safetensors_pattern
))
if
not
is_local
and
not
is_s3_path
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
)
if
not
hf_weights_files
:
raise
RuntimeError
(
f
"Cannot find any safetensors model weights with "
f
"`
{
model_name_or_path
}
`"
)
return
hf_weights_files
def
_get_weights_iterator
(
self
,
model_or_path
:
str
,
revision
:
str
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files
=
self
.
_prepare_weights
(
model_or_path
,
revision
)
return
runai_safetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download model if necessary"""
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
"""Perform streaming of the model to destination"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
self
.
_load_weights
(
model_config
,
model
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
model_weights
=
model_config
.
model
if
hasattr
(
model_config
,
"model_weights"
):
model_weights
=
model_config
.
model_weights
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_weights
,
model_config
.
revision
))
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
return
model
.
eval
()
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
if
isinstance
(
load_config
.
load_format
,
type
):
return
load_config
.
load_format
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
DUMMY
:
return
DummyModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
GGUF
:
return
GGUFModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER
:
return
RunaiModelStreamerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER_SHARDED
:
return
ShardedStateLoader
(
load_config
,
runai_model_streamer
=
True
)
return
DefaultModelLoader
(
load_config
)
vllm/model_executor/model_loader/default_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
import
glob
import
os
import
time
from
typing
import
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
import
huggingface_hub
import
torch
from
torch
import
nn
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
,
VllmConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
fastsafetensors_weights_iterator
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_lock
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
class
DefaultModelLoader
(
BaseModelLoader
):
"""Model loader that can load different file types from disk."""
@
dataclasses
.
dataclass
class
Source
:
"""A source for weights."""
model_or_path
:
str
"""The model ID or path."""
revision
:
Optional
[
str
]
"""The optional model revision."""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
fall_back_to_pt
:
bool
=
True
"""Whether .pt weights can be used."""
allow_patterns_overrides
:
Optional
[
list
[
str
]]
=
None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights
:
float
=
0.0
counter_after_loading_weights
:
float
=
0.0
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
_maybe_download_from_modelscope
(
self
,
model
:
str
,
revision
:
Optional
[
str
])
->
Optional
[
str
]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
if
not
os
.
path
.
exists
(
model
):
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with
get_lock
(
model
,
self
.
load_config
.
download_dir
):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
ignore_file_pattern
=
self
.
load_config
.
ignore_patterns
,
)
else
:
model_path
=
model
return
model_path
return
None
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
,
allow_patterns_overrides
:
Optional
[
list
[
str
]],
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path
=
(
self
.
_maybe_download_from_modelscope
(
model_name_or_path
,
revision
)
or
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
(
load_format
==
LoadFormat
.
SAFETENSORS
or
load_format
==
LoadFormat
.
FASTSAFETENSORS
):
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
allow_patterns
=
[
"*.bin"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
allow_patterns_overrides
is
not
None
:
allow_patterns
=
allow_patterns_overrides
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
else
:
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
use_safetensors
:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
,
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
self
,
source
:
"Source"
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
source
.
model_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
elif
use_safetensors
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
FASTSAFETENSORS
:
weights_iterator
=
fastsafetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
else
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
pt_load_map_location
,
)
if
current_platform
.
is_tpu
():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import
torch_xla.core.xla_model
as
xm
def
_xla_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
xm
.
mark_step
()
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
elif
current_platform
.
is_hpu
():
import
habana_frameworks.torch.core
as
htcore
def
_hpu_weights_iterator
(
iterator
:
Generator
):
for
weights
in
iterator
:
yield
weights
htcore
.
mark_step
()
weights_iterator
=
_hpu_weights_iterator
(
weights_iterator
)
if
self
.
counter_before_loading_weights
==
0.0
:
self
.
counter_before_loading_weights
=
time
.
perf_counter
()
# Apply the prefix.
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
def
get_all_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
primary_weights
=
DefaultModelLoader
.
Source
(
model_config
.
model
,
model_config
.
revision
,
prefix
=
""
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
),
allow_patterns_overrides
=
getattr
(
model
,
"allow_patterns_overrides"
,
None
),
)
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
secondary_weights
=
cast
(
Iterable
[
DefaultModelLoader
.
Source
],
getattr
(
model
,
"secondary_weights"
,
()),
)
for
source
in
secondary_weights
:
yield
from
self
.
_get_weights_iterator
(
source
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
True
,
allow_patterns_overrides
=
None
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
self
.
get_all_weights
(
model_config
,
model
))
self
.
counter_after_loading_weights
=
time
.
perf_counter
()
logger
.
info
(
"Loading weights took %.2f seconds"
,
self
.
counter_after_loading_weights
-
self
.
counter_before_loading_weights
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if
model_config
.
quantization
is
None
and
loaded_weights
is
not
None
:
weights_not_loaded
=
weights_to_load
-
loaded_weights
if
weights_not_loaded
:
raise
ValueError
(
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
vllm/model_executor/model_loader/dummy_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
import
torch
import
torch.nn
as
nn
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
initialize_dummy_weights
)
class
DummyModelLoader
(
BaseModelLoader
):
"""Model loader that will set model weights to random values."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
vllm/model_executor/model_loader/gguf_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Dict
,
Generator
,
Tuple
import
gguf
import
torch
import
torch.nn
as
nn
from
transformers
import
AutoModelForCausalLM
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
get_gguf_extra_tensor_names
,
gguf_quant_weights_iterator
)
class
GGUFModelLoader
(
BaseModelLoader
):
"""
Model loader that can load GGUF files. This is useful for loading models
that are quantized with GGUF and saved in the GGUF format. This loader
supports loading both full models and sharded models.
"""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
_prepare_weights
(
self
,
model_name_or_path
:
str
):
if
os
.
path
.
isfile
(
model_name_or_path
):
return
model_name_or_path
else
:
raise
ValueError
(
f
"
{
model_name_or_path
}
is not a file."
)
def
_get_gguf_weights_map
(
self
,
model_config
:
ModelConfig
):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
`blk.N.BB.weight` and `blk.N.BB.bias`
where N signifies the block number of a layer, and BB signifies the
attention/mlp layer components.
See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config
=
model_config
.
hf_config
model_type
=
config
.
model_type
gguf_to_hf_name_map
=
{}
# hack: ggufs have a different name than transformers
if
model_type
==
"cohere"
:
model_type
=
"command-r"
if
model_type
in
(
"deepseek_v3"
,
"deepseek_v2"
):
model_type
=
"deepseek2"
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for
idx
in
range
(
config
.
num_hidden_layers
):
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.exp_probs_b.bias"
]
=
\
f
"model.layers.
{
idx
}
.mlp.gate.e_score_correction_bias"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_down_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.down_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_gate_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.gate_proj.weight"
gguf_to_hf_name_map
[
f
"blk.
{
idx
}
.ffn_up_exps.weight"
]
=
\
f
"model.layers.
{
idx
}
.mlp.experts.0.up_proj.weight"
arch
=
None
for
key
,
value
in
gguf
.
MODEL_ARCH_NAMES
.
items
():
if
value
==
model_type
:
arch
=
key
break
if
arch
is
None
:
raise
RuntimeError
(
f
"Unknown gguf model_type:
{
model_type
}
"
)
num_layers
=
config
.
num_hidden_layers
name_map
=
gguf
.
get_tensor_name_map
(
arch
,
num_layers
)
with
torch
.
device
(
"meta"
):
dummy_model
=
AutoModelForCausalLM
.
from_config
(
config
,
trust_remote_code
=
model_config
.
trust_remote_code
)
state_dict
=
dummy_model
.
state_dict
()
for
hf_name
in
state_dict
:
name
,
suffix
=
hf_name
.
rsplit
(
"."
,
1
)
gguf_name
=
name_map
.
get_name
(
name
)
gguf_to_hf_name_map
[
f
"
{
gguf_name
}
.
{
suffix
}
"
]
=
hf_name
return
gguf_to_hf_name_map
def
_get_weights_iterator
(
self
,
model_name_or_path
:
str
,
gguf_to_hf_name_map
:
Dict
[
str
,
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
return
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
# we can only know if tie word embeddings after mapping weights
if
"lm_head.weight"
in
get_gguf_extra_tensor_names
(
local_model_path
,
gguf_weights_map
):
model_config
.
hf_config
.
update
({
"tie_word_embeddings"
:
True
})
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
(
local_model_path
,
gguf_weights_map
))
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
vllm/model_executor/model_loader/runai_streamer_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import
glob
import
os
from
typing
import
Generator
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
runai_safetensors_weights_iterator
)
from
vllm.transformers_utils.s3_utils
import
glob
as
s3_glob
from
vllm.transformers_utils.utils
import
is_s3
class
RunaiModelStreamerLoader
(
BaseModelLoader
):
"""
Model loader that can load safetensors
files from local FS or S3 bucket.
"""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
extra_config
=
load_config
.
model_loader_extra_config
if
(
"concurrency"
in
extra_config
and
isinstance
(
extra_config
.
get
(
"concurrency"
),
int
)):
os
.
environ
[
"RUNAI_STREAMER_CONCURRENCY"
]
=
str
(
extra_config
.
get
(
"concurrency"
))
if
(
"memory_limit"
in
extra_config
and
isinstance
(
extra_config
.
get
(
"memory_limit"
),
int
)):
os
.
environ
[
"RUNAI_STREAMER_MEMORY_LIMIT"
]
=
str
(
extra_config
.
get
(
"memory_limit"
))
runai_streamer_s3_endpoint
=
os
.
getenv
(
'RUNAI_STREAMER_S3_ENDPOINT'
)
aws_endpoint_url
=
os
.
getenv
(
'AWS_ENDPOINT_URL'
)
if
(
runai_streamer_s3_endpoint
is
None
and
aws_endpoint_url
is
not
None
):
os
.
environ
[
"RUNAI_STREAMER_S3_ENDPOINT"
]
=
aws_endpoint_url
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_s3_path
=
is_s3
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
safetensors_pattern
=
"*.safetensors"
index_file
=
SAFE_WEIGHTS_INDEX_NAME
hf_folder
=
(
model_name_or_path
if
(
is_local
or
is_s3_path
)
else
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
[
safetensors_pattern
],
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
))
if
is_s3_path
:
hf_weights_files
=
s3_glob
(
path
=
hf_folder
,
allow_pattern
=
[
safetensors_pattern
])
else
:
hf_weights_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
safetensors_pattern
))
if
not
is_local
and
not
is_s3_path
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
)
if
not
hf_weights_files
:
raise
RuntimeError
(
f
"Cannot find any safetensors model weights with "
f
"`
{
model_name_or_path
}
`"
)
return
hf_weights_files
def
_get_weights_iterator
(
self
,
model_or_path
:
str
,
revision
:
str
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files
=
self
.
_prepare_weights
(
model_or_path
,
revision
)
return
runai_safetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download model if necessary"""
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
"""Perform streaming of the model to destination"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model_weights
=
model_config
.
model
if
hasattr
(
model_config
,
"model_weights"
):
model_weights
=
model_config
.
model_weights
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_weights
,
model_config
.
revision
))
process_weights_after_loading
(
model
,
model_config
,
target_device
)
return
model
.
eval
()
vllm/model_executor/model_loader/sharded_state_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
import
collections
import
glob
import
os
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.utils
import
(
initialize_model
,
process_weights_after_loading
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
download_weights_from_hf
,
runai_safetensors_weights_iterator
)
from
vllm.transformers_utils.s3_utils
import
glob
as
s3_glob
from
vllm.transformers_utils.utils
import
is_s3
logger
=
init_logger
(
__name__
)
class
ShardedStateLoader
(
BaseModelLoader
):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/offline_inference/save_sharded_state.py` for creating a sharded
checkpoint.
"""
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
,
runai_model_streamer
:
bool
=
False
):
super
().
__init__
(
load_config
)
self
.
runai_model_streamer
=
runai_model_streamer
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
if
extra_config
:
raise
ValueError
(
f
"Unexpected extra config keys for load format "
f
"
{
load_config
.
load_format
}
: "
f
"
{
load_config
.
model_loader_extra_config
.
keys
()
}
"
)
@
staticmethod
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
],
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups
:
Dict
[
Any
,
List
[
Tuple
[
str
,
torch
.
Tensor
]]]
=
(
collections
.
defaultdict
(
list
))
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
same_storage_groups
[
tensor
.
device
,
ptr
].
append
((
key
,
tensor
))
def
get_end_ptr
(
tensor
:
torch
.
Tensor
)
->
int
:
return
tensor
.
view
(
-
1
)[
-
1
].
data_ptr
()
+
tensor
.
element_size
()
result
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
group
in
same_storage_groups
.
values
():
for
k
,
t
in
group
:
a
,
b
=
t
.
data_ptr
(),
get_end_ptr
(
t
)
for
k2
,
t2
in
group
:
if
not
t2
.
is_contiguous
():
continue
a2
,
b2
=
t2
.
data_ptr
(),
get_end_ptr
(
t2
)
if
a
<
a2
or
b2
<
b
:
continue
if
a2
<
a
or
b
<
b2
or
not
t
.
is_contiguous
():
break
# t2 covers strictly more memory than t.
if
k2
<
k
:
# Same tensors, keep the one with the smaller key.
break
else
:
result
[
k
]
=
t
return
result
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]):
if
is_s3
(
model_name_or_path
)
or
os
.
path
.
isdir
(
model_name_or_path
):
return
model_name_or_path
else
:
allow_patterns
=
[
"*.safetensors"
]
return
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
from
vllm.distributed
import
get_tensor_model_parallel_rank
model_weights
=
model_config
.
model
if
hasattr
(
model_config
,
"model_weights"
):
model_weights
=
model_config
.
model_weights
local_model_path
=
model_weights
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
initialize_model
(
vllm_config
=
vllm_config
)
process_weights_after_loading
(
model
,
model_config
,
target_device
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
local_model_path
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
[]
if
is_s3
(
local_model_path
):
file_pattern
=
f
"*
{
self
.
pattern
.
format
(
rank
=
rank
,
part
=
' * '
)
}
"
filepaths
=
s3_glob
(
path
=
local_model_path
,
allow_pattern
=
[
file_pattern
])
else
:
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
key
,
tensor
in
self
.
iterate_over_files
(
filepaths
):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
def
iterate_over_files
(
self
,
paths
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
if
self
.
runai_model_streamer
:
yield
from
runai_safetensors_weights_iterator
(
paths
,
True
)
else
:
from
safetensors.torch
import
safe_open
for
path
in
paths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
yield
key
,
tensor
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
path
:
str
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
from
safetensors.torch
import
save_file
from
vllm.distributed
import
get_tensor_model_parallel_rank
if
pattern
is
None
:
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
rank
=
get_tensor_model_parallel_rank
()
part_idx
=
0
total_size
=
0
state_dict
=
ShardedStateLoader
.
_filter_subtensors
(
model
.
state_dict
())
state_dict_part
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
key
,
tensor
in
state_dict
.
items
():
param_size
=
tensor
.
nelement
()
*
tensor
.
element_size
()
if
max_size
is
not
None
and
total_size
+
param_size
>
max_size
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
part_idx
+=
1
total_size
=
0
state_dict_part
=
{}
state_dict_part
[
key
]
=
tensor
total_size
+=
param_size
if
len
(
state_dict_part
)
>
0
:
filename
=
pattern
.
format
(
rank
=
rank
,
part
=
part_idx
)
save_file
(
state_dict_part
,
os
.
path
.
join
(
path
,
filename
),
)
vllm/model_executor/model_loader/tensorizer_loader.py
0 → 100644
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import
copy
from
typing
import
Generator
,
Tuple
import
torch
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
ModelConfig
,
ParallelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_tensorized
,
load_with_tensorizer
,
serialize_vllm_model
,
tensorizer_weights_iterator
)
from
vllm.model_executor.model_loader.utils
import
(
get_model_architecture
,
initialize_model
,
set_default_torch_dtype
)
logger
=
init_logger
(
__name__
)
class
TensorizerLoader
(
BaseModelLoader
):
"""Model loader using CoreWeave's tensorizer library."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
isinstance
(
load_config
.
model_loader_extra_config
,
TensorizerConfig
):
self
.
tensorizer_config
=
load_config
.
model_loader_extra_config
else
:
self
.
tensorizer_config
=
TensorizerConfig
(
**
load_config
.
model_loader_extra_config
)
def
_verify_config
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
):
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
self
.
tensorizer_config
.
verify_with_parallel_config
(
parallel_config
)
def
_get_weights_iterator
(
self
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
tensorizer_args
=
self
.
tensorizer_config
.
_construct_tensorizer_args
()
return
tensorizer_weights_iterator
(
tensorizer_args
)
def
_load_model_serialized_cpu
(
self
,
vllm_config
:
VllmConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer to the CPU.
This is only necessary when the model isn't vLLM-tensorized (see
examples/other/tensorize_vllm_model.py) This should still
be faster than default HuggingFace loading, but will be slower than
loading a vLLM-tensorized model.
"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
initialize_model
(
vllm_config
=
vllm_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
def
_load_model_serialized
(
self
,
vllm_config
:
VllmConfig
,
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the
examples/other/tensorize_vllm_model.py example script
for serializing vLLM models."""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
tensorizer_config
.
hf_config
=
model_config
.
hf_config
tensorizer_config
.
dtype
=
model_config
.
dtype
model
=
load_with_tensorizer
(
tensorizer_config
,
vllm_config
=
vllm_config
)
return
model
.
eval
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
with
self
.
tensorizer_config
.
open_stream
():
pass
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
self
.
_verify_config
(
model_config
,
parallel_config
)
if
parallel_config
.
tensor_parallel_size
>
1
:
from
vllm.distributed
import
get_tensor_model_parallel_rank
self
.
tensorizer_config
.
tensorizer_uri
=
(
self
.
tensorizer_config
.
tensorizer_uri
%
get_tensor_model_parallel_rank
())
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
vllm_config
=
vllm_config
)
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
)
@
staticmethod
def
save_model
(
model
:
torch
.
nn
.
Module
,
tensorizer_config
:
TensorizerConfig
,
)
->
None
:
serialize_vllm_model
(
model
=
model
,
tensorizer_config
=
tensorizer_config
,
)
vllm/model_executor/model_loader/utils.py
View file @
822de7fb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading models."""
"""Utilities for selecting and loading models."""
import
contextlib
import
contextlib
import
inspect
import
warnings
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
...
@@ -9,14 +12,18 @@ import transformers
...
@@ -9,14 +12,18 @@ import transformers
from
torch
import
nn
from
torch
import
nn
from
transformers.dynamic_module_utils
import
get_class_from_dynamic_module
from
transformers.dynamic_module_utils
import
get_class_from_dynamic_module
from
vllm.config
import
ModelConfig
,
ModelImpl
from
vllm.attention
import
Attention
from
vllm.config
import
(
ModelConfig
,
ModelImpl
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
QKVCrossParallelLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.adapters
import
(
as_classification_model
,
from
vllm.model_executor.models.adapters
import
(
as_classification_model
,
as_embedding_model
,
as_embedding_model
,
as_reward_model
)
as_reward_model
)
from
vllm.utils
import
is_pin_memory_available
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -30,6 +37,128 @@ def set_default_torch_dtype(dtype: torch.dtype):
...
@@ -30,6 +37,128 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch
.
set_default_dtype
(
old_dtype
)
torch
.
set_default_dtype
(
old_dtype
)
def
initialize_model
(
vllm_config
:
VllmConfig
,
*
,
prefix
:
str
=
""
,
model_class
:
Optional
[
type
[
nn
.
Module
]]
=
None
,
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
if
model_class
is
None
:
model_class
,
_
=
get_model_architecture
(
model_config
)
if
vllm_config
.
quant_config
is
not
None
:
configure_quant_config
(
vllm_config
.
quant_config
,
model_class
)
signatures
=
inspect
.
signature
(
model_class
.
__init__
)
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
# new-style model class
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
):
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
warnings
.
warn
(
msg
,
DeprecationWarning
,
stacklevel
=
2
)
logger
.
warning
(
"Trying to guess the arguments for old-style model class %s"
,
model_class
,
)
# try to be compatible with old-style model class
kwargs
=
{}
if
"prefix"
in
all_params
:
kwargs
[
"prefix"
]
=
prefix
if
"config"
in
all_params
:
kwargs
[
"config"
]
=
model_config
.
hf_config
if
"cache_config"
in
all_params
:
kwargs
[
"cache_config"
]
=
vllm_config
.
cache_config
if
"quant_config"
in
all_params
:
kwargs
[
"quant_config"
]
=
vllm_config
.
quant_config
if
"lora_config"
in
all_params
:
kwargs
[
"lora_config"
]
=
vllm_config
.
lora_config
if
"scheduler_config"
in
all_params
:
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
with
set_current_vllm_config
(
vllm_config
,
check_compile
=
True
):
return
model_class
(
**
kwargs
)
def
process_weights_after_loading
(
model
:
nn
.
Module
,
model_config
:
ModelConfig
,
target_device
:
torch
.
device
)
->
None
:
for
_
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
QKVCrossParallelLinear
):
# NOTE(Isotr0py): special case for cross QKV layer because
# q and kv proj aren't registered as submodules intentionally
module
.
process_weights_after_loading
()
continue
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
isinstance
(
quant_method
,
QuantizeMethodBase
):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
# Currently only used by MLA.
# NOTE: This intentionally happens after other modules so we can easily
# decompress the weights for MLA.
for
_
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
Attention
)
and
\
hasattr
(
module
,
"process_weights_after_loading"
):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module
.
process_weights_after_loading
(
model_config
.
dtype
)
@
contextmanager
def
device_loading_context
(
module
:
torch
.
nn
.
Module
,
target_device
:
torch
.
device
):
if
target_device
.
type
==
"cpu"
:
# If target is CPU, no need to move anything
yield
module
return
original_device_states
:
Dict
[
str
,
torch
.
device
]
=
{}
# Store original device states and move parameters to GPU if they're on CPU
for
name
,
p
in
module
.
named_parameters
():
if
p
.
device
.
type
==
"cpu"
:
original_device_states
[
name
]
=
p
.
device
p
.
data
=
p
.
data
.
to
(
target_device
)
# Parameters already on target device are not touched
try
:
yield
module
finally
:
# Restore parameters to their original devices, ignoring new parameters
pin_memory
=
is_pin_memory_available
()
for
name
,
p
in
module
.
named_parameters
():
if
name
in
original_device_states
:
original_device
:
torch
.
device
=
original_device_states
[
name
]
if
original_device
.
type
==
"cpu"
:
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
else
:
p
.
data
=
p
.
data
.
to
(
original_device
)
# New parameters or parameters already on target device are untouched
def
resolve_transformers_arch
(
model_config
:
ModelConfig
,
def
resolve_transformers_arch
(
model_config
:
ModelConfig
,
architectures
:
list
[
str
]):
architectures
:
list
[
str
]):
for
i
,
arch
in
enumerate
(
architectures
):
for
i
,
arch
in
enumerate
(
architectures
):
...
...
vllm/model_executor/models/mllama4.py
View file @
822de7fb
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.
loader
import
_
initialize_model
from
vllm.model_executor.model_loader.
utils
import
initialize_model
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
@@ -670,7 +670,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -670,7 +670,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
config
,
self
.
config
,
None
,
None
,
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
))
prefix
=
maybe_prefix
(
prefix
,
"multi_modal_projector"
))
self
.
language_model
=
_
initialize_model
(
self
.
language_model
=
initialize_model
(
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
,
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
,
[
"LlamaForCausalLM"
]),
[
"LlamaForCausalLM"
]),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
...
...
vllm/model_executor/models/ultravox.py
View file @
822de7fb
...
@@ -17,7 +17,7 @@ from vllm.config import VllmConfig
...
@@ -17,7 +17,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.layers.activation
import
MulAndSilu
,
get_act_fn
from
vllm.model_executor.layers.activation
import
MulAndSilu
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.model_loader
.loader
import
DefaultModelLoader
from
vllm.model_executor.model_loader
import
DefaultModelLoader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
vllm/model_executor/models/utils.py
View file @
822de7fb
...
@@ -273,7 +273,7 @@ def init_vllm_registered_model(
...
@@ -273,7 +273,7 @@ def init_vllm_registered_model(
Helper function to initialize an inner model registered to vLLM,
Helper function to initialize an inner model registered to vLLM,
based on the arguments passed to the outer vLLM model.
based on the arguments passed to the outer vLLM model.
"""
"""
from
vllm.model_executor.model_loader.
loader
import
_
initialize_model
from
vllm.model_executor.model_loader.
utils
import
initialize_model
if
hf_config
is
None
and
architectures
is
not
None
:
if
hf_config
is
None
and
architectures
is
not
None
:
# So that the architectures field is overridden
# So that the architectures field is overridden
...
@@ -283,7 +283,7 @@ def init_vllm_registered_model(
...
@@ -283,7 +283,7 @@ def init_vllm_registered_model(
vllm_config
=
vllm_config
.
with_hf_config
(
hf_config
,
vllm_config
=
vllm_config
.
with_hf_config
(
hf_config
,
architectures
=
architectures
)
architectures
=
architectures
)
return
_
initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
return
initialize_model
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
@
overload
@
overload
...
...
vllm/v1/spec_decode/eagle.py
View file @
822de7fb
...
@@ -7,7 +7,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
...
@@ -7,7 +7,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config
,
set_current_vllm_config
)
get_layers_from_vllm_config
,
set_current_vllm_config
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
.loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
...
...
vllm/v1/worker/gpu_worker.py
View file @
822de7fb
...
@@ -318,7 +318,7 @@ class Worker(WorkerBase):
...
@@ -318,7 +318,7 @@ class Worker(WorkerBase):
pattern
:
Optional
[
str
]
=
None
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
from
vllm.model_executor.model_loader
.loader
import
ShardedStateLoader
from
vllm.model_executor.model_loader
import
ShardedStateLoader
ShardedStateLoader
.
save_model
(
ShardedStateLoader
.
save_model
(
self
.
model_runner
.
model
,
self
.
model_runner
.
model
,
path
,
path
,
...
...
vllm/worker/model_runner.py
View file @
822de7fb
...
@@ -1220,7 +1220,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1220,7 +1220,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
pattern
:
Optional
[
str
]
=
None
,
pattern
:
Optional
[
str
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
max_size
:
Optional
[
int
]
=
None
,
)
->
None
:
)
->
None
:
from
vllm.model_executor.model_loader
.loader
import
ShardedStateLoader
from
vllm.model_executor.model_loader
import
ShardedStateLoader
ShardedStateLoader
.
save_model
(
ShardedStateLoader
.
save_model
(
self
.
model
,
self
.
model
,
path
,
path
,
...
@@ -1232,7 +1232,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1232,7 +1232,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
,
self
,
tensorizer_config
:
TensorizerConfig
,
tensorizer_config
:
TensorizerConfig
,
)
->
None
:
)
->
None
:
from
vllm.model_executor.model_loader
.loader
import
TensorizerLoader
from
vllm.model_executor.model_loader
import
TensorizerLoader
TensorizerLoader
.
save_model
(
TensorizerLoader
.
save_model
(
self
.
model
,
self
.
model
,
tensorizer_config
=
tensorizer_config
,
tensorizer_config
=
tensorizer_config
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment