Unverified Commit 4fd93750 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 661a34fd
...@@ -11,8 +11,8 @@ from torch.library import Library ...@@ -11,8 +11,8 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel, VllmConfig
from vllm.config import VllmConfig from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
global_counter = 0 global_counter = 0
...@@ -82,7 +82,9 @@ def test_simple_piecewise_compile(): ...@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
model = SillyModel(vllm_config=VllmConfig(), prefix='') vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
inputs = torch.randn(100).cuda() inputs = torch.randn(100).cuda()
......
...@@ -15,12 +15,10 @@ from torch import nn ...@@ -15,12 +15,10 @@ from torch import nn
from torch.library import Library from torch.library import Library
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import VllmConfig from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.plugins import set_compilation_config
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
# create a library to hold the custom op # create a library to hold the custom op
...@@ -272,9 +270,11 @@ def run_model(llama_config, ...@@ -272,9 +270,11 @@ def run_model(llama_config,
CompilationLevel.NO_COMPILATION) CompilationLevel.NO_COMPILATION)
set_compilation_config(None) set_compilation_config(None)
model = LlamaModel(config=llama_config, vllm_config = VllmConfig()
vllm_config=VllmConfig(), with set_current_vllm_config(vllm_config):
prefix="").eval().cuda() model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()
B = 16 # max batch size B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
...@@ -395,9 +395,11 @@ def benchmark(): ...@@ -395,9 +395,11 @@ def benchmark():
else: else:
set_compilation_config(None) set_compilation_config(None)
model = LlamaModel(config=llama_config, vllm_config = VllmConfig()
vllm_config=VllmConfig(), with set_current_vllm_config(vllm_config):
prefix="").eval().cuda().to(torch.bfloat16) model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda().to(torch.bfloat16)
B = 256 # max batch size B = 256 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
......
...@@ -3,7 +3,7 @@ from typing import Dict, List, Optional ...@@ -3,7 +3,7 @@ from typing import Dict, List, Optional
import pytest import pytest
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from ..utils import compare_all_settings from ..utils import compare_all_settings
......
import pytest import pytest
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel
from ..utils import fork_new_process_for_each_test from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support from .utils import TEST_MODELS, check_full_graph_support
......
...@@ -3,10 +3,10 @@ import torch ...@@ -3,10 +3,10 @@ import torch
from compressed_tensors.quantization import FP8_DTYPE from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn, from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe) find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear) apply_fp8_linear)
......
...@@ -3,6 +3,7 @@ from typing import Optional ...@@ -3,6 +3,7 @@ from typing import Optional
import torch import torch
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel
class MyMod(torch.nn.Module): class MyMod(torch.nn.Module):
...@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): ...@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
compiled_callable = torch.compile(self.forward, backend="eager") compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable) super().__init__(compiled_callable,
compilation_level=CompilationLevel.DYNAMO_ONCE)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled # this is the function to be compiled
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel
from vllm.platforms import current_platform from vllm.platforms import current_platform
TEST_MODELS = [ TEST_MODELS = [
......
...@@ -3,11 +3,13 @@ from typing import List ...@@ -3,11 +3,13 @@ from typing import List
import pytest import pytest
from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul, from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation, ReLUSquaredActivation,
SiluAndMul) SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config
# Registered subclass for test # Registered subclass for test
...@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation): ...@@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation):
]) ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool): default_on: bool):
os.environ["VLLM_CUSTOM_OPS"] = env
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
# Reset default_on (computed once): ops_enabled = [bool(x) for x in ops_enabled]
CustomOp.default_on.cache_clear()
assert CustomOp.default_on() == default_on assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]
ops_enabled = [bool(x) for x in ops_enabled] assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1]
assert RMSNorm(1024).enabled() == ops_enabled[0] assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
assert SiluAndMul().enabled() == ops_enabled[1] # If registered, subclasses should follow their own name
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
assert GeluAndMul().enabled() == ops_enabled[2] # Unregistered subclass
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] class SiluAndMul2(SiluAndMul):
pass
# If registered, subclasses should follow their own name # Subclasses should not require registration
assert Relu3().enabled() == ops_enabled[3] assert SiluAndMul2().enabled() == SiluAndMul().enabled()
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]
# Unregistered subclass
class SiluAndMul2(SiluAndMul):
pass
# Subclasses should not require registration
assert SiluAndMul2().enabled() == SiluAndMul().enabled()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str): def test_enabled_ops_invalid(env: str):
os.environ["VLLM_CUSTOM_OPS"] = env with pytest.raises(Exception): # noqa
CustomOp.default_on.cache_clear() vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
with pytest.raises(AssertionError): with set_current_vllm_config(vllm_config):
RMSNorm(1024).enabled() RMSNorm(1024).enabled()
...@@ -5,7 +5,7 @@ import tempfile ...@@ -5,7 +5,7 @@ import tempfile
import depyf import depyf
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel
# disable custom dispatcher, let Dynamo takes over # disable custom dispatcher, let Dynamo takes over
# all the control # all the control
......
import os import os
from vllm.compilation.levels import CompilationLevel from vllm.config import CompilationLevel
from ..utils import compare_two_settings from ..utils import compare_two_settings
......
...@@ -10,13 +10,12 @@ import torch ...@@ -10,13 +10,12 @@ import torch
import torch.fx as fx import torch.fx as fx
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors from vllm.utils import combine_fx_passes, weak_ref_tensors
from .config import CompilationConfig
from .counter import compilation_counter from .counter import compilation_counter
from .fusion import FusionPass from .fusion import FusionPass
from .levels import CompilationLevel
from .reshapes import RedundantReshapesPass from .reshapes import RedundantReshapesPass
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -392,7 +391,10 @@ class VllmBackend: ...@@ -392,7 +391,10 @@ class VllmBackend:
sym_tensor_indices: List[int] sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor] input_buffers: List[torch.Tensor]
def __init__(self, post_grad_passes: Sequence[Callable] = ()): def __init__(
self,
compilation_configs: CompilationConfig,
):
global global_graph_pool global global_graph_pool
if global_graph_pool is None: if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle() global_graph_pool = torch.cuda.graph_pool_handle()
...@@ -401,11 +403,13 @@ class VllmBackend: ...@@ -401,11 +403,13 @@ class VllmBackend:
# streams, it might not be safe to share a global pool. # streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams # only investigate this when we use multiple streams
self.graph_pool = global_graph_pool self.graph_pool = global_graph_pool
self.post_grad_passes = post_grad_passes self.post_grad_passes = []
self.sym_tensor_indices = [] self.sym_tensor_indices = []
self.input_buffers = [] self.input_buffers = []
self.compilation_configs = compilation_configs
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
...@@ -437,10 +441,10 @@ class VllmBackend: ...@@ -437,10 +441,10 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once" assert not self._called, "VllmBackend can only be called once"
self.graph = graph self.graph = graph
# config is read now, because only here can # config is updated now, because only here can
# we get the sizes to capture for cudagraph # we get the sizes to capture for cudagraph
# from compilation context # from compilation context
self.compilation_configs = CompilationConfig.select_and_init_config() self.compilation_configs.init_during_runtime()
self.add_passes_to_config() self.add_passes_to_config()
self.split_gm, self.piecewise_graphs = split_graph( self.split_gm, self.piecewise_graphs = split_graph(
...@@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]: ...@@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
return backend_str return backend_str
assert level == CompilationLevel.PIECEWISE assert level == CompilationLevel.PIECEWISE
return VllmBackend() from vllm.plugins import get_current_vllm_config
compilation_config = get_current_vllm_config().compilation_config
return VllmBackend(compilation_config)
import copy
from pathlib import Path
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
import vllm.envs as envs
from vllm.logger import init_logger
from .compile_context import get_compile_context
logger = init_logger(__name__)
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has two parts:
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- Custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
"""
use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
def model_post_init(self, __context: Any) -> None:
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_compile_config[k] = v
continue
# resolve function from qualified name
names = v.split(".")
module = ".".join(names[:-1])
func_name = names[-1]
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
self.capture_sizes = self.cudagraph_capture_sizes
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, (
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [
x for x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
]
else:
assert self.inductor_compile_sizes is not None, (
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes
@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())
config.init_during_runtime()
return config
...@@ -3,10 +3,8 @@ from typing import Dict, List, Optional, Union ...@@ -3,10 +3,8 @@ from typing import Dict, List, Optional, Union
import torch import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo from vllm.utils import supports_dynamo
...@@ -126,12 +124,14 @@ def _support_torch_compile(cls: type, ...@@ -126,12 +124,14 @@ def _support_torch_compile(cls: type,
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here. # will handle the compilation, so we don't need to do anything here.
self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [ self.do_not_compile = \
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo() ] or not supports_dynamo()
if self.do_not_compile: if self.do_not_compile:
return return
TorchCompileWrapperWithCustomDispatcher.__init__(self) TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
cls.__init__ = __init__ # type: ignore cls.__init__ = __init__ # type: ignore
......
...@@ -6,8 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized ...@@ -6,8 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement) fwd_only, register_replacement)
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.inductor_pass import InductorPass
from vllm.config import CompilationConfig
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod ...@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
import torch import torch
from vllm.compilation.config import CompilationConfig from vllm.config import CompilationConfig
# yapf: disable # yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import ( from vllm.distributed import (
......
# constants for the levels of the compilation process
class CompilationLevel:
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
PIECEWISE = 3
...@@ -8,8 +8,7 @@ from typing import Callable, List, Optional ...@@ -8,8 +8,7 @@ from typing import Callable, List, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationLevel
from .levels import CompilationLevel
class TorchCompileWrapperWithCustomDispatcher: class TorchCompileWrapperWithCustomDispatcher:
...@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher: ...@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method. `torch.compile` over the forward method.
""" """
def __init__(self, compiled_callable: Optional[Callable] = None): def __init__(self,
compiled_callable: Optional[Callable] = None,
compilation_level: int = 0):
if compiled_callable is None: if compiled_callable is None:
# default compilation settings # default compilation settings
...@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher: ...@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
backend = get_torch_compile_backend() backend = get_torch_compile_backend()
if backend is None: if backend is None:
from vllm.compilation.backends import select_default_backend from vllm.compilation.backends import select_default_backend
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) backend = select_default_backend(compilation_level)
compiled_callable = torch.compile( compiled_callable = torch.compile(
self.forward, self.forward,
...@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher: ...@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher # subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism. # and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \ self.use_custom_dispatcher: bool = \
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE compilation_level >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level. """Implement the dispatch logic here, beyond the torch.compile level.
......
...@@ -3,10 +3,12 @@ import enum ...@@ -3,10 +3,12 @@ import enum
import json import json
import warnings import warnings
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
Literal, Mapping, Optional, Set, Tuple, Type, Union) Literal, Mapping, Optional, Set, Tuple, Type, Union)
import torch import torch
from pydantic import BaseModel, Field, PrivateAttr
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
...@@ -2052,6 +2054,185 @@ class ObservabilityConfig: ...@@ -2052,6 +2054,185 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}") f"installed. Original error:\n{otel_import_error_traceback}")
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
PIECEWISE = 3
class CompilationConfig(BaseModel):
"""
Configuration for compilation.
It has three parts:
- Top-level Compilation control:
- level: the level of compilation.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
""" # noqa
level: int = 0
custom_ops: List[str] = Field(default_factory=list)
use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict)
inductor_compile_config: Dict = Field(default_factory=dict)
inductor_passes: Dict[str, str] = Field(default_factory=dict)
use_cudagraph: bool = False
non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
def model_post_init(self, __context: Any) -> None:
self.level = envs.VLLM_TORCH_COMPILE_LEVEL
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
for k, v in self.inductor_passes.items():
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_compile_config[k] = v
continue
# resolve function from qualified name
names = v.split(".")
module = ".".join(names[:-1])
func_name = names[-1]
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
from vllm.compilation.compile_context import get_compile_context
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
self.capture_sizes = self.cudagraph_capture_sizes
logger.info(("cudagraph sizes specified by model runner"
" %s is overridden by config %s"),
sizes_to_specialize, self.cudagraph_capture_sizes)
if self.inductor_specialize_for_cudagraph_no_more_than is not None:
assert self.inductor_compile_sizes is None, (
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None")
self.compile_sizes = [
x for x in self.capture_sizes
if x <= self.inductor_specialize_for_cudagraph_no_more_than
]
else:
assert self.inductor_compile_sizes is not None, (
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes
@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())
return config
@dataclass @dataclass
class VllmConfig: class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This """Dataclass which contains all vllm-related configuration. This
...@@ -2073,6 +2254,8 @@ class VllmConfig: ...@@ -2073,6 +2254,8 @@ class VllmConfig:
observability_config: Optional[ObservabilityConfig] = None observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None
quant_config: Optional[QuantizationConfig] = None quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
@staticmethod @staticmethod
def _get_quantization_config( def _get_quantization_config(
...@@ -2133,6 +2316,12 @@ class VllmConfig: ...@@ -2133,6 +2316,12 @@ class VllmConfig:
self.quant_config = VllmConfig._get_quantization_config( self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config) self.model_config, self.load_config)
if self.compilation_config is None:
self.compilation_config = CompilationConfig.select_and_init_config(
)
current_platform.check_and_update_config(self)
def __str__(self): def __str__(self):
return ("model=%r, speculative_config=%r, tokenizer=%r, " return ("model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
......
...@@ -69,7 +69,6 @@ if TYPE_CHECKING: ...@@ -69,7 +69,6 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK: bool = False VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
...@@ -217,18 +216,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -217,18 +216,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TORCH_COMPILE_CONFIG": "VLLM_TORCH_COMPILE_CONFIG":
lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS":
lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": "LOCAL_RANK":
......
from functools import lru_cache
from typing import Dict, Type from typing import Dict, Type
import torch.nn as nn import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -87,6 +85,8 @@ class CustomOp(nn.Module): ...@@ -87,6 +85,8 @@ class CustomOp(nn.Module):
@classmethod @classmethod
def enabled(cls) -> bool: def enabled(cls) -> bool:
# if no name, then it was not registered # if no name, then it was not registered
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"): if not hasattr(cls, "name"):
print_warning_once( print_warning_once(
f"Custom op {cls.__name__} was not registered, " f"Custom op {cls.__name__} was not registered, "
...@@ -94,22 +94,25 @@ class CustomOp(nn.Module): ...@@ -94,22 +94,25 @@ class CustomOp(nn.Module):
f"It will be enabled/disabled based on the global settings.") f"It will be enabled/disabled based on the global settings.")
return CustomOp.default_on() return CustomOp.default_on()
enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS enabled = f"+{cls.name}" in custom_ops
disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS disabled = f"-{cls.name}" in custom_ops
assert not (enabled assert not (enabled
and disabled), f"Cannot enable and disable {cls.name}" and disabled), f"Cannot enable and disable {cls.name}"
return (CustomOp.default_on() or enabled) and not disabled return (CustomOp.default_on() or enabled) and not disabled
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod @staticmethod
@lru_cache
def default_on() -> bool: def default_on() -> bool:
count_none = envs.VLLM_CUSTOM_OPS.count("none") """
count_all = envs.VLLM_CUSTOM_OPS.count("all") On by default if level < CompilationLevel.PIECEWISE
assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" Specifying 'all' or 'none' in custom_op takes precedence.
return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \ """
from vllm.config import CompilationLevel
compilation_config = get_current_vllm_config().compilation_config
custom_ops = compilation_config.custom_ops
count_none = custom_ops.count("none")
count_all = custom_ops.count("all")
return compilation_config.level < CompilationLevel.PIECEWISE and \
not count_none > 0 or count_all > 0 not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name). # Dictionary of all custom ops (classes, indexed by registered name).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment