Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51bb12d1
Unverified
Commit
51bb12d1
authored
Nov 17, 2024
by
youkaichao
Committed by
GitHub
Nov 17, 2024
Browse files
[4/N][torch.compile] clean up set_torch_compile_backend (#10401)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
47826cac
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
49 additions
and
42 deletions
+49
-42
vllm/compilation/backends.py
vllm/compilation/backends.py
+2
-14
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+3
-8
vllm/config.py
vllm/config.py
+30
-1
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+3
-4
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+1
-13
vllm/utils.py
vllm/utils.py
+9
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-2
No files found.
vllm/compilation/backends.py
View file @
51bb12d1
...
@@ -2,15 +2,14 @@ import copy
...
@@ -2,15 +2,14 @@ import copy
import
dataclasses
import
dataclasses
import
operator
import
operator
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Tuple
,
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Tuple
Union
)
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
import
torch
import
torch.fx
as
fx
import
torch.fx
as
fx
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.config
import
CompilationConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
combine_fx_passes
,
weak_ref_tensors
from
vllm.utils
import
combine_fx_passes
,
weak_ref_tensors
...
@@ -684,14 +683,3 @@ class PiecewiseBackend:
...
@@ -684,14 +683,3 @@ class PiecewiseBackend:
entry
.
cudagraph
.
replay
()
entry
.
cudagraph
.
replay
()
return
entry
.
output
return
entry
.
output
def select_default_backend(level: int) -> Union[str, Callable]:
    """Pick the default ``torch.compile`` backend for a compilation level.

    Args:
        level: one of the ``CompilationLevel`` values.

    Returns:
        The string ``"eager"`` for ``DYNAMO_AS_IS`` and ``DYNAMO_ONCE``;
        otherwise a ``VllmBackend`` built from the compilation config of
        the current vLLM config.
    """
    whole_graph_levels = (
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    )
    if level in whole_graph_levels:
        return "eager"

    # Any level that reaches this point must be piecewise compilation.
    assert level == CompilationLevel.PIECEWISE

    # Imported inside the function, as in the original code.
    from vllm.plugins import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    return VllmBackend(vllm_config.compilation_config)
vllm/compilation/wrapper.py
View file @
51bb12d1
...
@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings
# default compilation settings
# compiling the forward method
# compiling the forward method
# choose the compile backend
from
vllm.plugins
import
get_current_vllm_config
backend
=
get_current_vllm_config
(
# if the user has set the backend, use it
).
compilation_config
.
init_backend
()
from
vllm.plugins
import
get_torch_compile_backend
backend
=
get_torch_compile_backend
()
if
backend
is
None
:
from
vllm.compilation.backends
import
select_default_backend
backend
=
select_default_backend
(
compilation_level
)
compiled_callable
=
torch
.
compile
(
compiled_callable
=
torch
.
compile
(
self
.
forward
,
self
.
forward
,
...
...
vllm/config.py
View file @
51bb12d1
...
@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
...
@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config
,
get_pooling_config
,
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
identity
,
print_warning_once
)
identity
,
print_warning_once
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
...
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
- 1: dynamo as is.
- 1: dynamo as is.
- 2: dynamo once.
- 2: dynamo once.
- 3: piecewise compilation.
- 3: piecewise compilation.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use a string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable.
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
Also specify a list of custom op names to enable (prefixed with a '+'),
...
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
...
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
certain small batchsizes, where inductor is good at optimizing.
certain small batchsizes, where inductor is good at optimizing.
"""
# noqa
"""
# noqa
level
:
int
=
0
level
:
int
=
0
backend
:
str
=
""
custom_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
custom_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
use_inductor
:
bool
=
True
use_inductor
:
bool
=
True
...
@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
...
@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
func
=
__import__
(
module
).
__dict__
[
func_name
]
func
=
__import__
(
module
).
__dict__
[
func_name
]
self
.
inductor_compile_config
[
k
]
=
func
self
.
inductor_compile_config
[
k
]
=
func
def init_backend(self) -> Union[str, Callable]:
    """Build the compilation backend selected by this config.

    Returns:
        - For ``DYNAMO_AS_IS`` / ``DYNAMO_ONCE``: ``"eager"`` when
          ``self.backend`` is the empty string, ``self.backend`` itself
          when it names a backend registered in PyTorch, otherwise the
          object obtained by resolving ``self.backend`` as a qualified
          name.
        - For ``PIECEWISE``: a ``VllmBackend`` wrapping this config.

    Raises:
        ValueError: if the level is ``NO_COMPILATION``.
    """
    if self.level == CompilationLevel.NO_COMPILATION:
        raise ValueError("No compilation level is set.")

    # Query every backend registered in PyTorch, with no tags excluded.
    from torch._dynamo.backends.registry import list_backends
    registered_backends = list_backends(exclude_tags=())

    whole_graph_levels = (
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    )
    if self.level not in whole_graph_levels:
        # TODO: pass user-specified backend to piecewise compilation
        # merge with the config use_inductor
        assert self.level == CompilationLevel.PIECEWISE

        from vllm.compilation.backends import VllmBackend
        return VllmBackend(self)

    requested = self.backend
    if requested == "":
        return "eager"
    if requested in registered_backends:
        return requested
    # Not a registered name: treat it as "full.module.name" and import it.
    return resolve_obj_by_qualname(requested)
def
init_during_runtime
(
self
):
def
init_during_runtime
(
self
):
"""To complete the initialization of config,
"""To complete the initialization of config,
we need to know the compile context, which is only available
we need to know the compile context, which is only available
...
...
vllm/platforms/tpu.py
View file @
51bb12d1
...
@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING
...
@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING
import
torch
import
torch
from
vllm.plugins
import
set_torch_compile_backend
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -12,8 +10,6 @@ if TYPE_CHECKING:
...
@@ -12,8 +10,6 @@ if TYPE_CHECKING:
else
:
else
:
VllmConfig
=
None
VllmConfig
=
None
set_torch_compile_backend
(
"openxla"
)
class
TpuPlatform
(
Platform
):
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
_enum
=
PlatformEnum
.
TPU
...
@@ -38,3 +34,6 @@ class TpuPlatform(Platform):
...
@@ -38,3 +34,6 @@ class TpuPlatform(Platform):
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
"TPU does not support Inductor."
"TPU does not support Inductor."
if
compilation_config
.
backend
==
""
:
compilation_config
.
backend
=
"openxla"
vllm/plugins/__init__.py
View file @
51bb12d1
import
logging
import
logging
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -50,18 +50,6 @@ def load_general_plugins():
...
@@ -50,18 +50,6 @@ def load_general_plugins():
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
# Process-wide override for the torch.compile backend; None means unset.
_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
    """Record ``backend`` as the process-wide torch.compile backend override.

    Args:
        backend: a backend name or a backend callable.
    """
    global _torch_compile_backend
    _torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
    """Return the backend stored by ``set_torch_compile_backend``.

    Returns:
        The stored backend, or ``None`` if none has been set.
    """
    return _torch_compile_backend
_compilation_config
:
Optional
[
CompilationConfig
]
=
None
_compilation_config
:
Optional
[
CompilationConfig
]
=
None
...
...
vllm/utils.py
View file @
51bb12d1
...
@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
...
@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
my_lib
.
impl
(
op_name
,
op_func
,
"CUDA"
)
my_lib
.
impl
(
op_name
,
op_func
,
"CUDA"
)
if
fake_impl
is
not
None
:
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully qualified name.

    The name is split at its last dot: everything before it is imported
    as a module, and the part after it is looked up as an attribute of
    that module.

    Args:
        qualname: a dotted name of the form ``"package.module.attr"``.

    Returns:
        The attribute named by the last component of ``qualname``.

    Raises:
        ValueError: if ``qualname`` contains no dot, so it cannot be split
            into a module and an attribute.
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_name, dot, obj_name = qualname.rpartition(".")
    if not dot:
        # Without this check a bare name (e.g. a mistyped backend string)
        # fails with an opaque tuple-unpacking error.
        raise ValueError(
            f"Cannot resolve {qualname!r}: expected a fully qualified name "
            "of the form 'module.attr'.")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)
vllm/worker/model_runner.py
View file @
51bb12d1
...
@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if
self
.
vllm_config
.
compilation_config
.
level
==
\
if
self
.
vllm_config
.
compilation_config
.
level
==
\
CompilationLevel
.
DYNAMO_AS_IS
and
supports_dynamo
():
CompilationLevel
.
DYNAMO_AS_IS
and
supports_dynamo
():
from
vllm.plugins
import
get_torch_compile_backend
backend
=
self
.
vllm_config
.
compilation_config
.
init_backend
()
backend
=
get_torch_compile_backend
()
or
"eager"
self
.
model
=
torch
.
compile
(
self
.
model
=
torch
.
compile
(
self
.
model
,
self
.
model
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment