# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import inspect
from typing import Callable, Optional, TypeVar, Union, overload
from unittest.mock import patch

import torch
import torch.nn as nn
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

from vllm import envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.forward_context import get_forward_context, get_profilling
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

from .monitor import start_monitoring_torch_compile

logger = init_logger(__name__)

# Type of the class being decorated: any `nn.Module` subclass. The decorators
# below return the same class object, so they are typed as `_T -> _T`.
_T = TypeVar("_T", bound=type[nn.Module])


# Factory form: `@support_torch_compile(dynamic_arg_dims={...})` returns a
# class decorator.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
) -> Callable[[_T], _T]:
    ...


# Bare form: `@support_torch_compile` applied directly to the class.
@overload
def support_torch_compile(cls: _T) -> _T:
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names of `forward`
    to the dynamic dimension(s) of that argument. The dynamic dimensions can
    be either a single integer or a list of integers; negative integers index
    from the end.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `Optional[IntermediateTensors]`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    At runtime, before the first compilation, the dimensions are marked
    according to the value actually passed for each listed argument:

    - if it is a `torch.Tensor`, the listed dimensions of that tensor will
        be marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the listed dimensions of every tensor
        in the intermediate tensors will be marked as dynamic.
    - otherwise, a `ValueError` is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    Raises:
        TypeError: if the decorated class has no `forward` method.
        ValueError: if no dynamic dimensions are provided or inferred, or if
            a key of `dynamic_arg_dims` is not a parameter of `forward`.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                # NOTE(review): this compares annotation objects, so it
                # presumably does not match string annotations (e.g. under
                # `from __future__ import annotations`) — confirm if needed.
                if v.annotation in [
                        torch.Tensor, Optional[torch.Tensor],
                        IntermediateTensors, Optional[IntermediateTensors]
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # Every requested argument must exist in the forward signature.
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, inferred_dynamic_arg_dims)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    # used with arguments: return the actual class decorator
    return cls_decorator_helper


def _mark_dims_dynamic(tensor: torch.Tensor, dims: list[int]) -> None:
    """Mark `dims` of `tensor` as dynamic for Dynamo.

    Negative entries are resolved against this tensor's own `ndim`. The
    caller's `dims` list is not modified, so the same spec can safely be
    applied to several tensors of different rank.
    """
    resolved_dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
    torch._dynamo.mark_dynamic(tensor, resolved_dims)


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Patches `cls` in place: appends `TorchCompileWrapperWithCustomDispatcher`
    to its bases and replaces its `__init__` and `__call__`. Decorating a
    class that already has the wrapper as a direct base is a no-op.

    Args:
        cls: the model class to patch.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            (int or list of ints, negatives allowed) that are marked dynamic
            before the first compilation.

    Returns:
        The same class object `cls`.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        # for CompilationLevel.DYNAMO_AS_IS, the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.level in [
                CompilationLevel.NO_COMPILATION,
                CompilationLevel.DYNAMO_AS_IS,
            ] or not supports_dynamo())
        if self.do_not_compile:
            return
        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # When VLLM_ENABLE_TBO is set, batches whose forward context asks to
        # skip CUDA graphs run `forward` eagerly. The env flag is checked
        # first so the forward context is only queried when it matters.
        if envs.VLLM_ENABLE_TBO and get_forward_context().skip_cuda_graphs:
            return self.forward(*args, **kwargs)

        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside. `get_profilling()` also forces
        # the eager path (presumably during profiling runs — confirm).
        if (self.do_not_compile or torch.compiler.is_compiling()
                or get_profilling()):
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    continue
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    _mark_dims_dynamic(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    # each tensor resolves negative dims against its own rank
                    for tensor in arg.tensors.values():
                        _mark_dims_dynamic(tensor, dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            with patch.object(InliningInstructionTranslator, 'inline_call',
                              patched_inline_call):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls