Unverified commit 4fd93750, authored by youkaichao, committed by GitHub
Browse files

[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 661a34fd
...@@ -42,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -42,6 +42,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator) safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -97,6 +98,7 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: ...@@ -97,6 +98,7 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
all_params = [param.name for param in signatures.parameters.values()] all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params: if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class # new-style model class
with set_current_vllm_config(vllm_config):
return model_class(vllm_config=vllm_config, prefix=prefix) return model_class(vllm_config=vllm_config, prefix=prefix)
msg = ("vLLM model class should accept `vllm_config` and `prefix` as " msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class" "input arguments. Possibly you have an old-style model class"
...@@ -121,6 +123,7 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: ...@@ -121,6 +123,7 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
kwargs["lora_config"] = vllm_config.lora_config kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params: if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config):
return model_class(**kwargs) return model_class(**kwargs)
......
import enum import enum
import random import random
from typing import NamedTuple, Optional, Tuple, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
if TYPE_CHECKING:
from vllm.config import VllmConfig
else:
VllmConfig = None
class PlatformEnum(enum.Enum): class PlatformEnum(enum.Enum):
CUDA = enum.auto() CUDA = enum.auto()
...@@ -129,6 +134,19 @@ class Platform: ...@@ -129,6 +134,19 @@ class Platform:
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
    """Validate and, if needed, adjust ``vllm_config`` for this platform.

    The base implementation is a no-op. A platform may override it to
    raise an exception when the configuration is incompatible with the
    platform, or to modify the configuration so that it becomes
    compatible.

    Args:
        vllm_config: The configuration to check. It is passed by
            reference, so an override can modify it in place.
    """
    return None
class UnspecifiedPlatform(Platform): class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED _enum = PlatformEnum.UNSPECIFIED
import os import os
from typing import TYPE_CHECKING
import torch import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend from vllm.plugins import set_torch_compile_backend
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: if TYPE_CHECKING:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) from vllm.config import VllmConfig
else:
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\ VllmConfig = None
"TPU does not support Inductor."
set_torch_compile_backend("openxla") set_torch_compile_backend("openxla")
...@@ -31,3 +29,12 @@ class TpuPlatform(Platform): ...@@ -31,3 +29,12 @@ class TpuPlatform(Platform):
@classmethod @classmethod
def inference_mode(cls): def inference_mode(cls):
return torch.no_grad() return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
    """Check and update the compilation settings for the TPU platform.

    If the ``VLLM_TORCH_COMPILE_LEVEL`` environment variable is not set,
    the compilation level is set to ``CompilationLevel.DYNAMO_ONCE``.
    The resulting level must then be below ``CompilationLevel.PIECEWISE``.

    Args:
        vllm_config: The configuration to check; its
            ``compilation_config.level`` may be modified in place.

    Raises:
        AssertionError: If the compilation level is not below
            ``CompilationLevel.PIECEWISE``.
    """
    # Kept function-local as in the original; presumably this avoids a
    # circular import with vllm.config -- TODO confirm.
    from vllm.config import CompilationLevel

    compilation_config = vllm_config.compilation_config
    if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
        compilation_config.level = CompilationLevel.DYNAMO_ONCE

    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would silently skip this check.
    # AssertionError is kept so existing handlers see the same type.
    if not compilation_config.level < CompilationLevel.PIECEWISE:
        raise AssertionError("TPU does not support Inductor.")
import logging import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union from typing import TYPE_CHECKING, Callable, Optional, Union
import vllm.envs as envs import vllm.envs as envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.compilation.config import CompilationConfig from vllm.config import CompilationConfig, VllmConfig
from vllm.config import VllmConfig
else: else:
CompilationConfig = None CompilationConfig = None
VllmConfig = None VllmConfig = None
...@@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]): ...@@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def get_compilation_config() -> Optional[CompilationConfig]: def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config return _compilation_config
# Config currently in effect, or None when none has been set.
# `VllmConfig` is only imported under TYPE_CHECKING, so annotations that
# mention it are written as forward-reference strings.
_current_vllm_config: Optional["VllmConfig"] = None


@contextmanager
def set_current_vllm_config(vllm_config: "VllmConfig"):
    """
    Temporarily set the current VLLM config.
    Used during model initialization.
    We save the current VLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the VLLM config to determine how to dispatch.

    The previously active config (possibly None) is restored when the
    ``with`` block exits, including when it exits via an exception, so
    nested uses unwind in order.
    """
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    try:
        _current_vllm_config = vllm_config
        yield
    finally:
        # Always put back whatever was active before entering.
        _current_vllm_config = old_vllm_config


def get_current_vllm_config() -> "VllmConfig":
    """Return the config installed by ``set_current_vllm_config``.

    Raises:
        AssertionError: If no config is currently set.
    """
    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would let None be returned here.
    # AssertionError is kept so existing handlers see the same type.
    if _current_vllm_config is None:
        raise AssertionError("Current VLLM config is not set.")
    return _current_vllm_config
import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
...@@ -8,11 +7,8 @@ import torch ...@@ -8,11 +7,8 @@ import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
from vllm import envs
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -99,7 +95,7 @@ class GPUModelRunner: ...@@ -99,7 +95,7 @@ class GPUModelRunner:
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE == CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager) and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
...@@ -517,9 +513,9 @@ class GPUModelRunner: ...@@ -517,9 +513,9 @@ class GPUModelRunner:
# CUDA graphs do not work properly with the custom CUDA kernels. # CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time # FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor. # and avoid any potential issues with the inductor.
os.environ["VLLM_CUSTOM_OPS"] = "none"
set_compilation_config( set_compilation_config(
CompilationConfig( CompilationConfig(
custom_ops=["none"],
use_cudagraph=True, use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_v1_flash_attention"], non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
use_inductor=True, use_inductor=True,
......
...@@ -19,8 +19,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend ...@@ -19,8 +19,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel, VllmConfig
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
...@@ -1142,8 +1141,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1142,8 +1141,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. " "provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ if self.vllm_config.compilation_config.level ==\
and supports_dynamo(): CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
from vllm.plugins import get_torch_compile_backend from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager" backend = get_torch_compile_backend() or "eager"
self.model = torch.compile( self.model = torch.compile(
......
...@@ -140,7 +140,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -140,7 +140,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
model = get_model(vllm_config=self.vllm_config) model = get_model(vllm_config=self.vllm_config)
model = model.eval() model = model.eval()
xm.wait_device_ops() xm.wait_device_ops()
self.model = ModelWrapper(model) self.model = ModelWrapper(model, self.vllm_config)
def _dummy_run( def _dummy_run(
self, self,
...@@ -669,13 +669,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -669,13 +669,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module, vllm_config: VllmConfig):
self.model = model self.model = model
compiled_callable = torch.compile(self.forward, compiled_callable = torch.compile(self.forward,
backend="openxla", backend="openxla",
fullgraph=True, fullgraph=True,
dynamic=False) dynamic=False)
super().__init__(compiled_callable) super().__init__(
compiled_callable,
compilation_level=vllm_config.compilation_config.level)
def __call__(self, *args, is_prompt: bool, **kwargs): def __call__(self, *args, is_prompt: bool, **kwargs):
if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment