backends.py 37.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import contextvars
6
import dataclasses
7
import hashlib
8
import json
9
import operator
10
11
import os
import pprint
12
import time
13
from collections.abc import Callable, Generator, Sequence
14
from contextlib import contextmanager
15
from copy import deepcopy
16
from functools import partial
17
from typing import Any
18
19
20

import torch
import torch.fx as fx
21
from torch._dispatch.python import enable_python_dispatcher
22

23
import vllm.envs as envs
24
25
26
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
27
    should_split,
28
)
29
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
30
from vllm.config.compilation import DynamicShapesType
31
from vllm.config.utils import Range, hash_factors
32
from vllm.logger import init_logger
33
from vllm.logging_utils import lazy
34
from vllm.platforms import current_platform
35
from vllm.utils.import_utils import resolve_obj_by_qualname
36

37
38
39
40
41
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
42
    is_compile_cache_enabled,
43
)
44
from .counter import compilation_counter
45
46
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
47
48
49

logger = init_logger(__name__)

50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def make_copy_and_call(
    sym_tensor_indices: list[int],
    input_buffers: list[torch.Tensor | None],
    callable_fn: Callable[..., Any],
) -> Callable[..., Any]:
    """Create a wrapper that copies inputs to static buffers before calling.

    This is used for cudagraph input copying where we need to copy dynamic
    tensors to static buffers before invoking the compiled graph.

    Args:
        sym_tensor_indices: Indices of tensors with symbolic shapes
        input_buffers: List of static buffers (can contain None for lazy init)
        callable_fn: The compiled function to call

    Returns:
        A wrapper function that copies inputs and calls the compiled function
    """

70
    def copy_and_call(*args: Any) -> Any:
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
        list_args = list(args)
        for i, index in enumerate(sym_tensor_indices):
            runtime_tensor = list_args[index]
            runtime_shape = runtime_tensor.shape[0]

            # lazy initialization of buffer on first call
            if input_buffers[i] is None:
                input_buffers[i] = runtime_tensor.clone()

            static_tensor = input_buffers[i][:runtime_shape]  # type: ignore[index]
            static_tensor.copy_(runtime_tensor)
            list_args[index] = static_tensor
        return callable_fn(*list_args)

    return copy_and_call


88
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Build the compiler adaptor selected by ``compilation_config.backend``.

    - ``"inductor"``: ``InductorStandaloneAdaptor`` when
      ``VLLM_USE_STANDALONE_COMPILE`` is set and this PyTorch build exposes
      ``torch._inductor.standalone_compile``; otherwise ``InductorAdaptor``.
    - ``"eager"``: ``EagerAdaptor``.
    - any other value: the class named by
      ``current_platform.get_compile_backend()``, instantiated with no
      arguments.

    Raises:
        AssertionError: if ``VLLM_USE_MEGA_AOT_ARTIFACT`` is set without
            ``VLLM_USE_STANDALONE_COMPILE``, or if the platform-provided
            backend is not a ``CompilerInterface``.
    """
    assert not envs.VLLM_USE_MEGA_AOT_ARTIFACT or envs.VLLM_USE_STANDALONE_COMPILE, (
        "VLLM_USE_MEGA_AOT_ARTIFACT=1 requires VLLM_USE_STANDALONE_COMPILE=1"
    )

    backend = compilation_config.backend

    if backend == "eager":
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    if backend != "inductor":
        # Any non-built-in backend name is delegated to the platform.
        logger.debug("Using custom backend: %s", backend)
        platform_compiler = resolve_obj_by_qualname(
            current_platform.get_compile_backend()
        )()
        assert isinstance(platform_compiler, CompilerInterface)
        return platform_compiler

    # Inductor: the standalone path is opt-in. It is taken only when the
    # running PyTorch build actually provides torch._inductor.standalone_compile.
    if envs.VLLM_USE_STANDALONE_COMPILE and hasattr(
        torch._inductor, "standalone_compile"
    ):
        logger.debug("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor(compilation_config.compile_cache_save_format)

    logger.debug("Using InductorAdaptor")
    return InductorAdaptor()
114
115


116
117
118
119
120
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(compile_range, graph_index, compiler_name)`
    to the handle returned by the compiler for that graph.

    When serializing the cache, we save it to a Python file
    for readability and parse it back with `ast.literal_eval`.
    We don't use json here because json only supports string keys,
    while our keys are tuples.
    """

    def __init__(self, compilation_config: CompilationConfig) -> None:
        # (compile_range, graph_index, compiler name) -> compiler handle.
        self.cache: dict[tuple[Range, int, str], Any] = dict()
        # Set once a new handle is stored; save_to_file() does nothing
        # while this is False.
        self.is_cache_updated = False

        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Delegate hashing of ``vllm_config`` to the underlying compiler."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(compile_range):
            if self.compilation_config.use_inductor_graph_partition:
                # With Inductor graph partitioning, splitting_ops are
                # installed as partition rules for this compilation only.
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        Unless ``disable_cache`` is set, an existing
        ``vllm_compile_cache.py`` is parsed and replaces ``self.cache``.
        The compiler's own cache is initialized in either case.

        Raises:
            TypeError: if a key read from the cache file has unexpected
                element types.
            ValueError/SyntaxError: from ``ast.literal_eval`` if the cache
                file is not a valid Python literal.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                cache = ast.literal_eval(f.read())

            def check_type(value: Any, ty: type) -> None:
                if not isinstance(value, ty):
                    raise TypeError(f"Expected {ty} but got {type(value)} for {value}")

            def parse_key(key: Any) -> tuple[Range, int, str]:
                # The range may have been written out as a plain
                # (start, end) tuple; rebuild the Range in that case.
                range_tuple, graph_index, compiler_name = key
                check_type(graph_index, int)
                check_type(compiler_name, str)
                if isinstance(range_tuple, tuple):
                    start, end = range_tuple
                    check_type(start, int)
                    check_type(end, int)
                    range_tuple = Range(start=start, end=end)
                check_type(range_tuple, Range)
                return range_tuple, graph_index, compiler_name

            self.cache = {parse_key(key): value for key, value in cache.items()}

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self) -> None:
        """Write ``self.cache`` to ``cache_file_path`` as a pprint'ed literal.

        Does nothing when caching is disabled or no new handle has been
        stored. Must be called after ``initialize_cache``, which is where
        ``disable_cache`` and ``cache_file_path`` are set.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any] | None:
        """Load graph ``graph_index`` for ``compile_range`` via its cached handle.

        Returns:
            The callable produced by ``self.compiler.load``, or ``None`` when
            no handle is cached for this (range, index, compiler) key.
        """
        key = (compile_range, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, compile_range
        )
        logger.debug(
            "Directly load the %s-th graph for compile range %s from %s via handle %s",
            graph_index,
            str(compile_range),
            self.compiler.name,
            handle,
        )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        additional_inductor_config: dict[str, Any],
        compilation_config: CompilationConfig,
        compile_range: Range,
        graph_index: int = 0,
        num_graphs: int = 1,
    ) -> Any:
        """Load ``graph`` from the cache, or compile it and cache the handle.

        Timing spans all ``num_graphs`` piecewise graphs of one range. Index
        0 starts the module-level ``compilation_start_time`` timer. The last
        index adds the elapsed time to ``compilation_config.compilation_time``.

        Args:
            graph: The FX graph to compile.
            example_inputs: Example inputs forwarded to the compiler.
            additional_inductor_config: Extra config forwarded to the compiler.
                It is also consulted by ``is_compile_cache_enabled``.
            compilation_config: Receives the accumulated compilation time.
            compile_range: Range of runtime sizes this compilation covers.
            graph_index: Index of this graph among the piecewise graphs.
            num_graphs: Total number of piecewise graphs for this range.

        Returns:
            The loaded or freshly compiled graph.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                logger.info(
                    "Directly load the compiled graph(s) for compile range %s "
                    "from the cache, took %.3f s",
                    str(compile_range),
                    elapsed,
                )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = "artifact_compile_range_"
            maybe_key += f"{compile_range.start}_{compile_range.end}"
            maybe_key += f"_subgraph_{graph_index}"
        with self.compile_context(compile_range):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                compile_range,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(compile_range, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info_once(
                    "Cache the graph of compile range %s for later use",
                    str(compile_range),
                )
            logger.debug(
                "Store the %s-th graph for compile range %s from %s via handle %s",
                graph_index,
                str(compile_range),
                self.compiler.name,
                handle,
            )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            logger.info_once(
                "Compiling a graph for compile range %s takes %.2f s",
                str(compile_range),
                elapsed,
                scope="local",
            )

        return compiled_graph
324
325


326
327
328
@dataclasses.dataclass()
class SplitItem:
    """One direct submodule of the GraphModule produced by ``split_graph``."""

    # Attribute name of the submodule on the split module ("submod_<id>").
    submod_name: str
    # Integer id parsed from submod_name; split_graph sorts items by it.
    graph_id: int
    # True when this submodule was created for a splitting op.
    is_splitting_graph: bool
    # The submodule itself.
    graph: fx.GraphModule


334
def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into submodules around the ops selected by ``should_split``.

    Nodes are walked in graph order and given subgraph ids; placeholder and
    output nodes are left to ``split_module``. Every splitting op opens a
    fresh id of its own. Ordinary nodes after it continue under the next id.
    An ``operator.getitem`` whose input is not a placeholder is put in the
    same subgraph as that input. This means tuples are never passed whole
    between submodules, which standalone_compile and AOTAutograd require.

    Returns:
        The stitching GraphModule from ``split_module``, and one
        ``SplitItem`` per direct submodule, sorted by integer ``graph_id``.
    """
    current_id = 0
    assignment: dict[fx.Node, int] = {}
    splitting_ids: set[int] = set()

    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        if node.op == "call_function" and node.target == operator.getitem:
            source = node.args[0]
            if source.op != "placeholder":
                # Keep the getitem next to the node producing the tuple.
                assert source in assignment
                assignment[node] = assignment[source]
                continue

        if not should_split(node, splitting_ops):
            assignment[node] = current_id
            continue

        # The splitting op gets the next id to itself; following nodes
        # start again one id later.
        splitting_ids.add(current_id + 1)
        assignment[node] = current_id + 1
        current_id += 2

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: assignment[node], keep_original_order=True
    )

    items: list[SplitItem] = []
    for child_name, _ in split_gm.named_modules():
        # Skip the root module ("") and anything nested below a child.
        if child_name == "" or "." in child_name:
            continue
        child = getattr(split_gm, child_name)
        idx = int(child_name.replace("submod_", ""))
        items.append(SplitItem(child_name, idx, idx in splitting_ids, child))

    # Order numerically; string order would put submod_10 before submod_2.
    items.sort(key=lambda item: item.graph_id)
    return split_gm, items
391
392


393
394
# Wall-clock timestamp (from time.time()) that CompilerManager.compile sets
# when it handles graph index 0. It is used to compute the elapsed time once
# the last piecewise graph for a compile range has been loaded or compiled.
compilation_start_time = 0.0

395

396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
def wrap_with_cudagraph_if_needed(
    piecewise_backend: Any,
    vllm_config: VllmConfig,
    compilation_config: CompilationConfig,
    is_first_graph: bool,
    is_last_graph: bool,
) -> Any:
    """
    Return ``piecewise_backend`` wrapped in the platform's static graph
    wrapper, or unchanged when no wrapping applies.

    Wrapping applies only when piecewise cudagraphs are enabled and Inductor
    graph partitioning is not in use. Shared by PiecewiseCompileInterpreter
    and the cached-artifact reconstruction path, so both wrap backends
    identically.

    Args:
        piecewise_backend: The backend to wrap
        vllm_config: The vLLM configuration
        compilation_config: The compilation configuration
        is_first_graph: Whether this is the first graph in the sequence
            (enables debug logging; otherwise gc is disabled during capture)
        is_last_graph: Whether this is the last graph in the sequence
            (enables weak-ref outputs)

    Returns:
        The wrapper instance, or ``piecewise_backend`` itself.
    """
    wrapping_needed = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and not compilation_config.use_inductor_graph_partition
    )
    if not wrapping_needed:
        return piecewise_backend

    # Dynamo-based piecewise splitting: the whole subgraph gets a static
    # graph wrapper.
    from .cuda_graph import CUDAGraphOptions

    # The wrapper class (e.g. CUDAGraphWrapper) is platform dependent.
    wrapper_cls = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    options = CUDAGraphOptions(
        debug_log_enable=is_first_graph,
        gc_disable=not is_first_graph,
        weak_ref_output=is_last_graph,
    )

    # The runtime mode is always PIECEWISE here, even if the fx graph being
    # wrapped is the full graph, to keep it distinct from FULL cudagraphs.
    return wrapper_cls(
        runnable=piecewise_backend,
        vllm_config=vllm_config,
        runtime_mode=CUDAGraphMode.PIECEWISE,
        cudagraph_options=options,
    )


450
class PiecewiseCompileInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
    """Interpret the split graph on fake inputs, compiling chosen submodules.

    Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. While the
    stitching graph runs under fake mode, each submodule named in
    `compile_submod_names` is replaced in the module's `__dict__` by a
    `PiecewiseBackend`. That backend is wrapped by
    `wrap_with_cudagraph_if_needed`.

    NOTE: the order in `compile_submod_names` matters, because it
    determines the index of each compiled piecewise graph. The first graph
    handles logging, and the last graph has some special cudagraph output
    handling.

    Note: This class shares similar logic with
    reconstruct_serializable_fn_from_mega_artifact in caching.py.
    Both create PiecewiseBackend instances and wrap them with cudagraph.
    The key difference is:
    - reconstruct_serializable_fn_from_mega_artifact: PiecewiseBackend receives
      pre-compiled runnables (compiled_runnables is set, graph is None)
    - this class: PiecewiseBackend receives the FX graph to compile
      (graph is set, compiled_runnables is None)

    If modifying the backend creation/wrapping logic, consider updating both.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ) -> None:
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # Fake mode used by run(); whatever detect_fake_mode() reports as
        # active when the interpreter is built.
        self.fake_mode = detect_fake_mode()
        # List position of a name == index of that piecewise graph.
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args: Any) -> Any:
        """Run the graph with tensor args converted to fake tensors."""
        # maybe instead just assert inputs are fake?
        fake_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                arg = self.fake_mode.from_tensor(arg)
            fake_args.append(arg)

        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule ``target``; if it is to be compiled, install its backend.

        The submodule is always executed first (on fake inputs) so its output
        can flow to later nodes. Then, for names in ``compile_submod_names``,
        the module attribute is replaced by the (possibly cudagraph-wrapped)
        ``PiecewiseBackend``.
        """
        assert isinstance(target, str)

        output = super().call_module(target, args, kwargs)
        if target not in self.compile_submod_names:
            return output

        graph_index = self.compile_submod_names.index(target)
        submod = self.fetch_attr(target)

        # Positions of the SymInt (symbolic size) arguments of this call.
        sym_shape_indices = [
            position
            for position, value in enumerate(args)
            if isinstance(value, torch.SymInt)
        ]

        # Lazy import here to avoid circular import
        from torch._inductor.compile_fx import graph_returns_tuple

        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            graph_index,
            len(self.compile_submod_names),
            sym_shape_indices,
            self.vllm_backend,
            graph_returns_tuple(submod),
        )

        # Writing into __dict__ shadows the registered submodule, so later
        # attribute lookups on the module return the backend instead.
        self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
            piecewise_backend,
            self.vllm_config,
            self.compilation_config,
            piecewise_backend.is_first_graph,
            piecewise_backend.is_last_graph,
        )

        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return output


547
548
549
# Tag for the part of the model currently being compiled, e.g. "backbone"
# or "eagle_head". Changed temporarily by set_model_tag().
model_tag: str = "backbone"
# Whether the part being compiled is an encoder; also managed by
# set_model_tag().
model_is_encoder: bool = False
551

552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
    contextvars.ContextVar("on_compilation_complete_callback", default=None)
)


@contextmanager
def set_on_compilation_complete(
    callback: Callable[[], None],
) -> Generator[None, None, None]:
    token = _on_compilation_complete_callback.set(callback)
    try:
        yield
    finally:
        _on_compilation_complete_callback.reset(token)

567
568

@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
    """Context manager to set the model tag.

    Temporarily sets the module-level ``model_tag`` and ``model_is_encoder``,
    and restores both previous values on exit, even if the body raises.

    Raises:
        AssertionError: if ``tag`` equals the currently active tag.
    """
    global model_tag, model_is_encoder
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    saved = (model_tag, model_is_encoder)
    model_tag, model_is_encoder = tag, is_encoder
    try:
        yield
    finally:
        model_tag, model_is_encoder = saved
586
587


588
class VllmBackend:
589
    """The compilation backend for `torch.compile` with vLLM.
590
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
591
    where we customize the compilation.
592

593
594
    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.
595

596
597
    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
598
    """
599

600
601
    vllm_config: VllmConfig
    compilation_config: CompilationConfig
602
603
604
605
606
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stiching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
607
    piecewise_graphs: list[SplitItem]
608
    returned_callable: Callable[..., Any]
609
    # Inductor passes to run on the graph pre-defunctionalization
610
    post_grad_passes: Sequence[Callable[..., Any]]
611
    compiler_manager: CompilerManager
612
613
614
    # Copy of CompilationConfig.inductor_compile_config +
    # an entry for PostGradPassManager
    inductor_config: dict[str, Any]
615

616
617
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        is_encoder: bool = False,
    ) -> None:
        """Set up the pass manager, compiler manager and private inductor config.

        Args:
            vllm_config: The vLLM configuration; its ``compilation_config``
                selects the compiler and provides the inductor config.
            prefix: Name of the model part being compiled (e.g.
                ``language_model``). When empty, the module-level
                ``model_tag`` is used instead.
            is_encoder: Marks this compilation as an encoder one. It is also
                treated as set when the module-level ``model_is_encoder`` is.
        """
        # A non-empty prefix (language_model, vision_model, ...) is usually
        # enough on its own. When several parts are initialized as
        # independent models, model_tag (backbone by default, eagle_head,
        # ...) distinguishes them.
        self.prefix = prefix or model_tag

        # Mark compilation for encoder.
        self.is_encoder = is_encoder or model_is_encoder

        # Post-grad pass manager, plus the inductor config key it is stored
        # under; both come from the current platform.
        pass_manager_cls = resolve_obj_by_qualname(
            current_platform.get_pass_manager_cls()
        )
        self.pass_manager = pass_manager_cls()
        self.pass_key = current_platform.pass_key

        self.vllm_config = vllm_config
        compilation_config = vllm_config.compilation_config
        self.compilation_config = compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(compilation_config)

        # Deepcopy the inductor config to detach the post-grad custom pass
        # from CompilationConfig.
        # We want to avoid PostGradPassManager in CompilationConfig because
        # in future we need PostGradPassManager.uuid() to be executed
        # only at compile time.
        self.inductor_config = deepcopy(compilation_config.inductor_compile_config)

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
    def collect_standalone_compile_artifacts(
        self,
    ) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]:
        """Collect inductor cache artifacts from all piecewise backends.

        Only active when ``VLLM_USE_MEGA_AOT_ARTIFACT`` is set; otherwise
        ``(None, None, None)`` is returned.

        Returns:
            tuple: (standalone_compile_artifacts, sym_shape_indices_map,
                    returns_tuple_map)
                - standalone_compile_artifacts: StandaloneCompiledArtifacts
                  with compiled artifacts
                - sym_shape_indices_map: dict mapping submod_name to
                  sym_shape_indices
                - returns_tuple_map: dict mapping submod_name to
                  returns_tuple
        """
        if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            return None, None, None

        from .caching import StandaloneCompiledArtifacts
        from .piecewise_backend import PiecewiseBackend

        artifacts = StandaloneCompiledArtifacts()
        sym_shape_indices_map: dict[str, list[int]] = {}
        returns_tuple_map: dict[str, bool] = {}

        for submod_name, _ in self.split_gm.named_children():
            # getattr returns the object stored in the module's __dict__
            # (the installed backend), which shadows the registered child.
            installed = getattr(self.split_gm, submod_name)
            # Static graph wrappers hold the real backend in `.runnable`.
            if hasattr(installed, "runnable"):
                installed = installed.runnable
            if not isinstance(installed, PiecewiseBackend):
                continue

            sym_shape_indices_map[submod_name] = installed.sym_shape_indices
            returns_tuple_map[submod_name] = installed.returns_tuple

            for shape_str, blob in installed.to_bytes().items():
                artifacts.insert(submod_name, shape_str, blob)
                logger.debug(
                    "collected artifact for %s shape %s (%d bytes)",
                    submod_name,
                    shape_str,
                    len(blob),
                )

        logger.info(
            "collected artifacts: %d entries, %d artifacts, %d bytes total",
            artifacts.num_entries(),
            artifacts.num_artifacts(),
            artifacts.size_bytes(),
        )
        logger.debug(
            "standalone compile artifact keys: %s",
            list(artifacts.submodule_bytes.keys()),
        )

        return artifacts, sym_shape_indices_map, returns_tuple_map

717
    def configure_post_pass(self) -> None:
        """Configure the pass manager and install it under ``self.pass_key``.

        Post-grad custom passes run through Inductor's
        post_grad_custom_post_pass hook. If the compilation config already
        registers a pass under that key, the pass is added to the pass
        manager, and the pass manager then takes its place in this
        backend's private inductor config.

        Raises:
            ValueError: if the inductor config already holds a
                PostGradPassManager under ``self.pass_key``.
        """
        self.pass_manager.configure(self.vllm_config)

        if self.pass_key in self.inductor_config:
            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
                raise ValueError(
                    "PostGradPassManager can not be kept in CompilationConfig."
                )
            user_pass = self.compilation_config.inductor_compile_config[self.pass_key]
            # Config should automatically wrap all inductor passes
            assert isinstance(user_pass, InductorPass)
            self.pass_manager.add(user_pass)

        self.inductor_config[self.pass_key] = self.pass_manager
737

738
739
740
741
742
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
        from .caching import (
            VllmSerializableFunction,
        )

743
        vllm_config = self.vllm_config
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        forward_code_files = list(sorted(self.compilation_config.traced_files))

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
767
            except (OSError, UnicodeDecodeError):
768
769
770
771
772
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()
773
774
775
776
777
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
778
779
780
781
            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
782
            cache_dir = os.path.join(
783
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
784
785
786
            )
            self.compilation_config.cache_dir = cache_dir

787
        cache_dir = self.compilation_config.cache_dir
788
        os.makedirs(cache_dir, exist_ok=True)
789
        self.compilation_config.cache_dir = cache_dir
790
        rank = vllm_config.parallel_config.rank
791
        dp_rank = vllm_config.parallel_config.data_parallel_index
792
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
793
        os.makedirs(local_cache_dir, exist_ok=True)
794
        self.compilation_config.local_cache_dir = local_cache_dir
795

796
        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
797
        disable_cache = not is_compile_cache_enabled(self.inductor_config)
798
799

        if disable_cache:
800
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
801
        else:
802
803
804
805
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
806
            )
807

808
809
810
        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )
811

812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
        # Reuses existing cache key

        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

856
857
        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
858
        compilation_counter.num_graphs_seen += 1
859
        from .monitor import torch_compile_start_time
860

861
        dynamo_time = time.time() - torch_compile_start_time
862
863
864
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
865
        self.compilation_config.compilation_time += dynamo_time
866
867
868
869
870
871

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
872
        self.configure_post_pass()
873

874
875
876
877
878
879
        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

880
        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
881

882
883
884
885
886
887
        # keep a split_gm copy from BEFORE the interpreter replaces
        # submodules with PiecewiseBackend -- used for serialization
        original_split_gm = None
        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            original_split_gm = deepcopy(self.split_gm)

888
        from torch._dynamo.utils import lazy_format_graph_code
889
890
891
892
893

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)
894

895
        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
896
        submod_names_to_compile = [
897
898
            item.submod_name
            for item in self.piecewise_graphs
899
900
901
            if not item.is_splitting_graph
        ]

902
903
904
905
906
907
908
909
910
911
        # Extract fake values from the graph to use them when needed.
        all_fake_values = []
        for i in graph.graph.find_nodes(op="placeholder"):
            all_fake_values.append(i.meta["example_value"])

        fake_args = [
            all_fake_values[i] if isinstance(t, torch.Tensor) else t
            for i, t in enumerate(example_inputs)
        ]

912
913
        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
914
915
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
916
        ).run(*fake_args)
917

918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()

        if (
            self.compilation_config.dynamic_shapes_config.evaluate_guards
            and self.compilation_config.dynamic_shapes_config.type
            == DynamicShapesType.BACKED
        ):
            from torch.utils._sympy.value_ranges import ValueRanges

            # Drop counter-0/1 specializations guards; for backed dynamic shapes,
            # torch.compile will specialize for 0/1 inputs or otherwise guards that
            # shape is >= 2. This is because it's really hard not to hit a check
            # against 0/1. When we evaluate shape guards, we exclude checking those
            # guards (We would fail always otherwise).

            # We avoid that by updating the ranges of backed sizes when the min is
            # 2 for any, we assume it's 0.
            for s, r in fake_mode.shape_env.var_to_range.items():
                if r.lower == 2:
                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)

941
942
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
943
944
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
945
            # use `print_readable` because it can include submodules
946
947
948
949
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
950
951
952
953
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

954
955
956
            logger.debug_once(
                "Computation graph saved to %s", graph_path, scope="local"
            )
957

958
        self._called = True
959
960
961
        graph_to_serialize = (
            original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
        )
962

963
964
965
966
        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
967
            return VllmSerializableFunction(
968
969
970
971
972
973
                graph_to_serialize,
                example_inputs,
                self.prefix,
                self.split_gm,
                is_encoder=self.is_encoder,
                vllm_backend=self,
974
            )
975
976

        # index of tensors that have symbolic shapes (batch size)
977
978
979
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
980

981
        sym_tensor_indices = [
982
983
984
985
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
986
987
988
989
990
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
991
992
993
994
995
        copy_and_call = make_copy_and_call(
            sym_tensor_indices,
            [example_inputs[x].clone() for x in sym_tensor_indices],
            self.split_gm,
        )
996

997
        return VllmSerializableFunction(
998
999
1000
1001
1002
1003
1004
            graph_to_serialize,
            example_inputs,
            self.prefix,
            copy_and_call,
            is_encoder=self.is_encoder,
            vllm_backend=self,
            sym_tensor_indices=sym_tensor_indices,
1005
        )