decorators.py 25.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import hashlib
6
import inspect
7
8
import os
import sys
9
10
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
11
from unittest.mock import patch
12
13

import torch
14
import torch.nn as nn
15
from packaging import version
16
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
17

18
import vllm.envs as envs
19
from vllm.compilation.counter import compilation_counter
20
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
21
22
23
24
25
26
from vllm.config import (
    CompilationMode,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
27
from vllm.config.compilation import DynamicShapesType
28
from vllm.logger import init_logger
29
from vllm.sequence import IntermediateTensors
30
from vllm.utils.import_utils import resolve_obj_by_qualname
31
from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
32

33
34
from .monitor import start_monitoring_torch_compile

35
36
37
38
39
40
41
42
if TYPE_CHECKING:
    # Only added on nightly/2.10 so wrap
    try:
        from torch._dynamo.package import SourceInfo
    except ImportError:
        # Fallback for old versions not supporting
        SourceInfo = Any

43
logger = init_logger(__name__)

# Class attribute used as a per-class "skip compilation" flag: set to True by
# `ignore_torch_compile`, reset to False by `_support_torch_compile`, and read
# (with a default of False) by `_should_ignore_torch_compile`.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# The decorated `nn.Module` subclass; the class decorators in this module
# return the same class object they receive.
_T = TypeVar("_T", bound=type[nn.Module])


def ignore_torch_compile(cls: _T) -> _T:
    """
    A decorator to ignore support_torch_compile decorator
    on the class. This is useful when a parent class has
    a support_torch_compile decorator, but we don't want to
    compile the class `cls` that inherits the parent class.
    This only ignores compiling the forward of the class the
    decorator is applied to.

    If the parent has ignore_torch_compile but the child has
    support_torch_compile, the child will still be compiled.

    If the class has one or more submodules
    that have support_torch_compile decorator applied, compile will
    not be ignored for those submodules.

    Args:
        cls: The module class to mark.

    Returns:
        The same class object, with `IGNORE_COMPILE_KEY` set to True.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls


def _should_ignore_torch_compile(cls: _T) -> bool:
    """
    Check if the class should be ignored for torch.compile.

    Returns the class's `IGNORE_COMPILE_KEY` attribute, found through normal
    attribute lookup (so a value set on a base class is visible), or False
    when the attribute is absent.
    """
    return getattr(cls, IGNORE_COMPILE_KEY, False)


@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[_T], _T]: ...


# Catch-all keyword form. The implementation accepts any combination of these
# keyword arguments (including `shape_invariants`, and `enable_if` together
# with the dim maps), which the narrower overloads above do not describe.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T]: ...


@overload
def support_torch_compile(cls: _T) -> _T: ...


def support_torch_compile(
    cls: _T | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers. Every key must be a parameter of `forward`.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` (optionally
        `Optional[...]`), the first dimension of all the tensors in the
        intermediate tensors will be marked as dynamic.

    Annotations are compared as objects, so string annotations (e.g. from
    `from __future__ import annotations`) are not recognized. If no dynamic
    dimensions can be found, a `ValueError` is raised.

    During runtime, when we actually mark dimensions of tensors, it depends
    on the value passed for each argument:

    - if it is a `torch.Tensor`, the configured dimension(s) of the tensor
        will be marked as dynamic (negative dims index from the end).
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of all the
        tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy
    input such as for vision model compilation.

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
            torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors.

    Raises:
        TypeError: if the decorated class has no `forward` method.
        ValueError: if no dynamic dimensions are given or inferred, or a key
            of `dynamic_arg_dims` is not a parameter of `forward`.
    """

    def cls_decorator_helper(cls: _T) -> _T:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                    torch.Tensor,
                    torch.Tensor | None,
                    IntermediateTensors,
                    IntermediateTensors | None,
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(
                "Inferred dynamic dimensions for forward method of %s: %s",
                cls,
                list(inferred_dynamic_arg_dims.keys()),
            )

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls,
            inferred_dynamic_arg_dims,
            mark_unbacked_dims,
            enable_if,
            shape_invariants,
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


def _model_hash_key(fn: Callable[..., Any]) -> str:
    """
    Compute a SHA-256 hex digest identifying a forward function for caching.

    The digest covers the installed vLLM version, `fn.__qualname__`, and the
    first line number of `fn.__code__`. Upgrading vLLM, renaming the function,
    or shifting its definition to another line therefore yields a new key.

    NOTE(review): the module/file name is not hashed, so two functions with
    the same qualname and first line in different files map to the same key.
    The AOT cache path also mixes in the compilation-config hash factors,
    which may disambiguate in practice.
    """
    import vllm

    sha256_hash = hashlib.sha256()
    sha256_hash.update(vllm.__version__.encode())
    sha256_hash.update(fn.__qualname__.encode())
    sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
    return sha256_hash.hexdigest()


def _verify_source_unchanged(
    source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
    """
    Check that the sources inlined into a loaded AOT artifact are unchanged.

    For every entry in `source_info.inlined_sources`:
    - the entry's module is looked up in `sys.modules`;
    - its file path is resolved with `inspect.getfile`;
    - the path is added to `vllm_config.compilation_config.traced_files`.

    Then `_compute_code_hash_with_content` over the recorded contents is
    compared with `_compute_code_hash` over the same paths (presumably hashing
    their current on-disk contents).

    Args:
        source_info: Source metadata returned by the loaded compiled function.
        vllm_config: Config whose `traced_files` set is updated.

    Raises:
        KeyError: if a recorded module has not been imported.
        RuntimeError: if the two checksums differ.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    file_contents = {}
    for source in source_info.inlined_sources:
        module = sys.modules[source.module]
        file = inspect.getfile(module)
        vllm_config.compilation_config.traced_files.add(file)
        # If several sources resolve to the same file, the last one wins.
        file_contents[file] = source.content
    expected_checksum = _compute_code_hash_with_content(file_contents)
    actual_checksum = _compute_code_hash(set(file_contents.keys()))
    if expected_checksum != actual_checksum:
        raise RuntimeError(
            "Source code has changed since the last compilation. Recompiling the model."
        )


def _support_torch_compile(
    cls: _T,
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
    """
    A decorator to add support for compiling the forward method of a class.

    Patches `cls` in place:
    - appends `TorchCompileWithNoGuardsWrapper` to its bases (unless it is
      already there, in which case `cls` is returned untouched);
    - replaces `__init__` with a version that decides once, at construction
      time, whether this instance compiles;
    - replaces `__call__` with a version that marks dynamic dims on the first
      call and compiles (optionally loading/saving an AOT artifact).

    Args:
        cls: The `nn.Module` subclass to patch.
        dynamic_arg_dims: Maps `forward` argument names to the dimension(s)
            marked dynamic (or unbacked, for `DynamicShapesType.UNBACKED`)
            on the first compiling call.
        mark_unbacked_dims: Maps `forward` argument names to tensor
            dimension(s) that are additionally marked unbacked.
        enable_if: Optional predicate on the `VllmConfig`; when it returns
            False the instance does not compile.
        shape_invariants: Stored on the instance as `_check_shape_invariants`
            when compilation is enabled.

    Returns:
        The same class object.
    """
    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(
        self: _T,
        *,
        vllm_config: VllmConfig | None = None,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch
        # it
        sig = inspect.signature(old_init)
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        self.compilation_config = self.vllm_config.compilation_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            self.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or not supports_dynamo()
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        self._check_shape_invariants = shape_invariants

        compilation_counter.num_models_seen += 1
        self.compiled = False

        # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
        TorchCompileWithNoGuardsWrapper.__init__(self)  # type: ignore[arg-type]

    cls.__init__ = __init__

    def _mark_dynamic_inputs(
        mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
    ) -> None:
        """Mark the configured dims of `forward`'s arguments for Dynamo.

        Raises:
            ValueError: if a non-None argument listed in `dynamic_arg_dims`
                is neither a tensor nor `IntermediateTensors`.
        """

        def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
            if ds_type == DynamicShapesType.UNBACKED:
                if is_torch_equal_or_newer("2.10.0.dev"):
                    for dim in dims:
                        torch._dynamo.decorators.mark_unbacked(
                            arg, dim, hint_override=arg.size()[dim]
                        )
                else:
                    torch._dynamo.decorators.mark_unbacked(arg, dims)
            else:
                torch._dynamo.mark_dynamic(arg, dims)

        sig = inspect.signature(mod.__class__.forward)  # type: ignore[attr-defined]
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)

            if arg is not None:
                dims = [dims] if isinstance(dims, int) else dims
                if isinstance(arg, torch.Tensor):
                    # In case dims is specified with negative indexing
                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                    mark_dynamic(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        # Normalize negative dims per tensor into a fresh list:
                        # the tensors may differ in rank, so `dims` itself must
                        # not be overwritten with one tensor's normalization.
                        tensor_dims = [
                            tensor.ndim + dim if dim < 0 else dim for dim in dims
                        ]
                        mark_dynamic(tensor, tensor_dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}."
                    )
        if mark_unbacked_dims:
            for k, dims in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    dims = [dims] if isinstance(dims, int) else dims
                    if isinstance(arg, torch.Tensor):
                        # In case dims is specified with negative indexing
                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
                        if is_torch_equal_or_newer("2.10.0.dev"):
                            for dim in dims:
                                torch._dynamo.decorators.mark_unbacked(
                                    arg, dim, hint_override=arg.size()[dim]
                                )
                        else:
                            torch._dynamo.decorators.mark_unbacked(arg, dims)

    def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, call it with partition wrapper context.
        # The partition wrapper must be active at runtime for CUDA graph
        # capture to work correctly with inductor graph partitioning.
        if getattr(self, "aot_compiled_fn", None) is not None:
            with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                return self.aot_compiled_fn(self, *args, **kwargs)

        ds_type = self.compilation_config.dynamic_shapes_config.type
        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            # When using torch.compile in AOT mode, we store the cache artifacts
            # under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
            # contains all of the factors except for the source files being
            # traced through, because we don't actually know which source files
            # to check at this point (before dynamo runs).
            # On loading we will actually look at the source files being traced
            # through. If any source file have changed (compared with the
            # serialized backend artifacts), then we need to generate a new AOT
            # compile artifact from scratch.
            from .caching import compilation_config_hash_factors

            factors: list[str] = compilation_config_hash_factors(self.vllm_config)

            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_aot_compile",
                hash_key,
            )

            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_index
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            try:
                with (
                    set_current_vllm_config(self.vllm_config),
                    open(aot_compilation_path, "rb") as f,
                ):
                    start_monitoring_torch_compile(self.vllm_config)
                    loaded_fn = torch.compiler.load_compiled_function(
                        f, f_globals=self.forward.__globals__
                    )
                _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
                if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
                    loaded_fn.disable_guard_check()
                self.aot_compiled_fn = loaded_fn
            except Exception as e:
                # A missing artifact is the normal first-run case and is not
                # logged; anything else (corrupt file, changed sources, ...)
                # is reported. Either way we fall through to compiling, unless
                # loading is forced.
                if os.path.exists(aot_compilation_path):
                    logger.warning(
                        "Cannot load aot compilation from path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
                if envs.VLLM_FORCE_AOT_LOAD:
                    raise
            if getattr(self, "aot_compiled_fn", None) is not None:
                logger.info(
                    "Directly load AOT compilation from path %s", aot_compilation_path
                )
                # Apply partition wrapper context for proper CUDA graph capture
                with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                    return self.aot_compiled_fn(self, *args, **kwargs)

        if self.compiled:
            assert (
                not envs.VLLM_USE_AOT_COMPILE
                or self.vllm_config.compilation_config.backend == "eager"
            )
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        # This is the path for the first compilation.
        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(
            self,
            ds_type,
            *args,
            **kwargs,
        )

        # here, it is the starting point of the `torch.compile` process
        start_monitoring_torch_compile(self.vllm_config)
        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.compilation_config.traced_files.add(original_code_object.co_filename)

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

        def patched_inline_call(self_: Any) -> Any:
            code = self_.f_code
            self.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can improve guard overhead. But, since
        # vllm skip guards anyways, setting this flag to False can improve
        # compile time.
        dynamo_config_patches = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        # Prepare backed_size_oblivious config patch if needed
        fx_config_patches = {}
        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
            fx_config_patches["backed_size_oblivious"] = True

        # Prepare inductor config patches
        # assume_32bit_indexing is only available in torch 2.10.0.dev+
        inductor_config_patches = {}
        if is_torch_equal_or_newer("2.10.0.dev"):
            inductor_config_patches["assume_32bit_indexing"] = (
                self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
            )

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            torch.fx.experimental._config.patch(**fx_config_patches),
            _torch27_patch_tensor_subclasses(),
            torch._inductor.config.patch(**inductor_config_patches),
        ):
            use_aot_compile = envs.VLLM_USE_AOT_COMPILE
            if self.vllm_config.compilation_config.backend == "eager":
                logger.warning("Detected eager backend, disabling AOT compile.")
                use_aot_compile = False
            if use_aot_compile:
                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                output = self.aot_compiled_fn(self, *args, **kwargs)
                assert aot_compilation_path is not None
                assert cache_dir is not None
                # Saving is best-effort: a failure only costs a recompile on
                # the next run, so it is logged rather than raised.
                try:
                    os.makedirs(cache_dir, exist_ok=True)
                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
                except Exception as e:
                    logger.warning(
                        "Cannot save aot compilation to path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
            else:
                output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        self.compiled = True
        return output

    cls.__call__ = __call__
    return cls


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is installed only when piecewise cudagraphs and Inductor graph
    partitioning are both enabled. It is cleared on exit even if the body
    raises, so a failed compile or run does not leave the customized wrapper
    installed for later, unrelated compilations.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(
        f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
    ) -> Any:
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        # Only the first partition logs and keeps gc enabled; only the last
        # partition returns weak-referenced outputs.
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        torch._inductor.utils.set_customized_partition_wrappers(None)


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
    """
    Add support for using tensor subclasses (ie `BasevLLMParameter`, etc) when
    using torch 2.7.0. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    On any torch version outside [2.7, 2.8) this context manager is a no-op.
    """
    from vllm.model_executor.parameter import (
        BasevLLMParameter,
        ModelWeightParameter,
        RowvLLMParameter,
        _ColumnvLLMParameter,
    )

    def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
        return False

    # The workaround targets torch 2.7 only; every other version is passed
    # through untouched. (The previous check was inverted and applied the
    # 2.7-specific patches everywhere *except* 2.7.)
    if not (
        version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8")
    ):
        yield
        return

    with (
        torch._dynamo.config.patch(
            "traceable_tensor_subclasses",
            [
                BasevLLMParameter,
                ModelWeightParameter,
                _ColumnvLLMParameter,
                RowvLLMParameter,
            ],
        ),
        # Presumably stops Dynamo from routing these subclasses through
        # `__torch_function__` dispatch during tracing -- TODO confirm.
        patch(
            "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
        ),
    ):
        yield