Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
610852a4
Unverified
Commit
610852a4
authored
Jul 24, 2025
by
22quinn
Committed by
GitHub
Jul 24, 2025
Browse files
[Core] Support model loader plugins (#21067)
Signed-off-by:
22quinn
<
33176974+22quinn@users.noreply.github.com
>
parent
f0f4de8f
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
159 additions
and
86 deletions
+159
-86
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
+1
-3
tests/model_executor/model_loader/__init__.py
tests/model_executor/model_loader/__init__.py
+0
-0
tests/model_executor/model_loader/test_registry.py
tests/model_executor/model_loader/test_registry.py
+37
-0
tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
...i_model_streamer_test/test_runai_model_streamer_loader.py
+4
-3
vllm/config.py
vllm/config.py
+6
-24
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+13
-15
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+87
-27
vllm/model_executor/model_loader/default_loader.py
vllm/model_executor/model_loader/default_loader.py
+9
-9
vllm/model_executor/model_loader/sharded_state_loader.py
vllm/model_executor/model_loader/sharded_state_loader.py
+2
-5
No files found.
tests/fastsafetensors_loader/test_fastsafetensors_loader.py
View file @
610852a4
...
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
SamplingParams
from
vllm.config
import
LoadFormat
test_model
=
"openai-community/gpt2"
...
...
@@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def
test_model_loader_download_files
(
vllm_runner
):
with
vllm_runner
(
test_model
,
load_format
=
LoadFormat
.
FASTSAFETENSORS
)
as
llm
:
with
vllm_runner
(
test_model
,
load_format
=
"fastsafetensors"
)
as
llm
:
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
deserialized_outputs
tests/model_executor/model_loader/__init__.py
0 → 100644
View file @
610852a4
tests/model_executor/model_loader/test_registry.py
0 → 100644
View file @
610852a4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.model_executor.model_loader
import
(
get_model_loader
,
register_model_loader
)
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):
    """No-op loader registered under the ``custom_load_format`` name.

    It exists only so the tests in this module can check that a loader
    registered through ``register_model_loader`` is returned by
    ``get_model_loader``.
    """

    def __init__(self, load_config: LoadConfig) -> None:
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing; this stub has no weights to download."""

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Do nothing; this stub leaves ``model`` untouched."""
def test_register_model_loader():
    """A registered custom load format resolves to its loader class."""
    config = LoadConfig(load_format="custom_load_format")
    loader = get_model_loader(config)
    assert isinstance(loader, CustomModelLoader)
def test_invalid_model_loader():
    """Registering a class that is not a ``BaseModelLoader`` raises."""
    with pytest.raises(ValueError):

        class InValidModelLoader:
            pass

        # Apply the decorator by hand; the registration call is what raises.
        register_model_loader("invalid_load_format")(InValidModelLoader)
tests/runai_model_streamer_test/test_runai_model_streamer_loader.py
View file @
610852a4
...
...
@@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
SamplingParams
from
vllm.config
import
LoadConfig
,
LoadFormat
from
vllm.config
import
LoadConfig
from
vllm.model_executor.model_loader
import
get_model_loader
load_format
=
"runai_streamer"
test_model
=
"openai-community/gpt2"
prompts
=
[
...
...
@@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
def
get_runai_model_loader
():
load_config
=
LoadConfig
(
load_format
=
LoadFormat
.
RUNAI_STREAMER
)
load_config
=
LoadConfig
(
load_format
=
load_format
)
return
get_model_loader
(
load_config
)
...
...
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
def
test_runai_model_loader_download_files
(
vllm_runner
):
with
vllm_runner
(
test_model
,
load_format
=
LoadFormat
.
RUNAI_STREAMER
)
as
llm
:
with
vllm_runner
(
test_model
,
load_format
=
load_format
)
as
llm
:
deserialized_outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
deserialized_outputs
vllm/config.py
View file @
610852a4
...
...
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader
import
LoadFormats
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
ConfigType
=
type
[
DataclassInstance
]
...
...
@@ -78,6 +78,7 @@ else:
QuantizationConfig
=
Any
QuantizationMethods
=
Any
BaseModelLoader
=
Any
LoadFormats
=
Any
TensorizerConfig
=
Any
ConfigType
=
type
HfOverrides
=
Union
[
dict
[
str
,
Any
],
Callable
[[
type
],
type
]]
...
...
@@ -1773,29 +1774,12 @@ class CacheConfig:
logger
.
warning
(
"Possibly too large swap space. %s"
,
msg
)
class
LoadFormat
(
str
,
enum
.
Enum
):
AUTO
=
"auto"
PT
=
"pt"
SAFETENSORS
=
"safetensors"
NPCACHE
=
"npcache"
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
SHARDED_STATE
=
"sharded_state"
GGUF
=
"gguf"
BITSANDBYTES
=
"bitsandbytes"
MISTRAL
=
"mistral"
RUNAI_STREAMER
=
"runai_streamer"
RUNAI_STREAMER_SHARDED
=
"runai_streamer_sharded"
FASTSAFETENSORS
=
"fastsafetensors"
@
config
@
dataclass
class
LoadConfig
:
"""Configuration for loading the model weights."""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
.
value
load_format
:
Union
[
str
,
LoadFormats
]
=
"auto"
"""The format of the model weights to load:
\n
- "auto" will try to load the weights in the safetensors format and fall
back to the pytorch bin format if safetensors format is not available.
\n
...
...
@@ -1816,7 +1800,8 @@ class LoadConfig:
- "gguf" will load weights from GGUF format files (details specified in
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).
\n
- "mistral" will load weights from consolidated safetensors files used by
Mistral models."""
Mistral models.
- Other custom values can be supported via plugins."""
download_dir
:
Optional
[
str
]
=
None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
...
...
@@ -1864,10 +1849,7 @@ class LoadConfig:
return
hash_str
def
__post_init__
(
self
):
if
isinstance
(
self
.
load_format
,
str
):
load_format
=
self
.
load_format
.
lower
()
self
.
load_format
=
LoadFormat
(
load_format
)
self
.
load_format
=
self
.
load_format
.
lower
()
if
self
.
ignore_patterns
is
not
None
and
len
(
self
.
ignore_patterns
)
>
0
:
logger
.
info
(
"Ignoring the following patterns when downloading weights: %s"
,
...
...
vllm/engine/arg_utils.py
View file @
610852a4
...
...
@@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DistributedExecutorBackend
,
GuidedDecodingBackend
,
GuidedDecodingBackendV1
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LoadFormat
,
LogprobsMode
,
LoRAConfig
,
ModelConfig
,
ModelDType
,
ModelImpl
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
KVTransferConfig
,
LoadConfig
,
LogprobsMode
,
LoRAConfig
,
ModelConfig
,
ModelDType
,
ModelImpl
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
PrefixCachingHashAlgo
,
SchedulerConfig
,
SchedulerPolicy
,
SpeculativeConfig
,
TaskOption
,
TokenizerMode
,
VllmConfig
,
get_attr_docs
,
get_field
)
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
...
...
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
if
TYPE_CHECKING
:
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.model_loader
import
LoadFormats
from
vllm.usage.usage_lib
import
UsageContext
else
:
ExecutorBase
=
Any
QuantizationMethods
=
Any
LoadFormats
=
Any
UsageContext
=
Any
logger
=
init_logger
(
__name__
)
...
...
@@ -276,7 +277,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
download_dir
:
Optional
[
str
]
=
LoadConfig
.
download_dir
load_format
:
str
=
LoadConfig
.
load_format
load_format
:
Union
[
str
,
LoadFormats
]
=
LoadConfig
.
load_format
config_format
:
str
=
ModelConfig
.
config_format
dtype
:
ModelDType
=
ModelConfig
.
dtype
kv_cache_dtype
:
CacheDType
=
CacheConfig
.
cache_dtype
...
...
@@ -547,9 +548,7 @@ class EngineArgs:
title
=
"LoadConfig"
,
description
=
LoadConfig
.
__doc__
,
)
load_group
.
add_argument
(
"--load-format"
,
choices
=
[
f
.
value
for
f
in
LoadFormat
],
**
load_kwargs
[
"load_format"
])
load_group
.
add_argument
(
"--load-format"
,
**
load_kwargs
[
"load_format"
])
load_group
.
add_argument
(
"--download-dir"
,
**
load_kwargs
[
"download_dir"
])
load_group
.
add_argument
(
"--model-loader-extra-config"
,
...
...
@@ -864,10 +863,9 @@ class EngineArgs:
# NOTE: This is to allow model loading from S3 in CI
if
(
not
isinstance
(
self
,
AsyncEngineArgs
)
and
envs
.
VLLM_CI_USE_S3
and
self
.
model
in
MODELS_ON_S3
and
self
.
load_format
==
LoadFormat
.
AUTO
):
# noqa: E501
and
self
.
model
in
MODELS_ON_S3
and
self
.
load_format
==
"auto"
):
self
.
model
=
f
"
{
MODEL_WEIGHTS_S3_BUCKET
}
/
{
self
.
model
}
"
self
.
load_format
=
LoadFormat
.
RUNAI_STREAMER
self
.
load_format
=
"runai_streamer"
return
ModelConfig
(
model
=
self
.
model
,
...
...
@@ -1299,7 +1297,7 @@ class EngineArgs:
#############################################################
# Unsupported Feature Flags on V1.
if
self
.
load_format
==
LoadFormat
.
SHARDED_STATE
.
value
:
if
self
.
load_format
==
"sharded_state"
:
_raise_or_fallback
(
feature_name
=
f
"--load_format
{
self
.
load_format
}
"
,
recommend_to_remove
=
False
)
...
...
vllm/model_executor/model_loader/__init__.py
View file @
610852a4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
from
typing
import
Literal
,
Optional
from
torch
import
nn
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
,
VllmConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.bitsandbytes_loader
import
(
BitsAndBytesModelLoader
)
...
...
@@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
,
get_model_cls
)
logger
=
init_logger
(
__name__
)
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
# Names of the built-in load formats, spelled exactly as the keys of the
# registry below. Used purely as a type annotation; formats added at runtime
# through `register_model_loader` are plain strings not listed here.
LoadFormats = Literal[
    "auto",
    "bitsandbytes",
    "dummy",
    "fastsafetensors",
    "gguf",
    "mistral",
    "npcache",
    "pt",
    "runai_streamer",
    "runai_streamer_sharded",
    "safetensors",
    "sharded_state",
    "tensorizer",
]
# Registry mapping a load format name to the loader class instantiated for it.
# Several formats share one class: "auto", "fastsafetensors", "mistral",
# "npcache", "pt" and "safetensors" all map to `DefaultModelLoader`, and both
# "sharded_state" and "runai_streamer_sharded" map to `ShardedStateLoader`.
# `register_model_loader` adds to (or overwrites entries in) this dict.
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
    "auto": DefaultModelLoader,
    "bitsandbytes": BitsAndBytesModelLoader,
    "dummy": DummyModelLoader,
    "fastsafetensors": DefaultModelLoader,
    "gguf": GGUFModelLoader,
    "mistral": DefaultModelLoader,
    "npcache": DefaultModelLoader,
    "pt": DefaultModelLoader,
    "runai_streamer": RunaiModelStreamerLoader,
    "runai_streamer_sharded": ShardedStateLoader,
    "safetensors": DefaultModelLoader,
    "sharded_state": ShardedStateLoader,
    "tensorizer": TensorizerLoader,
}
def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name. It is stored
            lower-cased, because `LoadConfig` lower-cases its `load_format`
            before the loader is looked up.

    Returns:
        A class decorator that stores the decorated class in the load format
        registry and returns the class unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the decorated
            object is not a subclass of `BaseModelLoader`.

    Examples:
        >>> from torch import nn
        >>> from vllm.config import LoadConfig, ModelConfig
        >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self, model_config: ModelConfig) -> None:
        ...         pass
        ...
        ...     def load_weights(self, model: nn.Module,
        ...                      model_config: ModelConfig) -> None:
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """  # noqa: E501
    # Normalise once so the registry key matches what `LoadConfig` produces.
    format_key = load_format.lower()

    def _wrapper(model_loader_cls):
        # Validate before touching the registry or logging anything, so an
        # invalid class never produces a misleading "will be overwritten"
        # warning. The `isinstance(..., type)` guard turns the `TypeError`
        # that `issubclass` raises for non-class objects into the documented
        # `ValueError`.
        if not (isinstance(model_loader_cls, type)
                and issubclass(model_loader_cls, BaseModelLoader)):
            raise ValueError("The model loader must be a subclass of "
                             "`BaseModelLoader`.")

        if format_key in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.", format_key,
                model_loader_cls)

        _LOAD_FORMAT_TO_MODEL_LOADER[format_key] = model_loader_cls
        logger.info("Registered model loader `%s` with load format `%s`",
                    model_loader_cls, format_key)
        return model_loader_cls

    return _wrapper
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
if
isinstance
(
load_config
.
load_format
,
type
):
return
load_config
.
load_format
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
DUMMY
:
return
DummyModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
GGUF
:
return
GGUFModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER
:
return
RunaiModelStreamerLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
RUNAI_STREAMER_SHARDED
:
return
ShardedStateLoader
(
load_config
,
runai_model_streamer
=
True
)
return
DefaultModelLoader
(
load_config
)
load_format
=
load_config
.
load_format
if
load_format
not
in
_LOAD_FORMAT_TO_MODEL_LOADER
:
raise
ValueError
(
f
"Load format `
{
load_format
}
` is not supported"
)
return
_LOAD_FORMAT_TO_MODEL_LOADER
[
load_format
](
load_config
)
def
get_model
(
*
,
...
...
@@ -66,6 +125,7 @@ __all__ = [
"get_architecture_class_name"
,
"get_model_architecture"
,
"get_model_cls"
,
"register_model_loader"
,
"BaseModelLoader"
,
"BitsAndBytesModelLoader"
,
"GGUFModelLoader"
,
...
...
vllm/model_executor/model_loader/default_loader.py
View file @
610852a4
...
...
@@ -13,7 +13,7 @@ from torch import nn
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm
import
envs
from
vllm.config
import
LoadConfig
,
LoadFormat
,
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.base_loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.weight_utils
import
(
...
...
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
if
load_format
==
"auto"
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
(
load_format
==
LoadFormat
.
SAFETENSORS
or
load_format
==
LoadFormat
.
FASTSAFETENSORS
):
elif
(
load_format
==
"safetensors"
or
load_format
==
"fastsafetensors"
):
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
elif
load_format
==
"mistral"
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
elif
load_format
==
"pt"
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
elif
load_format
==
"npcache"
:
allow_patterns
=
[
"*.bin"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
...
...
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
if
self
.
load_config
.
load_format
==
"npcache"
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
...
...
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
.
use_tqdm_on_load
,
)
elif
use_safetensors
:
if
self
.
load_config
.
load_format
==
LoadFormat
.
FASTSAFETENSORS
:
if
self
.
load_config
.
load_format
==
"fastsafetensors"
:
weights_iterator
=
fastsafetensors_weights_iterator
(
hf_weights_files
,
self
.
load_config
.
use_tqdm_on_load
,
...
...
vllm/model_executor/model_loader/sharded_state_loader.py
View file @
610852a4
...
...
@@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
,
runai_model_streamer
:
bool
=
False
):
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
self
.
runai_model_streamer
=
runai_model_streamer
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
...
...
@@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
def
iterate_over_files
(
self
,
paths
)
->
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
if
self
.
runai_model_streamer
:
if
self
.
load_config
.
load_format
==
"runai_streamer_sharded"
:
yield
from
runai_safetensors_weights_iterator
(
paths
,
True
)
else
:
from
safetensors.torch
import
safe_open
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment