decorators.py 27 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import hashlib
6
import inspect
7
8
import os
import sys
9
10
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
11
from unittest.mock import patch
12
13

import torch
14
import torch.nn as nn
15
from packaging import version
16
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
17

18
import vllm.envs as envs
19
from vllm.compilation.counter import compilation_counter
20
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
21
22
23
24
25
26
from vllm.config import (
    CompilationMode,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
27
from vllm.config.compilation import DynamicShapesType
28
from vllm.forward_context import get_forward_context, is_forward_context_available
29
from vllm.logger import init_logger
30
from vllm.sequence import IntermediateTensors
31
from vllm.utils.import_utils import resolve_obj_by_qualname
32
from vllm.utils.torch_utils import is_torch_equal_or_newer
33

34
from .monitor import start_monitoring_torch_compile
zhuwenwen's avatar
zhuwenwen committed
35
from vllm.forward_context import get_profilling
36

37
38
39
40
41
42
43
44
if TYPE_CHECKING:
    # Only added on nightly/2.10 so wrap
    try:
        from torch._dynamo.package import SourceInfo
    except ImportError:
        # Fallback for old versions not supporting
        SourceInfo = Any

45
logger = init_logger(__name__)

# Name of the class attribute used as an opt-out flag for compilation.
# `ignore_torch_compile` sets it to True, `_should_ignore_torch_compile`
# reads it, and `_support_torch_compile` resets it to False on the class it
# decorates.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type variable for the `nn.Module` subclasses that the decorators in this
# module accept and return unchanged in type.
_T = TypeVar("_T", bound=nn.Module)
50
51


52
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
    """
    A decorator to ignore the support_torch_compile decorator on the class.

    This is useful when a parent class has a support_torch_compile decorator,
    but we don't want to compile the class `cls` that inherits the parent
    class. This only ignores compiling the forward of the class the
    decorator is applied to.

    If the parent has ignore_torch_compile but the child has
    support_torch_compile, the child will still be compiled.

    If the class has one or more submodules that have the
    support_torch_compile decorator applied, compile will not be ignored for
    those submodules.

    Args:
        cls: the class to mark.

    Returns:
        The same class object, with `IGNORE_COMPILE_KEY` set to True on it.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls


72
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
    """
    Check if the class should be ignored for torch.compile.

    Returns the value of the `IGNORE_COMPILE_KEY` attribute on `cls`
    (looked up through normal attribute resolution, so inherited values
    count), or False when the attribute is absent.
    """
    return getattr(cls, IGNORE_COMPILE_KEY, False)


79
80
81
# Typing-only overloads for `support_torch_compile`: it can be used either as
# a bare class decorator or as a decorator factory taking keyword arguments.
@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


@overload
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
110

111
112

def support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` (optionally
        `Optional[...]`), the first dimension of all the tensors in the
        intermediate tensors will be marked as dynamic.

    Annotations are compared as objects, so string annotations (e.g. under
    `from __future__ import annotations`) will not match these rules.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value of arguments:

    - if it is a tensor, the configured dimension(s) of the argument will be
        marked as dynamic. Negative dimensions count from the end.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of all the
        tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy
    input such as for vision model compilation. Only tensor arguments are
    affected.

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
        torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors.

    Raises:
        TypeError: if the decorated class has no `forward` method, or if `cls`
            is given but is not a class.
        ValueError: if no dynamic dimensions are given or inferred, or if a
            key of the dynamic dims is not a parameter of `forward`.
    """

    def cls_decorator_helper(cls: type[_T]) -> type[_T]:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)

        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            # Default: dim 0 of every tensor-like argument is dynamic.
            tensor_like_annotations = [
                torch.Tensor,
                torch.Tensor | None,
                IntermediateTensors,
                IntermediateTensors | None,
            ]
            inferred_dynamic_arg_dims = {
                name: 0
                for name, param in sig.parameters.items()
                if param.annotation in tensor_like_annotations
            }
            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(inferred_dynamic_arg_dims.keys()),
            )

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls,
            inferred_dynamic_arg_dims,
            mark_unbacked_dims,
            enable_if,
            shape_invariants,
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments.
        # Explicit check instead of `assert` so it survives `python -O`.
        if not isinstance(cls, type):
            raise TypeError(
                f"support_torch_compile expects a class, got {type(cls)}."
            )
        return cls_decorator_helper(cls)

    return cls_decorator_helper


237
def _model_hash_key(fn: Callable[..., Any]) -> str:
    """
    Return a SHA-256 hex digest identifying `fn` for AOT-compile caching.

    The digest covers the installed vLLM version, `fn.__qualname__` and the
    first line number of `fn.__code__`. It does NOT hash the function body;
    source changes are checked separately after loading a cached artifact.
    """
    import vllm

    sha256_hash = hashlib.sha256()
    sha256_hash.update(vllm.__version__.encode())
    sha256_hash.update(fn.__qualname__.encode())
    sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
    return sha256_hash.hexdigest()


247
248
249
def _verify_source_unchanged(
    source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
    """
    Check that the sources inlined into a loaded AOT artifact are unchanged.

    For every inlined source recorded in `source_info`, the file backing its
    module is added to `compilation_config.traced_files`. The recorded
    content is hashed and compared with a hash of the files currently on
    disk.

    Raises:
        RuntimeError: if the two hashes differ.
        KeyError: if a recorded module is not present in `sys.modules`.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    file_contents = {}
    for source in source_info.inlined_sources:
        module = sys.modules[source.module]
        file = inspect.getfile(module)
        vllm_config.compilation_config.traced_files.add(file)
        file_contents[file] = source.content
    expected_checksum = _compute_code_hash_with_content(file_contents)
    actual_checksum = _compute_code_hash(set(file_contents.keys()))
    if expected_checksum != actual_checksum:
        raise RuntimeError(
            "Source code has changed since the last compilation. Recompiling the model."
        )


266
def _support_torch_compile(
    cls: type[_T],
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Patches `cls` in place:
    - appends `TorchCompileWithNoGuardsWrapper` to its bases;
    - sets `IGNORE_COMPILE_KEY` to False on it;
    - replaces `__init__`, `__call__` and `save_aot_compiled_function`.

    Decorating a class that already has the wrapper base is a no-op.

    Args:
        cls: the module class to patch.
        dynamic_arg_dims: maps `forward` argument names to the dimension(s)
            marked dynamic before the first compilation.
        mark_unbacked_dims: optional map of `forward` argument names to
            dimension(s) additionally marked unbacked (tensors only).
        enable_if: optional predicate on the `VllmConfig`; when it returns
            False the instance runs eagerly.
        shape_invariants: stored on compiled instances as
            `_check_shape_invariants`. It is not called here; presumably it is
            consumed by `TorchCompileWithNoGuardsWrapper` (verify).

    Returns:
        The same `cls` object.
    """
    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(
        self: _T,
        *,
        vllm_config: VllmConfig | None = None,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch
        # it
        sig = inspect.signature(old_init)
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        self.compilation_config = self.vllm_config.compilation_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            self.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        self._check_shape_invariants = shape_invariants
        self.was_aot_compile_fn_loaded_from_disk = False
        compilation_counter.num_models_seen += 1
        self.compiled = False

        # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
        TorchCompileWithNoGuardsWrapper.__init__(self)

    cls.__init__ = __init__

    def _mark_dynamic_inputs(
        mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
    ) -> None:
        """Mark the configured dims of `forward`'s inputs before compiling."""

        def mark_unbacked(arg: torch.Tensor, dims: list[int]) -> None:
            # Newer torch accepts a per-dim hint override; older versions take
            # the whole list of dims in a single call.
            if is_torch_equal_or_newer("2.10.0.dev"):
                for dim in dims:
                    torch._dynamo.decorators.mark_unbacked(
                        arg, dim, hint_override=arg.size()[dim]
                    )
            else:
                torch._dynamo.decorators.mark_unbacked(arg, dims)

        def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
            if ds_type == DynamicShapesType.UNBACKED:
                mark_unbacked(arg, dims)
            else:
                torch._dynamo.mark_dynamic(arg, dims)

        def resolve_dims(tensor: torch.Tensor, dims: list[int]) -> list[int]:
            # In case dims is specified with negative indexing
            return [tensor.ndim + dim if dim < 0 else dim for dim in dims]

        sig = inspect.signature(mod.__class__.forward)  # type: ignore[attr-defined]
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims = [dims] if isinstance(dims, int) else dims
            if isinstance(arg, torch.Tensor):
                mark_dynamic(arg, resolve_dims(arg, dims))
            elif isinstance(arg, IntermediateTensors):
                for tensor in arg.tensors.values():
                    # Resolve negative dims against *this* tensor's rank.
                    # Reassigning `dims` here would make later tensors reuse
                    # the first tensor's resolved indices.
                    mark_dynamic(tensor, resolve_dims(tensor, dims))
            else:
                raise ValueError(
                    "Unsupported dynamic dimensions"
                    f" {dims} for argument {k} with type {type(arg)}."
                )
        if mark_unbacked_dims:
            for k, dims in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    continue
                dims = [dims] if isinstance(dims, int) else dims
                # Non-tensor arguments are silently skipped here.
                if isinstance(arg, torch.Tensor):
                    mark_unbacked(arg, resolve_dims(arg, dims))

    def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
        """
        Run forward, compiling on first use.

        Dispatch order:
        1. eager forward if compilation is disabled, already compiling, or
           `get_profilling()` is truthy;
        2. eager forward if the forward context requests `skip_compiled`;
        3. a previously set `aot_compiled_fn`;
        4. (AOT mode) an artifact loaded from the on-disk cache;
        5. the wrapper's compiled call once compiled;
        6. otherwise the first compilation.
        """
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling() or get_profilling():
            return self.forward(*args, **kwargs)

        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
        # enc-dec models where tensor shapes/types vary across invocations, preventing
        # the capture of a single computational graph.
        if is_forward_context_available() and get_forward_context().skip_compiled:
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, call it with partition wrapper context.
        # The partition wrapper must be active at runtime for CUDA graph
        # capture to work correctly with inductor graph partitioning.
        if getattr(self, "aot_compiled_fn", None) is not None:
            with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                return self.aot_compiled_fn(self, *args, **kwargs)

        ds_type = self.compilation_config.dynamic_shapes_config.type
        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            # When using torch.compile in AOT mode, we store the cache artifacts
            # under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
            # contains all of the factors except for the source files being
            # traced through, because we don't actually know which source files
            # to check at this point (before dynamo runs).
            # On loading we will actually look at the source files being traced
            # through. If any source file have changed (compared with the
            # serialized backend artifacts), then we need to generate a new AOT
            # compile artifact from scratch.
            from .caching import aot_compile_hash_factors

            factors: list[str] = aot_compile_hash_factors(self.vllm_config)

            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_aot_compile",
                hash_key,
            )

            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_index
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            try:
                with (
                    set_current_vllm_config(self.vllm_config),
                    open(aot_compilation_path, "rb") as f,
                ):
                    start_monitoring_torch_compile(self.vllm_config)
                    loaded_fn = torch.compiler.load_compiled_function(
                        f, f_globals=self.forward.__globals__
                    )
                _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
                if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
                    loaded_fn.disable_guard_check()
                self.aot_compiled_fn = loaded_fn
                self.was_aot_compile_fn_loaded_from_disk = True
            except Exception as e:
                # A missing file is the normal cold-cache case and is not
                # logged; any other failure falls back to compiling.
                if os.path.exists(aot_compilation_path):
                    logger.warning(
                        "Cannot load aot compilation from path %s, error: %s",
                        aot_compilation_path,
                        str(e),
                    )
                if envs.VLLM_FORCE_AOT_LOAD:
                    raise
            if getattr(self, "aot_compiled_fn", None) is not None:
                logger.info(
                    "Directly load AOT compilation from path %s", aot_compilation_path
                )
                # Apply partition wrapper context for proper CUDA graph capture
                with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                    return self.aot_compiled_fn(self, *args, **kwargs)

        if self.compiled:
            assert (
                not envs.VLLM_USE_AOT_COMPILE
                or self.vllm_config.compilation_config.backend == "eager"
            )
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        # This is the path for the first compilation.
        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(
            self,
            ds_type,
            *args,
            **kwargs,
        )

        # here, it is the starting point of the `torch.compile` process
        start_monitoring_torch_compile(self.vllm_config)
        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.compilation_config.traced_files.add(original_code_object.co_filename)

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

        def patched_inline_call(self_: Any) -> Any:
            code = self_.f_code
            self.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can improve guard overhead. But, since
        # vllm skip guards anyways, setting this flag to False can improve
        # compile time.
        dynamo_config_patches: dict[str, Any] = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        # Prepare backed_size_oblivious config patch if needed
        fx_config_patches: dict[str, Any] = {}
        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
            fx_config_patches["backed_size_oblivious"] = True

        # Prepare inductor config patches
        # assume_32bit_indexing is only available in torch 2.10.0.dev+
        inductor_config_patches: dict[str, Any] = {}
        if is_torch_equal_or_newer("2.10.0.dev"):
            inductor_config_patches["assume_32bit_indexing"] = (
                self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
            )

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            torch.fx.experimental._config.patch(**fx_config_patches),
            _torch27_patch_tensor_subclasses(),
            torch._inductor.config.patch(**inductor_config_patches),
        ):
            use_aot_compile = envs.VLLM_USE_AOT_COMPILE
            if self.vllm_config.compilation_config.backend == "eager":
                logger.warning("Detected eager backend, disabling AOT compile.")
                use_aot_compile = False
            if use_aot_compile:
                from vllm.compilation.backends import set_on_compilation_complete

                # store the path for saving after warmup
                self._aot_compilation_path = aot_compilation_path
                self._aot_cache_dir = cache_dir
                # set callback in context so it's available when compilation completes
                with set_on_compilation_complete(self.save_aot_compiled_function):
                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                    output = self.aot_compiled_fn(self, *args, **kwargs)
            else:
                output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        self.compiled = True
        return output

    # triggers VllmSerializableFunction.serialize()
    def save_aot_compiled_function(self: type[_T]) -> None:
        """
        Persist the AOT-compiled function to the path chosen in `__call__`.

        Skipped when the function was itself loaded from disk. Save failures
        are logged as warnings and not raised.
        """
        if self.was_aot_compile_fn_loaded_from_disk:
            logger.debug("AOT compiled function was loaded from cache, skipping save")
            return

        assert (
            self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
        )

        logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
        try:
            os.makedirs(self._aot_cache_dir, exist_ok=True)
            self.aot_compiled_fn.save_compiled_function(self._aot_compilation_path)
            logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
        except Exception as e:
            logger.warning(
                "unable to save AOT compiled function to %s: %s",
                self._aot_compilation_path,
                e,
            )

    cls.__call__ = __call__
    cls.save_aot_compiled_function = save_aot_compiled_function
    return cls
592
593
594


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is installed only when piecewise cudagraphs and Inductor graph
    partitioning are both enabled. It is removed on exit, including when the
    body raises; otherwise it would leak into later Inductor compilations.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Evaluate once so the hook is reset exactly when it was installed.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(
        f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
    ) -> Any:
        # Wrap one Inductor partition in the platform's static graph wrapper:
        # debug logging only for the first partition, gc disabled for all but
        # the first, and weak-ref outputs only for the last partition.
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        torch._inductor.utils.set_customized_partition_wrappers(None)
651
652
653


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
    """
    Add support for using tensor subclasses (ie `BasevLLMParameter`, etc) when
    using torch 2.7.0. This enables using weight_loader_v2 and the use of
    `BasevLLMParameters` without having to replace them with regular tensors
    before `torch.compile`-time.

    When the patch is applied, for the duration of the context:
    - the vLLM parameter subclasses are listed in Dynamo's
      `traceable_tensor_subclasses` config;
    - `torch._dynamo.variables.torch.can_dispatch_torch_function` is replaced
      with a function that always returns False.

    NOTE(review): the version check below *skips* the patch on torch 2.7.x
    and applies it on every other version. That looks inverted relative to
    this docstring and the function name. Confirm the intended version range
    before changing it.
    """
    from vllm.model_executor.parameter import (
        BasevLLMParameter,
        ModelWeightParameter,
        RowvLLMParameter,
        _ColumnvLLMParameter,
    )

    def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
        return False

    if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
        yield
        return

    with (
        torch._dynamo.config.patch(
            "traceable_tensor_subclasses",
            [
                BasevLLMParameter,
                ModelWeightParameter,
                _ColumnvLLMParameter,
                RowvLLMParameter,
            ],
        ),
        patch(
            "torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
        ),
    ):
        yield