Unverified Commit f877a7d1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Improve type annotations for `support_torch_compile` (#10763)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 13370712
import inspect import inspect
from typing import Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
import torch import torch
import torch.nn as nn
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
...@@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo ...@@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
logger = init_logger(__name__) logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[nn.Module])
@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
...
@overload
def support_torch_compile(cls: _T) -> _T:
...
def support_torch_compile( def support_torch_compile(
cls: Optional[type] = None, cls: Optional[_T] = None,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None): *,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
...@@ -66,7 +84,7 @@ def support_torch_compile( ...@@ -66,7 +84,7 @@ def support_torch_compile(
computation graph. computation graph.
""" """
def cls_decorator_helper(cls: type): def cls_decorator_helper(cls: _T) -> _T:
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile``
if not hasattr(cls, 'forward'): if not hasattr(cls, 'forward'):
...@@ -105,8 +123,10 @@ def support_torch_compile( ...@@ -105,8 +123,10 @@ def support_torch_compile(
return cls_decorator_helper return cls_decorator_helper
def _support_torch_compile(cls: type, def _support_torch_compile(
dynamic_arg_dims: Dict[str, Union[int, List[int]]]): cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
""" """
A decorator to add support for compiling the forward method of a class. A decorator to add support for compiling the forward method of a class.
""" """
...@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type, ...@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher # other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__ # type: ignore old_init = cls.__init__
def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
...@@ -135,7 +155,7 @@ def _support_torch_compile(cls: type, ...@@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
TorchCompileWrapperWithCustomDispatcher.__init__( TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level) self, compilation_level=vllm_config.compilation_config.level)
cls.__init__ = __init__ # type: ignore cls.__init__ = __init__
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation # torch.compiler.is_compiling() means we are inside the compilation
...@@ -180,5 +200,5 @@ def _support_torch_compile(cls: type, ...@@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
model_output = self.forward(*args, **kwargs) model_output = self.forward(*args, **kwargs)
return model_output return model_output
cls.__call__ = __call__ # type: ignore cls.__call__ = __call__
return cls return cls
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment