decorators.py 28 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import contextlib
5
import hashlib
6
import inspect
7
8
import os
import sys
9
from collections.abc import Callable, Generator
10
from typing import TYPE_CHECKING, Any, TypeVar, overload
11
from unittest.mock import patch
12
13

import torch
14
import torch.nn as nn
15
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
16

17
import vllm.envs as envs
18
from vllm.compilation.counter import compilation_counter
19
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
20
21
22
23
24
25
from vllm.config import (
    CompilationMode,
    VllmConfig,
    get_current_vllm_config,
    set_current_vllm_config,
)
26
from vllm.config.compilation import DynamicShapesType
27
from vllm.forward_context import get_forward_context, is_forward_context_available
28
from vllm.logger import init_logger
29
from vllm.sequence import IntermediateTensors
30
from vllm.utils.import_utils import resolve_obj_by_qualname
31
from vllm.utils.torch_utils import is_torch_equal_or_newer
32

33
from .monitor import monitor_profiling_run, monitor_torch_compile
34

35
36
37
38
39
40
41
42
if TYPE_CHECKING:
    # `SourceInfo` is only present in newer torch builds (nightly / 2.10+),
    # so fall back to `Any` for older versions that lack it.
    try:
        from torch._dynamo.package import SourceInfo
    except ImportError:
        SourceInfo = Any

logger = init_logger(__name__)

# Name of the class attribute used as the "skip torch.compile" flag:
# `ignore_torch_compile` sets it to True, `_support_torch_compile` to False.
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"

# Type variable for the nn.Module subclasses these decorators operate on.
_T = TypeVar("_T", bound=nn.Module)
48
49


50
51
52
53
54
def should_torch_compile_mm_encoder(vllm_config: VllmConfig) -> bool:
    """Predicate for `@support_torch_compile(enable_if=...)`.

    Reports whether the multimodal encoder should be compiled, as configured
    by `compilation_config.compile_mm_encoder`.
    """
    compilation_config = vllm_config.compilation_config
    return compilation_config.compile_mm_encoder


55
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
    """
    Class decorator that opts `cls` out of an inherited
    `support_torch_compile` decorator.

    Use it when a parent class is decorated with `support_torch_compile`
    but a subclass should run its `forward` without compilation. Only the
    `forward` of the decorated class itself is affected:

    - if a subclass of `cls` applies `support_torch_compile` again, that
      subclass is still compiled;
    - submodules of `cls` that carry their own `support_torch_compile`
      decorator are still compiled.

    Returns `cls` itself, with the ignore flag set on it.
    """
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls


75
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
    """
    Return the "skip torch.compile" flag of `cls`.

    The flag is looked up with normal attribute resolution (so it may come
    from a base class); classes where it was never set report `False`.
    """
    ignore_flag = getattr(cls, IGNORE_COMPILE_KEY, False)
    return ignore_flag


82
83
84
@overload
def support_torch_compile(cls: type[_T]) -> type[_T]: ...


# A single keyword-only overload that mirrors the implementation, so any
# combination of options (e.g. `dynamic_arg_dims` together with `enable_if`,
# `is_encoder` or `shape_invariants`) type-checks.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]]: ...


def support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotation
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
        dimension of all the tensors in the intermediate tensors
        will be marked as dynamic.

    During runtime, when we actually mark dimensions of tensors,
    it depends on the value of arguments:

    - if it is a single integer (can be negative), the corresponding dimension
        of the argument will be marked as dynamic.
    - if it is `None`, ignored.
    - if it is `IntermediateTensors`, all the tensors in the intermediate
        tensors will be marked as dynamic.
    - otherwise, it will raise an error.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    `enable_if` is a function that takes a `VllmConfig` object as input and
    returns a boolean value indicating whether to compile the model or not.
    This is useful if you want to compile the model only when certain
    conditions are met.

    `mark_unbacked_dims` is a dictionary that maps argument names with a dynamic
    dim to be decorated with `mark_unbacked`. This is useful if we would like to
    enforce that dynamo does not specialize on 0/1 values in the case of dummy
    input such as for vision model compilation.

    `is_encoder` marks this module as a portion of a multimodal encoder.
    When True, the compile range upper bound is set to MAX_INT32 instead of
    max_num_batched_tokens, since encoder input shapes are unpredictable.
    This is typically used for vision encoder sub-modules in multimodal models.

    `shape_invariants` is a function that gets compiled right before forward.
    The function should have the torch._check calls that are needed to set
    the relationships between different input sizes. For example:
            torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This enforces constraints on the symbolic shapes without hardcoding
    specific values. It is needed for some models to avoid data dependent
    errors.

    Raises:
        TypeError: if the decorated class has no `forward` method.
        ValueError: if no dynamic dimensions are given or can be inferred, or
            if `dynamic_arg_dims` names an argument `forward` does not take.
    """

    def cls_decorator_helper(cls: type[_T]) -> type[_T]:
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                    torch.Tensor,
                    torch.Tensor | None,
                    IntermediateTensors,
                    IntermediateTensors | None,
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(inferred_dynamic_arg_dims.keys()),
            )

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}"
                )
        return _support_torch_compile(
            cls,
            inferred_dynamic_arg_dims,
            mark_unbacked_dims,
            enable_if,
            is_encoder,
            shape_invariants,
        )

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


247
def _model_hash_key(fn: Callable[..., Any]) -> str:
    """Return a SHA-256 hex digest identifying `fn` for the AOT cache key.

    The digest is computed over, in order: the installed vLLM version, the
    function's `__qualname__`, and the first line number of its code object.
    """
    import vllm

    fingerprint_parts = (
        vllm.__version__,
        fn.__qualname__,
        str(fn.__code__.co_firstlineno),
    )
    digest = hashlib.sha256()
    for part in fingerprint_parts:
        digest.update(part.encode())
    return digest.hexdigest()


257
258
259
def _verify_source_unchanged(
    source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
    """Check that the sources recorded in an AOT artifact still match disk.

    Every file backing `source_info.inlined_sources` is added to
    `compilation_config.traced_files`. The hash of the recorded source
    contents is then compared with the hash of those files as they exist now.

    Raises:
        RuntimeError: if the two hashes differ.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    recorded_contents = {}
    for inlined in source_info.inlined_sources:
        path = inspect.getfile(sys.modules[inlined.module])
        vllm_config.compilation_config.traced_files.add(path)
        recorded_contents[path] = inlined.content

    stored_digest = _compute_code_hash_with_content(recorded_contents)
    current_digest = _compute_code_hash(set(recorded_contents))
    if stored_digest != current_digest:
        raise RuntimeError(
            "Source code has changed since the last compilation. Recompiling the model."
        )


276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def _try_load_aot_compiled_fn(
    model: Any,
    aot_compilation_path: str,
) -> Any | None:
    """Attempt to restore an AOT-compiled forward function from disk.

    On success, the loaded callable is returned after three steps:
    - its recorded sources are checked against the current ones;
    - guard checking is disabled unless `evaluate_guards` is configured;
    - its compiled artifacts are finalized.

    On any failure `None` is returned. A warning is logged if the cache file
    exists. When `VLLM_FORCE_AOT_LOAD` is set, the exception is re-raised
    instead.
    """
    try:
        config = model.vllm_config
        with monitor_torch_compile(config):
            with (
                set_current_vllm_config(config),
                open(aot_compilation_path, "rb") as cache_file,
            ):
                loaded_fn = torch.compiler.load_compiled_function(
                    cache_file, f_globals=model.forward.__globals__
                )
            _verify_source_unchanged(loaded_fn.source_info(), config)

            if not model.compilation_config.dynamic_shapes_config.evaluate_guards:
                loaded_fn.disable_guard_check()

            # traced_files has just been populated by _verify_source_unchanged,
            # so the compiled artifacts can be loaded eagerly now.
            with maybe_use_cudagraph_partition_wrapper(config):
                loaded_fn._artifacts.compiled_fn.finalize_loading(config)

            compilation_counter.num_aot_artifacts_loaded += 1
            logger.info(
                "Directly load AOT compilation from path %s", aot_compilation_path
            )
    except Exception as err:
        # A missing cache file is the normal first-run case and stays silent.
        if os.path.exists(aot_compilation_path):
            reason = (
                "Compile cache file corrupted."
                if isinstance(err, EOFError)
                else str(err)
            )
            logger.warning(
                "Compiling model again due to a load failure from %s, reason: %s",
                aot_compilation_path,
                reason,
            )
        if envs.VLLM_FORCE_AOT_LOAD:
            raise err
        return None
    return loaded_fn


323
def _support_torch_compile(
    cls: type[_T],
    dynamic_arg_dims: dict[str, int | list[int]],
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    is_encoder: bool = False,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> type[_T]:
    """
    A decorator to add support for compiling the forward method of a class.

    `cls` is patched in place and returned:
    - `TorchCompileWithNoGuardsWrapper` is appended to its bases;
    - `__init__`, `__call__` and `save_aot_compiled_function` are replaced;
    - the ignore-compile flag is reset to False.

    Decorating a class that already has the wrapper as a direct base is a
    no-op.
    """
    if TorchCompileWithNoGuardsWrapper in cls.__bases__:
        # support decorating multiple times
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    #  other than TorchCompileWithNoGuardsWrapper
    cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)

    old_init = cls.__init__

    setattr(cls, IGNORE_COMPILE_KEY, False)

    def __init__(
        self: _T,
        *,
        vllm_config: VllmConfig | None = None,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        if vllm_config is None:
            vllm_config = get_current_vllm_config()

        # NOTE: to support multimodal models (such as encoder),
        # we may not have vllm_config so we may need to patch it
        sig = inspect.signature(old_init)
        if "vllm_config" in sig.parameters:
            kwargs["vllm_config"] = vllm_config
        if "prefix" in sig.parameters:
            kwargs["prefix"] = prefix
        old_init(self, **kwargs)

        self.vllm_config = vllm_config
        self.compilation_config = self.vllm_config.compilation_config
        enable_compile = enable_if is None or enable_if(vllm_config)
        # for CompilationMode.STOCK_TORCH_COMPILE, the upper level model runner
        # will handle the compilation, so we don't need to do anything here.
        self.do_not_compile = (
            self.compilation_config.mode
            in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
            or _should_ignore_torch_compile(self.__class__)
            or not enable_compile
        )
        if self.do_not_compile:
            return

        self._check_shape_invariants = shape_invariants
        self.was_aot_compile_fn_loaded_from_disk = False
        compilation_counter.num_models_seen += 1
        self.compiled = False

        # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
        TorchCompileWithNoGuardsWrapper.__init__(
            self,
            compile_prefix=cls.__name__ if is_encoder else "",
            is_encoder=is_encoder,
        )

    cls.__init__ = __init__

    def _normalize_dims(dims: list[int], ndim: int) -> list[int]:
        # Resolve negative (counted-from-the-end) indices against rank `ndim`.
        return [ndim + dim if dim < 0 else dim for dim in dims]

    def _mark_unbacked(arg: torch.Tensor, dims: list[int]) -> None:
        # torch >= 2.10 marks one dim at a time and accepts hint_override,
        # which is given the tensor's current size in that dim.
        if is_torch_equal_or_newer("2.10.0"):
            for dim in dims:
                torch._dynamo.decorators.mark_unbacked(
                    arg, dim, hint_override=arg.size()[dim]
                )
        else:
            torch._dynamo.decorators.mark_unbacked(arg, dims)

    def _mark_dynamic_inputs(
        mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
    ) -> None:
        """Mark the configured dims of forward()'s inputs as dynamic/unbacked."""

        def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
            if ds_type == DynamicShapesType.UNBACKED:
                _mark_unbacked(arg, dims)
            else:
                torch._dynamo.mark_dynamic(arg, dims)

        sig = inspect.signature(mod.__class__.forward)  # type: ignore[attr-defined]
        bound_args = sig.bind(mod, *args, **kwargs)
        bound_args.apply_defaults()
        for k, dims in dynamic_arg_dims.items():
            arg = bound_args.arguments.get(k)
            if arg is None:
                continue
            dims_list = [dims] if isinstance(dims, int) else dims
            if isinstance(arg, torch.Tensor):
                mark_dynamic(arg, _normalize_dims(dims_list, arg.ndim))
            elif isinstance(arg, IntermediateTensors):
                # Normalize per tensor: the tensors may differ in rank, so a
                # negative index must be resolved against each one separately
                # (previously the first tensor's resolution leaked into the rest).
                for tensor in arg.tensors.values():
                    mark_dynamic(tensor, _normalize_dims(dims_list, tensor.ndim))
            else:
                raise ValueError(
                    "Unsupported dynamic dimensions"
                    f" {dims_list} for argument {k} with type {type(arg)}."
                )
        if mark_unbacked_dims:
            for k, dims in mark_unbacked_dims.items():
                arg = bound_args.arguments.get(k)
                if isinstance(arg, torch.Tensor):
                    dims_list = [dims] if isinstance(dims, int) else dims
                    _mark_unbacked(arg, _normalize_dims(dims_list, arg.ndim))

    def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if self.do_not_compile or torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
        # enc-dec models where tensor shapes/types vary across invocations, preventing
        # the capture of a single computational graph.
        if is_forward_context_available() and get_forward_context().skip_compiled:
            return self.forward(*args, **kwargs)

        # if aot_compiled_fn is set, call it with partition wrapper context.
        # The partition wrapper must be active at runtime for CUDA graph
        # capture to work correctly with inductor graph partitioning.
        if getattr(self, "aot_compiled_fn", None) is not None:
            with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
                return self.aot_compiled_fn(self, *args, **kwargs)

        if self.compiled:
            # Compiled without an AOT artifact (only possible when AOT compile
            # is off or the backend is "eager"). Checked before the AOT cache
            # lookup below so steady-state calls don't recompute the cache hash
            # and retry a disk load on every invocation.
            assert (
                not envs.VLLM_USE_AOT_COMPILE
                or self.vllm_config.compilation_config.backend == "eager"
            )
            return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]

        ds_type = self.compilation_config.dynamic_shapes_config.type
        cache_dir = None
        aot_compilation_path = None
        if envs.VLLM_USE_AOT_COMPILE:
            # When using torch.compile in AOT mode, we store the cache artifacts
            # under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}.
            # The {hash} contains all of the factors except for the source files
            # being traced through, because we don't actually know which source
            # files to check at this point (before dynamo runs).
            #
            # On loading we will actually look at the source files being traced
            # through. If any source file have changed (compared with the
            # serialized backend artifacts), then we need to generate a new AOT
            # compile artifact from scratch.
            from .caching import aot_compile_hash_factors

            factors: list[str] = aot_compile_hash_factors(self.vllm_config)
            factors.append(_model_hash_key(self.forward))
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                "torch_aot_compile",
                hash_key,
            )

            rank = self.vllm_config.parallel_config.rank
            dp_rank = self.vllm_config.parallel_config.data_parallel_index
            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
            aot_compilation_path = os.path.join(cache_dir, "model")
            if not envs.VLLM_DISABLE_COMPILE_CACHE:
                loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
                if loaded_fn is not None:
                    self.aot_compiled_fn = loaded_fn
                    self.was_aot_compile_fn_loaded_from_disk = True
                    with (
                        monitor_profiling_run(),
                        maybe_use_cudagraph_partition_wrapper(self.vllm_config),
                    ):
                        output = self.aot_compiled_fn(self, *args, **kwargs)
                    return output

        # This is the path for the first compilation.
        # the first compilation needs to have dynamic shapes marked
        _mark_dynamic_inputs(
            self,
            ds_type,
            *args,
            **kwargs,
        )

        original_code_object = self.original_code_object()
        logger.debug("Start compiling function %s", original_code_object)

        # we do not want to delete the original code object entries since
        # we depend on them now to look up cached compiled functions.
        # torch._dynamo.eval_frame.remove_from_cache(original_code_object)

        # collect all relevant files traced by Dynamo,
        # so that the compilation cache can trigger re-compilation
        # properly when any of these files change.

        # 1. the file containing the top-level forward function
        self.compilation_config.traced_files.add(original_code_object.co_filename)

        # 2. every time Dynamo sees a function call, it will inline
        # the function by calling InliningInstructionTranslator.inline_call_
        # we hijack this function to know all the functions called
        # during Dynamo tracing, and their corresponding files
        inline_call = InliningInstructionTranslator.inline_call_

        def patched_inline_call(self_: Any) -> Any:
            code = self_.f_code
            self.compilation_config.traced_files.add(code.co_filename)
            return inline_call(self_)

        # Disable the C++ compilation of symbolic shape guards. C++-fication
        # of symbolic shape guards can improve guard overhead. But, since
        # vllm skip guards anyways, setting this flag to False can improve
        # compile time.
        dynamo_config_patches = {}
        try:
            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
        except AttributeError:
            # Note: this config is not available in torch 2.6, we can skip
            # if the config doesn't exist
            logger.debug("enable_cpp_symbolic_shape_guards config not available")

        # Prepare backed_size_oblivious config patch if needed
        fx_config_patches = {}
        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
            fx_config_patches["backed_size_oblivious"] = True

        # Prepare inductor config patches
        # assume_32bit_indexing is only available in torch 2.10.0+
        inductor_config_patches = {}
        if is_torch_equal_or_newer("2.10.0"):
            inductor_config_patches["assume_32bit_indexing"] = (
                self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
            )

        with (
            patch.object(
                InliningInstructionTranslator, "inline_call_", patched_inline_call
            ),
            torch._dynamo.config.patch(**dynamo_config_patches),
            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
            torch.fx.experimental._config.patch(**fx_config_patches),
            torch._inductor.config.patch(**inductor_config_patches),
        ):
            use_aot_compile = envs.VLLM_USE_AOT_COMPILE
            if self.vllm_config.compilation_config.backend == "eager":
                logger.warning("Detected eager backend, disabling AOT compile.")
                use_aot_compile = False
            if use_aot_compile:
                # store the path for saving after warmup
                self._aot_compilation_path = aot_compilation_path
                self._aot_cache_dir = cache_dir
                with monitor_torch_compile(self.vllm_config):
                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
                    compilation_counter.num_aot_compiles += 1
                    # All compilation is done at this point, save the
                    # AOT artifact.
                    self.save_aot_compiled_function()
                with monitor_profiling_run():
                    output = self.aot_compiled_fn(self, *args, **kwargs)
            else:
                with monitor_torch_compile(
                    self.vllm_config,
                    "torch.compile and initial profiling/warmup "
                    "run together took %.2f s in total",
                ):
                    output = TorchCompileWithNoGuardsWrapper.__call__(
                        self,  # type: ignore[arg-type]
                        *args,
                        **kwargs,
                    )

        self.compiled = True
        return output

    # triggers VllmSerializableFunction.serialize()
    def save_aot_compiled_function(self: type[_T]) -> None:
        """Write the AOT-compiled forward to the on-disk compile cache.

        No-op when the compile cache is disabled or when the function was
        itself loaded from disk. Failures are logged, not raised.
        """
        if envs.VLLM_DISABLE_COMPILE_CACHE:
            return

        if self.was_aot_compile_fn_loaded_from_disk:
            logger.debug("AOT compiled function was loaded from cache, skipping save")
            return

        assert (
            self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
        )

        # File saving should be atomic, so we save to a temporary location
        # first and then rename into place. Should be upstreamed to
        # PyTorch 2.12 as well.
        tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
        try:
            os.makedirs(self._aot_cache_dir, exist_ok=True)
            self.aot_compiled_fn.save_compiled_function(tmp_file)
            os.replace(tmp_file, self._aot_compilation_path)
            compilation_counter.num_aot_artifacts_saved += 1
            logger.info_once(
                "saved AOT compiled function to %s",
                self._aot_compilation_path,
                scope="local",
            )
        except Exception as e:
            logger.warning(
                "unable to save AOT compiled function to %s: %s",
                self._aot_compilation_path,
                e,
            )
            # Don't leave a partially written temp file behind: loads only
            # read the final path, so it would just leak disk space.
            with contextlib.suppress(OSError):
                os.remove(tmp_file)

    cls.__call__ = __call__
    cls.save_aot_compiled_function = save_aot_compiled_function
    return cls
655
656
657


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is only installed when piecewise cudagraphs and Inductor graph
    partitioning are both enabled. It is removed on exit even if the body
    raises; otherwise it would stay registered in Inductor's global state.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Evaluate once so installing and removing the hook always pair up.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(
        f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
    ) -> Any:
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        torch._inductor.utils.set_customized_partition_wrappers(None)