Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cea808f3
Unverified
Commit
cea808f3
authored
Nov 02, 2024
by
youkaichao
Committed by
GitHub
Nov 02, 2024
Browse files
[3/N] model runner pass the whole config to model (#9958)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
74b529ce
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
87 additions
and
140 deletions
+87
-140
tests/lora/conftest.py
tests/lora/conftest.py
+4
-5
vllm/model_executor/model_loader/__init__.py
vllm/model_executor/model_loader/__init__.py
+4
-16
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+54
-78
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+20
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-7
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+1
-7
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-7
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+1
-9
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+1
-9
No files found.
tests/lora/conftest.py
View file @
cea808f3
...
@@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
...
@@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory
(
shutdown_ray
=
True
)
cleanup_dist_env_and_memory
(
shutdown_ray
=
True
)
get_model_old
=
get_model
get_model_old
=
get_model
def
get_model_patched
(
*
,
model_config
,
device_config
,
**
kwargs
):
def
get_model_patched
(
**
kwargs
):
kwargs
[
"lora_config"
]
=
LoRAConfig
(
max_loras
=
4
,
max_lora_rank
=
8
)
kwargs
[
"vllm_config"
].
lora_config
=
LoRAConfig
(
max_loras
=
4
,
return
get_model_old
(
model_config
=
model_config
,
max_lora_rank
=
8
)
device_config
=
device_config
,
return
get_model_old
(
**
kwargs
)
**
kwargs
)
with
patch
(
"vllm.worker.model_runner.get_model"
,
get_model_patched
):
with
patch
(
"vllm.worker.model_runner.get_model"
,
get_model_patched
):
engine
=
vllm
.
LLM
(
"meta-llama/Llama-2-7b-hf"
,
enable_lora
=
False
)
engine
=
vllm
.
LLM
(
"meta-llama/Llama-2-7b-hf"
,
enable_lora
=
False
)
...
...
vllm/model_executor/model_loader/__init__.py
View file @
cea808f3
from
typing
import
Optional
from
torch
import
nn
from
torch
import
nn
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
VllmConfig
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
from
vllm.model_executor.model_loader.loader
import
(
BaseModelLoader
,
get_model_loader
)
get_model_loader
)
from
vllm.model_executor.model_loader.utils
import
(
from
vllm.model_executor.model_loader.utils
import
(
get_architecture_class_name
,
get_model_architecture
)
get_architecture_class_name
,
get_model_architecture
)
def
get_model
(
*
,
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
def
get_model
(
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
parallel_config
:
ParallelConfig
,
loader
=
get_model_loader
(
vllm_config
.
load_config
)
scheduler_config
:
SchedulerConfig
,
return
loader
.
load_model
(
vllm_config
=
vllm_config
)
lora_config
:
Optional
[
LoRAConfig
],
cache_config
:
CacheConfig
)
->
nn
.
Module
:
loader
=
get_model_loader
(
load_config
)
return
loader
.
load_model
(
model_config
=
model_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
cache_config
=
cache_config
)
__all__
=
[
__all__
=
[
...
...
vllm/model_executor/model_loader/loader.py
View file @
cea808f3
...
@@ -21,9 +21,9 @@ from torch import nn
...
@@ -21,9 +21,9 @@ from torch import nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
from
vllm.config
import
(
CacheConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
ParallelConfig
,
PoolerConfig
,
SchedulerConfig
)
PoolerConfig
,
SchedulerConfig
,
VllmConfig
)
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
...
@@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(
...
@@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(
def
build_model
(
model_class
:
Type
[
nn
.
Module
],
def
build_model
(
model_class
:
Type
[
nn
.
Module
],
vllm_config
:
VllmConfig
,
hf_config
:
PretrainedConfig
,
hf_config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
],
cache_config
:
Optional
[
CacheConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
quant_config
:
Optional
[
QuantizationConfig
],
...
@@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
...
@@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
if
prefix
:
if
prefix
:
extra_kwargs
[
"prefix"
]
=
prefix
extra_kwargs
[
"prefix"
]
=
prefix
# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from
vllm.plugins
import
set_vllm_config
set_vllm_config
(
vllm_config
)
return
model_class
(
config
=
hf_config
,
return
model_class
(
config
=
hf_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
**
extra_kwargs
)
**
extra_kwargs
)
def
_initialize_model
(
def
_initialize_model
(
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
model_config
:
ModelConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
cache_config
:
CacheConfig
,
scheduler_config
:
Optional
[
SchedulerConfig
]
=
None
)
->
nn
.
Module
:
"""Initialize a model with the given configurations."""
"""Initialize a model with the given configurations."""
model_config
=
vllm_config
.
model_config
lora_config
=
vllm_config
.
lora_config
scheduler_config
=
vllm_config
.
scheduler_config
cache_config
=
vllm_config
.
cache_config
load_config
=
vllm_config
.
load_config
model_class
,
_
=
get_model_architecture
(
model_config
)
model_class
,
_
=
get_model_architecture
(
model_config
)
return
build_model
(
return
build_model
(
model_class
,
model_class
,
vllm_config
,
model_config
.
hf_config
,
model_config
.
hf_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
_get_quantization_config
(
model_config
,
load_config
),
quant_config
=
_get_quantization_config
(
model_config
,
load_config
),
...
@@ -205,12 +212,7 @@ class BaseModelLoader(ABC):
...
@@ -205,12 +212,7 @@ class BaseModelLoader(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
"""Load a model with the given configurations."""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -396,18 +398,14 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -396,18 +398,14 @@ class DefaultModelLoader(BaseModelLoader):
model_config
.
revision
,
model_config
.
revision
,
fall_back_to_pt
=
True
)
fall_back_to_pt
=
True
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
device_config
=
vllm_config
.
device_config
lora_config
:
Optional
[
LoRAConfig
],
model_config
=
vllm_config
.
model_config
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
target_device
=
torch
.
device
(
device_config
.
device
)
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
with
target_device
:
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
,
scheduler_config
)
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
...
@@ -436,17 +434,12 @@ class DummyModelLoader(BaseModelLoader):
...
@@ -436,17 +434,12 @@ class DummyModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
pass
# Nothing to download
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
device_config
=
vllm_config
.
device_config
lora_config
:
Optional
[
LoRAConfig
],
model_config
=
vllm_config
.
model_config
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
,
scheduler_config
)
# NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
# random values to the weights.
initialize_dummy_weights
(
model
)
initialize_dummy_weights
(
model
)
...
@@ -488,10 +481,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -488,10 +481,7 @@ class TensorizerLoader(BaseModelLoader):
def
_load_model_serialized_cpu
(
def
_load_model_serialized_cpu
(
self
,
self
,
model_config
:
ModelConfig
,
vllm_config
:
VllmConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer to the CPU.
"""Load a serialized model with tensorizer to the CPU.
...
@@ -500,26 +490,30 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -500,26 +490,30 @@ class TensorizerLoader(BaseModelLoader):
default HuggingFace loading, but will be slower than loading a
default HuggingFace loading, but will be slower than loading a
vLLM-tensorized model.
vLLM-tensorized model.
"""
"""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
)
model
.
load_weights
(
self
.
_get_weights_iterator
())
model
.
load_weights
(
self
.
_get_weights_iterator
())
return
model
.
eval
()
return
model
.
eval
()
def
_load_model_serialized
(
def
_load_model_serialized
(
self
,
self
,
model_config
:
ModelConfig
,
vllm_config
:
VllmConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
cache_config
:
CacheConfig
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
"""Load a serialized model with tensorizer.
"""Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the
Expects a vLLM-tensorized model. See the
examples/tensorize_vllm_model.py example script
examples/tensorize_vllm_model.py example script
for serializing vLLM models."""
for serializing vLLM models."""
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
lora_config
=
vllm_config
.
lora_config
cache_config
=
vllm_config
.
cache_config
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
model_class
=
get_model_architecture
(
model_config
)[
0
]
...
@@ -544,12 +538,9 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -544,12 +538,9 @@ class TensorizerLoader(BaseModelLoader):
with
self
.
tensorizer_config
.
open_stream
():
with
self
.
tensorizer_config
.
open_stream
():
pass
pass
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
model_config
=
vllm_config
.
model_config
lora_config
:
Optional
[
LoRAConfig
],
parallel_config
=
vllm_config
.
parallel_config
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
self
.
_verify_config
(
model_config
,
parallel_config
)
self
.
_verify_config
(
model_config
,
parallel_config
)
if
parallel_config
.
tensor_parallel_size
>
1
:
if
parallel_config
.
tensor_parallel_size
>
1
:
...
@@ -559,10 +550,8 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -559,10 +550,8 @@ class TensorizerLoader(BaseModelLoader):
%
get_tensor_model_parallel_rank
()
%
get_tensor_model_parallel_rank
()
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
model_config
,
device_config
,
return
self
.
_load_model_serialized
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
)
return
self
.
_load_model_serialized_cpu
(
vllm_config
=
vllm_config
)
return
self
.
_load_model_serialized_cpu
(
model_config
,
device_config
,
lora_config
,
cache_config
)
@
staticmethod
@
staticmethod
def
save_model
(
def
save_model
(
...
@@ -648,12 +637,9 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -648,12 +637,9 @@ class ShardedStateLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
device_config
=
vllm_config
.
device_config
lora_config
:
Optional
[
LoRAConfig
],
model_config
=
vllm_config
.
model_config
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
from
safetensors.torch
import
safe_open
from
safetensors.torch
import
safe_open
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
@@ -663,8 +649,7 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -663,8 +649,7 @@ class ShardedStateLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
)
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
if
quant_method
is
not
None
:
...
@@ -1157,16 +1142,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1157,16 +1142,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
device_config
=
vllm_config
.
device_config
lora_config
:
Optional
[
LoRAConfig
],
model_config
=
vllm_config
.
model_config
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
)
self
.
_load_weights
(
model_config
,
model
)
self
.
_load_weights
(
model_config
,
model
)
...
@@ -1235,13 +1216,9 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -1235,13 +1216,9 @@ class GGUFModelLoader(BaseModelLoader):
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
:
DeviceConfig
,
device_config
=
vllm_config
.
device_config
lora_config
:
Optional
[
LoRAConfig
],
model_config
=
vllm_config
.
model_config
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
)
local_model_path
=
self
.
_prepare_weights
(
model_config
.
model
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
gguf_weights_map
=
self
.
_get_gguf_weights_map
(
model_config
)
# we can only know if tie word embeddings after mapping weights
# we can only know if tie word embeddings after mapping weights
...
@@ -1251,8 +1228,7 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -1251,8 +1228,7 @@ class GGUFModelLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
lora_config
,
cache_config
)
model
.
load_weights
(
model
.
load_weights
(
self
.
_get_weights_iterator
(
local_model_path
,
gguf_weights_map
))
self
.
_get_weights_iterator
(
local_model_path
,
gguf_weights_map
))
return
model
return
model
...
...
vllm/plugins/__init__.py
View file @
cea808f3
import
logging
import
logging
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.config
import
CompilationConfig
if
TYPE_CHECKING
:
from
vllm.compilation.config
import
CompilationConfig
from
vllm.config
import
VllmConfig
else
:
CompilationConfig
=
None
VllmConfig
=
None
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):
...
@@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def
get_compilation_config
()
->
Optional
[
CompilationConfig
]:
def
get_compilation_config
()
->
Optional
[
CompilationConfig
]:
return
_compilation_config
return
_compilation_config
_vllm_config
:
Optional
[
VllmConfig
]
=
None
def
set_vllm_config
(
config
:
Optional
[
VllmConfig
]):
global
_vllm_config
_vllm_config
=
config
def
get_vllm_config
()
->
Optional
[
VllmConfig
]:
return
_vllm_config
vllm/v1/worker/gpu_model_runner.py
View file @
cea808f3
...
@@ -369,13 +369,7 @@ class GPUModelRunner:
...
@@ -369,13 +369,7 @@ class GPUModelRunner:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
with
patch
(
"vllm.model_executor.layers.sampler.Sampler"
,
Sampler
):
with
patch
(
"vllm.model_executor.layers.sampler.Sampler"
,
Sampler
):
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
device_config
=
self
.
device_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
self
.
model_memory_usage
=
m
.
consumed_memory
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Loading model weights took %.4f GB"
,
logger
.
info
(
"Loading model weights took %.4f GB"
,
...
...
vllm/worker/cpu_model_runner.py
View file @
cea808f3
...
@@ -453,13 +453,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -453,13 +453,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
return
uses_mrope
(
self
.
model_config
.
hf_config
)
return
uses_mrope
(
self
.
model_config
.
hf_config
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
load_config
=
self
.
load_config
,
device_config
=
self
.
device_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
self
,
...
...
vllm/worker/model_runner.py
View file @
cea808f3
...
@@ -1051,13 +1051,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1051,13 +1051,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
with
DeviceMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
device_config
=
self
.
device_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
)
self
.
model_memory_usage
=
m
.
consumed_memory
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Loading model weights took %.4f GB"
,
logger
.
info
(
"Loading model weights took %.4f GB"
,
...
...
vllm/worker/tpu_model_runner.py
View file @
cea808f3
...
@@ -137,15 +137,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -137,15 +137,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
"vllm.model_executor.layers.vocab_parallel_embedding."
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank"
,
"get_tensor_model_parallel_rank"
,
return_value
=
xm_tp_rank
):
return_value
=
xm_tp_rank
):
model
=
get_model
(
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
device_config
=
self
.
device_config
,
parallel_config
=
self
.
parallel_config
,
cache_config
=
self
.
cache_config
,
scheduler_config
=
self
.
scheduler_config
,
lora_config
=
None
,
)
model
=
model
.
eval
()
model
=
model
.
eval
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
self
.
model
=
ModelWrapper
(
model
)
self
.
model
=
ModelWrapper
(
model
)
...
...
vllm/worker/xpu_model_runner.py
View file @
cea808f3
...
@@ -405,15 +405,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
...
@@ -405,15 +405,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
with
DeviceMemoryProfiler
()
as
m
:
with
DeviceMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model_config
=
self
.
model_config
,
device_config
=
self
.
device_config
,
load_config
=
self
.
load_config
,
lora_config
=
self
.
lora_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
cache_config
=
self
.
cache_config
,
)
self
.
model_memory_usage
=
m
.
consumed_memory
self
.
model_memory_usage
=
m
.
consumed_memory
logger
.
info
(
"Loading model weights took %.4f GB"
,
logger
.
info
(
"Loading model weights took %.4f GB"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment