# decorators.py (15.9 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import inspect
6
from typing import Callable, Optional, TypeVar, Union, overload
7
from unittest.mock import patch
8
9

import torch
10
import torch.nn as nn
11
from packaging import version
12
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
13

14
from vllm.compilation.counter import compilation_counter
15
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
16
from vllm.config import CompilationLevel, VllmConfig
17
from vllm.logger import init_logger
18
from vllm.sequence import IntermediateTensors
19
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
20

21
22
from .monitor import start_monitoring_torch_compile

logger = init_logger(__name__)

# Name of the class attribute used as a compile opt-out marker:
# `ignore_torch_compile` sets it to True on a class, and
# `_support_torch_compile` resets it to False on the class it decorates.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type of the classes accepted/returned by the class decorators in this
# module: `nn.Module` subclasses.
_T = TypeVar("_T", bound=type[nn.Module])


def ignore_torch_compile(cls: _T) -> _T:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Typical use: a parent class is decorated with `support_torch_compile`,
    but a subclass `cls` should run its `forward` without compilation.
    Only the `forward` of the decorated class itself is affected:

    - if a subclass of `cls` is decorated with `support_torch_compile`
      again, that subclass is still compiled;
    - submodules of `cls` that carry `support_torch_compile` are still
      compiled.

    The marker is written onto `cls` in place and the same class object is
    returned.
    """
    marker_name = IGNORE_COMPILE_KEY
    setattr(cls, marker_name, True)
    return cls


def _should_ignore_torch_compile(cls) -> bool:
    """
    Return the compile opt-out marker stored on `cls`.

    A class (or any of its bases) that never had the marker set is treated
    as not ignored, so False is returned in that case.
    """
    try:
        return getattr(cls, IGNORE_COMPILE_KEY)
    except AttributeError:
        # No marker anywhere in the MRO: compilation is not ignored.
        return False


# Typing-only overloads for `support_torch_compile`; the implementation
# follows below.
@overload
def support_torch_compile(
    *,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
    # Accepted together with `dynamic_arg_dims`, as the implementation does.
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]:
    ...


@overload
def support_torch_compile(cls: _T) -> _T:
    ...


def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers; negative integers count from the last
    dimension.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `Optional[IntermediateTensors]`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value of arguments:

    - if it is a `torch.Tensor`, the configured dimension(s) of the
        argument will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of all the
        tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    Raises:
        TypeError: if the decorated class has no `forward` method.
        ValueError: if no dynamic dimensions are given or can be inferred,
            or if a key of `dynamic_arg_dims` is not a parameter of
            `forward`.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                        torch.Tensor, Optional[torch.Tensor],
                        IntermediateTensors, Optional[IntermediateTensors]
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # Fail at decoration time rather than at the first forward call.
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, inferred_dynamic_arg_dims,
                                      enable_if)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Mutates `cls` in place and returns it:

    - appends `TorchCompileWrapperWithCustomDispatcher` to `cls.__bases__`
      (returns early if it is already a direct base, so decorating twice is
      harmless);
    - resets the `IGNORE_COMPILE_KEY` marker to False on `cls`;
    - wraps `__init__` so each instance decides whether to compile;
    - replaces `__call__` so the first call marks dynamic dimensions and
      goes through the compiled callable.

    Args:
        cls: the class to decorate.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            to mark dynamic; negative dimensions count from the end of the
            tensor they are applied to.
        enable_if: optional predicate on the `VllmConfig`; when it returns
            False the instance does not compile.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.level in [
                CompilationLevel.NO_COMPILATION,
                CompilationLevel.DYNAMO_AS_IS,
            ] or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile)
        if self.do_not_compile:
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level)

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    continue
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    torch._dynamo.mark_dynamic(
                        arg, [arg.ndim + d if d < 0 else d for d in dims])
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # Resolve negative indices against *this* tensor and
                        # keep `dims` untouched: the tensors may differ in
                        # rank, so indices resolved for one tensor must not
                        # be reused for the next.
                        torch._dynamo.mark_dynamic(
                            tensor,
                            [tensor.ndim + d if d < 0 else d for d in dims])
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}.")
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s",
                         self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename)

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(
                    code.co_filename)
                return inline_call(parent, func, args, kwargs)

            # Disable the C++ compilation of symbolic shape guards. C++-fication
            # of symbolic shape guards can improve guard overhead. But, since
            # vllm skip guards anyways, setting this flag to False can improve
            # compile time.
            dynamo_config_patches = {}
            try:
                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
                dynamo_config_patches[
                    "enable_cpp_symbolic_shape_guards"] = False
            except AttributeError:
                # Note: this config is not available in torch 2.6, we can skip
                # if the config doesn't exist
                logger.debug(
                    "enable_cpp_symbolic_shape_guards config not available")

            with (patch.object(InliningInstructionTranslator, "inline_call",
                               patched_inline_call),
                  torch._dynamo.config.patch(**dynamo_config_patches),
                  maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                  _torch27_patch_tensor_subclasses()):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is installed only when `cudagraph_mode` is not
    `CUDAGraphMode.NONE` and `use_inductor_graph_partition` is enabled;
    otherwise this context manager does nothing. When installed, the hook is
    removed on exit, including when the body raises.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Evaluate once so that installing and removing the hook always agree.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        and compilation_config.use_inductor_graph_partition)
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls())

    def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                # Debug logging only for the first partition, GC disabled
                # for every partition but the first, weak-ref output only
                # for the last partition.
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ))

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper)
    try:
        yield
    finally:
        # The hook is process-global Inductor state: always uninstall it,
        # even if compilation failed, so it cannot leak into later,
        # unrelated compilations.
        torch._inductor.utils.set_customized_partition_wrappers(None)


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
    """
    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
    when using torch 2.7. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    While active on torch 2.7.x, this registers the vLLM parameter classes in
    Dynamo's `traceable_tensor_subclasses` config and patches
    `torch._dynamo.variables.torch.can_dispatch_torch_function` to always
    return False. On any other torch version it is a no-op.
    """
    # The workaround targets torch 2.7 only. Compare the (major, minor)
    # release tuple so local / pre-release builds such as "2.7.0+cu126" are
    # classified by their release number rather than by version ordering.
    if version.parse(torch.__version__).release[:2] != (2, 7):
        yield
        return

    # Imported lazily, and only when the workaround is actually applied.
    from vllm.model_executor.parameter import (BasevLLMParameter,
                                               ModelWeightParameter,
                                               RowvLLMParameter,
                                               _ColumnvLLMParameter)

    def return_false(*args, **kwargs):
        return False

    with (torch._dynamo.config.patch("traceable_tensor_subclasses", [
            BasevLLMParameter, ModelWeightParameter, _ColumnvLLMParameter,
            RowvLLMParameter
    ]),
          patch("torch._dynamo.variables.torch.can_dispatch_torch_function",
                return_false)):
        yield