Unverified Commit aaf4b70a authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

[Misc][BE] Type coverage for vllm/compilation [2/3] (#31744)

parent 3adffd5b
...@@ -179,7 +179,7 @@ class CompilerManager: ...@@ -179,7 +179,7 @@ class CompilerManager:
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int, graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable | None: ) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache: if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)] handle = self.cache[(compile_range, graph_index, self.compiler.name)]
...@@ -199,7 +199,7 @@ class CompilerManager: ...@@ -199,7 +199,7 @@ class CompilerManager:
self, self,
graph: fx.GraphModule, graph: fx.GraphModule,
example_inputs: list[Any], example_inputs: list[Any],
additional_inductor_config, additional_inductor_config: dict[str, Any],
compilation_config: CompilationConfig, compilation_config: CompilationConfig,
compile_range: Range, compile_range: Range,
graph_index: int = 0, graph_index: int = 0,
...@@ -355,7 +355,7 @@ def split_graph( ...@@ -355,7 +355,7 @@ def split_graph(
compilation_start_time = 0.0 compilation_start_time = 0.0
class PiecewiseCompileInterpreter(torch.fx.Interpreter): class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compiles some It runs the given graph with fake inputs, and compiles some
submodules specified by `compile_submod_names` with the given submodules specified by `compile_submod_names` with the given
...@@ -506,9 +506,9 @@ class VllmBackend: ...@@ -506,9 +506,9 @@ class VllmBackend:
# the stitching graph module for all the piecewise graphs # the stitching graph module for all the piecewise graphs
split_gm: fx.GraphModule split_gm: fx.GraphModule
piecewise_graphs: list[SplitItem] piecewise_graphs: list[SplitItem]
returned_callable: Callable returned_callable: Callable[..., Any]
# Inductor passes to run on the graph pre-defunctionalization # Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable] post_grad_passes: Sequence[Callable[..., Any]]
sym_tensor_indices: list[int] sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor] input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager compiler_manager: CompilerManager
...@@ -821,7 +821,7 @@ class VllmBackend: ...@@ -821,7 +821,7 @@ class VllmBackend:
] ]
# this is the callable we return to Dynamo to run # this is the callable we return to Dynamo to run
def copy_and_call(*args): def copy_and_call(*args: Any) -> Any:
list_args = list(args) list_args = list(args)
for i, index in enumerate(self.sym_tensor_indices): for i, index in enumerate(self.sym_tensor_indices):
runtime_tensor = list_args[index] runtime_tensor = list_args[index]
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import inspect import inspect
import os import os
import pickle import pickle
from collections.abc import Callable, Sequence
from typing import Any, Literal
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type) ...@@ -25,7 +27,7 @@ assert isinstance(SerializableCallable, type)
logger = init_logger(__name__) logger = init_logger(__name__)
class VllmSerializableFunction(SerializableCallable): class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
""" """
A wrapper around a compiled function by vllm. It will forward the tensor A wrapper around a compiled function by vllm. It will forward the tensor
inputs to the compiled function and return the result. inputs to the compiled function and return the result.
...@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
""" """
def __init__( def __init__(
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False self,
): graph_module: torch.fx.GraphModule,
example_inputs: Sequence[Any],
prefix: str,
optimized_call: Callable[..., Any],
is_encoder: bool = False,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule) assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module self.graph_module = graph_module
self.example_inputs = example_inputs self.example_inputs = example_inputs
...@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -53,7 +60,7 @@ class VllmSerializableFunction(SerializableCallable):
if sym_input is not None: if sym_input is not None:
self.shape_env = sym_input.node.shape_env self.shape_env = sym_input.node.shape_env
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.optimized_call(*args, **kwargs) return self.optimized_call(*args, **kwargs)
@classmethod @classmethod
...@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -73,7 +80,9 @@ class VllmSerializableFunction(SerializableCallable):
graph_reducer_override = GraphPickler.reducer_override graph_reducer_override = GraphPickler.reducer_override
def _graph_reducer_override(self, obj): def _graph_reducer_override(
self: GraphPickler, obj: Any
) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
if ( if (
inspect.isclass(obj) inspect.isclass(obj)
and issubclass(obj, sympy.Function) and issubclass(obj, sympy.Function)
...@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -114,7 +123,7 @@ class VllmSerializableFunction(SerializableCallable):
get_current_vllm_config(), state["prefix"], is_encoder get_current_vllm_config(), state["prefix"], is_encoder
) )
def optimized_call(*example_inputs): def optimized_call(*example_inputs: Any) -> Any:
""" """
On the first run of the optimized call, we rerun the compiler On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend backend which should result in a cache hit. After the backend
...@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable): ...@@ -136,7 +145,7 @@ class VllmSerializableFunction(SerializableCallable):
return fn return fn
@property @property
def co_name(self): def co_name(self) -> Literal["VllmSerializableFunction"]:
""" """
Used for depyf debugging. Used for depyf debugging.
""" """
......
...@@ -42,7 +42,9 @@ class CUDAGraphLogging: ...@@ -42,7 +42,9 @@ class CUDAGraphLogging:
"Count", "Count",
] ]
def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None): def __init__(
self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
) -> None:
self.reset() self.reset()
self.cg_mode = str(cg_mode) self.cg_mode = str(cg_mode)
self.cg_capture_sizes = str(cg_capture_sizes or []) self.cg_capture_sizes = str(cg_capture_sizes or [])
...@@ -54,10 +56,10 @@ class CUDAGraphLogging: ...@@ -54,10 +56,10 @@ class CUDAGraphLogging:
"**CUDAGraph Stats:**\n\n" "**CUDAGraph Stats:**\n\n"
) )
def reset(self): def reset(self) -> None:
self.stats = [] self.stats: list[CUDAGraphStat] = []
def observe(self, cudagraph_stat: CUDAGraphStat): def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
self.stats.append(cudagraph_stat) self.stats.append(cudagraph_stat)
def generate_metric_table(self) -> str: def generate_metric_table(self) -> str:
...@@ -109,7 +111,7 @@ class CUDAGraphLogging: ...@@ -109,7 +111,7 @@ class CUDAGraphLogging:
+ "\n" + "\n"
) )
def log(self, log_fn=logger.info): def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
if not self.stats: if not self.stats:
return return
log_fn(self.generate_metric_table()) log_fn(self.generate_metric_table())
...@@ -161,11 +163,11 @@ class CUDAGraphWrapper: ...@@ -161,11 +163,11 @@ class CUDAGraphWrapper:
def __init__( def __init__(
self, self,
runnable: Callable, runnable: Callable[..., Any],
vllm_config: VllmConfig, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None, cudagraph_options: CUDAGraphOptions | None = None,
): ) -> None:
self.runnable = runnable self.runnable = runnable
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.runtime_mode = runtime_mode self.runtime_mode = runtime_mode
...@@ -189,7 +191,7 @@ class CUDAGraphWrapper: ...@@ -189,7 +191,7 @@ class CUDAGraphWrapper:
# cudagraphs for. # cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {} self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
def __getattr__(self, key: str): def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable. # allow accessing the attributes of the runnable.
if hasattr(self.runnable, key): if hasattr(self.runnable, key):
return getattr(self.runnable, key) return getattr(self.runnable, key)
...@@ -198,11 +200,11 @@ class CUDAGraphWrapper: ...@@ -198,11 +200,11 @@ class CUDAGraphWrapper:
f"cudagraph wrapper: {self.runnable}" f"cudagraph wrapper: {self.runnable}"
) )
def unwrap(self) -> Callable: def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable. # in case we need to access the original runnable.
return self.runnable return self.runnable
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context() forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
......
...@@ -6,8 +6,8 @@ import hashlib ...@@ -6,8 +6,8 @@ import hashlib
import inspect import inspect
import os import os
import sys import sys
from collections.abc import Callable from collections.abc import Callable, Generator
from typing import TypeVar, overload from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
from unittest.mock import patch from unittest.mock import patch
import torch import torch
...@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo ...@@ -32,6 +32,14 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from .monitor import start_monitoring_torch_compile from .monitor import start_monitoring_torch_compile
if TYPE_CHECKING:
# SourceInfo was only added in torch nightly/2.10, so wrap the import
try:
from torch._dynamo.package import SourceInfo
except ImportError:
# Fallback for older torch versions that do not provide SourceInfo
SourceInfo = Any
logger = init_logger(__name__) logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm" IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
...@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T: ...@@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return cls return cls
def _should_ignore_torch_compile(cls) -> bool: def _should_ignore_torch_compile(cls: _T) -> bool:
""" """
Check if the class should be ignored for torch.compile. Check if the class should be ignored for torch.compile.
""" """
...@@ -224,7 +232,7 @@ def support_torch_compile( ...@@ -224,7 +232,7 @@ def support_torch_compile(
return cls_decorator_helper return cls_decorator_helper
def _model_hash_key(fn) -> str: def _model_hash_key(fn: Callable[..., Any]) -> str:
import vllm import vllm
sha256_hash = hashlib.sha256() sha256_hash = hashlib.sha256()
...@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str: ...@@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
return sha256_hash.hexdigest() return sha256_hash.hexdigest()
def _verify_source_unchanged(source_info, vllm_config) -> None: def _verify_source_unchanged(
source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
from .caching import _compute_code_hash, _compute_code_hash_with_content from .caching import _compute_code_hash, _compute_code_hash_with_content
file_contents = {} file_contents = {}
...@@ -275,8 +285,12 @@ def _support_torch_compile( ...@@ -275,8 +285,12 @@ def _support_torch_compile(
setattr(cls, IGNORE_COMPILE_KEY, False) setattr(cls, IGNORE_COMPILE_KEY, False)
def __init__( def __init__(
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs self: _T,
): *,
vllm_config: VllmConfig | None = None,
prefix: str = "",
**kwargs: Any,
) -> None:
if vllm_config is None: if vllm_config is None:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
...@@ -309,13 +323,17 @@ def _support_torch_compile( ...@@ -309,13 +323,17 @@ def _support_torch_compile(
compilation_counter.num_models_seen += 1 compilation_counter.num_models_seen += 1
self.compiled = False self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type]
cls.__init__ = __init__ cls.__init__ = __init__
def _mark_dynamic_inputs(mod, type, *args, **kwargs): def _mark_dynamic_inputs(
def mark_dynamic(arg, dims): mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
if type == DynamicShapesType.UNBACKED: ) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED:
if is_torch_equal_or_newer("2.10.0.dev"): if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims: for dim in dims:
torch._dynamo.decorators.mark_unbacked( torch._dynamo.decorators.mark_unbacked(
...@@ -326,7 +344,7 @@ def _support_torch_compile( ...@@ -326,7 +344,7 @@ def _support_torch_compile(
else: else:
torch._dynamo.mark_dynamic(arg, dims) torch._dynamo.mark_dynamic(arg, dims)
sig = inspect.signature(mod.__class__.forward) sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
bound_args = sig.bind(mod, *args, **kwargs) bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults() bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items(): for k, dims in dynamic_arg_dims.items():
...@@ -364,7 +382,7 @@ def _support_torch_compile( ...@@ -364,7 +382,7 @@ def _support_torch_compile(
else: else:
torch._dynamo.decorators.mark_unbacked(arg, dims) torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self, *args, **kwargs): def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't # e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside. # need to compile the model inside.
...@@ -444,7 +462,7 @@ def _support_torch_compile( ...@@ -444,7 +462,7 @@ def _support_torch_compile(
not envs.VLLM_USE_AOT_COMPILE not envs.VLLM_USE_AOT_COMPILE
or self.vllm_config.compilation_config.backend == "eager" or self.vllm_config.compilation_config.backend == "eager"
) )
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
# This is the path for the first compilation. # This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked # the first compilation needs to have dynamic shapes marked
...@@ -477,7 +495,7 @@ def _support_torch_compile( ...@@ -477,7 +495,7 @@ def _support_torch_compile(
# during Dynamo tracing, and their corresponding files # during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call_ inline_call = InliningInstructionTranslator.inline_call_
def patched_inline_call(self_): def patched_inline_call(self_: Any) -> Any:
code = self_.f_code code = self_.f_code
self.compilation_config.traced_files.add(code.co_filename) self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_) return inline_call(self_)
...@@ -535,7 +553,7 @@ def _support_torch_compile( ...@@ -535,7 +553,7 @@ def _support_torch_compile(
str(e), str(e),
) )
else: else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
self.compiled = True self.compiled = True
return output return output
...@@ -545,7 +563,9 @@ def _support_torch_compile( ...@@ -545,7 +563,9 @@ def _support_torch_compile(
@contextlib.contextmanager @contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): def maybe_use_cudagraph_partition_wrapper(
vllm_config: VllmConfig,
) -> Generator[None, None, None]:
""" """
Context manager to set/unset customized cudagraph partition wrappers. Context manager to set/unset customized cudagraph partition wrappers.
...@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): ...@@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
current_platform.get_static_graph_wrapper_cls() current_platform.get_static_graph_wrapper_cls()
) )
def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata): def customized_cudagraph_wrapper(
f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
) -> Any:
partition_id = metadata.partition_index partition_id = metadata.partition_index
num_partitions = metadata.num_partitions num_partitions = metadata.num_partitions
return static_graph_wrapper_class( return static_graph_wrapper_class(
...@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): ...@@ -600,7 +622,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
@contextlib.contextmanager @contextlib.contextmanager
def _torch27_patch_tensor_subclasses(): def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
""" """
Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of using torch 2.7.0. This enables using weight_loader_v2 and the use of
...@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses(): ...@@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
_ColumnvLLMParameter, _ColumnvLLMParameter,
) )
def return_false(*args, **kwargs): def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
return False return False
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"): if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
......
...@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass):
""" """
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph) -> None:
# XPU does not support auto-functionalization yet. # XPU does not support auto-functionalization yet.
# Will enable this when switch to vllm-xpu-kernels. # Will enable this when switch to vllm-xpu-kernels.
if current_platform.is_xpu(): if current_platform.is_xpu():
...@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -179,7 +179,7 @@ class FixFunctionalizationPass(VllmInductorPass):
) )
self.nodes_to_remove.clear() self.nodes_to_remove.clear()
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]): def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
""" """
Stage a node (or nodes) for removal at the end of the pass. Stage a node (or nodes) for removal at the end of the pass.
""" """
...@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -194,7 +194,7 @@ class FixFunctionalizationPass(VllmInductorPass):
node: torch.fx.Node, node: torch.fx.Node,
mutated_args: dict[int, torch.fx.Node | str], mutated_args: dict[int, torch.fx.Node | str],
args: tuple[torch.fx.Node | str, ...] | None = None, args: tuple[torch.fx.Node | str, ...] | None = None,
): ) -> None:
""" """
De-functionalize a node by replacing it with a call to the original. De-functionalize a node by replacing it with a call to the original.
It also replaces the getitem users with the mutated arguments. It also replaces the getitem users with the mutated arguments.
...@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -206,7 +206,7 @@ class FixFunctionalizationPass(VllmInductorPass):
def replace_users_with_mutated_args( def replace_users_with_mutated_args(
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str] self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
): ) -> None:
""" """
Replace all getitem users of the auto-functionalized node with the Replace all getitem users of the auto-functionalized node with the
mutated arguments. mutated arguments.
...@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -237,7 +237,7 @@ class FixFunctionalizationPass(VllmInductorPass):
graph: torch.fx.Graph, graph: torch.fx.Graph,
node: torch.fx.Node, node: torch.fx.Node,
args: tuple[torch.fx.Node | str, ...] | None = None, args: tuple[torch.fx.Node | str, ...] | None = None,
): ) -> None:
""" """
Insert a new defunctionalized node into the graph before node. Insert a new defunctionalized node into the graph before node.
If one of the kwargs is 'out', provide args directly, If one of the kwargs is 'out', provide args directly,
......
...@@ -29,6 +29,9 @@ else: ...@@ -29,6 +29,9 @@ else:
Torch25CustomGraphPass as CustomGraphPass, Torch25CustomGraphPass as CustomGraphPass,
) )
# Re-export CustomGraphPass for external usage
__all__ = ["CustomGraphPass"]
_pass_context = None _pass_context = None
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
......
...@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass): ...@@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass):
""" """
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph) -> None:
count = 0 count = 0
# Remove no-op reshapes/views: # Remove no-op reshapes/views:
for node in graph.nodes: for node in graph.nodes:
...@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass): ...@@ -117,7 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
2. The dimensions both correspond to the same SymInt 2. The dimensions both correspond to the same SymInt
""" """
# Case 1 # Case 1
return statically_known_true(dim == i_dim) return statically_known_true(dim == i_dim) # type: ignore[no-any-return]
def all_dims_equivalent( def all_dims_equivalent(
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt] self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from collections.abc import Callable
from typing import Any, ParamSpec, TypeVar
from torch import fx as fx from torch import fx as fx
...@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass ...@@ -40,8 +42,11 @@ from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__) logger = init_logger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def with_pattern_match_debug(fn):
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
""" """
Function decorator that turns on inductor pattern match debug Function decorator that turns on inductor pattern match debug
for the duration of the call. for the duration of the call.
...@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn): ...@@ -49,7 +54,7 @@ def with_pattern_match_debug(fn):
""" """
@functools.wraps(fn) @functools.wraps(fn)
def wrapper(*args, **kwargs): def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None: if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
# optionally check rank here # optionally check rank here
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val): with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
...@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn): ...@@ -59,7 +64,7 @@ def with_pattern_match_debug(fn):
return wrapper return wrapper
class PostGradPassManager(CustomGraphPass): class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
""" """
The pass manager for post-grad passes. The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes. It handles configuration, adding custom passes, and running passes.
...@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass): ...@@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass):
This way, all passes operate on a functionalized graph. This way, all passes operate on a functionalized graph.
""" """
def __init__(self): def __init__(self) -> None:
self.passes: list[InductorPass] = [] self.passes: list[InductorPass] = []
@with_pattern_match_debug @with_pattern_match_debug
def __call__(self, graph: fx.Graph): def __call__(self, graph: fx.Graph) -> None:
VllmInductorPass.dump_prefix = 0 # reset dump index VllmInductorPass.dump_prefix = 0 # reset dump index
compile_range = get_pass_context().compile_range compile_range = get_pass_context().compile_range
...@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass): ...@@ -98,7 +103,7 @@ class PostGradPassManager(CustomGraphPass):
self.fix_functionalization(graph) self.fix_functionalization(graph)
VllmInductorPass.dump_prefix = None # Cleanup index VllmInductorPass.dump_prefix = None # Cleanup index
def configure(self, config: VllmConfig): def configure(self, config: VllmConfig) -> None:
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
# Set the current vllm config to allow tracing CustomOp instances # Set the current vllm config to allow tracing CustomOp instances
...@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass): ...@@ -135,23 +140,25 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config) self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
def add(self, pass_: InductorPass): def add(self, pass_: InductorPass) -> None:
assert isinstance(pass_, InductorPass) assert isinstance(pass_, InductorPass)
self.passes.append(pass_) self.passes.append(pass_)
def uuid(self): def uuid(self) -> str:
""" """
The PostGradPassManager is set as a custom pass in the Inductor and The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info. dependent passes and the pass config. See InductorPass for more info.
""" """
state = {"pass_config": self.pass_config.compute_hash(), "passes": []} passes = []
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
for pass_ in self.passes: for pass_ in self.passes:
state["passes"].append(pass_.uuid()) passes.append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) passes.append(self.fix_functionalization.uuid())
# Include the compile range in the uuid to ensure that inductor # Include the compile range in the uuid to ensure that inductor
# recompiles the graph for the new dynamic compile range. # recompiles the graph for the new dynamic compile range.
state["compile_range"] = str(get_pass_context().compile_range) state["compile_range"] = str(get_pass_context().compile_range)
state["passes"] = passes
return InductorPass.hash_dict(state) return InductorPass.hash_dict(state)
...@@ -86,7 +86,16 @@ class PiecewiseBackend: ...@@ -86,7 +86,16 @@ class PiecewiseBackend:
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly. # We only keep compilation management inside this class directly.
if self.compile_sizes is not None:
for size in self.compile_sizes: for size in self.compile_sizes:
if isinstance(size, str):
assert size == "cudagraph_capture_sizes"
raise NotImplementedError(
"cudagraph_capture_sizes not supported in compile_sizes."
"This should be handled in `post_init_cudagraph_sizes`."
)
else:
assert isinstance(size, int)
range = Range(start=size, end=size) range = Range(start=size, end=size)
if range not in self.compile_ranges: if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry( self.range_entries[range] = RangeEntry(
...@@ -99,14 +108,14 @@ class PiecewiseBackend: ...@@ -99,14 +108,14 @@ class PiecewiseBackend:
compile_range=range, compile_range=range,
) )
def check_for_ending_compilation(self): def check_for_ending_compilation(self) -> None:
if self.is_last_graph and not self.to_be_compiled_ranges: if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile # no specific sizes to compile
# save the hash of the inductor graph for the next run # save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file() self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config) end_monitoring_torch_compile(self.vllm_config)
def _fakify_args(self, args: list[Any]) -> list[Any]: def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
# We need to pass fake example_inputs, otherwise torch.compile # We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs potentially causing some non dynamic # will fakify the example_inputs potentially causing some non dynamic
# dimension to be duck shaped to other existing shapes that have hints # dimension to be duck shaped to other existing shapes that have hints
...@@ -127,7 +136,9 @@ class PiecewiseBackend: ...@@ -127,7 +136,9 @@ class PiecewiseBackend:
assert len(fake_example_inputs) == len(args) assert len(fake_example_inputs) == len(args)
return fake_example_inputs return fake_example_inputs
def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
if not range_entry.compiled: if not range_entry.compiled:
range_entry.compiled = True range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range) self.to_be_compiled_ranges.remove(range_entry.compile_range)
...@@ -136,14 +147,14 @@ class PiecewiseBackend: ...@@ -136,14 +147,14 @@ class PiecewiseBackend:
# fakify for range, real args for concrete size. # fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in # For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify. # compiler_manager.compile() so no need to fakify.
args = ( args_list = (
self._fakify_args(args) self._fakify_args(args)
if not range_entry.compile_range.is_single_size() if not range_entry.compile_range.is_single_size()
else args else list(args)
) )
range_entry.runnable = self.vllm_backend.compiler_manager.compile( range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph, self.graph,
args, args_list,
self.vllm_backend.inductor_config, self.vllm_backend.inductor_config,
self.compilation_config, self.compilation_config,
compile_range=range_entry.compile_range, compile_range=range_entry.compile_range,
...@@ -153,10 +164,13 @@ class PiecewiseBackend: ...@@ -153,10 +164,13 @@ class PiecewiseBackend:
self.check_for_ending_compilation() self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> Range | None: def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
# First we try to find the range entry for the concrete compile size # First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry # If not found, we search for the range entry
# that contains the runtime shape. # that contains the runtime shape.
if self.compile_sizes is None:
return None
if runtime_shape in self.compile_sizes: if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)] return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
else: else:
...@@ -165,7 +179,7 @@ class PiecewiseBackend: ...@@ -165,7 +179,7 @@ class PiecewiseBackend:
return self.range_entries[range] return self.range_entries[range]
return None return None
def __call__(self, *args) -> Any: def __call__(self, *args: Any) -> Any:
runtime_shape = args[self.sym_shape_indices[0]] runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape) range_entry = self._find_range_for_shape(runtime_shape)
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
import os import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from types import CodeType from types import CodeType
from typing import Any from typing import Any, ParamSpec, TypeVar
import torch import torch
import torch._C._dynamo.guards import torch._C._dynamo.guards
...@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context ...@@ -19,19 +20,26 @@ from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
logger = init_logger(__name__) logger = init_logger(__name__)
R = TypeVar("R")
P = ParamSpec("P")
def _noop_add_global_state_guard(self, *args, **kwargs):
def _noop_add_global_state_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the GLOBAL_STATE guard entirely""" """No-op to skip the GLOBAL_STATE guard entirely"""
pass pass
def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs): def _noop_add_torch_function_mode_stack_guard(
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely""" """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
pass pass
@contextmanager @contextmanager
def _compilation_context(): def _compilation_context() -> Generator[None, None, None]:
"""Context manager for compilation settings and patches. """Context manager for compilation settings and patches.
This manager: This manager:
...@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards. since we drop all guards.
""" """
def check_invariants_and_forward(self, *args, **kwargs): def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
assert hasattr(self, "_check_shape_invariants") assert hasattr(self, "_check_shape_invariants")
self._check_shape_invariants(*args, **kwargs) self._check_shape_invariants(*args, **kwargs)
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs): def _call_with_optional_nvtx_range(
self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Any:
if self.layerwise_nvtx_tracing_enabled: if self.layerwise_nvtx_tracing_enabled:
args_list = list(args) args_list = list(args)
kwargs_dict = dict(kwargs) kwargs_dict = dict(kwargs)
...@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -108,7 +118,7 @@ class TorchCompileWithNoGuardsWrapper:
return ctx.result return ctx.result
return callable_fn(*args, **kwargs) return callable_fn(*args, **kwargs)
def __init__(self): def __init__(self) -> None:
self.compiled = False self.compiled = False
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
...@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -192,9 +202,9 @@ class TorchCompileWithNoGuardsWrapper:
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
self._compiled_bytecode = None self._compiled_bytecode: CodeType | None = None
def aot_compile(self, *args, **kwargs): def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
if not hasattr(self._compiled_callable, "aot_compile"): if not hasattr(self._compiled_callable, "aot_compile"):
raise RuntimeError( raise RuntimeError(
"aot_compile is not supported by the current configuration. " "aot_compile is not supported by the current configuration. "
...@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -203,7 +213,7 @@ class TorchCompileWithNoGuardsWrapper:
) )
return self._compiled_callable.aot_compile((args, kwargs)) return self._compiled_callable.aot_compile((args, kwargs))
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> Any:
if envs.VLLM_USE_BYTECODE_HOOK: if envs.VLLM_USE_BYTECODE_HOOK:
if ( if (
self.vllm_config.compilation_config.mode self.vllm_config.compilation_config.mode
...@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -236,13 +246,13 @@ class TorchCompileWithNoGuardsWrapper:
) )
@abstractmethod @abstractmethod
def forward(self, *args, **kwargs): ... def forward(self, *args: Any, **kwargs: Any) -> Any: ...
def original_code_object(self) -> CodeType: def original_code_object(self) -> CodeType:
"""Return the original code object of the forward method.""" """Return the original code object of the forward method."""
return self.__class__.forward.__code__ return self.__class__.forward.__code__
def bytecode_hook(self, old_code: CodeType, new_code: CodeType): def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
"""Hook to save the compiled bytecode for direct execution.""" """Hook to save the compiled bytecode for direct execution."""
if old_code is not self.original_code_object(): if old_code is not self.original_code_object():
return return
...@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper: ...@@ -299,7 +309,7 @@ class TorchCompileWithNoGuardsWrapper:
raise RuntimeError(msg) raise RuntimeError(msg)
@contextmanager @contextmanager
def _dispatch_to_compiled_code(self): def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
# noqa: E501 # noqa: E501
""" """
Context manager to dispatch to internally compiled code for torch<2.8. Context manager to dispatch to internally compiled code for torch<2.8.
......
...@@ -32,6 +32,9 @@ else: ...@@ -32,6 +32,9 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
# Explicitly exports Range
__all__ = ["Range"]
class CompilationMode(enum.IntEnum): class CompilationMode(enum.IntEnum):
"""The compilation approach used for torch.compile-based compilation of the """The compilation approach used for torch.compile-based compilation of the
......
...@@ -60,7 +60,7 @@ def is_symmetric_memory_tensor(tensor: torch.Tensor): ...@@ -60,7 +60,7 @@ def is_symmetric_memory_tensor(tensor: torch.Tensor):
return False return False
def set_graph_pool_id(graph_pool_id): def set_graph_pool_id(graph_pool_id: Any) -> None:
global _graph_pool_id global _graph_pool_id
_graph_pool_id = graph_pool_id _graph_pool_id = graph_pool_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment