# decorators.py
import inspect
2
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
3
4

import torch
5
import torch.nn as nn
6

7
from vllm.compilation.counter import compilation_counter
8
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
9
from vllm.config import CompilationLevel, VllmConfig
10
from vllm.logger import init_logger
11
12
13
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

14
15
from .monitor import start_monitoring_torch_compile

16
logger = init_logger(__name__)

# Type variable for the decorated class: `support_torch_compile` returns the
# same class object it receives, so annotating with `_T` lets type checkers
# keep the concrete `nn.Module` subclass instead of widening it.
_T = TypeVar("_T", bound=type[nn.Module])


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(cls: _T) -> _T:
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `Optional[IntermediateTensors]`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    During runtime, when the dimensions are actually marked (on the first
    compiled call), what happens depends on the value passed for each
    argument listed in `dynamic_arg_dims`:

    - if it is a `torch.Tensor`, the dimension(s) given for that argument
        will be marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the given dimension(s) of all the
        tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, an error is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    Raises:
        TypeError: if the decorated object is not a class, or if it has no
            `forward` method.
        ValueError: if no dynamic dimensions are given or can be inferred, or
            if `dynamic_arg_dims` names an argument `forward` does not accept.
    """
    # Annotations that cause an argument's first dimension to be marked as
    # dynamic when `dynamic_arg_dims` has to be inferred. Built once per
    # decorator call instead of once per inspected parameter.
    dynamic_annotations = (
        torch.Tensor,
        Optional[torch.Tensor],
        IntermediateTensors,
        Optional[IntermediateTensors],
    )

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not isinstance(cls, type):
            # explicit check instead of `assert`, which is stripped under -O
            raise TypeError(
                f"support_torch_compile expects a class, got {type(cls)}.")
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)

        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {
                name: 0
                for name, param in sig.parameters.items()
                if param.annotation in dynamic_annotations
            }
            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # every requested argument must exist in `forward`'s signature
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, inferred_dynamic_arg_dims)

    if cls is None:
        # used as `@support_torch_compile(...)`: hand back the real decorator
        return cls_decorator_helper

    # use `support_torch_compile` as a decorator without arguments
    return cls_decorator_helper(cls)


def _mark_dynamic_dims(tensor: torch.Tensor,
                       dims: Union[int, List[int]]) -> None:
    """
    Mark the given dimension(s) of `tensor` as dynamic for Dynamo.

    `dims` may be a single integer or a list of integers. Negative values are
    resolved against `tensor.ndim` (so `-1` means the last dimension) before
    the explicit non-negative indices are passed to
    `torch._dynamo.mark_dynamic`.
    """
    dim_list = [dims] if isinstance(dims, int) else list(dims)
    # NOTE(review): `torch._dynamo.mark_dynamic` does not appear to normalize
    # negative indices itself (a raw -1 would not match any real dimension),
    # so resolve them here the same way regular tensor indexing would.
    resolved = [dim + tensor.ndim if dim < 0 else dim for dim in dim_list]
    torch._dynamo.mark_dynamic(tensor, resolved)


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    It appends `TorchCompileWrapperWithCustomDispatcher` to the bases of
    `cls`, patches `cls.__init__` and `cls.__call__` in place, and returns
    the same class object. A class that already has the wrapper as a direct
    base is returned unchanged.

    Args:
        cls: the module class to patch.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            to mark as dynamic before the first compilation. Negative
            dimensions count from the end of each tensor's shape.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        # for CompilationLevel.DYNAMO_AS_IS, the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.level in [
                CompilationLevel.NO_COMPILATION,
                CompilationLevel.DYNAMO_AS_IS,
            ] or not supports_dynamo())
        if self.do_not_compile:
            return
        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    # `None` arguments have nothing to mark
                    continue
                if isinstance(arg, torch.Tensor):
                    _mark_dynamic_dims(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        _mark_dynamic_dims(tensor, dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuses the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)
            return self.compiled_callable(*args, **kwargs)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls