"vscode:/vscode.git/clone" did not exist on "f6f8db8142b1301c5f2bf1a0dda2e8eef03381a7"
Unverified commit 15e302df, authored by Lucas Kabela and committed by GitHub
Browse files

[Misc][BE] Turn on strict type coverage for vllm/compilation (#31756)


Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
parent d117a4d1
...@@ -100,6 +100,13 @@ ignore_missing_imports = true ...@@ -100,6 +100,13 @@ ignore_missing_imports = true
check_untyped_defs = true check_untyped_defs = true
follow_imports = "silent" follow_imports = "silent"
[[tool.mypy.overrides]]
module = "vllm.compilation.*"
disallow_untyped_defs = true
disallow_incomplete_defs = true
warn_return_any = true
follow_imports = "silent"
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = [ markers = [
"slow_test", "slow_test",
......
...@@ -28,7 +28,7 @@ def test_bad_callable(): ...@@ -28,7 +28,7 @@ def test_bad_callable():
pass_manager.configure(config) pass_manager.configure(config)
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
pass_manager.add(simple_callable) pass_manager.add(simple_callable) # type: ignore[arg-type]
# Pass that inherits from InductorPass # Pass that inherits from InductorPass
......
...@@ -77,6 +77,11 @@ EXCLUDE = [ ...@@ -77,6 +77,11 @@ EXCLUDE = [
"vllm/v1/attention/ops", "vllm/v1/attention/ops",
] ]
# Directories that should be checked with --strict
STRICT_DIRS = [
"vllm/compilation",
]
def group_files(changed_files: list[str]) -> dict[str, list[str]]: def group_files(changed_files: list[str]) -> dict[str, list[str]]:
""" """
...@@ -108,11 +113,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: ...@@ -108,11 +113,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return file_groups return file_groups
def is_strict_file(filepath: str) -> bool:
    """Return True if ``filepath`` must be type-checked with ``mypy --strict``.

    A file qualifies when it equals one of the ``STRICT_DIRS`` entries or is
    nested underneath one of them.

    Args:
        filepath: Path of the file to classify. Presumably given relative to
            the repository root with forward slashes, like the
            ``STRICT_DIRS`` entries -- confirm against the caller.

    Returns:
        True if the path is a strict directory or lives inside one,
        False otherwise.
    """
    for strict_dir in STRICT_DIRS:
        prefix = strict_dir.rstrip("/")
        # Match on a path-component boundary. A bare ``startswith(prefix)``
        # is a string-prefix test and would also accept siblings such as
        # "vllm/compilation_utils.py" for the entry "vllm/compilation".
        if filepath == prefix or filepath.startswith(prefix + "/"):
            return True
    return False
def mypy( def mypy(
targets: list[str], targets: list[str],
python_version: str | None, python_version: str | None,
follow_imports: str | None, follow_imports: str | None,
file_group: str, file_group: str,
strict: bool = False,
) -> int: ) -> int:
""" """
Run mypy on the given targets. Run mypy on the given targets.
...@@ -124,6 +135,7 @@ def mypy( ...@@ -124,6 +135,7 @@ def mypy(
follow_imports: Value for the --follow-imports option or None to use follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior. the default mypy behavior.
file_group: The file group name for logging purposes. file_group: The file group name for logging purposes.
strict: If True, run mypy with --strict flag.
Returns: Returns:
The return code from mypy. The return code from mypy.
...@@ -133,6 +145,8 @@ def mypy( ...@@ -133,6 +145,8 @@ def mypy(
args += ["--python-version", python_version] args += ["--python-version", python_version]
if follow_imports is not None: if follow_imports is not None:
args += ["--follow-imports", follow_imports] args += ["--follow-imports", follow_imports]
if strict:
args += ["--strict"]
print(f"$ {' '.join(args)} {file_group}") print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode return subprocess.run(args + targets, check=False).returncode
...@@ -149,8 +163,28 @@ def main(): ...@@ -149,8 +163,28 @@ def main():
for file_group, changed_files in file_groups.items(): for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip" follow_imports = None if ci and file_group == "" else "skip"
if changed_files: if changed_files:
# Separate files into strict and non-strict groups
strict_files = [f for f in changed_files if is_strict_file(f)]
non_strict_files = [f for f in changed_files if not is_strict_file(f)]
# Run mypy on non-strict files
if non_strict_files:
returncode |= mypy(
non_strict_files,
python_version,
follow_imports,
file_group,
strict=False,
)
# Run mypy on strict files with --strict flag
if strict_files:
returncode |= mypy( returncode |= mypy(
changed_files, python_version, follow_imports, file_group strict_files,
python_version,
follow_imports,
f"{file_group} (strict)",
strict=True,
) )
return returncode return returncode
......
...@@ -68,7 +68,7 @@ def make_copy_and_call( ...@@ -68,7 +68,7 @@ def make_copy_and_call(
A wrapper function that copies inputs and calls the compiled function A wrapper function that copies inputs and calls the compiled function
""" """
def copy_and_call(*args): def copy_and_call(*args: Any) -> Any:
list_args = list(args) list_args = list(args)
for i, index in enumerate(sym_tensor_indices): for i, index in enumerate(sym_tensor_indices):
runtime_tensor = list_args[index] runtime_tensor = list_args[index]
......
...@@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts: ...@@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts:
split on attn) split on attn)
""" """
def __init__(self): def __init__(self) -> None:
# dict from submodule name to byte hash # dict from submodule name to byte hash
self.submodule_bytes = {} self.submodule_bytes: dict[str, str] = {}
# dict from byte hash to bytes # dict from byte hash to bytes
self.submodule_bytes_store = {} self.submodule_bytes_store: dict[str, bytes] = {}
# dict from byte hash to loaded module # dict from byte hash to loaded module
self.loaded_submodule_store = {} self.loaded_submodule_store: dict[str, Any] = {}
def insert(self, submod_name: str, shape: str, entry: bytes): def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
hasher = hashlib.sha256() hasher = hashlib.sha256()
hasher.update(entry) hasher.update(entry)
hex_digest = hasher.hexdigest() hex_digest = hasher.hexdigest()
...@@ -86,7 +86,7 @@ class StandaloneCompiledArtifacts: ...@@ -86,7 +86,7 @@ class StandaloneCompiledArtifacts:
self.submodule_bytes[f"{submod_name}_{shape}"] self.submodule_bytes[f"{submod_name}_{shape}"]
] ]
def get_loaded(self, submod_name: str, shape: str): def get_loaded(self, submod_name: str, shape: str) -> Any:
logger.debug( logger.debug(
"getting artifact for submod %s with shape %s", "getting artifact for submod %s with shape %s",
submod_name, submod_name,
...@@ -119,7 +119,7 @@ class StandaloneCompiledArtifacts: ...@@ -119,7 +119,7 @@ class StandaloneCompiledArtifacts:
from torch._inductor.standalone_compile import AOTCompiledArtifact from torch._inductor.standalone_compile import AOTCompiledArtifact
def _load_entry(entry_bytes) -> AOTCompiledArtifact: def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes) entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry) return AOTCompiledArtifact.deserialize(entry)
...@@ -132,13 +132,13 @@ class StandaloneCompiledArtifacts: ...@@ -132,13 +132,13 @@ class StandaloneCompiledArtifacts:
logger.debug("loaded all %s submodules", self.num_artifacts()) logger.debug("loaded all %s submodules", self.num_artifacts())
def __getstate__(self): def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
return { return {
"submodule_bytes": self.submodule_bytes, "submodule_bytes": self.submodule_bytes,
"submodule_bytes_store": self.submodule_bytes_store, "submodule_bytes_store": self.submodule_bytes_store,
} }
def __setstate__(self, state): def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
self.submodule_bytes = state["submodule_bytes"] self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"] self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {} self.loaded_submodule_store = {}
...@@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact( ...@@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
standalone_compile_artifacts.load_all() standalone_compile_artifacts.load_all()
submod_names = standalone_compile_artifacts.submodule_names() submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable]] = {} compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
for cache_key in standalone_compile_artifacts.submodule_bytes: for cache_key in standalone_compile_artifacts.submodule_bytes:
submod_name, shape_str = cache_key.rsplit("_", 1) submod_name, shape_str = cache_key.rsplit("_", 1)
...@@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: ...@@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these. # e.g. exec(). We can't actually check these.
continue continue
hash_content.append(content) hash_content.append(content)
return safe_hash( result: str = safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False "\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest() ).hexdigest()
return result
def _compute_code_hash(files: set[str]) -> str: def _compute_code_hash(files: set[str]) -> str:
......
...@@ -30,19 +30,15 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass ...@@ -30,19 +30,15 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"): if find_spec("flashinfer"):
try: try:
import flashinfer.comm as flashinfer_comm import flashinfer.comm as _flashinfer_comm
flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef] if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"):
flashinfer_comm flashinfer_comm = _flashinfer_comm
if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
else None
)
except ImportError: except ImportError:
flashinfer_comm = None # type: ignore[assignment] pass
else:
flashinfer_comm = None # type: ignore[assignment]
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -441,7 +437,7 @@ class AsyncTPPass(VllmPatternMatcherPass): ...@@ -441,7 +437,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
): ):
return True return True
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return compile_range.is_single_size() and compile_range.end % tp_size == 0 return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None: def __call__(self, graph: fx.Graph) -> None:
...@@ -516,7 +512,7 @@ if flashinfer_comm is not None: ...@@ -516,7 +512,7 @@ if flashinfer_comm is not None:
# Get one shot input size limit for the current world size # Get one shot input size limit for the current world size
# for the current device capability # for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, # type: ignore[arg-type] device_capability, # type: ignore[arg-type, unused-ignore]
{}, {},
).get(world_size, None) ).get(world_size, None)
# Use one shot if no max size is specified # Use one shot if no max size is specified
...@@ -666,6 +662,7 @@ class AllReduceRMSNormPattern(BasePattern): ...@@ -666,6 +662,7 @@ class AllReduceRMSNormPattern(BasePattern):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
rms_result = torch.empty_like(input) rms_result = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -722,6 +719,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern): ...@@ -722,6 +719,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
def replacement( def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -800,6 +798,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -800,6 +798,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
result_rms = torch.empty_like(input) result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype) result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -875,6 +874,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern): ...@@ -875,6 +874,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
scale: torch.Tensor, scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
result_quant = torch.empty_like(input, dtype=self.quant_dtype) result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -960,6 +960,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -960,6 +960,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input) residual = torch.zeros_like(input)
result_rms = torch.empty_like(input) result_rms = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -1055,6 +1056,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -1055,6 +1056,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight: torch.Tensor, weight: torch.Tensor,
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized( allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm, flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input, allreduce_in=input,
...@@ -1131,7 +1133,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -1131,7 +1133,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
) )
self.ipc_handles, workspace_tensor = ( self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc] flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank, tp_rank=rank,
tp_size=self.tp_size, tp_size=self.tp_size,
max_token_num=self.max_token_num, max_token_num=self.max_token_num,
...@@ -1204,7 +1206,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass): ...@@ -1204,7 +1206,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
if self.disabled: if self.disabled:
logger.warning_once("AllReduce fusion pass is disabled.") logger.warning_once("AllReduce fusion pass is disabled.")
return False return False
return compile_range.end <= self.max_token_num return bool(compile_range.end <= self.max_token_num)
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None: def __call__(self, graph: fx.Graph) -> None:
......
...@@ -201,9 +201,9 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -201,9 +201,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors() factors = get_inductor_factors()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ hash_str: str = safe_hash(
:10 str(factors).encode(), usedforsecurity=False
] ).hexdigest()[:10]
return hash_str return hash_str
def initialize_cache( def initialize_cache(
...@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface): ...@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors() factors = get_inductor_factors()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ hash_str: str = safe_hash(
:10 str(factors).encode(), usedforsecurity=False
] ).hexdigest()[:10]
return hash_str return hash_str
def initialize_cache( def initialize_cache(
......
...@@ -45,10 +45,10 @@ logger = init_logger(__name__) ...@@ -45,10 +45,10 @@ logger = init_logger(__name__)
IGNORE_COMPILE_KEY = "_ignore_compile_vllm" IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
_T = TypeVar("_T", bound=type[nn.Module]) _T = TypeVar("_T", bound=nn.Module)
def ignore_torch_compile(cls: _T) -> _T: def ignore_torch_compile(cls: type[_T]) -> type[_T]:
""" """
A decorator to ignore support_torch_compile decorator A decorator to ignore support_torch_compile decorator
on the class. This is useful when a parent class has on the class. This is useful when a parent class has
...@@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T: ...@@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return cls return cls
def _should_ignore_torch_compile(cls: _T) -> bool: def _should_ignore_torch_compile(cls: type[_T]) -> bool:
""" """
Check if the class should be ignored for torch.compile. Check if the class should be ignored for torch.compile.
""" """
...@@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool: ...@@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool:
def support_torch_compile( def support_torch_compile(
*, *,
enable_if: Callable[[VllmConfig], bool] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T]: ... ) -> Callable[[type[_T]], type[_T]]: ...
@overload @overload
def support_torch_compile( def support_torch_compile(
*, *,
dynamic_arg_dims: dict[str, int | list[int]] | None, dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ... ) -> Callable[[type[_T]], type[_T]]: ...
@overload @overload
def support_torch_compile( def support_torch_compile(
*, *,
mark_unbacked_dims: dict[str, int | list[int]] | None, mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ... ) -> Callable[[type[_T]], type[_T]]: ...
@overload @overload
...@@ -101,21 +101,21 @@ def support_torch_compile( ...@@ -101,21 +101,21 @@ def support_torch_compile(
*, *,
dynamic_arg_dims: dict[str, int | list[int]] | None, dynamic_arg_dims: dict[str, int | list[int]] | None,
mark_unbacked_dims: dict[str, int | list[int]] | None, mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ... ) -> Callable[[type[_T]], type[_T]]: ...
@overload @overload
def support_torch_compile(cls: _T) -> _T: ... def support_torch_compile(cls: type[_T]) -> type[_T]: ...
def support_torch_compile( def support_torch_compile(
cls: _T | None = None, cls: type[_T] | None = None,
*, *,
dynamic_arg_dims: dict[str, int | list[int]] | None = None, dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None, mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None, shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T: ) -> Callable[[type[_T]], type[_T]] | type[_T]:
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
...@@ -182,7 +182,7 @@ def support_torch_compile( ...@@ -182,7 +182,7 @@ def support_torch_compile(
errors. errors.
""" """
def cls_decorator_helper(cls: _T) -> _T: def cls_decorator_helper(cls: type[_T]) -> type[_T]:
# helper to pass `dynamic_arg_dims` to `_support_torch_compile` # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile` # to avoid too much indentation for `_support_torch_compile`
if not hasattr(cls, "forward"): if not hasattr(cls, "forward"):
...@@ -263,12 +263,12 @@ def _verify_source_unchanged( ...@@ -263,12 +263,12 @@ def _verify_source_unchanged(
def _support_torch_compile( def _support_torch_compile(
cls: _T, cls: type[_T],
dynamic_arg_dims: dict[str, int | list[int]], dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None, mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None, shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T: ) -> type[_T]:
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
""" """
...@@ -325,12 +325,12 @@ def _support_torch_compile( ...@@ -325,12 +325,12 @@ def _support_torch_compile(
self.compiled = False self.compiled = False
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type] TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__ cls.__init__ = __init__
def _mark_dynamic_inputs( def _mark_dynamic_inputs(
mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
) -> None: ) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None: def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED: if ds_type == DynamicShapesType.UNBACKED:
...@@ -382,7 +382,7 @@ def _support_torch_compile( ...@@ -382,7 +382,7 @@ def _support_torch_compile(
else: else:
torch._dynamo.decorators.mark_unbacked(arg, dims) torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self: _T, *args: Any, **kwargs: Any) -> Any: def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't # e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside. # need to compile the model inside.
...@@ -564,7 +564,7 @@ def _support_torch_compile( ...@@ -564,7 +564,7 @@ def _support_torch_compile(
return output return output
# triggers VllmSerializableFunction.serialize() # triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self): def save_aot_compiled_function(self: type[_T]) -> None:
if self.was_aot_compile_fn_loaded_from_disk: if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save") logger.debug("AOT compiled function was loaded from cache, skipping save")
return return
......
...@@ -141,7 +141,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp): ...@@ -141,7 +141,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
key: torch.Tensor | None, key: torch.Tensor | None,
cos_sin_cache: torch.Tensor, cos_sin_cache: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
return RotaryEmbedding.forward_static( result: tuple[torch.Tensor, torch.Tensor | None] = (
RotaryEmbedding.forward_static(
positions, positions,
query, query,
key, key,
...@@ -150,6 +151,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp): ...@@ -150,6 +151,8 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
cos_sin_cache, cos_sin_cache,
self.is_neox, self.is_neox,
) )
)
return result
class MatcherRMSNorm(MatcherCustomOp): class MatcherRMSNorm(MatcherCustomOp):
...@@ -275,9 +278,10 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp): ...@@ -275,9 +278,10 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
weight: torch.Tensor, weight: torch.Tensor,
residual: torch.Tensor, residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return RMSNorm.forward_static( result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
) )
return result
class MatcherQuantFP8(MatcherCustomOp): class MatcherQuantFP8(MatcherCustomOp):
......
...@@ -25,7 +25,7 @@ logger = init_logger(__name__) ...@@ -25,7 +25,7 @@ logger = init_logger(__name__)
class RangeEntry: class RangeEntry:
compile_range: Range compile_range: Range
compiled: bool = False compiled: bool = False
runnable: Callable = None # type: ignore runnable: Callable[..., Any] = None # type: ignore
class PiecewiseBackend: class PiecewiseBackend:
...@@ -38,7 +38,7 @@ class PiecewiseBackend: ...@@ -38,7 +38,7 @@ class PiecewiseBackend:
sym_shape_indices: list[int], sym_shape_indices: list[int],
vllm_backend: VllmBackend, vllm_backend: VllmBackend,
returns_tuple: bool, returns_tuple: bool,
compiled_runnables: dict[str, Callable] | None = None, compiled_runnables: dict[str, Callable[..., Any]] | None = None,
): ):
""" """
The backend for piecewise compilation. The backend for piecewise compilation.
...@@ -138,8 +138,10 @@ class PiecewiseBackend: ...@@ -138,8 +138,10 @@ class PiecewiseBackend:
self.on_compilation_complete = _on_compilation_complete_callback.get() self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper(self, compiled_graph): def get_compiled_graph_wrapper(
def compiled_graph_wrapper(*args): self, compiled_graph: Callable[..., Any]
) -> Callable[..., Any]:
def compiled_graph_wrapper(*args: Any) -> Any:
graph_output = compiled_graph(*args) graph_output = compiled_graph(*args)
# unpack the tuple if needed # unpack the tuple if needed
# TODO(rzou): the implication is that we're not # TODO(rzou): the implication is that we're not
...@@ -163,7 +165,7 @@ class PiecewiseBackend: ...@@ -163,7 +165,7 @@ class PiecewiseBackend:
def to_bytes(self) -> dict[str, bytes]: def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler): class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj): def reducer_override(self, obj: object) -> Any:
if isinstance(obj, CachingAutotuner): if isinstance(obj, CachingAutotuner):
obj.prepare_for_pickle() obj.prepare_for_pickle()
return pickle.loads, ( return pickle.loads, (
...@@ -173,7 +175,7 @@ class PiecewiseBackend: ...@@ -173,7 +175,7 @@ class PiecewiseBackend:
) )
return NotImplemented return NotImplemented
def serialize(fn) -> bytes: def serialize(fn: Callable[..., Any]) -> bytes:
assert hasattr(fn, "serialize"), "fn must have serialize method" assert hasattr(fn, "serialize"), "fn must have serialize method"
with torch._functorch.config.patch("bundled_autograd_cache", True): with torch._functorch.config.patch("bundled_autograd_cache", True):
entry = fn.serialize() entry = fn.serialize()
......
...@@ -358,7 +358,10 @@ class SequenceParallelismPass(VllmPatternMatcherPass): ...@@ -358,7 +358,10 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
): ):
return True return True
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) result: bool = (compile_range.is_single_size()) and (
compile_range.end % tp_size == 0
)
return result
@VllmInductorPass.time_and_log @VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None: def __call__(self, graph: fx.Graph) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment