Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69e1d2fb
Unverified
Commit
69e1d2fb
authored
Apr 16, 2024
by
Antoni Baum
Committed by
GitHub
Apr 16, 2024
Browse files
[Core] Refactor model loading code (#4097)
parent
05434764
Changes
67
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
684 additions
and
446 deletions
+684
-446
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+5
-5
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+1
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+3
-2
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+0
-128
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+30
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+354
-0
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+0
-0
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+82
-34
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+40
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+127
-168
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+4
-10
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+4
-10
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+4
-10
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+4
-12
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+4
-12
vllm/model_executor/models/decilm.py
vllm/model_executor/models/decilm.py
+4
-10
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+6
-14
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+4
-10
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+4
-10
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+4
-10
No files found.
vllm/executor/executor_base.py
View file @
69e1d2fb
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
Lo
RA
Config
,
Model
Config
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
Lo
ad
Config
,
LoRA
Config
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
Tensorizer
Config
,
VisionLanguageConfig
)
Speculative
Config
,
VisionLanguageConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
...
@@ -23,20 +23,20 @@ class ExecutorBase(ABC):
...
@@ -23,20 +23,20 @@ class ExecutorBase(ABC):
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
tensorizer_config
:
Optional
[
TensorizerConfig
],
)
->
None
:
)
->
None
:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
load_config
=
load_config
self
.
parallel_config
=
parallel_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
device_config
=
device_config
self
.
vision_language_config
=
vision_language_config
self
.
vision_language_config
=
vision_language_config
self
.
speculative_config
=
speculative_config
self
.
speculative_config
=
speculative_config
self
.
tensorizer_config
=
tensorizer_config
self
.
_init_executor
()
self
.
_init_executor
()
...
...
vllm/executor/gpu_executor.py
View file @
69e1d2fb
...
@@ -35,12 +35,12 @@ class GPUExecutor(ExecutorBase):
...
@@ -35,12 +35,12 @@ class GPUExecutor(ExecutorBase):
scheduler_config
=
self
.
scheduler_config
,
scheduler_config
=
self
.
scheduler_config
,
device_config
=
self
.
device_config
,
device_config
=
self
.
device_config
,
cache_config
=
self
.
cache_config
,
cache_config
=
self
.
cache_config
,
load_config
=
self
.
load_config
,
local_rank
=
0
,
local_rank
=
0
,
rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
vision_language_config
=
self
.
vision_language_config
,
tensorizer_config
=
self
.
tensorizer_config
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
self
.
driver_worker
.
init_device
()
self
.
driver_worker
.
init_device
()
...
...
vllm/executor/ray_gpu_executor.py
View file @
69e1d2fb
...
@@ -147,6 +147,7 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -147,6 +147,7 @@ class RayGPUExecutor(ExecutorBase):
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
model_config
=
copy
.
deepcopy
(
self
.
model_config
)
parallel_config
=
copy
.
deepcopy
(
self
.
parallel_config
)
parallel_config
=
copy
.
deepcopy
(
self
.
parallel_config
)
scheduler_config
=
copy
.
deepcopy
(
self
.
scheduler_config
)
scheduler_config
=
copy
.
deepcopy
(
self
.
scheduler_config
)
load_config
=
copy
.
deepcopy
(
self
.
load_config
)
device_config
=
copy
.
deepcopy
(
self
.
device_config
)
device_config
=
copy
.
deepcopy
(
self
.
device_config
)
lora_config
=
copy
.
deepcopy
(
self
.
lora_config
)
lora_config
=
copy
.
deepcopy
(
self
.
lora_config
)
cache_config
=
copy
.
deepcopy
(
self
.
cache_config
)
cache_config
=
copy
.
deepcopy
(
self
.
cache_config
)
...
@@ -165,12 +166,12 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -165,12 +166,12 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
device_config
=
device_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
load_config
=
load_config
,
local_rank
=
local_rank
,
local_rank
=
local_rank
,
rank
=
rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
lora_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
vision_language_config
=
vision_language_config
,
tensorizer_config
=
self
.
tensorizer_config
,
))
))
# Initialize the driver worker with the Worker class.
# Initialize the driver worker with the Worker class.
...
@@ -187,7 +188,7 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -187,7 +188,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
vision_language_config
=
self
.
vision_language_config
,
vision_language_config
=
self
.
vision_language_config
,
tensorizer
_config
=
self
.
tensorizer
_config
,
load
_config
=
self
.
load
_config
,
is_driver_worker
=
True
,
is_driver_worker
=
True
,
)
)
...
...
vllm/model_executor/model_loader.py
deleted
100644 → 0
View file @
05434764
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Tuple
,
Type
import
torch
from
torch
import
nn
from
vllm.config
import
DeviceConfig
,
ModelConfig
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
from
vllm.model_executor.tensorizer_loader
import
(
ParameterizedLoadFormat
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
)
from
vllm.model_executor.weight_utils
import
(
get_quant_config
,
initialize_dummy_weights
)
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
]
@
contextlib
.
contextmanager
def
_set_default_torch_dtype
(
dtype
:
torch
.
dtype
):
"""Sets the default torch dtype to the given dtype."""
old_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
dtype
)
yield
torch
.
set_default_dtype
(
old_dtype
)
def
_get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
for
arch
in
architectures
:
model_cls
=
ModelRegistry
.
load_model_cls
(
arch
)
if
model_cls
is
not
None
:
return
(
model_cls
,
arch
)
raise
ValueError
(
f
"Model architectures
{
architectures
}
are not supported for now. "
f
"Supported architectures:
{
ModelRegistry
.
get_supported_archs
()
}
"
)
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
return
_get_model_architecture
(
model_config
)[
1
]
def
get_model
(
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
**
kwargs
)
->
nn
.
Module
:
lora_config
=
kwargs
.
get
(
"lora_config"
,
None
)
vision_language_config
=
kwargs
.
get
(
"vision_language_config"
,
None
)
tensorizer_config
=
kwargs
.
get
(
"tensorizer_config"
,
None
)
model_class
=
_get_model_architecture
(
model_config
)[
0
]
# Get the (maybe quantized) linear method.
linear_method
=
None
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
is not "
"supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
if
model_config
.
dtype
not
in
supported_dtypes
:
raise
ValueError
(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
linear_method
=
quant_config
.
get_linear_method
()
with
_set_default_torch_dtype
(
model_config
.
dtype
):
# Create a model instance.
# The weights will be initialized as empty tensors.
extra_kwargs
=
{}
if
hasattr
(
model_class
,
"supported_lora_modules"
):
extra_kwargs
[
"lora_config"
]
=
lora_config
elif
lora_config
:
raise
ValueError
(
f
"Model
{
model_class
.
__name__
}
does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
elif
model_class
in
_VISION_MODEL_CLASSES
:
extra_kwargs
[
"vision_language_config"
]
=
vision_language_config
with
torch
.
device
(
device_config
.
device
):
if
(
model_config
.
load_format
==
"tensorizer"
and
is_vllm_serialized_tensorizer
(
tensorizer_config
)):
extra_kwargs
[
"linear_method"
]
=
linear_method
tensorizer_config
.
model_class
=
model_class
tensorizer_config
.
hf_config
=
model_config
.
hf_config
tensorizer_config
.
dtype
=
model_config
.
dtype
model
=
load_with_tensorizer
(
tensorizer_config
,
**
extra_kwargs
)
return
model
.
eval
()
model
=
model_class
(
config
=
model_config
.
hf_config
,
linear_method
=
linear_method
,
**
extra_kwargs
)
if
model_config
.
load_format
==
"dummy"
:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
else
:
# Load the weights from the cached or downloaded files.
if
model_config
.
load_format
==
"tensorizer"
:
# Provide a dynamic load format for `model.load_weights`
# to retain tensorizer args from CLI.
model_config
.
load_format
=
ParameterizedLoadFormat
(
model_config
.
load_format
)
model_config
.
load_format
.
params
=
(
tensorizer_config
.
_construct_tensorizer_args
())
model
.
load_weights
(
model_config
.
model
,
model_config
.
download_dir
,
model_config
.
load_format
,
model_config
.
revision
,
)
return
model
.
eval
()
vllm/model_executor/model_loader/__init__.py
0 → 100644
View file @
69e1d2fb
from
typing
import
Optional
from
torch
import
nn
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
get_model_loader
)
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
)
def
get_model
(
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
loader
=
get_model_loader
(
load_config
)
return
loader
.
load_model
(
model_config
=
model_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
vision_language_config
=
vision_language_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
)
__all__
=
[
"get_model"
,
"get_model_loader"
,
"BaseModelLoader"
,
"get_architecture_class_name"
,
"get_model_architecture"
]
vllm/model_executor/model_loader/loader.py
0 → 100644
View file @
69e1d2fb
# ruff: noqa: SIM117
import
copy
import
glob
import
os
from
abc
import
ABC
,
abstractmethod
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
)
import
torch
from
torch
import
nn
from
vllm.config
import
(
VLLM_USE_MODELSCOPE
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_serialized_tensorizer
,
load_with_tensorizer
,
tensorizer_weights_iterator
)
from
vllm.model_executor.model_loader.utils
import
(
get_model_architecture
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
download_weights_from_hf
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.linear
import
LinearMethodBase
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
]
logger
=
init_logger
(
__name__
)
def
_get_linear_method
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
"LinearMethodBase"
]:
"""Get the (maybe quantized) linear method."""
linear_method
=
None
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
is not "
"supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
if
model_config
.
dtype
not
in
supported_dtypes
:
raise
ValueError
(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
linear_method
=
quant_config
.
get_linear_method
()
return
linear_method
def
_get_model_initialization_kwargs
(
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
]
)
->
Dict
[
str
,
Any
]:
"""Get extra kwargs for model initialization."""
extra_kwargs
=
{}
if
hasattr
(
model_class
,
"supported_lora_modules"
):
extra_kwargs
[
"lora_config"
]
=
lora_config
elif
lora_config
:
raise
ValueError
(
f
"Model
{
model_class
.
__name__
}
does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
elif
model_class
in
_VISION_MODEL_CLASSES
:
extra_kwargs
[
"vision_language_config"
]
=
vision_language_config
return
extra_kwargs
def
_initialize_model
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
])
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model_config
,
load_config
)
return
model_class
(
config
=
model_config
.
hf_config
,
linear_method
=
linear_method
,
**
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
))
class
BaseModelLoader
(
ABC
):
"""Base class for model loaders."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
@
abstractmethod
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
...
class
DefaultModelLoader
(
BaseModelLoader
):
"""Model loader that can load different file types from disk."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
_maybe_download_from_modelscope
(
self
,
model
:
str
,
revision
:
Optional
[
str
])
->
Optional
[
str
]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if
VLLM_USE_MODELSCOPE
:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from
modelscope.hub.snapshot_download
import
snapshot_download
if
not
os
.
path
.
exists
(
model
):
model_path
=
snapshot_download
(
model_id
=
model
,
cache_dir
=
self
.
load_config
.
download_dir
,
revision
=
revision
)
else
:
model_path
=
model
return
model_path
return
None
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path
=
self
.
_maybe_download_from_modelscope
(
model_name_or_path
,
revision
)
or
model_name_or_path
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
allow_patterns
=
[
"*.bin"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
not
is_local
:
hf_folder
=
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
)
else
:
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
not
use_safetensors
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
model_name_or_path
,
revision
,
fall_back_to_pt
)
if
self
.
load_config
.
load_format
==
LoadFormat
.
NPCACHE
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
return
np_cache_weights_iterator
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
)
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
return
pt_weights_iterator
(
hf_weights_files
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
)),
)
return
model
.
eval
()
class
DummyModelLoader
(
BaseModelLoader
):
"""Model loader that will set model weights to random values."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
load_config
.
model_loader_extra_config
:
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights
(
model
)
return
model
.
eval
()
class
TensorizerLoader
(
BaseModelLoader
):
"""Model loader using CoreWeave's tensorizer library."""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
if
isinstance
(
load_config
.
model_loader_extra_config
,
TensorizerConfig
):
self
.
tensorizer_config
=
load_config
.
model_loader_extra_config
else
:
self
.
tensorizer_config
=
TensorizerConfig
(
**
load_config
.
model_loader_extra_config
)
def
_verify_config
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
):
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
self
.
tensorizer_config
.
verify_with_parallel_config
(
parallel_config
)
def
_get_weights_iterator
(
self
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
tensorizer_args
=
self
.
tensorizer_config
.
_construct_tensorizer_args
()
return
tensorizer_weights_iterator
(
tensorizer_args
)
def
_load_model_unserialized
(
self
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
]
)
->
nn
.
Module
:
"""Load an unserialized model with tensorizer.
Unserialized here means "not serialized with tensorizer". This
should still be faster than default HuggingFace loading, but will
be slower than loading a tensorizer-serialized model.
"""
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
vision_language_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
def
_load_model_serialized
(
self
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
]
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer.
See the examples/tensorize_vllm_model.py example "
script for serializing vLLM models."""
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
linear_method
=
_get_linear_method
(
model_config
,
self
.
load_config
)
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
vision_language_config
)
extra_kwargs
[
"linear_method"
]
=
linear_method
tensorizer_config
=
copy
.
copy
(
self
.
tensorizer_config
)
tensorizer_config
.
model_class
=
model_class
tensorizer_config
.
hf_config
=
model_config
.
hf_config
tensorizer_config
.
dtype
=
model_config
.
dtype
model
=
load_with_tensorizer
(
tensorizer_config
,
**
extra_kwargs
)
return
model
.
eval
()
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
vision_language_config
:
Optional
[
VisionLanguageConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
self
.
_verify_config
(
model_config
,
parallel_config
)
if
is_vllm_serialized_tensorizer
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
model_config
,
device_config
,
lora_config
,
vision_language_config
)
return
self
.
_load_model_unserialized
(
model_config
,
device_config
,
lora_config
,
vision_language_config
)
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
if
isinstance
(
load_config
.
load_format
,
type
):
return
load_config
.
load_format
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
DUMMY
:
return
DummyModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
TENSORIZER
:
return
TensorizerLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
vllm/model_executor/
neuron_
model_loader.py
→
vllm/model_executor/model_loader
/neuron
.py
View file @
69e1d2fb
File moved
vllm/model_executor/tensorizer
_loader
.py
→
vllm/model_executor/
model_loader/
tensorizer.py
View file @
69e1d2fb
...
@@ -4,20 +4,20 @@ import io
...
@@ -4,20 +4,20 @@ import io
import
os
import
os
import
time
import
time
import
typing
import
typing
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
from
typing
import
Generator
,
Optional
,
Tuple
,
Type
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
Tensorizer
Config
from
vllm.config
import
ModelConfig
,
Parallel
Config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
tensorizer_load_fail
=
Fals
e
tensorizer_load_fail
=
Non
e
try
:
try
:
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
from
tensorizer
import
(
DecryptionParams
,
EncryptionParams
,
...
@@ -25,51 +25,78 @@ try:
...
@@ -25,51 +25,78 @@ try:
from
tensorizer.stream_io
import
open_stream
from
tensorizer.stream_io
import
open_stream
from
tensorizer.utils
import
(
convert_bytes
,
get_mem_usage
,
from
tensorizer.utils
import
(
convert_bytes
,
get_mem_usage
,
no_init_or_tensor
)
no_init_or_tensor
)
except
ImportError
:
except
ImportError
as
e
:
tensorizer_load_fail
=
Tru
e
tensorizer_load_fail
=
e
__all__
=
[
__all__
=
[
'EncryptionParams'
,
'DecryptionParams'
,
'TensorDeserializer'
,
'EncryptionParams'
,
'DecryptionParams'
,
'TensorDeserializer'
,
'TensorSerializer'
,
'open_stream'
,
'convert_bytes'
,
'get_mem_usage'
,
'TensorSerializer'
,
'open_stream'
,
'convert_bytes'
,
'get_mem_usage'
,
'no_init_or_tensor'
'no_init_or_tensor'
,
'TensorizerConfig'
]
]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
dataclass
class
TensorizerConfig
:
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
typing
.
BinaryIO
,
str
,
bytes
,
os
.
PathLike
,
int
]
vllm_tensorized
:
bool
verify_hash
:
Optional
[
bool
]
=
False
num_readers
:
Optional
[
int
]
=
1
encryption_keyfile
:
Optional
[
str
]
=
None
s3_access_key_id
:
Optional
[
str
]
=
None
s3_secret_access_key
:
Optional
[
str
]
=
None
s3_endpoint
:
Optional
[
str
]
=
None
model_class
:
Optional
[
Type
[
torch
.
nn
.
Module
]]
=
None
hf_config
:
Optional
[
PretrainedConfig
]
=
None
dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
def
_construct_tensorizer_args
(
self
)
->
"TensorizerArgs"
:
tensorizer_args
=
{
"tensorizer_uri"
:
self
.
tensorizer_uri
,
"vllm_tensorized"
:
self
.
vllm_tensorized
,
"verify_hash"
:
self
.
verify_hash
,
"num_readers"
:
self
.
num_readers
,
"encryption_keyfile"
:
self
.
encryption_keyfile
,
"s3_access_key_id"
:
self
.
s3_access_key_id
,
"s3_secret_access_key"
:
self
.
s3_secret_access_key
,
"s3_endpoint"
:
self
.
s3_endpoint
,
}
return
TensorizerArgs
(
**
tensorizer_args
)
def
verify_with_parallel_config
(
self
,
parallel_config
:
"ParallelConfig"
,
)
->
None
:
if
(
parallel_config
.
tensor_parallel_size
>
1
and
self
.
tensorizer_uri
is
not
None
):
raise
ValueError
(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`."
)
def
verify_with_model_config
(
self
,
model_config
:
"ModelConfig"
)
->
None
:
if
(
model_config
.
quantization
is
not
None
and
self
.
tensorizer_uri
is
not
None
):
logger
.
warning
(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
def
load_with_tensorizer
(
tensorizer_config
:
TensorizerConfig
,
def
load_with_tensorizer
(
tensorizer_config
:
TensorizerConfig
,
**
extra_kwargs
)
->
nn
.
Module
:
**
extra_kwargs
)
->
nn
.
Module
:
tensorizer
=
TensorizerAgent
(
tensorizer_config
,
**
extra_kwargs
)
tensorizer
=
TensorizerAgent
(
tensorizer_config
,
**
extra_kwargs
)
return
tensorizer
.
deserialize
()
return
tensorizer
.
deserialize
()
def
tensorizer_warning
(
message
:
str
):
return
warnings
.
warn
(
message
,
category
=
PerformanceWarning
,
stacklevel
=
2
)
def
is_vllm_serialized_tensorizer
(
tensorizer_config
:
TensorizerConfig
)
->
bool
:
def
is_vllm_serialized_tensorizer
(
tensorizer_config
:
TensorizerConfig
)
->
bool
:
if
tensorizer_config
is
None
:
if
tensorizer_config
is
None
:
return
False
return
False
return
tensorizer_config
.
vllm_tensorized
return
tensorizer_config
.
vllm_tensorized
class
ParameterizedLoadFormat
(
str
):
__slots__
=
"params"
class
PerformanceWarning
(
UserWarning
):
def
__str__
(
self
):
return
(
f
"
{
super
().
__str__
()
}
"
" (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
" environment variable to hide this)"
)
if
(
os
.
getenv
(
"VLLM_SILENCE_PERFORMANCE_WARNINGS"
,
""
).
lower
()
not
in
(
""
,
"0"
,
"n"
,
"no"
,
"off"
,
"disable"
)):
warnings
.
simplefilter
(
"ignore"
,
category
=
PerformanceWarning
)
@
dataclass
@
dataclass
class
TensorizerArgs
:
class
TensorizerArgs
:
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
typing
.
BinaryIO
,
tensorizer_uri
:
Union
[
io
.
BufferedIOBase
,
io
.
RawIOBase
,
typing
.
BinaryIO
,
...
@@ -219,11 +246,17 @@ class TensorizerAgent:
...
@@ -219,11 +246,17 @@ class TensorizerAgent:
behavior of the TensorDeserializer when loading tensors from a serialized
behavior of the TensorDeserializer when loading tensors from a serialized
model. For deserializations of HuggingFace models, TensorDeserializer is
model. For deserializations of HuggingFace models, TensorDeserializer is
instead used as an iterator directly in the func hf_model_weights_iterator
instead used as an iterator directly in the func hf_model_weights_iterator
in vllm/model_executor/weight_utils.py
in vllm/model_executor/
model_loader/
weight_utils.py
"""
"""
def
__init__
(
self
,
tensorizer_config
:
TensorizerConfig
,
def
__init__
(
self
,
tensorizer_config
:
TensorizerConfig
,
linear_method
:
LinearMethodBase
,
**
extra_kwargs
):
linear_method
:
LinearMethodBase
,
**
extra_kwargs
):
if
tensorizer_load_fail
is
not
None
:
raise
ImportError
(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`."
)
from
tensorizer_load_fail
self
.
tensorizer_config
=
tensorizer_config
self
.
tensorizer_config
=
tensorizer_config
self
.
tensorizer_args
=
(
self
.
tensorizer_args
=
(
self
.
tensorizer_config
.
_construct_tensorizer_args
())
self
.
tensorizer_config
.
_construct_tensorizer_args
())
...
@@ -234,11 +267,6 @@ class TensorizerAgent:
...
@@ -234,11 +267,6 @@ class TensorizerAgent:
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
self
.
_init_model
()
self
.
model
=
self
.
_init_model
()
if
tensorizer_load_fail
:
raise
ImportError
(
"Tensorizer is not installed. Please install tensorizer "
"to use this feature with `pip install vllm[tensorizer]`."
)
def
_init_model
(
self
):
def
_init_model
(
self
):
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
=
self
.
tensorizer_config
.
hf_config
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
model_args
.
torch_dtype
=
self
.
tensorizer_config
.
dtype
...
@@ -313,3 +341,23 @@ class TensorizerAgent:
...
@@ -313,3 +341,23 @@ class TensorizerAgent:
self
.
_check_tensors_on_meta_device
()
self
.
_check_tensors_on_meta_device
()
self
.
_resize_lora_embeddings
()
self
.
_resize_lora_embeddings
()
return
self
.
model
.
eval
()
return
self
.
model
.
eval
()
def
tensorizer_weights_iterator
(
tensorizer_args
:
"TensorizerArgs"
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
logger
.
warning
(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models."
)
deserializer_args
=
tensorizer_args
.
deserializer_params
stream_params
=
tensorizer_args
.
stream_params
stream
=
open_stream
(
tensorizer_args
.
tensorizer_uri
,
**
stream_params
)
with
TensorDeserializer
(
stream
,
**
deserializer_args
,
device
=
"cpu"
)
as
state
:
for
name
,
param
in
state
.
items
():
yield
name
,
param
del
state
vllm/model_executor/model_loader/utils.py
0 → 100644
View file @
69e1d2fb
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Tuple
,
Type
import
torch
from
torch
import
nn
from
vllm.config
import
ModelConfig
from
vllm.model_executor.models
import
ModelRegistry
@
contextlib
.
contextmanager
def
set_default_torch_dtype
(
dtype
:
torch
.
dtype
):
"""Sets the default torch dtype to the given dtype."""
old_dtype
=
torch
.
get_default_dtype
()
torch
.
set_default_dtype
(
dtype
)
yield
torch
.
set_default_dtype
(
old_dtype
)
def
get_model_architecture
(
model_config
:
ModelConfig
)
->
Tuple
[
Type
[
nn
.
Module
],
str
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
for
arch
in
architectures
:
model_cls
=
ModelRegistry
.
load_model_cls
(
arch
)
if
model_cls
is
not
None
:
return
(
model_cls
,
arch
)
raise
ValueError
(
f
"Model architectures
{
architectures
}
are not supported for now. "
f
"Supported architectures:
{
ModelRegistry
.
get_supported_archs
()
}
"
)
def
get_architecture_class_name
(
model_config
:
ModelConfig
)
->
str
:
return
get_model_architecture
(
model_config
)[
1
]
vllm/model_executor/weight_utils.py
→
vllm/model_executor/
model_loader/
weight_utils.py
View file @
69e1d2fb
...
@@ -4,8 +4,9 @@ import glob
...
@@ -4,8 +4,9 @@ import glob
import
hashlib
import
hashlib
import
json
import
json
import
os
import
os
import
tempfile
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Any
,
Iterable
,
Itera
tor
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Generator
,
Itera
ble
,
List
,
Optional
,
Tuple
import
filelock
import
filelock
import
huggingface_hub.constants
import
huggingface_hub.constants
...
@@ -15,7 +16,7 @@ from huggingface_hub import HfFileSystem, snapshot_download
...
@@ -15,7 +16,7 @@ from huggingface_hub import HfFileSystem, snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
vllm.config
import
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
get_quantization_config
)
...
@@ -27,8 +28,7 @@ logger = init_logger(__name__)
...
@@ -27,8 +28,7 @@ logger = init_logger(__name__)
# can share the same lock without error.
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
# system reboots, so users will not complain about annoying lock files
temp_dir
=
os
.
environ
.
get
(
'TMPDIR'
)
or
os
.
environ
.
get
(
temp_dir
=
tempfile
.
gettempdir
()
'TEMP'
)
or
os
.
environ
.
get
(
'TMP'
)
or
"/tmp/"
def
enable_hf_transfer
():
def
enable_hf_transfer
():
...
@@ -46,7 +46,7 @@ def enable_hf_transfer():
...
@@ -46,7 +46,7 @@ def enable_hf_transfer():
enable_hf_transfer
()
enable_hf_transfer
()
class
Disabled
t
qdm
(
tqdm
):
class
Disabled
T
qdm
(
tqdm
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
,
disable
=
True
)
super
().
__init__
(
*
args
,
**
kwargs
,
disable
=
True
)
...
@@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file(
...
@@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
# TODO(woosuk): Move this to other place.
def
get_quant_config
(
model_config
:
ModelConfig
)
->
QuantizationConfig
:
def
get_quant_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
...
@@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
...
@@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
if
not
is_local
:
if
not
is_local
:
# Download the config files.
# Download the config files.
with
get_lock
(
model_name_or_path
,
model
_config
.
download_dir
):
with
get_lock
(
model_name_or_path
,
load
_config
.
download_dir
):
hf_folder
=
snapshot_download
(
model_name_or_path
,
hf_folder
=
snapshot_download
(
model_name_or_path
,
revision
=
model_config
.
revision
,
revision
=
model_config
.
revision
,
allow_patterns
=
"*.json"
,
allow_patterns
=
"*.json"
,
cache_dir
=
model
_config
.
download_dir
,
cache_dir
=
load
_config
.
download_dir
,
tqdm_class
=
Disabled
t
qdm
)
tqdm_class
=
Disabled
T
qdm
)
else
:
else
:
hf_folder
=
model_name_or_path
hf_folder
=
model_name_or_path
config_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
"*.json"
))
config_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
"*.json"
))
...
@@ -153,36 +154,24 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
...
@@ -153,36 +154,24 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
return
quant_cls
.
from_config
(
config
)
return
quant_cls
.
from_config
(
config
)
def
prepare_hf_model_weights
(
def
download_weights_from_hf
(
model_name_or_path
:
str
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
cache_dir
:
Optional
[
str
]
=
None
,
allow_patterns
:
List
[
str
],
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
)
->
str
:
fall_back_to_pt
:
bool
=
True
,
"""Download model weights from Hugging Face Hub.
revision
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
# Download model weights from huggingface.
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
\
and
load_format
!=
"tensorizer"
use_safetensors
=
False
# Some quantized models use .pt files for storing the weights.
if
load_format
==
"auto"
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
"safetensors"
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
"pt"
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
"npcache"
:
allow_patterns
=
[
"*.bin"
]
elif
load_format
==
"tensorizer"
:
allow_patterns
=
[
"*.tensors"
]
else
:
raise
ValueError
(
f
"Unknown load_format:
{
load_format
}
"
)
if
fall_back_to_pt
:
Args:
allow_patterns
+=
[
"*.pt"
]
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults.
allow_patterns (List[str]): The allowed patterns for the
weight files. Files matched by any of the patterns will be
downloaded.
revision (Optional[str]): The revision of the model.
if
not
is_local
and
load_format
!=
"tensorizer"
:
Returns:
str: The path to the downloaded model weights.
"""
# Before we download we look at that is available:
# Before we download we look at that is available:
fs
=
HfFileSystem
()
fs
=
HfFileSystem
()
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
file_list
=
fs
.
ls
(
model_name_or_path
,
detail
=
False
,
revision
=
revision
)
...
@@ -201,20 +190,18 @@ def prepare_hf_model_weights(
...
@@ -201,20 +190,18 @@ def prepare_hf_model_weights(
hf_folder
=
snapshot_download
(
model_name_or_path
,
hf_folder
=
snapshot_download
(
model_name_or_path
,
allow_patterns
=
allow_patterns
,
allow_patterns
=
allow_patterns
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
tqdm_class
=
Disabled
t
qdm
,
tqdm_class
=
Disabled
T
qdm
,
revision
=
revision
)
revision
=
revision
)
else
:
return
hf_folder
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
def
filter_files_not_needed_for_inference
(
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
hf_weights_files
:
List
[
str
])
->
List
[
str
]:
if
len
(
hf_weights_files
)
>
0
:
"""
if
pattern
==
"*.safetensors"
:
Exclude files that are not needed for inference.
use_safetensors
=
True
break
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
if
not
use_safetensors
:
"""
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist
=
[
blacklist
=
[
"training_args.bin"
,
"training_args.bin"
,
"optimizer.bin"
,
"optimizer.bin"
,
...
@@ -226,35 +213,17 @@ def prepare_hf_model_weights(
...
@@ -226,35 +213,17 @@ def prepare_hf_model_weights(
f
for
f
in
hf_weights_files
f
for
f
in
hf_weights_files
if
not
any
(
f
.
endswith
(
x
)
for
x
in
blacklist
)
if
not
any
(
f
.
endswith
(
x
)
for
x
in
blacklist
)
]
]
return
hf_weights_files
if
load_format
==
"tensorizer"
:
return
hf_folder
,
hf_weights_files
,
use_safetensors
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
np_cache_weights_iterator
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
],
hf_folder
:
str
,
hf_weights_files
:
List
[
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model np files.
def
hf_model_weights_iterator
(
Will dump the model weights to numpy files if they are not already dumped.
model_name_or_path
:
str
,
"""
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
Union
[
Tuple
,
str
]
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
fall_back_to_pt
:
Optional
[
bool
]
=
True
,
)
->
Iterator
[
Tuple
[
str
,
torch
.
Tensor
]]:
hf_folder
,
hf_weights_files
,
use_safetensors
=
prepare_hf_model_weights
(
model_name_or_path
,
cache_dir
=
cache_dir
,
load_format
=
load_format
,
fall_back_to_pt
=
fall_back_to_pt
,
revision
=
revision
)
if
load_format
==
"npcache"
:
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
# Convert the model weights from torch tensors to numpy arrays for
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
# faster loading.
np_folder
=
os
.
path
.
join
(
hf_folder
,
"np"
)
np_folder
=
os
.
path
.
join
(
hf_folder
,
"np"
)
...
@@ -283,33 +252,23 @@ def hf_model_weights_iterator(
...
@@ -283,33 +252,23 @@ def hf_model_weights_iterator(
with
open
(
param_path
,
"rb"
)
as
f
:
with
open
(
param_path
,
"rb"
)
as
f
:
param
=
np
.
load
(
f
)
param
=
np
.
load
(
f
)
yield
name
,
torch
.
from_numpy
(
param
)
yield
name
,
torch
.
from_numpy
(
param
)
elif
load_format
==
"tensorizer"
:
from
vllm.model_executor.tensorizer_loader
import
(
TensorDeserializer
,
open_stream
,
def
safetensors_weights_iterator
(
tensorizer_warning
)
hf_weights_files
:
List
[
str
]
tensorizer_args
=
load_format
.
params
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
tensorizer_warning
(
"""Iterate over the weights in the model safetensor files."""
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models."
)
deserializer_args
=
tensorizer_args
.
deserializer_params
stream_params
=
tensorizer_args
.
stream_params
stream
=
open_stream
(
tensorizer_args
.
tensorizer_uri
,
**
stream_params
)
with
TensorDeserializer
(
stream
,
**
deserializer_args
,
device
=
"cpu"
)
as
state
:
for
name
,
param
in
state
.
items
():
yield
name
,
param
del
state
elif
use_safetensors
:
for
st_file
in
hf_weights_files
:
for
st_file
in
hf_weights_files
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
with
safe_open
(
st_file
,
framework
=
"pt"
)
as
f
:
for
name
in
f
.
keys
():
# noqa: SIM118
for
name
in
f
.
keys
():
# noqa: SIM118
param
=
f
.
get_tensor
(
name
)
param
=
f
.
get_tensor
(
name
)
yield
name
,
param
yield
name
,
param
else
:
def
pt_weights_iterator
(
hf_weights_files
:
List
[
str
]
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Iterate over the weights in the model bin/pt files."""
for
bin_file
in
hf_weights_files
:
for
bin_file
in
hf_weights_files
:
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
state
=
torch
.
load
(
bin_file
,
map_location
=
"cpu"
)
for
name
,
param
in
state
.
items
():
for
name
,
param
in
state
.
items
():
...
...
vllm/model_executor/models/baichuan.py
View file @
69e1d2fb
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
# limitations under the License.
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -340,19 +339,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -340,19 +339,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
name
==
"lm_head.weight"
:
if
name
==
"lm_head.weight"
:
...
...
vllm/model_executor/models/bloom.py
View file @
69e1d2fb
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# limitations under the License.
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import
math
import
math
from
typing
import
List
,
Optional
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -35,9 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -35,9 +35,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -298,14 +297,9 @@ class BloomForCausalLM(nn.Module):
...
@@ -298,14 +297,9 @@ class BloomForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
name
==
"lm_head.weight"
:
if
name
==
"lm_head.weight"
:
continue
continue
if
not
name
.
startswith
(
"transformer."
):
if
not
name
.
startswith
(
"transformer."
):
...
...
vllm/model_executor/models/chatglm.py
View file @
69e1d2fb
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Adapted from
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
"""Inference-only ChatGLM model compatible with THUDM weights."""
from
typing
import
List
,
Optional
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -22,9 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -22,9 +22,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
ChatGLMConfig
from
vllm.transformers_utils.configs
import
ChatGLMConfig
...
@@ -370,14 +369,9 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -370,14 +369,9 @@ class ChatGLMForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_pos_emb.inv_freq"
in
name
:
if
"rotary_pos_emb.inv_freq"
in
name
:
continue
continue
if
"word_embeddings"
in
name
:
if
"word_embeddings"
in
name
:
...
...
vllm/model_executor/models/commandr.py
View file @
69e1d2fb
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
"""PyTorch Cohere model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
...
@@ -41,10 +41,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -41,10 +41,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -335,13 +334,7 @@ class CohereForCausalLM(nn.Module):
...
@@ -335,13 +334,7 @@ class CohereForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -352,8 +345,7 @@ class CohereForCausalLM(nn.Module):
...
@@ -352,8 +345,7 @@ class CohereForCausalLM(nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
()
loaded_params
=
set
()
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
for
param_name
,
shard_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
shard_name
,
shard_id
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
if
shard_name
not
in
name
:
continue
continue
...
...
vllm/model_executor/models/dbrx.py
View file @
69e1d2fb
# coding=utf-8
# coding=utf-8
from
typing
import
List
,
Optional
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -18,10 +18,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -18,10 +18,9 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
...
@@ -391,20 +390,13 @@ class DbrxForCausalLM(nn.Module):
...
@@ -391,20 +390,13 @@ class DbrxForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
):
expert_params_mapping
=
[(
expert_params_mapping
=
[(
"ws"
if
weight_name
in
[
"w1"
,
"v1"
]
else
"w2s"
,
"ws"
if
weight_name
in
[
"w1"
,
"v1"
]
else
"w2s"
,
f
"experts.mlp.
{
weight_name
}
"
,
f
"experts.mlp.
{
weight_name
}
"
,
)
for
weight_name
in
[
"w1"
,
"v1"
,
"w2"
]]
)
for
weight_name
in
[
"w1"
,
"v1"
,
"w2"
]]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
for
param_name
,
weight_name
in
expert_params_mapping
:
for
param_name
,
weight_name
in
expert_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
...
vllm/model_executor/models/decilm.py
View file @
69e1d2fb
...
@@ -23,16 +23,15 @@
...
@@ -23,16 +23,15 @@
# limitations under the License.
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from
typing
import
Optional
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearMethodBase
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
class
DeciLMForCausalLM
(
LlamaForCausalLM
):
class
DeciLMForCausalLM
(
LlamaForCausalLM
):
...
@@ -65,11 +64,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
...
@@ -65,11 +64,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
linear_method
=
linear_method
,
linear_method
=
linear_method
,
lora_config
=
lora_config
)
lora_config
=
lora_config
)
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -79,8 +74,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
...
@@ -79,8 +74,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
...
vllm/model_executor/models/deepseek.py
View file @
69e1d2fb
...
@@ -21,7 +21,7 @@
...
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only Deepseek model."""
"""Inference-only Deepseek model."""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -44,9 +44,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -44,9 +44,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -316,6 +315,8 @@ class DeepseekDecoderLayer(nn.Module):
...
@@ -316,6 +315,8 @@ class DeepseekDecoderLayer(nn.Module):
class
DeepseekModel
(
nn
.
Module
):
class
DeepseekModel
(
nn
.
Module
):
fall_back_to_pt_during_load
=
False
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
...
@@ -395,11 +396,7 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -395,11 +396,7 @@ class DeepseekForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -410,12 +407,7 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -410,12 +407,7 @@ class DeepseekForCausalLM(nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
,
fall_back_to_pt
=
False
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
...
...
vllm/model_executor/models/falcon.py
View file @
69e1d2fb
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
"""PyTorch Falcon model."""
"""PyTorch Falcon model."""
import
math
import
math
from
typing
import
List
,
Optional
,
Union
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -40,9 +40,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.configs
import
RWConfig
from
vllm.transformers_utils.configs
import
RWConfig
...
@@ -399,11 +398,7 @@ class FalconForCausalLM(nn.Module):
...
@@ -399,11 +398,7 @@ class FalconForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
total_num_heads
=
self
.
config
.
num_attention_heads
total_num_heads
=
self
.
config
.
num_attention_heads
if
self
.
config
.
new_decoder_architecture
:
if
self
.
config
.
new_decoder_architecture
:
total_num_kv_heads
=
self
.
config
.
num_kv_heads
total_num_kv_heads
=
self
.
config
.
num_kv_heads
...
@@ -413,8 +408,7 @@ class FalconForCausalLM(nn.Module):
...
@@ -413,8 +408,7 @@ class FalconForCausalLM(nn.Module):
total_num_kv_heads
=
total_num_heads
total_num_kv_heads
=
total_num_heads
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
num_query_heads_per_kv_head
=
total_num_heads
//
total_num_kv_heads
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
name
==
"lm_head.weight"
:
if
name
==
"lm_head.weight"
:
# Falcon uses tied embeddings.
# Falcon uses tied embeddings.
continue
continue
...
...
vllm/model_executor/models/gemma.py
View file @
69e1d2fb
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
# limitations under the License.
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
"""Inference-only Gemma model compatible with HuggingFace weights."""
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -36,9 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -36,9 +36,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -346,11 +345,7 @@ class GemmaForCausalLM(nn.Module):
...
@@ -346,11 +345,7 @@ class GemmaForCausalLM(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"q_proj"
,
"q"
),
...
@@ -361,8 +356,7 @@ class GemmaForCausalLM(nn.Module):
...
@@ -361,8 +356,7 @@ class GemmaForCausalLM(nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
()
loaded_params
=
set
()
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
shard_name
,
shard_id
)
in
stacked_params_mapping
:
if
shard_name
not
in
name
:
if
shard_name
not
in
name
:
continue
continue
...
...
vllm/model_executor/models/gpt2.py
View file @
69e1d2fb
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -34,9 +34,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -34,9 +34,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -239,14 +238,9 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -239,14 +238,9 @@ class GPT2LMHeadModel(nn.Module):
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
weights
:
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"lm_head.weight"
in
name
:
if
"lm_head.weight"
in
name
:
# GPT-2 ties the weights of the embedding layer and the final
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
# linear layer.
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment