decorators.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
from typing import Callable, Optional, TypeVar, Union, overload
6
from unittest.mock import patch
7
8

import torch
9
import torch.nn as nn
10
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
11

12
from vllm.compilation.counter import compilation_counter
13
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
14
from vllm.config import CompilationLevel, VllmConfig
15
from vllm.logger import init_logger
16
17
18
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

19
20
from .monitor import start_monitoring_torch_compile

# Class attribute used as an opt-out marker: `ignore_torch_compile` sets it
# to True on a class, and `_support_torch_compile` resets it to False on the
# class it decorates.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type variable for the class decorators in this module: they receive an
# `nn.Module` subclass and hand back that same class.
_T = TypeVar("_T", bound=type[nn.Module])

logger = init_logger(__name__)


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def ignore_torch_compile(cls: _T) -> _T:
    """
    A decorator to ignore support_torch_compile decorator
    on the class. This is useful when a parent class has
    a support_torch_compile decorator, but we don't want to
    compile the class `cls` that inherits the parent class.
    This only ignores compiling the forward of the class the
    decorator is applied to. 

    If the parent has ignore_torch_compile but the child has
    support_torch_compile, the child will still be compiled.
    
    If the class has one or more submodules
    that have support_torch_compile decorator applied, compile will
    not be ignored for those submodules.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls


def _should_ignore_torch_compile(cls) -> bool:
    """
    Report whether `cls` has been marked to skip torch.compile.

    The marker is resolved like any class attribute, so it may be
    inherited from a base class. Classes that never received the marker
    are reported as not ignored (False).
    """
    try:
        return getattr(cls, IGNORE_COMPILE_KEY)
    except AttributeError:
        # no marker anywhere in the MRO: compile as usual
        return False


# Typing-only signatures for the three ways `support_torch_compile` can be
# used; the real implementation follows immediately after these stubs.


# `@support_torch_compile()` or `@support_torch_compile(enable_if=...)`
@overload
def support_torch_compile(
    *,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
    ...


# `@support_torch_compile(dynamic_arg_dims=..., [enable_if=...])`.
# `enable_if` is accepted here too: the implementation supports passing
# both keywords at once, and without it type checkers reject that call.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
    ...


# bare `@support_torch_compile` applied directly to a class
@overload
def support_torch_compile(cls: _T) -> _T:
    ...

def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers; negative values count from the last
    dimension.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `Optional[IntermediateTensors]`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    String annotations (e.g. under `from __future__ import annotations`) are
    resolved with `typing.get_type_hints` on a best-effort basis; if that
    fails they stay strings and do not match the rules above.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value passed for each argument:

    - if it is a `torch.Tensor`, the configured dimension(s) of the tensor
        will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of every
        tensor in the intermediate tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    Raises:
        TypeError: if applied to something that is not a class, or to a
            class without a `forward` method.
        ValueError: if no dynamic dimensions are given or can be inferred,
            or if `dynamic_arg_dims` names an argument `forward` lacks.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            annotations = _resolved_forward_annotations(cls.forward, sig)
            inferred_dynamic_arg_dims = {}
            for k in sig.parameters:
                if annotations[k] in [
                        torch.Tensor, Optional[torch.Tensor],
                        IntermediateTensors, Optional[IntermediateTensors]
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")

        return _support_torch_compile(cls, inferred_dynamic_arg_dims,
                                      enable_if)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments.
        # Raise instead of `assert` so the check survives `python -O`.
        if not isinstance(cls, type):
            raise TypeError(
                f"support_torch_compile expects a class, got {type(cls)}.")
        return cls_decorator_helper(cls)

    # used as `@support_torch_compile(...)`: hand back the real decorator
    return cls_decorator_helper


def _resolved_forward_annotations(forward: Callable,
                                  sig: inspect.Signature) -> dict:
    """
    Map each parameter name of `forward` to its annotation.

    Annotations stored as plain strings (postponed evaluation) are resolved
    with `typing.get_type_hints` so they can be compared against real types.
    Resolution is best effort: on failure the strings are returned unchanged.
    Non-string annotations are always returned exactly as `sig` reports them.
    """
    annotations = {
        name: param.annotation
        for name, param in sig.parameters.items()
    }
    if not any(isinstance(ann, str) for ann in annotations.values()):
        return annotations

    from typing import get_type_hints
    try:
        hints = get_type_hints(forward)
    except Exception as e:  # e.g. NameError for TYPE_CHECKING-only imports
        logger.debug("Could not resolve annotations of %s: %s", forward, e)
        return annotations

    return {
        name: hints.get(name, ann) if isinstance(ann, str) else ann
        for name, ann in annotations.items()
    }


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Implementation behind `support_torch_compile`. It patches `cls` in place:

    - appends `TorchCompileWrapperWithCustomDispatcher` to its bases;
    - wraps `__init__` to decide, per instance, whether to compile
      (stored in `self.do_not_compile`);
    - replaces `__call__` so that the first call marks the configured
      dynamic dimensions and starts compilation, and later calls go
      through the compiled code.

    Args:
        cls: the module class to patch.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            to mark dynamic; negative values count from the last dimension.
        enable_if: optional predicate on the `VllmConfig`; when it returns
            False, instances of `cls` run `forward` eagerly.

    Returns:
        `cls` itself, so this can be used as a class decorator.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    # Opt this class (back) in: overrides an `ignore_torch_compile` marker
    # that would otherwise be inherited from a parent class.
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = \
            vllm_config.compilation_config.level in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
        ] or not supports_dynamo() or _should_ignore_torch_compile(
            self.__class__) or not enable_compile
        if self.do_not_compile:
            # the wrapper base is never initialized on this path; __call__
            # checks `do_not_compile` before touching any wrapper state
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def _to_positive_dims(dims: list[int], ndim: int) -> list[int]:
        # In case dims is specified with negative indexing, translate it
        # for a tensor with `ndim` dimensions.
        return [ndim + dim if dim < 0 else dim for dim in dims]

    def _mark_dynamic_args(bound_args: inspect.BoundArguments) -> None:
        # Mark the configured dimensions of every non-None argument as
        # dynamic before the first trace.
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            if isinstance(arg, torch.Tensor):
                torch._dynamo.mark_dynamic(
                    arg, _to_positive_dims(dims, arg.ndim))
            elif isinstance(arg, IntermediateTensors):
                for tensor in arg.tensors.values():
                    # Normalize per tensor and do NOT rebind `dims`: the
                    # tensors may differ in rank, so a negative dim must be
                    # resolved against each tensor's own ndim.
                    torch._dynamo.mark_dynamic(
                        tensor, _to_positive_dims(dims, tensor.ndim))
            else:
                raise ValueError(
                    "Unsupported dynamic dimensions"
                    f" {dims} for argument {k} with type {type(arg)}.")

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            _mark_dynamic_args(bound_args)
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            # Disable the C++ compilation of symbolic shape guards. C++-fication
            # of symbolic shape guards can improve guard overhead. But, since
            # vllm skip guards anyways, setting this flag to False can improve
            # compile time.
            dynamo_config_patches = {}
            try:
                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
                dynamo_config_patches[
                    "enable_cpp_symbolic_shape_guards"] = False
            except AttributeError:
                # Note: this config is not available in torch 2.6, we can skip
                # if the config doesn't exist
                logger.debug(
                    "enable_cpp_symbolic_shape_guards config not available")

            with patch.object(InliningInstructionTranslator, 'inline_call',
                              patched_inline_call), torch._dynamo.config.patch(
                                  **dynamo_config_patches):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls