decorators.py 21.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import hashlib
6
import inspect
7
8
import os
import sys
9
10
from collections.abc import Callable
from typing import TypeVar, overload
11
from unittest.mock import patch
12
13

import torch
14
import torch.nn as nn
15
from packaging import version
16
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
17

18
import vllm.envs as envs
19
from vllm.compilation.counter import compilation_counter
20
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
21
22
23
24
25
26
from vllm.config import (
    CompilationMode,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
27
from vllm.logger import init_logger
28
from vllm.sequence import IntermediateTensors
29
from vllm.utils.import_utils import resolve_obj_by_qualname
30
from vllm.utils.torch_utils import supports_dynamo
31

32
33
from .monitor import start_monitoring_torch_compile

34
logger = init_logger(__name__)
35

36
37
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

38
39
40
_T = TypeVar("_T", bound=type[nn.Module])


41
42
43
44
45
46
47
def ignore_torch_compile(cls: _T) -> _T:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Typical use: a parent class is decorated with `support_torch_compile`,
    but the inheriting class `cls` should not be compiled. Only the forward
    of the class this decorator is applied to is affected.

    If the parent has `ignore_torch_compile` but the child has
    `support_torch_compile`, the child will still be compiled.

    If the class has one or more submodules that have the
    `support_torch_compile` decorator applied, compile will not be ignored
    for those submodules.
    """
    # The marker is stored on the class itself and read back at
    # instantiation time via `IGNORE_COMPILE_KEY`.
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls


def _should_ignore_torch_compile(cls) -> bool:
    """
    Report whether `cls` carries the ignore-compile marker.

    Returns the value stored under `IGNORE_COMPILE_KEY` on `cls`, or
    `False` when the attribute has never been set.
    """
    ignore_marker = getattr(cls, IGNORE_COMPILE_KEY, False)
    return ignore_marker


68
69
70
# Typing-only stubs for `support_torch_compile`. The keyword-only forms
# return a class decorator; the bare form decorates the class directly.


# Keyword form: only `enable_if` (also matches a call with no arguments).
@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T]: ...


# Keyword form: only `dynamic_arg_dims`.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


# Keyword form: only `mark_unbacked_dims`.
@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


# Keyword form: both `dynamic_arg_dims` and `mark_unbacked_dims`.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


# Bare form: `@support_torch_compile` applied without parentheses.
@overload
def support_torch_compile(cls: _T) -> _T: ...
99

100
101

def support_torch_compile(
    cls: _T | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer (can be negative) or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
        dimension of all the tensors in the intermediate tensors
        will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    the behavior depends on the value passed for each listed argument:

    - if it is a `torch.Tensor`, the configured dimension(s) of the
        argument will be marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of all
        the tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, an error is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy
    input, such as for vision model compilation.

    Raises:
        TypeError: if the decorated class has no `forward` attribute.
        ValueError: if no dynamic dimensions are given or can be inferred, or
            if a key of `dynamic_arg_dims` is not a parameter of `forward`.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            # Nothing given explicitly: default to dim 0 of every parameter
            # whose annotation is one of the recognized tensor-like types.
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                    torch.Tensor,
                    torch.Tensor | None,
                    IntermediateTensors,
                    IntermediateTensors | None,
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(inferred_dynamic_arg_dims.keys()),
            )

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        # Reject keys that do not name a parameter of `forward`.
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
def _model_hash_key(fn) -> str:
    """
    Build a hex SHA-256 key identifying `fn` for the AOT compile cache.

    The digest covers, in order, the vLLM version string, the qualified
    name of `fn`, and the first line number of its code object.
    """
    import vllm

    digest = hashlib.sha256()
    components = (
        vllm.__version__,
        fn.__qualname__,
        str(fn.__code__.co_firstlineno),
    )
    # Feeding the parts one by one hashes their concatenation.
    for component in components:
        digest.update(component.encode())
    return digest.hexdigest()


def _verify_source_unchanged(source_info, vllm_config) -> None:
    """
    Check that the sources recorded in `source_info` still match disk.

    For every entry of `source_info.inlined_sources`, the file of the
    already-imported module is looked up, registered in
    `vllm_config.compilation_config.traced_files`, and paired with the
    recorded content. A hash over the recorded contents is then compared
    with a hash computed for the same set of files.

    Raises:
        RuntimeError: if the two hashes differ.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    recorded_contents = {}
    for entry in source_info.inlined_sources:
        # The module must already be present in `sys.modules`.
        path = inspect.getfile(sys.modules[entry.module])
        vllm_config.compilation_config.traced_files.add(path)
        recorded_contents[path] = entry.content

    expected_checksum = _compute_code_hash_with_content(recorded_contents)
    actual_checksum = _compute_code_hash(set(recorded_contents))
    if expected_checksum == actual_checksum:
        return
    raise RuntimeError(
        "Source code has changed since the last compilation. Recompiling the model."
    )


240
241
def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> _T:
    """
    Add support for compiling the forward method of `cls`, in place.

    `cls` gains `TorchCompileWithNoGuardsWrapper` as an extra base class, and
    its `__init__` and `__call__` are replaced by wrappers that decide per
    instance whether to compile and that drive the first compilation
    (optionally through the AOT compile cache).

    Args:
        cls: the class to patch.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            to mark as dynamic before the first compilation.
        mark_unbacked_dims: optional map of `forward` argument names to the
            dimension(s) passed to `mark_unbacked`; only tensor arguments
            are handled.
        enable_if: optional predicate on the `VllmConfig`; compilation is
            disabled for an instance when it returns a falsy value.

    Returns:
        The same `cls` object (untouched if it was already patched).
    """
    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

    # Explicitly clear the marker so a decorated child of an ignored parent
    # is still compiled.
    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(
        self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
    ):
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch
        # it
        sig = inspect.signature(old_init)
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            vllm_config.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        compilation_counter.num_models_seen += 1
        self.compiled = False
        TorchCompileWithNoGuardsWrapper.__init__(self)

    cls.__init__ = __init__

    def _resolve_negative_dims(tensor, dims):
        # Translate negative indices against this tensor's own rank.
        return [tensor.ndim + dim if dim < 0 else dim for dim in dims]

    def _mark_dynamic_inputs(mod, *args, **kwargs):
        # Bind the call arguments to `forward`'s parameter names so they can
        # be looked up by the keys of `dynamic_arg_dims`.
        sig = inspect.signature(mod.__class__.forward)
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            if isinstance(arg, torch.Tensor):
                torch._dynamo.mark_dynamic(arg, _resolve_negative_dims(arg, dims))
            elif isinstance(arg, IntermediateTensors):
                for tensor in arg.tensors.values():
                    # Resolve negative indices per tensor; reusing the
                    # already-resolved list would apply the first tensor's
                    # rank to every later tensor.
                    torch._dynamo.mark_dynamic(
                        tensor, _resolve_negative_dims(tensor, dims)
                    )
            else:
                raise ValueError(
                    "Unsupported dynamic dimensions"
                    f" {dims} for argument {k} with type {type(arg)}."
                )
        if mark_unbacked_dims:
            for k, dims in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    continue
                dims = [dims] if isinstance(dims, int) else dims
                # Non-tensor arguments are silently skipped here.
                if isinstance(arg, torch.Tensor):
                    torch._dynamo.decorators.mark_unbacked(
                        arg, _resolve_negative_dims(arg, dims)
                    )

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, just call it.
        if getattr(self, "aot_compiled_fn", None) is not None:
            return self.aot_compiled_fn(self, *args, **kwargs)

        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            # When using torch.compile in AOT mode, we store the cache
            # artifacts under
            # VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
            # contains all of the factors except for the source files being
            # traced through, because we don't actually know which source
            # files to check at this point (before dynamo runs).
            # On loading we will actually look at the source files being
            # traced through. If any source file has changed (compared with
            # the serialized backend artifacts), then we need to generate a
            # new AOT compile artifact from scratch.
            from .caching import compilation_config_hash_factors

            factors: list[str] = compilation_config_hash_factors(self.vllm_config)

            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()

            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_aot_compile",
                hash_key,
            )

            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_rank
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            try:
                with (
                    set_current_vllm_config(self.vllm_config),
                    open(aot_compilation_path, "rb") as f,
                ):
                    start_monitoring_torch_compile(self.vllm_config)
                    loaded_fn = torch.compiler.load_compiled_function(f)
                _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
                loaded_fn.disable_guard_check()
                self.aot_compiled_fn = loaded_fn
            except Exception as e:
                # A missing artifact is an ordinary cache miss; only warn
                # when a file exists but could not be used.
                if os.path.exists(aot_compilation_path):
                    logger.warning(
                        "Cannot load aot compilation from path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
                if envs.VLLM_FORCE_AOT_LOAD:
                    raise
            if getattr(self, "aot_compiled_fn", None) is not None:
                logger.info(
                    "Directly load AOT compilation from path %s", aot_compilation_path
                )
                return self.aot_compiled_fn(self, *args, **kwargs)

        if self.compiled:
            assert not envs.VLLM_USE_AOT_COMPILE
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)

        # This is the path for the first compilation.

        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(self, *args, **kwargs)

        # here, it is the starting point of the `torch.compile` process
        start_monitoring_torch_compile(self.vllm_config)
        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.vllm_config.compilation_config.traced_files.add(
            original_code_object.co_filename
        )

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

        def patched_inline_call(self_):
            code = self_.f_code
            self.vllm_config.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can improve guard overhead. But, since
        # vllm skip guards anyways, setting this flag to False can improve
        # compile time.
        dynamo_config_patches = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            _torch27_patch_tensor_subclasses(),
        ):
            if envs.VLLM_USE_AOT_COMPILE:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                output = self.aot_compiled_fn(self, *args, **kwargs)
                assert aot_compilation_path is not None
                assert cache_dir is not None
                # Saving is best-effort: a failure only costs the cache.
                try:
                    os.makedirs(cache_dir, exist_ok=True)
                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
                except Exception as e:
                    logger.warning(
                        "Cannot save aot compilation to path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
            else:
                output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)

        self.compiled = True
        return output

    cls.__call__ = __call__
    return cls
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    When the configuration does not use piecewise cudagraphs together with
    Inductor graph partitioning, this context manager does nothing.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Decide once, so that setting and unsetting the hook stay symmetric
    # even if the config is modified while the context is active.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                # Options differ for the first and the last partition.
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        # Always clear the hook, including when the wrapped compilation
        # raises; otherwise it would stay installed after this context.
        torch._inductor.utils.set_customized_partition_wrappers(None)
533
534
535
536
537
538
539
540
541
542


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
    """
    Add support for using tensor subclasses (i.e. `BasevLLMParameter`, etc.)
    when using torch 2.7.0. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    While active (see the version gate below), the Dynamo config entry
    `traceable_tensor_subclasses` is patched to list the vLLM parameter
    classes, and `torch._dynamo.variables.torch.can_dispatch_torch_function`
    is replaced by a function that always returns `False`.
    """
    from vllm.model_executor.parameter import (
        BasevLLMParameter,
        ModelWeightParameter,
        RowvLLMParameter,
        _ColumnvLLMParameter,
    )

    def return_false(*args, **kwargs):
        return False

    # NOTE(review): as written, the patches below are skipped for torch
    # versions in [2.7, 2.8) and applied for every other version, which
    # reads as the opposite of the docstring ("when using torch 2.7.0").
    # Behavior is intentionally left unchanged here -- confirm the intended
    # polarity of this check before altering it.
    if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
        yield
        return

    with (
        torch._dynamo.config.patch(
            "traceable_tensor_subclasses",
            [
                BasevLLMParameter,
                ModelWeightParameter,
                _ColumnvLLMParameter,
                RowvLLMParameter,
            ],
        ),
        patch(
            "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
        ),
    ):
        yield