decorators.py 11.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import inspect
5
from typing import Callable, Optional, TypeVar, Union, overload
6
from unittest.mock import patch
7
8

import torch
9
import torch.nn as nn
10
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
11

12
from vllm import envs
13
from vllm.compilation.counter import compilation_counter
14
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
15
from vllm.forward_context import get_forward_context, get_profilling
16
from vllm.config import CompilationLevel, VllmConfig
17
from vllm.logger import init_logger
18
19
20
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

21
22
from .monitor import start_monitoring_torch_compile

23
logger = init_logger(__name__)

# Name of the class attribute used as the "skip compilation" marker.
# `ignore_torch_compile` sets it to True on a class, and
# `_support_torch_compile` resets it to False on the class it decorates.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type variable for the class decorators in this module: any `nn.Module`
# subclass, returned with the same static type.
_T = TypeVar("_T", bound=type[nn.Module])


def ignore_torch_compile(cls: _T) -> _T:
    """
    Opt `cls` out of the compilation it would otherwise inherit.

    Use this on a subclass whose parent is decorated with
    `support_torch_compile` when the subclass's own `forward` should run
    without compilation. Only the class this decorator is applied to is
    affected:

    - a child class that applies `support_torch_compile` itself is still
      compiled, even if its parent carries `ignore_torch_compile`;
    - submodules of `cls` whose classes are decorated with
      `support_torch_compile` are still compiled.

    Returns the same class object, with the ignore marker set on it.
    """
    marked_cls = cls
    setattr(marked_cls, IGNORE_COMPILE_KEY, True)
    return marked_cls


def _should_ignore_torch_compile(cls) -> bool:
    """
    Report whether compilation is disabled for `cls`.

    The marker is read with ordinary attribute lookup, so the nearest class
    in the MRO that set it decides. A class for which neither it nor any
    ancestor set the marker is not ignored.
    """
    marker = getattr(cls, IGNORE_COMPILE_KEY, False)
    return marker


# Typing-only overloads describing the two ways `support_torch_compile`
# can be used.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
) -> Callable[[_T], _T]:
    # Called with keyword arguments: returns a class decorator.
    ...


@overload
def support_torch_compile(cls: _T) -> _T:
    # Applied directly to a class: returns the patched class.
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers; negative values count from the last
    dimension.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `Optional[IntermediateTensors]`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    Inference compares annotation objects directly, so string annotations
    (e.g. under `from __future__ import annotations`) would not match; pass
    `dynamic_arg_dims` explicitly in that case.

    When the dimensions are actually marked (before the first compilation),
    the behavior depends on the value passed for each listed argument:

    - if it is a `torch.Tensor`, the listed dimensions are marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the listed dimensions of the tensors
        in the intermediate tensors are marked as dynamic.
    - otherwise, a `ValueError` is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    Raises:
        TypeError: if the decorated class has no `forward` method.
        ValueError: if no dynamic dimensions are given or inferred, or if
            `dynamic_arg_dims` names an argument `forward` does not accept.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)

        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            # Annotations whose arguments get dimension 0 marked by default.
            default_dynamic_annotations = (
                torch.Tensor,
                Optional[torch.Tensor],
                IntermediateTensors,
                Optional[IntermediateTensors],
            )
            inferred_dynamic_arg_dims = {
                name: 0
                for name, param in sig.parameters.items()
                if param.annotation in default_dynamic_annotations
            }

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims))

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # Reject names that `forward` does not accept (e.g. typos in an
        # explicitly provided `dynamic_arg_dims`).
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")

        return _support_torch_compile(cls, inferred_dynamic_arg_dims)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    # used as `@support_torch_compile(...)`: return the actual decorator
    return cls_decorator_helper


def _normalize_dims(ndim: int, dims: list[int]) -> list[int]:
    """Return `dims` with negative indices converted to non-negative ones
    for a tensor with `ndim` dimensions (e.g. -1 becomes ndim - 1)."""
    return [ndim + dim if dim < 0 else dim for dim in dims]


def _mark_dynamic_arg_dims(
    arguments: dict,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
) -> None:
    """
    Mark the configured dimensions of the bound `forward` arguments as
    dynamic for Dynamo.

    Args:
        arguments: mapping from `forward` parameter names to the values bound
            for this call (`inspect.BoundArguments.arguments`).
        dynamic_arg_dims: mapping from parameter names to the dimension, or
            list of dimensions, to mark; negative values count from the end.

    Raises:
        ValueError: if a listed argument is not `None`, a `torch.Tensor`, or
            `IntermediateTensors`.
    """
    for k, dims in dynamic_arg_dims.items():
        arg = arguments.get(k)
        if arg is None:
            # `None` arguments carry no tensor to mark.
            continue
        dims = [dims] if isinstance(dims, int) else dims
        if isinstance(arg, torch.Tensor):
            torch._dynamo.mark_dynamic(arg, _normalize_dims(arg.ndim, dims))
        elif isinstance(arg, IntermediateTensors):
            for tensor in arg.tensors.values():
                # Normalize per tensor and keep `dims` untouched: tensors may
                # differ in rank, so negative dims must be resolved against
                # each tensor's own `ndim`.
                torch._dynamo.mark_dynamic(
                    tensor, _normalize_dims(tensor.ndim, dims))
        else:
            raise ValueError(
                "Unsupported dynamic dimensions"
                f" {dims} for argument {k} with type {type(arg)}.")


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Mutates `cls` in place:
    - appends `TorchCompileWrapperWithCustomDispatcher` to its bases;
    - resets the ignore-compile marker on it;
    - replaces its `__init__` and `__call__`.
    Returns `cls` itself.

    Args:
        cls: the `nn.Module` subclass to patch.
        dynamic_arg_dims: mapping from `forward` argument names to the
            dimension(s) to mark as dynamic before the first compilation.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    # A class decorated here is compiled even if a parent class was marked
    # with `ignore_torch_compile`.
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        """Run the original constructor, then set up compilation unless it
        is disabled for this model."""
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        compilation_level = vllm_config.compilation_config.level
        self.do_not_compile = (
            compilation_level in (CompilationLevel.NO_COMPILATION,
                                  CompilationLevel.DYNAMO_AS_IS)
            or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__))
        if self.do_not_compile:
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=compilation_level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        """Run `forward`, compiling it on first use unless compilation is
        disabled, and dispatching to the captured code afterwards."""
        # When `VLLM_ENABLE_TBO` is set and the current forward context asks
        # to skip CUDA graphs, run eagerly. The flag is checked first so the
        # forward context is not consulted at all when the feature is off.
        if envs.VLLM_ENABLE_TBO and get_forward_context().skip_cuda_graphs:
            return self.forward(*args, **kwargs)

        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if (self.do_not_compile or torch.compiler.is_compiling()
                or get_profilling()):
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            _mark_dynamic_arg_dims(bound_args.arguments, dynamic_arg_dims)
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            with patch.object(InliningInstructionTranslator, 'inline_call',
                              patched_inline_call):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls