decorators.py 15.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import inspect
6
from typing import Callable, Optional, TypeVar, Union, overload
7
from unittest.mock import patch
8
9

import torch
10
import torch.nn as nn
11
from packaging import version
12
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
13

14
from vllm.compilation.counter import compilation_counter
15
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
16
from vllm.config import CompilationLevel, VllmConfig
17
from vllm.logger import init_logger
18
from vllm.sequence import IntermediateTensors
19
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
20

21
22
from .monitor import start_monitoring_torch_compile

23
logger = init_logger(__name__)

# Class attribute that `ignore_torch_compile` sets to True to opt a class out
# of compilation; `_support_torch_compile` resets it to False on the class it
# decorates.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type handled by the class decorators in this module: they receive an
# `nn.Module` subclass and return that same class.
_T = TypeVar("_T", bound=type[nn.Module])


30
31
32
33
34
35
36
def ignore_torch_compile(cls: _T) -> _T:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Useful when a parent class is decorated with `support_torch_compile` but
    a subclass `cls` should not have its `forward` compiled. Only the class
    this decorator is applied to is affected:

    - If the parent has `ignore_torch_compile` but the child has
      `support_torch_compile`, the child will still be compiled.
    - Submodules of `cls` that are themselves decorated with
      `support_torch_compile` are still compiled.

    Returns `cls` itself, with the `IGNORE_COMPILE_KEY` class attribute set to
    `True`.
    """
    marked_cls = cls
    setattr(marked_cls, IGNORE_COMPILE_KEY, True)
    return marked_cls


def _should_ignore_torch_compile(cls) -> bool:
    """
    Return whether `cls` has been opted out of torch.compile.

    The flag is read with ordinary attribute lookup, so a subclass sees its
    parent's setting unless it sets its own; a class on which the flag was
    never set yields `False`.
    """
    ignore_flag = getattr(cls, IGNORE_COMPILE_KEY, False)
    return ignore_flag


57
58
59
60
# Keyword-only form without explicit dims:
#     @support_torch_compile(enable_if=...)
@overload
def support_torch_compile(
    *,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]: ...


# Keyword-only form with explicit dims; `enable_if` may be combined with it,
# matching what the implementation accepts:
#     @support_torch_compile(dynamic_arg_dims={...}, enable_if=...)
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Callable[[_T], _T]: ...


# Bare form: @support_torch_compile
@overload
def support_torch_compile(cls: _T) -> _T: ...
73

74
75

def support_torch_compile(
    cls: Optional[_T] = None,
    *,
    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> Union[Callable[[_T], _T], _T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` maps argument names of `forward` to the dimension(s)
    that should be marked dynamic: either a single integer or a list of
    integers. Negative values count from the last dimension.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of `forward`. Every argument annotated as `torch.Tensor`,
    `Optional[torch.Tensor]`, `IntermediateTensors` or
    `Optional[IntermediateTensors]` gets dimension 0 marked dynamic.

    At runtime, before the first compilation, each listed argument is handled
    according to its value:

    - a `torch.Tensor`: the given dimension(s) are marked dynamic;
    - an `IntermediateTensors`: the given dimension(s) of every tensor it
      holds are marked dynamic;
    - `None`: ignored;
    - anything else: a `ValueError` is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    Raises:
        TypeError: if the decorated class has no `forward` method.
        ValueError: if no dynamic dimensions are given or can be inferred,
            or if `dynamic_arg_dims` names an argument that `forward` lacks.
    """

    def cls_decorator_helper(target: _T) -> _T:
        # Resolves and validates the dynamic dims, then delegates the actual
        # class patching to `_support_torch_compile` (keeps nesting shallow).
        if not hasattr(target, "forward"):
            raise TypeError("decorated class should have a forward method.")
        forward_params = inspect.signature(target.forward).parameters

        if dynamic_arg_dims is not None:
            resolved_dims = dynamic_arg_dims
        else:
            tensor_like_annotations = [
                torch.Tensor,
                Optional[torch.Tensor],
                IntermediateTensors,
                Optional[IntermediateTensors],
            ]
            resolved_dims = {
                name: 0
                for name, param in forward_params.items()
                if param.annotation in tensor_like_annotations
            }
            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                target,
                list(resolved_dims.keys()),
            )

        if not resolved_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{target}. Please provide dynamic_arg_dims explicitly."
            )

        unknown_args = [name for name in resolved_dims if name not in forward_params]
        if unknown_args:
            raise ValueError(
                f"Argument {unknown_args[0]} not found in the forward method of {target}"
            )
        return _support_torch_compile(target, resolved_dims, enable_if)

    if cls is None:
        # Called with keyword arguments: hand back the real decorator.
        return cls_decorator_helper

    # Bare `@support_torch_compile` usage: decorate immediately.
    assert isinstance(cls, type)
    return cls_decorator_helper(cls)


179
180
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, Union[int, list[int]]],
    enable_if: Optional[Callable[[VllmConfig], bool]] = None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Appends `TorchCompileWrapperWithCustomDispatcher` to the bases of `cls`
    and replaces its `__init__` and `__call__` so that calling an instance
    goes through the compiled code, unless compilation is disabled for that
    instance.

    Args:
        cls: the `nn.Module` subclass to patch in place.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            marked dynamic before the first compilation.
        enable_if: optional predicate on the `VllmConfig`; when it returns
            `False`, the instance runs `forward` eagerly.

    Returns:
        `cls` itself.
    """
    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)

    old_init = cls.__init__

    # Explicitly opt this class back in, overriding an `ignore_torch_compile`
    # flag that may be inherited from a parent class.
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.level
            in [CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS]
            or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        compilation_counter.num_models_seen += 1
        TorchCompileWrapperWithCustomDispatcher.__init__(
            self, compilation_level=vllm_config.compilation_config.level
        )

    cls.__init__ = __init__

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    continue
                dim_list = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    resolved = [arg.ndim + d if d < 0 else d for d in dim_list]
                    torch._dynamo.mark_dynamic(arg, resolved)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # Resolve negative dims against this tensor's own
                        # ndim. `dim_list` is left untouched because the
                        # tensors held here may not all share the same ndim.
                        resolved = [
                            tensor.ndim + d if d < 0 else d for d in dim_list
                        ]
                        torch._dynamo.mark_dynamic(tensor, resolved)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dim_list} for argument {k} with type {type(arg)}."
                    )
            # here, it is the starting point of the `torch.compile` process
            start_monitoring_torch_compile(self.vllm_config)
            logger.debug("Start compiling function %s", self.original_code_object)

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)

            # collect all relevant files traced by Dynamo,
            # so that the compilation cache can trigger re-compilation
            # properly when any of these files change.

            # 1. the file containing the top-level forward function
            self.vllm_config.compilation_config.traced_files.add(
                self.original_code_object.co_filename
            )

            # 2. every time Dynamo sees a function call, it will inline
            # the function by calling InliningInstructionTranslator.inline_call
            # we hijack this function to know all the functions called
            # during Dynamo tracing, and their corresponding files
            inline_call = InliningInstructionTranslator.inline_call

            def patched_inline_call(parent, func, args, kwargs):
                code = func.get_code()
                self.vllm_config.compilation_config.traced_files.add(code.co_filename)
                return inline_call(parent, func, args, kwargs)

            # Disable the C++ compilation of symbolic shape guards. C++-fication
            # of symbolic shape guards can improve guard overhead. But, since
            # vllm skip guards anyways, setting this flag to False can improve
            # compile time.
            dynamo_config_patches = {}
            try:
                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
                dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
            except AttributeError:
                # Note: this config is not available in torch 2.6, we can skip
                # if the config doesn't exist
                logger.debug("enable_cpp_symbolic_shape_guards config not available")

            with (
                patch.object(
                    InliningInstructionTranslator, "inline_call", patched_inline_call
                ),
                torch._dynamo.config.patch(**dynamo_config_patches),
                maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                _torch27_patch_tensor_subclasses(),
            ):
                output = self.compiled_callable(*args, **kwargs)
            return output

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__
    return cls
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is installed only when piecewise cudagraphs and Inductor graph
    partitioning are both enabled. It is reset to `None` on exit even if the
    body of the `with` block raises, so it is not left installed afterwards.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Evaluate once so the teardown undoes exactly what the setup did.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
        # Wrap one Inductor partition; per-partition options depend on
        # whether it is the first / last partition.
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        # Always uninstall the hook, including when compilation raised.
        torch._inductor.utils.set_customized_partition_wrappers(None)
376
377
378
379
380
381
382
383
384
385


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
    """
    Add support for using tensor subclasses (ie `BasevLLMParameter`, etc) when
    using torch 2.7.0. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    On torch versions outside [2.7, 2.8) this context manager does nothing.
    """
    from vllm.model_executor.parameter import (
        BasevLLMParameter,
        ModelWeightParameter,
        RowvLLMParameter,
        _ColumnvLLMParameter,
    )

    def return_false(*args, **kwargs):
        return False

    torch_version = version.parse(torch.__version__)
    if not (version.parse("2.7") <= torch_version < version.parse("2.8")):
        # The Dynamo patches below target torch 2.7.x only; elsewhere this is
        # a no-op. (Previously this check was inverted, so the patches were
        # skipped on 2.7 and applied on every other version.)
        yield
        return

    with (
        torch._dynamo.config.patch(
            "traceable_tensor_subclasses",
            [
                BasevLLMParameter,
                ModelWeightParameter,
                _ColumnvLLMParameter,
                RowvLLMParameter,
            ],
        ),
        patch(
            "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
        ),
    ):
        yield