Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
610852a4
Unverified
Commit
610852a4
authored
Jul 24, 2025
by
22quinn
Committed by
GitHub
Jul 24, 2025
Browse files
[Core] Support model loader plugins (#21067)
Signed-off-by:
22quinn
<
33176974+22quinn@users.noreply.github.com
>
parent
f0f4de8f
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
159 additions
and
86 deletions
+159
-86
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
+1
-3
tests/model_executor/model_loader/__init__.py
tests/model_executor/model_loader/__init__.py
+0
-0
tests/model_executor/model_loader/test_registry.py
tests/model_executor/model_loader/test_registry.py
+37
-0
tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
...i_model_streamer_test/test_runai_model_streamer_loader.py
+4
-3
vllm/config.py
vllm/config.py
+6
-24
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+13
-15
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+87
-27
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+9
-9
vllm/model_executor/model_loader/sharded_state_loader.py
vllm/model_executor/model_loader/sharded_state_loader.py
+2
-5
No files found.
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
View file @
610852a4
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
LoadFormat
test_model
=
"openai-community/gpt2"
test_model
=
"openai-community/gpt2"
...
@@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
...
@@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def
test_model_loader_download_files
(
vllm_runner
):
def
test_model_loader_download_files
(
vllm_runner
):
with
vllm_runner
(
test_model
,
with
vllm_runner
(
test_model
,
load_format
=
"fastsafetensors"
)
as
llm
:
load_format
=
LoadFormat
.
FASTSAFETENSORS
)
as
llm
:
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
deserialized_outputs
assert
deserialized_outputs
tests/model_executor/model_loader/__init__.py
0 → 100644
View file @
610852a4
tests/model_executor/model_loader/test_registry.py
0 → 100644
View file @
610852a4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.model_executor.model_loader
import
(
get_model_loader
,
register_model_loader
)
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
@
register_model_loader
(
"custom_load_format"
)
class
CustomModelLoader
(
BaseModelLoader
):
def
__init__
(
self
,
load_config
:
LoadConfig
)
->
None
:
super
().
__init__
(
load_config
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
pass
def
test_register_model_loader
():
load_config
=
LoadConfig
(
load_format
=
"custom_load_format"
)
assert
isinstance
(
get_model_loader
(
load_config
),
CustomModelLoader
)
def
test_invalid_model_loader
():
with
pytest
.
raises
(
ValueError
):
@
register_model_loader
(
"invalid_load_format"
)
class
InValidModelLoader
:
pass
tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
View file @
610852a4
...
@@ -2,9 +2,10 @@
...
@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
LoadConfig
,
LoadFormat
from
vllm.config
import
LoadConfig
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
load_format
=
"runai_streamer"
test_model
=
"openai-community/gpt2"
test_model
=
"openai-community/gpt2"
prompts
=
[
prompts
=
[
...
@@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
...
@@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def
get_runai_model_loader
():
def
get_runai_model_loader
():
load_config
=
LoadConfig
(
load_format
=
L
oad
F
ormat
.
RUNAI_STREAMER
)
load_config
=
LoadConfig
(
load_format
=
l
oad
_f
ormat
)
return
get_model_loader
(
load_config
)
return
get_model_loader
(
load_config
)
...
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
...
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
def
test_runai_model_loader_download_files
(
vllm_runner
):
def
test_runai_model_loader_download_files
(
vllm_runner
):
with
vllm_runner
(
test_model
,
load_format
=
L
oad
F
ormat
.
RUNAI_STREAMER
)
as
llm
:
with
vllm_runner
(
test_model
,
load_format
=
l
oad
_f
ormat
)
as
llm
:
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
deserialized_outputs
assert
deserialized_outputs
vllm/config.py
View file @
610852a4
...
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
...
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.model_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader
import
LoadFormats
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
ConfigType
=
type
[
DataclassInstance
]
ConfigType
=
type
[
DataclassInstance
]
...
@@ -78,6 +78,7 @@ else:
...
@@ -78,6 +78,7 @@ else:
QuantizationConfig
=
Any
QuantizationConfig
=
Any
QuantizationMethods
=
Any
QuantizationMethods
=
Any
BaseModelLoader
=
Any
BaseModelLoader
=
Any
LoadFormats
=
Any
TensorizerConfig
=
Any
TensorizerConfig
=
Any
ConfigType
=
type
ConfigType
=
type
HfOverrides
=
Union
[
dict
[
str
,
Any
],
Callable
[[
type
],
type
]]
HfOverrides
=
Union
[
dict
[
str
,
Any
],
Callable
[[
type
],
type
]]
...
@@ -1773,29 +1774,12 @@ class CacheConfig:
...
@@ -1773,29 +1774,12 @@ class CacheConfig:
logger
.
warning
(
"Possibly too large swap space. %s"
,
msg
)
logger
.
warning
(
"Possibly too large swap space. %s"
,
msg
)
class
LoadFormat
(
str
,
enum
.
Enum
):
AUTO
=
"auto"
PT
=
"pt"
SAFETENSORS
=
"safetensors"
NPCACHE
=
"npcache"
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
SHARDED_STATE
=
"sharded_state"
GGUF
=
"gguf"
BITSANDBYTES
=
"bitsandbytes"
MISTRAL
=
"mistral"
RUNAI_STREAMER
=
"runai_streamer"
RUNAI_STREAMER_SHARDED
=
"runai_streamer_sharded"
FASTSAFETENSORS
=
"fastsafetensors"
@
config
@
config
@
dataclass
@
dataclass
class
LoadConfig
:
class
LoadConfig
:
"""Configuration for loading the model weights."""
"""Configuration for loading the model weights."""
load_format
:
Union
[
str
,
LoadFormat
,
load_format
:
Union
[
str
,
LoadFormats
]
=
"auto"
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
.
value
"""The format of the model weights to load:
\n
"""The format of the model weights to load:
\n
- "auto" will try to load the weights in the safetensors format and fall
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.
\n
back to the pytorch bin format if safetensors format is not available.
\n
...
@@ -1816,7 +1800,8 @@ class LoadConfig:
...
@@ -1816,7 +1800,8 @@ class LoadConfig:
- "gguf" will load weights from GGUF format files (details specified in
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).
\n
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).
\n
- "mistral" will load weights from consolidated safetensors files used by
- "mistral" will load weights from consolidated safetensors files used by
Mistral models."""
Mistral models.
- Other custom values can be supported via plugins."""
download_dir
:
Optional
[
str
]
=
None
download_dir
:
Optional
[
str
]
=
None
"""Directory to download and load the weights, default to the default
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
cache directory of Hugging Face."""
...
@@ -1864,10 +1849,7 @@ class LoadConfig:
...
@@ -1864,10 +1849,7 @@ class LoadConfig:
return
hash_str
return
hash_str
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
isinstance
(
self
.
load_format
,
str
):
self
.
load_format
=
self
.
load_format
.
lower
()
load_format
=
self
.
load_format
.
lower
()
self
.
load_format
=
LoadFormat
(
load_format
)
if
self
.
ignore_patterns
is
not
None
and
len
(
self
.
ignore_patterns
)
>
0
:
if
self
.
ignore_patterns
is
not
None
and
len
(
self
.
ignore_patterns
)
>
0
:
logger
.
info
(
logger
.
info
(
"Ignoring the following patterns when downloading weights: %s"
,
"Ignoring the following patterns when downloading weights: %s"
,
...
...
vllm/engine/arg_utils.py
View file @
610852a4
...
@@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
...
@@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
KVTransferConfig
,
LoadConfig
,
LogprobsMode
,
LogprobsMode
,
LoRAConfig
,
ModelConfig
,
ModelDType
,
LoRAConfig
,
ModelConfig
,
ModelDType
,
ModelImpl
,
ModelImpl
,
MultiModalConfig
,
ObservabilityConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
ParallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
PoolerConfig
,
PrefixCachingHashAlgo
,
SchedulerConfig
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
get_field
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
from
vllm.plugins
import
load_general_plugins
...
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
...
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.model_loader
import
LoadFormats
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
else
:
else
:
ExecutorBase
=
Any
ExecutorBase
=
Any
QuantizationMethods
=
Any
QuantizationMethods
=
Any
LoadFormats
=
Any
UsageContext
=
Any
UsageContext
=
Any
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -276,7 +277,7 @@ class EngineArgs:
...
@@ -276,7 +277,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
download_dir
:
Optional
[
str
]
=
LoadConfig
.
download_dir
download_dir
:
Optional
[
str
]
=
LoadConfig
.
download_dir
load_format
:
str
=
LoadConfig
.
load_format
load_format
:
Union
[
str
,
LoadFormats
]
=
LoadConfig
.
load_format
config_format
:
str
=
ModelConfig
.
config_format
config_format
:
str
=
ModelConfig
.
config_format
dtype
:
ModelDType
=
ModelConfig
.
dtype
dtype
:
ModelDType
=
ModelConfig
.
dtype
kv_cache_dtype
:
CacheDType
=
CacheConfig
.
cache_dtype
kv_cache_dtype
:
CacheDType
=
CacheConfig
.
cache_dtype
...
@@ -547,9 +548,7 @@ class EngineArgs:
...
@@ -547,9 +548,7 @@ class EngineArgs:
title
=
"LoadConfig"
,
title
=
"LoadConfig"
,
description
=
LoadConfig
.
__doc__
,
description
=
LoadConfig
.
__doc__
,
)
)
load_group
.
add_argument
(
"--load-format"
,
load_group
.
add_argument
(
"--load-format"
,
**
load_kwargs
[
"load_format"
])
choices
=
[
f
.
value
for
f
in
LoadFormat
],
**
load_kwargs
[
"load_format"
])
load_group
.
add_argument
(
"--download-dir"
,
load_group
.
add_argument
(
"--download-dir"
,
**
load_kwargs
[
"download_dir"
])
**
load_kwargs
[
"download_dir"
])
load_group
.
add_argument
(
"--model-loader-extra-config"
,
load_group
.
add_argument
(
"--model-loader-extra-config"
,
...
@@ -864,10 +863,9 @@ class EngineArgs:
...
@@ -864,10 +863,9 @@ class EngineArgs:
# NOTE: This is to allow model loading from S3 in CI
# NOTE: This is to allow model loading from S3 in CI
if
(
not
isinstance
(
self
,
AsyncEngineArgs
)
and
envs
.
VLLM_CI_USE_S3
if
(
not
isinstance
(
self
,
AsyncEngineArgs
)
and
envs
.
VLLM_CI_USE_S3
and
self
.
model
in
MODELS_ON_S3
and
self
.
model
in
MODELS_ON_S3
and
self
.
load_format
==
"auto"
):
and
self
.
load_format
==
LoadFormat
.
AUTO
):
# noqa: E501
self
.
model
=
f
"
{
MODEL_WEIGHTS_S3_BUCKET
}
/
{
self
.
model
}
"
self
.
model
=
f
"
{
MODEL_WEIGHTS_S3_BUCKET
}
/
{
self
.
model
}
"
self
.
load_format
=
LoadFormat
.
RUNAI_STREAMER
self
.
load_format
=
"runai_streamer"
return
ModelConfig
(
return
ModelConfig
(
model
=
self
.
model
,
model
=
self
.
model
,
...
@@ -1299,7 +1297,7 @@ class EngineArgs:
...
@@ -1299,7 +1297,7 @@ class EngineArgs:
#############################################################
#############################################################
# Unsupported Feature Flags on V1.
# Unsupported Feature Flags on V1.
if
self
.
load_format
==
LoadFormat
.
SHARDED_STATE
.
value
:
if
self
.
load_format
==
"sharded_state"
:
_raise_or_fallback
(
_raise_or_fallback
(
feature_name
=
f
"--load_format
{
self
.
load_format
}
"
,
feature_name
=
f
"--load_format
{
self
.
load_format
}
"
,
recommend_to_remove
=
False
)
recommend_to_remove
=
False
)
...
...
vllm/model_executor/model_loader/__init__.py
View file @
610852a4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
typing
import
Literal
,
Optional
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
,
VllmConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.bitsandbytes_loader
import
(
from
vllm.model_executor.model_loader.bitsandbytes_loader
import
(
BitsAndBytesModelLoader
)
BitsAndBytesModelLoader
)
...
@@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
...
@@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from
vllm.model_executor.model_loader.utils
import
(
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
,
get_model_cls
)
get_architecture_class_name
,
get_model_architecture
,
get_model_cls
)
logger
=
init_logger
(
__name__
)
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats
=
Literal
[
"auto"
,
"bitsandbytes"
,
"dummy"
,
"fastsafetensors"
,
"gguf"
,
"mistral"
,
"npcache"
,
"pt"
,
"runai_streamer"
,
"runai_streamer_sharded"
,
"safetensors"
,
"sharded_state"
,
"tensorizer"
,
]
_LOAD_FORMAT_TO_MODEL_LOADER
:
dict
[
str
,
type
[
BaseModelLoader
]]
=
{
"auto"
:
DefaultModelLoader
,
"bitsandbytes"
:
BitsAndBytesModelLoader
,
"dummy"
:
DummyModelLoader
,
"fastsafetensors"
:
DefaultModelLoader
,
"gguf"
:
GGUFModelLoader
,
"mistral"
:
DefaultModelLoader
,
"npcache"
:
DefaultModelLoader
,
"pt"
:
DefaultModelLoader
,
"runai_streamer"
:
RunaiModelStreamerLoader
,
"runai_streamer_sharded"
:
ShardedStateLoader
,
"safetensors"
:
DefaultModelLoader
,
"sharded_state"
:
ShardedStateLoader
,
"tensorizer"
:
TensorizerLoader
,
}
def
register_model_loader
(
load_format
:
str
):
"""Register a customized vllm model loader.
When a load format is not supported by vllm, you can register a customized
model loader to support it.
Args:
load_format (str): The model loader format name.
Examples:
>>> from vllm.config import LoadConfig
>>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
>>>
>>> @register_model_loader("my_loader")
... class MyModelLoader(BaseModelLoader):
... def download_model(self):
... pass
...
... def load_weights(self):
... pass
>>>
>>> load_config = LoadConfig(load_format="my_loader")
>>> type(get_model_loader(load_config))
<class 'MyModelLoader'>
"""
# noqa: E501
def
_wrapper
(
model_loader_cls
):
if
load_format
in
_LOAD_FORMAT_TO_MODEL_LOADER
:
logger
.
warning
(
"Load format `%s` is already registered, and will be "
"overwritten by the new loader class `%s`."
,
load_format
,
model_loader_cls
)
if
not
issubclass
(
model_loader_cls
,
BaseModelLoader
):
raise
ValueError
(
"The model loader must be a subclass of "
"`BaseModelLoader`."
)
_LOAD_FORMAT_TO_MODEL_LOADER
[
load_format
]
=
model_loader_cls
logger
.
info
(
"Registered model loader `%s` with load format `%s`"
,
model_loader_cls
,
load_format
)
return
model_loader_cls
return
_wrapper
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
"""Get a model loader based on the load format."""
if
isinstance
(
load_config
.
load_format
,
type
):
load_format
=
load_config
.
load_format
return
load_config
.
load_format
(
load_config
)
if
load_format
not
in
_LOAD_FORMAT_TO_MODEL_LOADER
:
raise
ValueError
(
f
"Load format `
{
load_format
}
` is not supported"
)
if
load_config
.
load_format
==
LoadFormat
.
DUMMY
:
return
_LOAD_FORMAT_TO_MODEL_LOADER
[
load_format
](
load_config
)
return
DummyModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
GGUF
:
return
GGUFModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER
:
return
RunaiModelStreamerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER_SHARDED
:
return
ShardedStateLoader
(
load_config
,
runai_model_streamer
=
True
)
return
DefaultModelLoader
(
load_config
)
def
get_model
(
*
,
def
get_model
(
*
,
...
@@ -66,6 +125,7 @@ __all__ = [
...
@@ -66,6 +125,7 @@ __all__ = [
"get_architecture_class_name"
,
"get_architecture_class_name"
,
"get_model_architecture"
,
"get_model_architecture"
,
"get_model_cls"
,
"get_model_cls"
,
"register_model_loader"
,
"BaseModelLoader"
,
"BaseModelLoader"
,
"BitsAndBytesModelLoader"
,
"BitsAndBytesModelLoader"
,
"GGUFModelLoader"
,
"GGUFModelLoader"
,
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
610852a4
...
@@ -13,7 +13,7 @@ from torch import nn
...
@@ -13,7 +13,7 @@ from torch import nn
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm
import
envs
from
vllm
import
envs
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
...
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
use_safetensors
=
False
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
if
load_format
==
"auto"
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
(
load_format
==
LoadFormat
.
SAFETENSORS
elif
(
load_format
==
"safetensors"
or
load_format
==
LoadFormat
.
FASTSAFETENSORS
):
or
load_format
==
"fastsafetensors"
):
use_safetensors
=
True
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
elif
load_format
==
"mistral"
:
use_safetensors
=
True
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
elif
load_format
==
"pt"
:
allow_patterns
=
[
"*.pt"
]
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
elif
load_format
==
"npcache"
:
allow_patterns
=
[
"*.bin"
]
allow_patterns
=
[
"*.bin"
]
else
:
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
...
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
)
source
.
allow_patterns_overrides
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
if
self
.
load_config
.
load_format
==
"npcache"
:
# Currently np_cache only support *.bin checkpoints
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
weights_iterator
=
np_cache_weights_iterator
(
...
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
use_tqdm_on_load
,
)
)
elif
use_safetensors
:
elif
use_safetensors
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
FASTSAFETENSORS
:
if
self
.
load_config
.
load_format
==
"fastsafetensors"
:
weights_iterator
=
fastsafetensors_weights_iterator
(
weights_iterator
=
fastsafetensors_weights_iterator
(
hf_weights_files
,
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
self
.
load_config
.
use_tqdm_on_load
,
...
...
vllm/model_executor/model_loader/sharded_state_loader.py
View file @
610852a4
...
@@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
def
__init__
(
self
,
load_config
:
LoadConfig
):
load_config
:
LoadConfig
,
runai_model_streamer
:
bool
=
False
):
super
().
__init__
(
load_config
)
super
().
__init__
(
load_config
)
self
.
runai_model_streamer
=
runai_model_streamer
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
...
@@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
def
iterate_over_files
(
def
iterate_over_files
(
self
,
paths
)
->
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
self
,
paths
)
->
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
if
self
.
runai_model_streamer
:
if
self
.
load_config
.
load_format
==
"runai_streamer_sharded"
:
yield
from
runai_safetensors_weights_iterator
(
paths
,
True
)
yield
from
runai_safetensors_weights_iterator
(
paths
,
True
)
else
:
else
:
from
safetensors.torch
import
safe_open
from
safetensors.torch
import
safe_open
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment