Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fd93750
Unverified
Commit
4fd93750
authored
Nov 16, 2024
by
youkaichao
Committed by
GitHub
Nov 16, 2024
Browse files
[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
661a34fd
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
77 additions
and
26 deletions
+77
-26
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+5
-2
vllm/platforms/interface.py
vllm/platforms/interface.py
+19
-1
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+14
-7
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+28
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-7
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-4
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+5
-3
No files found.
vllm/model_executor/model_loader/loader.py
View file @
4fd93750
...
@@ -42,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -42,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator
)
safetensors_weights_iterator
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
...
@@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
all_params
=
[
param
.
name
for
param
in
signatures
.
parameters
.
values
()]
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
if
"vllm_config"
in
all_params
and
"prefix"
in
all_params
:
# new-style model class
# new-style model class
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
with
set_current_vllm_config
(
vllm_config
):
return
model_class
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
msg
=
(
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
" registered from out of tree and it is used for new vLLM version. "
...
@@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
...
@@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
kwargs
[
"lora_config"
]
=
vllm_config
.
lora_config
kwargs
[
"lora_config"
]
=
vllm_config
.
lora_config
if
"scheduler_config"
in
all_params
:
if
"scheduler_config"
in
all_params
:
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
kwargs
[
"scheduler_config"
]
=
vllm_config
.
scheduler_config
return
model_class
(
**
kwargs
)
with
set_current_vllm_config
(
vllm_config
):
return
model_class
(
**
kwargs
)
class
BaseModelLoader
(
ABC
):
class
BaseModelLoader
(
ABC
):
...
...
vllm/platforms/interface.py
View file @
4fd93750
import
enum
import
enum
import
random
import
random
from
typing
import
NamedTuple
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
else
:
VllmConfig
=
None
class
PlatformEnum
(
enum
.
Enum
):
class
PlatformEnum
(
enum
.
Enum
):
CUDA
=
enum
.
auto
()
CUDA
=
enum
.
auto
()
...
@@ -129,6 +134,19 @@ class Platform:
...
@@ -129,6 +134,19 @@ class Platform:
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
"""
Check and update the configuration for the current platform.
It can raise an exception if the configuration is not compatible with
the current platform, or it can update the configuration to make it
compatible with the current platform.
The config is passed by reference, so it can be modified in place.
"""
pass
class
UnspecifiedPlatform
(
Platform
):
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
_enum
=
PlatformEnum
.
UNSPECIFIED
vllm/platforms/tpu.py
View file @
4fd93750
import
os
import
os
from
typing
import
TYPE_CHECKING
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.plugins
import
set_torch_compile_backend
from
vllm.plugins
import
set_torch_compile_backend
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
if
"VLLM_TORCH_COMPILE_LEVEL"
not
in
os
.
environ
:
if
TYPE_CHECKING
:
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
DYNAMO_ONCE
)
from
vllm.config
import
VllmConfig
else
:
assert
envs
.
VLLM_TORCH_COMPILE_LEVEL
<
CompilationLevel
.
PIECEWISE
,
\
VllmConfig
=
None
"TPU does not support Inductor."
set_torch_compile_backend
(
"openxla"
)
set_torch_compile_backend
(
"openxla"
)
...
@@ -31,3 +29,12 @@ class TpuPlatform(Platform):
...
@@ -31,3 +29,12 @@ class TpuPlatform(Platform):
@
classmethod
@
classmethod
def
inference_mode
(
cls
):
def
inference_mode
(
cls
):
return
torch
.
no_grad
()
return
torch
.
no_grad
()
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
from
vllm.config
import
CompilationLevel
compilation_config
=
vllm_config
.
compilation_config
if
"VLLM_TORCH_COMPILE_LEVEL"
not
in
os
.
environ
:
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
"TPU does not support Inductor."
vllm/plugins/__init__.py
View file @
4fd93750
import
logging
import
logging
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
import
vllm.envs
as
envs
import
vllm.envs
as
envs
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.compilation.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
else
:
else
:
CompilationConfig
=
None
CompilationConfig
=
None
VllmConfig
=
None
VllmConfig
=
None
...
@@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]):
...
@@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def
get_compilation_config
()
->
Optional
[
CompilationConfig
]:
def
get_compilation_config
()
->
Optional
[
CompilationConfig
]:
return
_compilation_config
return
_compilation_config
_current_vllm_config
:
Optional
[
VllmConfig
]
=
None
@
contextmanager
def
set_current_vllm_config
(
vllm_config
:
VllmConfig
):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global
_current_vllm_config
old_vllm_config
=
_current_vllm_config
try
:
_current_vllm_config
=
vllm_config
yield
finally
:
_current_vllm_config
=
old_vllm_config
def
get_current_vllm_config
()
->
VllmConfig
:
assert
_current_vllm_config
is
not
None
,
"Current VLLM config is not set."
return
_current_vllm_config
vllm/v1/worker/gpu_model_runner.py
View file @
4fd93750
import
os
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Set
,
Tuple
...
@@ -8,11 +7,8 @@ import torch
...
@@ -8,11 +7,8 @@ import torch
import
torch.distributed
import
torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm
import
envs
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -99,7 +95,7 @@ class GPUModelRunner:
...
@@ -99,7 +95,7 @@ class GPUModelRunner:
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
)
)
self
.
use_cuda_graph
=
(
envs
.
VLLM_TORCH_COMPILE_LEVEL
self
.
use_cuda_graph
=
(
self
.
vllm_config
.
compilation_config
.
level
==
CompilationLevel
.
PIECEWISE
==
CompilationLevel
.
PIECEWISE
and
not
self
.
model_config
.
enforce_eager
)
and
not
self
.
model_config
.
enforce_eager
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
...
@@ -517,9 +513,9 @@ class GPUModelRunner:
...
@@ -517,9 +513,9 @@ class GPUModelRunner:
# CUDA graphs do not work properly with the custom CUDA kernels.
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
# and avoid any potential issues with the inductor.
os
.
environ
[
"VLLM_CUSTOM_OPS"
]
=
"none"
set_compilation_config
(
set_compilation_config
(
CompilationConfig
(
CompilationConfig
(
custom_ops
=
[
"none"
],
use_cudagraph
=
True
,
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"vllm.unified_v1_flash_attention"
],
non_cudagraph_ops
=
[
"vllm.unified_v1_flash_attention"
],
use_inductor
=
True
,
use_inductor
=
True
,
...
...
vllm/worker/model_runner.py
View file @
4fd93750
...
@@ -19,8 +19,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
...
@@ -19,8 +19,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.abstract
import
AttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
...
@@ -1142,8 +1141,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1142,8 +1141,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!"
)
"This may lead to less accurate results!"
)
if
envs
.
VLLM_TORCH_COMPILE_LEVEL
==
CompilationLevel
.
DYNAMO_AS_IS
\
if
self
.
vllm_config
.
compilation_config
.
level
==
\
and
supports_dynamo
():
CompilationLevel
.
DYNAMO_AS_IS
and
supports_dynamo
():
from
vllm.plugins
import
get_torch_compile_backend
from
vllm.plugins
import
get_torch_compile_backend
backend
=
get_torch_compile_backend
()
or
"eager"
backend
=
get_torch_compile_backend
()
or
"eager"
self
.
model
=
torch
.
compile
(
self
.
model
=
torch
.
compile
(
...
...
vllm/worker/tpu_model_runner.py
View file @
4fd93750
...
@@ -140,7 +140,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -140,7 +140,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
model
=
model
.
eval
()
model
=
model
.
eval
()
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
self
.
model
=
ModelWrapper
(
model
)
self
.
model
=
ModelWrapper
(
model
,
self
.
vllm_config
)
def
_dummy_run
(
def
_dummy_run
(
self
,
self
,
...
@@ -669,13 +669,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -669,13 +669,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
class
ModelWrapper
(
TorchCompileWrapperWithCustomDispatcher
):
class
ModelWrapper
(
TorchCompileWrapperWithCustomDispatcher
):
def
__init__
(
self
,
model
:
nn
.
Module
):
def
__init__
(
self
,
model
:
nn
.
Module
,
vllm_config
:
VllmConfig
):
self
.
model
=
model
self
.
model
=
model
compiled_callable
=
torch
.
compile
(
self
.
forward
,
compiled_callable
=
torch
.
compile
(
self
.
forward
,
backend
=
"openxla"
,
backend
=
"openxla"
,
fullgraph
=
True
,
fullgraph
=
True
,
dynamic
=
False
)
dynamic
=
False
)
super
().
__init__
(
compiled_callable
)
super
().
__init__
(
compiled_callable
,
compilation_level
=
vllm_config
.
compilation_config
.
level
)
def
__call__
(
self
,
*
args
,
is_prompt
:
bool
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
is_prompt
:
bool
,
**
kwargs
):
if
len
(
self
.
compiled_codes
)
<
3
or
not
self
.
use_custom_dispatcher
:
if
len
(
self
.
compiled_codes
)
<
3
or
not
self
.
use_custom_dispatcher
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment