Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51bb12d1
Unverified
Commit
51bb12d1
authored
Nov 17, 2024
by
youkaichao
Committed by
GitHub
Nov 17, 2024
Browse files
[4/N][torch.compile] clean up set_torch_compile_backend (#10401)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
47826cac
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
49 additions
and
42 deletions
+49
-42
vllm/compilation/backends.py
vllm/compilation/backends.py
+2
-14
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+3
-8
vllm/config.py
vllm/config.py
+30
-1
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+3
-4
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+1
-13
vllm/utils.py
vllm/utils.py
+9
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-2
No files found.
vllm/compilation/backends.py
View file @
51bb12d1
...
...
@@ -2,15 +2,14 @@ import copy
import
dataclasses
import
operator
from
contextlib
import
ExitStack
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Tuple
,
Union
)
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Set
,
Tuple
from
unittest.mock
import
patch
import
torch
import
torch.fx
as
fx
import
vllm.envs
as
envs
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.config
import
CompilationConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
combine_fx_passes
,
weak_ref_tensors
...
...
@@ -684,14 +683,3 @@ class PiecewiseBackend:
entry
.
cudagraph
.
replay
()
return
entry
.
output
def select_default_backend(level: int) -> Union[str, Callable]:
    """Return the default compilation backend for ``level``.

    Args:
        level: one of the ``CompilationLevel`` values that enables
            compilation.

    Returns:
        The string ``"eager"`` for ``DYNAMO_AS_IS`` and ``DYNAMO_ONCE``;
        otherwise a ``VllmBackend`` built from the current vLLM config's
        compilation config.

    Raises:
        AssertionError: if ``level`` is neither of the two dynamo levels
            nor ``PIECEWISE``.
    """
    dynamo_levels = (CompilationLevel.DYNAMO_AS_IS,
                     CompilationLevel.DYNAMO_ONCE)
    if level in dynamo_levels:
        return "eager"

    # Only piecewise compilation is left at this point.
    assert level == CompilationLevel.PIECEWISE

    # Imported inside the function, as in the original code.
    from vllm.plugins import get_current_vllm_config
    current_config = get_current_vllm_config()
    return VllmBackend(current_config.compilation_config)
vllm/compilation/wrapper.py
View file @
51bb12d1
...
...
@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings
# compiling the forward method
# choose the compile backend
# if the user has set the backend, use it
from
vllm.plugins
import
get_torch_compile_backend
backend
=
get_torch_compile_backend
()
if
backend
is
None
:
from
vllm.compilation.backends
import
select_default_backend
backend
=
select_default_backend
(
compilation_level
)
from
vllm.plugins
import
get_current_vllm_config
backend
=
get_current_vllm_config
(
).
compilation_config
.
init_backend
()
compiled_callable
=
torch
.
compile
(
self
.
forward
,
...
...
vllm/config.py
View file @
51bb12d1
...
...
@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config
,
get_pooling_config
,
get_sentence_transformer_tokenizer_config
,
is_encoder_decoder
,
uses_mrope
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
identity
,
print_warning_once
)
identity
,
print_warning_once
,
resolve_obj_by_qualname
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use a string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
...
...
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
certain small batchsizes, where inductor is good at optimizing.
"""
# noqa
level
:
int
=
0
backend
:
str
=
""
custom_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
use_inductor
:
bool
=
True
...
...
@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
func
=
__import__
(
module
).
__dict__
[
func_name
]
self
.
inductor_compile_config
[
k
]
=
func
def init_backend(self) -> Union[str, Callable]:
    """Pick the compilation backend for the configured level.

    Returns:
        For ``DYNAMO_AS_IS`` / ``DYNAMO_ONCE``:
          - ``"eager"`` when ``self.backend`` is the empty string;
          - ``self.backend`` unchanged when it is one of the names
            reported by ``list_backends``;
          - otherwise whatever ``resolve_obj_by_qualname(self.backend)``
            returns.
        For ``PIECEWISE``: a ``VllmBackend`` constructed from this config
        (``self.backend`` is not consulted on this path yet).

    Raises:
        ValueError: if the level is ``NO_COMPILATION``.
        AssertionError: if the level is none of the levels handled above.
    """
    if self.level == CompilationLevel.NO_COMPILATION:
        raise ValueError("No compilation level is set.")

    from torch._dynamo.backends.registry import list_backends
    registered_names = list_backends(exclude_tags=tuple())

    whole_graph_levels = (CompilationLevel.DYNAMO_AS_IS,
                          CompilationLevel.DYNAMO_ONCE)
    if self.level not in whole_graph_levels:
        # TODO: pass user-specified backend to piecewise compilation
        # merge with the config use_inductor
        assert self.level == CompilationLevel.PIECEWISE

        from vllm.compilation.backends import VllmBackend
        return VllmBackend(self)

    requested = self.backend
    if requested == "":
        return "eager"
    if requested in registered_names:
        return requested
    # Not a registered backend name: treat it as a qualified name.
    return resolve_obj_by_qualname(requested)
def
init_during_runtime
(
self
):
"""To complete the initialization of config,
we need to know the compile context, which is only available
...
...
vllm/platforms/tpu.py
View file @
51bb12d1
...
...
@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING
import
torch
from
vllm.plugins
import
set_torch_compile_backend
from
.interface
import
Platform
,
PlatformEnum
if
TYPE_CHECKING
:
...
...
@@ -12,8 +10,6 @@ if TYPE_CHECKING:
else
:
VllmConfig
=
None
set_torch_compile_backend
(
"openxla"
)
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
...
...
@@ -38,3 +34,6 @@ class TpuPlatform(Platform):
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
"TPU does not support Inductor."
if
compilation_config
.
backend
==
""
:
compilation_config
.
backend
=
"openxla"
vllm/plugins/__init__.py
View file @
51bb12d1
import
logging
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
import
vllm.envs
as
envs
...
...
@@ -50,18 +50,6 @@ def load_general_plugins():
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
# Module-level slot holding the backend override; None until a caller
# stores one via set_torch_compile_backend().
_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
    """Store ``backend`` in the module-level ``_torch_compile_backend``."""
    global _torch_compile_backend
    _torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
    """Return the stored backend, or ``None`` if none has been set."""
    return _torch_compile_backend
_compilation_config
:
Optional
[
CompilationConfig
]
=
None
...
...
vllm/utils.py
View file @
51bb12d1
...
...
@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
my_lib
.
impl
(
op_name
,
op_func
,
"CUDA"
)
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully qualified name.

    The name is split at its last dot: everything before it is imported
    as a module, and the part after it is looked up as an attribute of
    that module (e.g. ``"os.path.join"`` imports ``os.path`` and returns
    its ``join`` attribute).

    Args:
        qualname: dotted name of the form ``"package.module.attr"``.

    Returns:
        The attribute named by the last component of ``qualname``.

    Raises:
        ValueError: if ``qualname`` contains no dot, so there is no
            module part to import.
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_name, dot, obj_name = qualname.rpartition(".")
    if not dot:
        # Tuple-unpacking the result of rsplit would also raise ValueError
        # here, but with a message that does not mention the bad input.
        raise ValueError(
            "Expected a fully qualified name of the form 'module.attr', "
            f"got {qualname!r}")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)
vllm/worker/model_runner.py
View file @
51bb12d1
...
...
@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if
self
.
vllm_config
.
compilation_config
.
level
==
\
CompilationLevel
.
DYNAMO_AS_IS
and
supports_dynamo
():
from
vllm.plugins
import
get_torch_compile_backend
backend
=
get_torch_compile_backend
()
or
"eager"
backend
=
self
.
vllm_config
.
compilation_config
.
init_backend
()
self
.
model
=
torch
.
compile
(
self
.
model
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment