Unverified commit 873480d1, authored by Lucas Kabela and committed by GitHub
Browse files

[Misc][BE] Type coverage for vllm/compilation [1/3] (#31554)


Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
parent 6f351548
...@@ -9,7 +9,7 @@ import operator ...@@ -9,7 +9,7 @@ import operator
import os import os
import pprint import pprint
import time import time
from collections.abc import Callable, Sequence from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
...@@ -90,7 +90,7 @@ class CompilerManager: ...@@ -90,7 +90,7 @@ class CompilerManager:
support int as key. support int as key.
""" """
def __init__(self, compilation_config: CompilationConfig): def __init__(self, compilation_config: CompilationConfig) -> None:
self.cache: dict[tuple[Range, int, str], Any] = dict() self.cache: dict[tuple[Range, int, str], Any] = dict()
self.is_cache_updated = False self.is_cache_updated = False
self.compilation_config = compilation_config self.compilation_config = compilation_config
...@@ -100,7 +100,7 @@ class CompilerManager: ...@@ -100,7 +100,7 @@ class CompilerManager:
return self.compiler.compute_hash(vllm_config) return self.compiler.compute_hash(vllm_config)
@contextmanager @contextmanager
def compile_context(self, compile_range: Range): def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
"""Provide compilation context for the duration of compilation to set """Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context).""" compilation (e.g. partition rules, pass context)."""
...@@ -115,7 +115,7 @@ class CompilerManager: ...@@ -115,7 +115,7 @@ class CompilerManager:
def initialize_cache( def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = "" self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
): ) -> None:
""" """
Initialize the cache directory for the compiler. Initialize the cache directory for the compiler.
...@@ -143,7 +143,7 @@ class CompilerManager: ...@@ -143,7 +143,7 @@ class CompilerManager:
# do not use eval(), it is unsafe. # do not use eval(), it is unsafe.
cache = ast.literal_eval(f.read()) cache = ast.literal_eval(f.read())
def check_type(value, ty): def check_type(value: Any, ty: type) -> None:
if not isinstance(value, ty): if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}") raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
...@@ -165,7 +165,7 @@ class CompilerManager: ...@@ -165,7 +165,7 @@ class CompilerManager:
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
) )
def save_to_file(self): def save_to_file(self) -> None:
if self.disable_cache or not self.is_cache_updated: if self.disable_cache or not self.is_cache_updated:
return return
printer = pprint.PrettyPrinter(indent=4) printer = pprint.PrettyPrinter(indent=4)
...@@ -198,7 +198,7 @@ class CompilerManager: ...@@ -198,7 +198,7 @@ class CompilerManager:
def compile( def compile(
self, self,
graph: fx.GraphModule, graph: fx.GraphModule,
example_inputs, example_inputs: list[Any],
additional_inductor_config, additional_inductor_config,
compilation_config: CompilationConfig, compilation_config: CompilationConfig,
compile_range: Range, compile_range: Range,
...@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -373,7 +373,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
compile_submod_names: list[str], compile_submod_names: list[str],
vllm_config: VllmConfig, vllm_config: VllmConfig,
vllm_backend: "VllmBackend", vllm_backend: "VllmBackend",
): ) -> None:
super().__init__(module) super().__init__(module)
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
...@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): ...@@ -385,7 +385,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
# When True, it annoyingly dumps the torch.fx.Graph on errors. # When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False self.extra_traceback = False
def run(self, *args): def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake? # maybe instead just assert inputs are fake?
fake_args = [ fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
...@@ -467,7 +467,7 @@ model_is_encoder: bool = False ...@@ -467,7 +467,7 @@ model_is_encoder: bool = False
@contextmanager @contextmanager
def set_model_tag(tag: str, is_encoder: bool = False): def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
"""Context manager to set the model tag.""" """Context manager to set the model tag."""
global model_tag global model_tag
global model_is_encoder global model_is_encoder
...@@ -521,7 +521,7 @@ class VllmBackend: ...@@ -521,7 +521,7 @@ class VllmBackend:
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
is_encoder: bool = False, is_encoder: bool = False,
): ) -> None:
# if the model is initialized with a non-empty prefix, # if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix, # then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc. # e.g. language_model, vision_model, etc.
...@@ -558,7 +558,7 @@ class VllmBackend: ...@@ -558,7 +558,7 @@ class VllmBackend:
# `torch.compile` is JIT compiled, so we don't need to # `torch.compile` is JIT compiled, so we don't need to
# do anything here # do anything here
def configure_post_pass(self): def configure_post_pass(self) -> None:
self.pass_manager.configure(self.vllm_config) self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass # Post-grad custom passes are run using the post_grad_custom_post_pass
...@@ -580,7 +580,7 @@ class VllmBackend: ...@@ -580,7 +580,7 @@ class VllmBackend:
self.inductor_config[self.pass_key] = self.pass_manager self.inductor_config[self.pass_key] = self.pass_manager
def __call__( def __call__(
self, graph: fx.GraphModule, example_inputs self, graph: fx.GraphModule, example_inputs: Sequence[Any]
) -> VllmSerializableFunction: ) -> VllmSerializableFunction:
vllm_config = self.vllm_config vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below. # Minimal hashing here with existing utilities, reused below.
......
...@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"): ...@@ -50,7 +50,7 @@ if hasattr(torch.ops._C, "scaled_fp4_quant"):
class BasePattern: class BasePattern:
def __init__(self, dtype: torch.dtype, device: str): def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.tp = get_tp_group() self.tp = get_tp_group()
...@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern): ...@@ -637,7 +637,7 @@ class AllReduceRMSNormPattern(BasePattern):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
allreduce_params: FlashInferFusedAllReduceParams, allreduce_params: FlashInferFusedAllReduceParams,
): ):
super().__init__(dtype, device) super().__init__(dtype, device)
...@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): ...@@ -692,7 +692,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
allreduce_params: FlashInferFusedAllReduceParams, allreduce_params: FlashInferFusedAllReduceParams,
): ):
super().__init__(dtype, device) super().__init__(dtype, device)
...@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -759,7 +759,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
allreduce_params: FlashInferFusedAllReduceParams, allreduce_params: FlashInferFusedAllReduceParams,
): ):
super().__init__(dtype, device) super().__init__(dtype, device)
...@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -828,7 +828,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
allreduce_params: FlashInferFusedAllReduceParams, allreduce_params: FlashInferFusedAllReduceParams,
): ):
super().__init__(dtype, device) super().__init__(dtype, device)
...@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -902,7 +902,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
allreduce_params: FlashInferFusedAllReduceParams, allreduce_params: FlashInferFusedAllReduceParams,
): ):
super().__init__(dtype, device) super().__init__(dtype, device)
...@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -988,7 +988,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
allreduce_params: FlashInferFusedAllReduceParams, allreduce_params: FlashInferFusedAllReduceParams,
): ):
super().__init__(dtype, device) super().__init__(dtype, device)
......
...@@ -31,7 +31,7 @@ class CompilerInterface: ...@@ -31,7 +31,7 @@ class CompilerInterface:
def initialize_cache( def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = "" self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
): ) -> None:
""" """
when the vLLM process uses `cache_dir` as the cache directory, when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory, the compiler should initialize itself with the cache directory,
...@@ -66,7 +66,7 @@ class CompilerInterface: ...@@ -66,7 +66,7 @@ class CompilerInterface:
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: str | None = None, key: str | None = None,
) -> tuple[Callable | None, Any | None]: ) -> tuple[Callable[..., Any] | None, Any | None]:
""" """
Compile the graph with the given example inputs and compiler config, Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs, with a range. The `compile_range` specifies the range of the inputs,
...@@ -100,7 +100,7 @@ class CompilerInterface: ...@@ -100,7 +100,7 @@ class CompilerInterface:
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int, graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable: ) -> Callable[..., Any]:
""" """
Load the compiled function from the handle. Load the compiled function from the handle.
Raises an error if the handle is invalid. Raises an error if the handle is invalid.
...@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv: ...@@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
def __init__(self) -> None: def __init__(self) -> None:
self.guards: list[Any] = [] self.guards: list[Any] = []
def evaluate_guards_expression(self, *args, **kwargs): def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
return True return True
def get_pruned_guards(self, *args, **kwargs): def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
return [] return []
def produce_guards_expression(self, *args, **kwargs): def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
return "" return ""
...@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone" name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]): def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
self.save_format = save_format self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
...@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -205,7 +205,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
def initialize_cache( def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = "" self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
): ) -> None:
self.cache_dir = cache_dir self.cache_dir = cache_dir
def compile( def compile(
...@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -215,7 +215,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: str | None = None, key: str | None = None,
) -> tuple[Callable | None, Any | None]: ) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1 compilation_counter.num_inductor_compiles += 1
current_config = {} current_config = {}
if compiler_config is not None: if compiler_config is not None:
...@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -252,7 +252,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int, graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable: ) -> Callable[..., Any]:
assert isinstance(handle, tuple) assert isinstance(handle, tuple)
assert isinstance(handle[0], str) assert isinstance(handle[0], str)
assert isinstance(handle[1], str) assert isinstance(handle[1], str)
...@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -264,7 +264,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
returns_tuple = graph_returns_tuple(graph) returns_tuple = graph_returns_tuple(graph)
def compiled_graph_wrapper(*args): def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
graph_output = inductor_compiled_graph(*args) graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed # unpack the tuple if needed
# TODO(rzou): the implication is that we're not # TODO(rzou): the implication is that we're not
...@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -293,7 +293,7 @@ class InductorAdaptor(CompilerInterface):
def initialize_cache( def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = "" self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
): ) -> None:
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.prefix = prefix self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
...@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -317,7 +317,7 @@ class InductorAdaptor(CompilerInterface):
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: str | None = None, key: str | None = None,
) -> tuple[Callable | None, Any | None]: ) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1 compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx
...@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -348,7 +348,7 @@ class InductorAdaptor(CompilerInterface):
original_load = FxGraphCache.load original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load" original_load_name = "torch._inductor.codecache.FxGraphCache.load"
def hijack_load(*args, **kwargs): def hijack_load(*args: Any, **kwargs: Any) -> Any:
inductor_compiled_graph = original_load(*args, **kwargs) inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable compiled_fn = inductor_compiled_graph.current_callable
...@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -375,7 +375,7 @@ class InductorAdaptor(CompilerInterface):
# function renamed in 2.6 # function renamed in 2.6
original_load_name = None original_load_name = None
def hijacked_compile_fx_inner(*args, **kwargs): def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs) output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str nonlocal hash_str
inductor_compiled_graph = output inductor_compiled_graph = output
...@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface): ...@@ -401,13 +401,13 @@ class InductorAdaptor(CompilerInterface):
hash_str = inductor_compiled_graph._fx_graph_cache_key hash_str = inductor_compiled_graph._fx_graph_cache_key
return output return output
def hijack_compiled_fx_graph_hash(*args, **kwargs): def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs) out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str nonlocal hash_str
hash_str = out[0] hash_str = out[0]
return out return out
def _check_can_cache(*args, **kwargs): def _check_can_cache(*args: Any, **kwargs: Any) -> None:
# no error means it can be cached. # no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo # Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs # tracing context, and also disables caching for graphs
...@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -513,7 +513,7 @@ class InductorAdaptor(CompilerInterface):
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int, graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable: ) -> Callable[..., Any]:
assert isinstance(handle, tuple) assert isinstance(handle, tuple)
assert isinstance(handle[0], str) assert isinstance(handle[0], str)
assert isinstance(handle[1], str) assert isinstance(handle[1], str)
...@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -572,7 +572,7 @@ class InductorAdaptor(CompilerInterface):
returns_tuple = graph_returns_tuple(graph) returns_tuple = graph_returns_tuple(graph)
# this is the callable we return to Dynamo to run # this is the callable we return to Dynamo to run
def compiled_graph(*args): def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
# convert args to list # convert args to list
list_args = list(args) list_args = list(args)
graph_output = inductor_compiled_graph(list_args) graph_output = inductor_compiled_graph(list_args)
...@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface): ...@@ -584,7 +584,7 @@ class InductorAdaptor(CompilerInterface):
return compiled_graph return compiled_graph
def metrics_context(self) -> contextlib.AbstractContextManager: def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
""" """
This method returns the Dynamo metrics context (if it exists, This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components. otherwise a null context). It is used by various compile components.
...@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface): ...@@ -603,12 +603,12 @@ class InductorAdaptor(CompilerInterface):
if is_torch_equal_or_newer("2.6"): if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context() return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
else: else:
return contextlib.nullcontext() return contextlib.nullcontext()
def set_inductor_config(config, compile_range: Range): def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
if compile_range.is_single_size(): if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters # for a specific batch size, tuning triton kernel parameters
# can be beneficial # can be beneficial
...@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range): ...@@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
) )
def set_functorch_config(): def set_functorch_config() -> None:
torch._functorch.config.bundled_autograd_cache = False torch._functorch.config.bundled_autograd_cache = False
...@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface): ...@@ -632,7 +632,7 @@ class EagerAdaptor(CompilerInterface):
compiler_config: dict[str, Any], compiler_config: dict[str, Any],
compile_range: Range, compile_range: Range,
key: str | None = None, key: str | None = None,
) -> tuple[Callable | None, Any | None]: ) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_eager_compiles += 1 compilation_counter.num_eager_compiles += 1
# we don't need to compile the graph, just return the graph itself. # we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle. # It does not support caching, return None for the handle.
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
import copy import copy
import dataclasses import dataclasses
from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any
@dataclasses.dataclass @dataclasses.dataclass
...@@ -34,7 +36,7 @@ class CompilationCounter: ...@@ -34,7 +36,7 @@ class CompilationCounter:
return copy.deepcopy(self) return copy.deepcopy(self)
@contextmanager @contextmanager
def expect(self, **kwargs): def expect(self, **kwargs: Any) -> Generator[None, None, None]:
old = self.clone() old = self.clone()
yield yield
for k, v in kwargs.items(): for k, v in kwargs.items():
......
...@@ -219,6 +219,7 @@ class CUDAGraphWrapper: ...@@ -219,6 +219,7 @@ class CUDAGraphWrapper:
# runtime modes. # runtime modes.
return self.runnable(*args, **kwargs) return self.runnable(*args, **kwargs)
assert batch_descriptor is not None
if batch_descriptor not in self.concrete_cudagraph_entries: if batch_descriptor not in self.concrete_cudagraph_entries:
# create a new entry for this batch descriptor # create a new entry for this batch descriptor
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry( self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
......
...@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator ...@@ -7,10 +7,11 @@ from collections.abc import Iterable, Iterator
from torch import fx from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload, OpOverloadPacket from torch._ops import OpOverload, OpOverloadPacket
from torch.fx.node import Target
def is_func(node: fx.Node, target) -> bool: def is_func(node: fx.Node, target: Target) -> bool:
return node.op == "call_function" and node.target == target return bool(node.op == "call_function" and node.target == target)
def is_auto_func(node: fx.Node, op: OpOverload) -> bool: def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
......
...@@ -8,9 +8,9 @@ import hashlib ...@@ -8,9 +8,9 @@ import hashlib
import inspect import inspect
import json import json
import types import types
from collections.abc import Callable from collections.abc import Callable, Generator
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
import torch import torch
from torch import fx from torch import fx
...@@ -30,6 +30,8 @@ else: ...@@ -30,6 +30,8 @@ else:
) )
_pass_context = None _pass_context = None
P = ParamSpec("P")
R = TypeVar("R")
class PassContext: class PassContext:
...@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext: ...@@ -44,7 +46,7 @@ def get_pass_context() -> PassContext:
@contextmanager @contextmanager
def pass_context(compile_range: Range): def pass_context(compile_range: Range) -> Generator[None, None, None]:
"""A context manager that stores the current pass context, """A context manager that stores the current pass context,
usually it is a list of sizes to specialize. usually it is a list of sizes to specialize.
""" """
...@@ -57,7 +59,7 @@ def pass_context(compile_range: Range): ...@@ -57,7 +59,7 @@ def pass_context(compile_range: Range):
_pass_context = prev_context _pass_context = prev_context
class InductorPass(CustomGraphPass): class InductorPass(CustomGraphPass): # type: ignore[misc]
""" """
A custom graph pass that uses a hash of its source as the UUID. A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases. This is defined as a convenience and should work in most cases.
...@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass): ...@@ -73,7 +75,7 @@ class InductorPass(CustomGraphPass):
return InductorPass.hash_source(self) return InductorPass.hash_source(self)
@staticmethod @staticmethod
def hash_source(*srcs: str | Any): def hash_source(*srcs: str | Any) -> str:
""" """
Utility method to hash the sources of functions or objects. Utility method to hash the sources of functions or objects.
:param srcs: strings or objects to add to the hash. :param srcs: strings or objects to add to the hash.
...@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass): ...@@ -93,7 +95,7 @@ class InductorPass(CustomGraphPass):
return hasher.hexdigest() return hasher.hexdigest()
@staticmethod @staticmethod
def hash_dict(dict_: dict[Any, Any]): def hash_dict(dict_: dict[Any, Any]) -> str:
""" """
Utility method to hash a dictionary, can alternatively be used for uuid. Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary. :return: A sha256 hash of the json rep of the dictionary.
...@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass): ...@@ -101,7 +103,7 @@ class InductorPass(CustomGraphPass):
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest() return hashlib.sha256(encoded).hexdigest()
def is_applicable_for_range(self, compile_range: Range): def is_applicable_for_range(self, compile_range: Range) -> bool:
return True return True
...@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass): ...@@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass):
implementation of the UUID. implementation of the UUID.
""" """
def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None): def __init__(
self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
) -> None:
self.callable = callable self.callable = callable
self._uuid = self.hash_source(callable) if uuid is None else uuid self._uuid = self.hash_source(callable) if uuid is None else uuid
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph) -> None:
self.callable(graph) self.callable(graph)
def uuid(self) -> Any: def uuid(self) -> Any:
return self._uuid return self._uuid
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
""" """
Applies a FakeTensorMode context. This is useful when you don't want to Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors. create or run things with real tensors.
""" """
@functools.wraps(fn) @functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any: def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs) result = fn(*args, **kwargs)
......
...@@ -12,7 +12,7 @@ context_manager = None ...@@ -12,7 +12,7 @@ context_manager = None
torch_compile_start_time: float = 0.0 torch_compile_start_time: float = 0.0
def start_monitoring_torch_compile(vllm_config: VllmConfig): def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
global torch_compile_start_time global torch_compile_start_time
torch_compile_start_time = time.time() torch_compile_start_time = time.time()
...@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): ...@@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
context_manager.__enter__() context_manager.__enter__()
def end_monitoring_torch_compile(vllm_config: VllmConfig): def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
compilation_config: CompilationConfig = vllm_config.compilation_config compilation_config: CompilationConfig = vllm_config.compilation_config
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once( logger.info_once(
...@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig): ...@@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig):
cudagraph_capturing_enabled: bool = True cudagraph_capturing_enabled: bool = True
def validate_cudagraph_capturing_enabled(): def validate_cudagraph_capturing_enabled() -> None:
# used to monitor whether a cudagraph capturing is legal at runtime. # used to monitor whether a cudagraph capturing is legal at runtime.
# should be called before any cudagraph capturing. # should be called before any cudagraph capturing.
# if an illegal cudagraph capturing happens, raise an error. # if an illegal cudagraph capturing happens, raise an error.
...@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled(): ...@@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled():
) )
def set_cudagraph_capturing_enabled(enabled: bool): def set_cudagraph_capturing_enabled(enabled: bool) -> None:
global cudagraph_capturing_enabled global cudagraph_capturing_enabled
cudagraph_capturing_enabled = enabled cudagraph_capturing_enabled = enabled
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib import contextlib
from collections.abc import Generator
import torch import torch
...@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool: ...@@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
@contextlib.contextmanager @contextlib.contextmanager
def inductor_partition_rule_context(splitting_ops: list[str]): def inductor_partition_rule_context(
splitting_ops: list[str] | None,
) -> Generator[None, None, None]:
"""Context manager to temporarily register Inductor partition rules. """Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the Registers custom partition rules for specified operators, forcing the
......
...@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper: ...@@ -41,8 +41,8 @@ class _SequenceParallelPatternHelper:
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
): ) -> None:
self.epsilon = epsilon self.epsilon = epsilon
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
...@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper: ...@@ -64,7 +64,7 @@ class _SequenceParallelPatternHelper:
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str): def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
...@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -74,7 +74,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
return [input, arg3_1] return [input, arg3_1]
def register(self, pm_pass: PatternMatcherPass): def register(self, pm_pass: PatternMatcherPass) -> None:
def pattern( def pattern(
input: torch.Tensor, input: torch.Tensor,
arg3_1: torch.Tensor, arg3_1: torch.Tensor,
...@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ...@@ -100,7 +100,7 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str): def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
...@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ...@@ -162,7 +162,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
self, self,
epsilon: float, epsilon: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str, device: str | None,
): ):
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherRMSNorm(epsilon) self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
...@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): ...@@ -203,7 +203,7 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(self, epsilon: float, dtype: torch.dtype, device: str): def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None):
super().__init__(epsilon, dtype, device) super().__init__(epsilon, dtype, device)
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any, NoReturn
import torch import torch
...@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition) ...@@ -29,14 +29,14 @@ class Torch25CustomGraphPass(ABC): # noqa (redefinition)
Return None to skip inductor code caching entirely. Return None to skip inductor code caching entirely.
""" """
def __getstate__(self): def __getstate__(self) -> Any | None:
""" """
Pickling is used instead of uuid() in torch<2.6. Just return uuid() Pickling is used instead of uuid() in torch<2.6. Just return uuid()
to enable subclasses to only have to implement uuid. to enable subclasses to only have to implement uuid.
""" """
return self.uuid() return self.uuid()
def __setstate__(self, state): def __setstate__(self, state: Any) -> NoReturn:
raise ValueError( raise ValueError(
"Cannot unpickle CustomGraphPass because pickling" "Cannot unpickle CustomGraphPass because pickling"
" is used for cache key uuid. Use torch>=2.6 with" " is used for cache key uuid. Use torch>=2.6 with"
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import functools import functools
import operator import operator
import time import time
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar from typing import ClassVar
...@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass): ...@@ -43,13 +44,17 @@ class VllmInductorPass(InductorPass):
) )
self.pass_config = config.compilation_config.pass_config self.pass_config = config.compilation_config.pass_config
self.model_dtype = config.model_config.dtype if config.model_config else None self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None self.device: str | None = (
config.device_config.device if config.device_config else None
)
self.pass_name = self.__class__.__name__ self.pass_name = self.__class__.__name__
@staticmethod @staticmethod
def time_and_log(call_fn): def time_and_log(
call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
@functools.wraps(call_fn) @functools.wraps(call_fn)
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
self.begin() self.begin()
self.dump_graph(graph, "before") self.dump_graph(graph, "before")
call_fn(self, graph) call_fn(self, graph)
...@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass): ...@@ -58,17 +63,17 @@ class VllmInductorPass(InductorPass):
return wrapped return wrapped
def dump_graph(self, graph: torch.fx.Graph, stage: str): def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
i = VllmInductorPass.dump_prefix i = VllmInductorPass.dump_prefix
i_str = "" if i is None else f".{i}" i_str = "" if i is None else f".{i}"
lazy_format_graph_code( lazy_format_graph_code(
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
) )
def begin(self): def begin(self) -> None:
self._start_time = time.perf_counter_ns() self._start_time = time.perf_counter_ns()
def end_and_log(self): def end_and_log(self) -> None:
self._end_time = time.perf_counter_ns() self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6 duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
...@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass): ...@@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass):
def _replace_op_overloads(self, string: str) -> str: def _replace_op_overloads(self, string: str) -> str:
"""Replace <OpOverload(..., ...)> with nicer formulations""" """Replace <OpOverload(..., ...)> with nicer formulations"""
return self._OP_OVERLOAD_PATTERN.sub( return str(
self._OP_OVERLOAD_PATTERN.sub(
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
string, string,
) )
)
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
""" """
If debug dumping is enabled, dump the Inductor pattern-matcher patterns If debug dumping is enabled, dump the Inductor pattern-matcher patterns
into the debug_dump_path folder next to the dumped fx graphs. into the debug_dump_path folder next to the dumped fx graphs.
...@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass): ...@@ -165,9 +172,9 @@ class VllmPatternMatcherPass(VllmInductorPass):
class PrinterInductorPass(VllmInductorPass): class PrinterInductorPass(VllmInductorPass):
def __init__(self, name: str, config: VllmConfig): def __init__(self, name: str, config: VllmConfig) -> None:
super().__init__(config) super().__init__(config)
self.name = name self.name = name
def __call__(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph) -> None:
self.dump_graph(graph, self.name) self.dump_graph(graph, self.name)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment