Commit 51bb12d1 (unverified), authored by youkaichao, committed by GitHub
Browse files

[4/N][torch.compile] clean up set_torch_compile_backend (#10401)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 47826cac
...@@ -2,15 +2,14 @@ import copy ...@@ -2,15 +2,14 @@ import copy
import dataclasses import dataclasses
import operator import operator
from contextlib import ExitStack from contextlib import ExitStack
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
Union)
from unittest.mock import patch from unittest.mock import patch
import torch import torch
import torch.fx as fx import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel from vllm.config import CompilationConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors from vllm.utils import combine_fx_passes, weak_ref_tensors
...@@ -684,14 +683,3 @@ class PiecewiseBackend: ...@@ -684,14 +683,3 @@ class PiecewiseBackend:
entry.cudagraph.replay() entry.cudagraph.replay()
return entry.output return entry.output
def select_default_backend(level: int) -> Union[str, Callable]:
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
backend_str = "eager"
return backend_str
assert level == CompilationLevel.PIECEWISE
from vllm.plugins import get_current_vllm_config
compilation_config = get_current_vllm_config().compilation_config
return VllmBackend(compilation_config)
...@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher: ...@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings # default compilation settings
# compiling the forward method # compiling the forward method
# choose the compile backend from vllm.plugins import get_current_vllm_config
backend = get_current_vllm_config(
# if the user has set the backend, use it ).compilation_config.init_backend()
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(compilation_level)
compiled_callable = torch.compile( compiled_callable = torch.compile(
self.forward, self.forward,
......
...@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import ( ...@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
get_hf_text_config, get_pooling_config, get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
identity, print_warning_once) identity, print_warning_once, resolve_obj_by_qualname)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel): ...@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
- 1: dynamo as is. - 1: dynamo as is.
- 2: dynamo once. - 2: dynamo once.
- 3: piecewise compilation. - 3: piecewise compilation.
- backend: the backend for compilation. It needs to be a string.
- "" (empty string): use the default backend.
- "eager"/"openxla"/...: use the specified backend registered in PyTorch.
- "full.module.name": a qualified name which can be used to import the backend function.
We use string to avoid serialization issues when using compilation in a distributed setting.
When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
- custom_ops: fine-grained control over which custom ops to enable/disable. - custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all. Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'), Also specify a list of custom op names to enable (prefixed with a '+'),
...@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel): ...@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
certain small batchsizes, where inductor is good at optimizing. certain small batchsizes, where inductor is good at optimizing.
""" # noqa """ # noqa
level: int = 0 level: int = 0
backend: str = ""
custom_ops: List[str] = Field(default_factory=list) custom_ops: List[str] = Field(default_factory=list)
use_inductor: bool = True use_inductor: bool = True
...@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel): ...@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
func = __import__(module).__dict__[func_name] func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func self.inductor_compile_config[k] = func
def init_backend(self) -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
from torch._dynamo.backends.registry import list_backends
torch_backends = list_backends(exclude_tags=tuple())
if self.level in [
CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
]:
if self.backend == "":
return "eager"
if self.backend in torch_backends:
return self.backend
return resolve_obj_by_qualname(self.backend)
# TODO: pass user-specified backend to piecewise compilation
# merge with the config use_inductor
assert self.level == CompilationLevel.PIECEWISE
from vllm.compilation.backends import VllmBackend
return VllmBackend(self)
def init_during_runtime(self): def init_during_runtime(self):
"""To complete the initialization of config, """To complete the initialization of config,
we need to know the compile context, which is only available we need to know the compile context, which is only available
......
...@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING ...@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING
import torch import torch
from vllm.plugins import set_torch_compile_backend
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -12,8 +10,6 @@ if TYPE_CHECKING: ...@@ -12,8 +10,6 @@ if TYPE_CHECKING:
else: else:
VllmConfig = None VllmConfig = None
set_torch_compile_backend("openxla")
class TpuPlatform(Platform): class TpuPlatform(Platform):
_enum = PlatformEnum.TPU _enum = PlatformEnum.TPU
...@@ -38,3 +34,6 @@ class TpuPlatform(Platform): ...@@ -38,3 +34,6 @@ class TpuPlatform(Platform):
compilation_config.level = CompilationLevel.DYNAMO_ONCE compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level < CompilationLevel.PIECEWISE,\ assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor." "TPU does not support Inductor."
if compilation_config.backend == "":
compilation_config.backend = "openxla"
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union from typing import TYPE_CHECKING, Optional
import vllm.envs as envs import vllm.envs as envs
...@@ -50,18 +50,6 @@ def load_general_plugins(): ...@@ -50,18 +50,6 @@ def load_general_plugins():
logger.exception("Failed to load plugin %s", plugin.name) logger.exception("Failed to load plugin %s", plugin.name)
_torch_compile_backend: Optional[Union[Callable, str]] = None
def set_torch_compile_backend(backend: Union[Callable, str]):
global _torch_compile_backend
_torch_compile_backend = backend
def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
return _torch_compile_backend
_compilation_config: Optional[CompilationConfig] = None _compilation_config: Optional[CompilationConfig] = None
......
...@@ -1600,3 +1600,12 @@ def direct_register_custom_op( ...@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
my_lib.impl(op_name, op_func, "CUDA") my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None: if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl) my_lib._register_fake(op_name, fake_impl)
def resolve_obj_by_qualname(qualname: str) -> Any:
"""
Resolve an object by its fully qualified name.
"""
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
...@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if self.vllm_config.compilation_config.level ==\ if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
from vllm.plugins import get_torch_compile_backend backend = self.vllm_config.compilation_config.init_backend()
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile( self.model = torch.compile(
self.model, self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment