Unverified Commit f877a7d1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Improve type annotations for `support_torch_compile` (#10763)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 13370712
import inspect
from typing import Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
import torch
import torch.nn as nn
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
......@@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[nn.Module])
@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
...
@overload
def support_torch_compile(cls: _T) -> _T:
...
def support_torch_compile(
cls: Optional[type] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
......@@ -66,7 +84,7 @@ def support_torch_compile(
computation graph.
"""
def cls_decorator_helper(cls: type):
def cls_decorator_helper(cls: _T) -> _T:
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
if not hasattr(cls, 'forward'):
......@@ -105,8 +123,10 @@ def support_torch_compile(
return cls_decorator_helper
def _support_torch_compile(cls: type,
dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
"""
......@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__ # type: ignore
old_init = cls.__init__
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
......@@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
cls.__init__ = __init__ # type: ignore
cls.__init__ = __init__
def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
......@@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
model_output = self.forward(*args, **kwargs)
return model_output
cls.__call__ = __call__ # type: ignore
cls.__call__ = __call__
return cls
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment