backends.py 49.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
import json
8
import operator
9
10
import os
import pprint
11
import time
12
from collections import defaultdict
13
from collections.abc import Callable, Generator, Sequence
14
from contextlib import contextmanager
15
from copy import deepcopy
16
from functools import partial
17
from typing import Any
18
19
20

import torch
import torch.fx as fx
21
from torch._dynamo.utils import dynamo_timed
22
from torch._logging._internal import trace_structured
23
from torch.fx._lazy_graph_module import _use_lazy_graph_module
24

25
import vllm.envs as envs
26
27
28
29
from vllm.compilation.codegen import (
    compile_execution_fn,
    generate_execution_code,
)
30
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
31
from vllm.config.compilation import DynamicShapesType
32
from vllm.config.utils import Range, hash_factors
33
from vllm.logger import init_logger
34
from vllm.logging_utils import lazy
35
from vllm.platforms import current_platform
36
from vllm.tracing import instrument, instrument_manual
37
from vllm.utils.import_utils import resolve_obj_by_qualname
38
from vllm.utils.torch_utils import is_torch_equal_or_newer
39

40
41
42
43
44
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
45
    is_compile_cache_enabled,
46
)
47
from .counter import compilation_counter
48
49
50
51
52
53
from .partition_rules import (
    inductor_partition_rule_context,
    should_split,
)
from .passes.inductor_pass import InductorPass, pass_context
from .passes.pass_manager import PostGradPassManager
54
55
56

logger = init_logger(__name__)

57

58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def make_copy_and_call(
    sym_tensor_indices: list[int],
    input_buffers: list[torch.Tensor | None],
    callable_fn: Callable[..., Any],
) -> Callable[..., Any]:
    """Create a wrapper that copies inputs to static buffers before calling.

    This is used for cudagraph input copying where we need to copy dynamic
    tensors to static buffers before invoking the compiled graph.

    Args:
        sym_tensor_indices: Indices of tensors with symbolic shapes
        input_buffers: List of static buffers (can contain None for lazy init)
        callable_fn: The compiled function to call

    Returns:
        A wrapper function that copies inputs and calls the compiled function
    """

77
    def copy_and_call(*args: Any) -> Any:
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
        list_args = list(args)
        for i, index in enumerate(sym_tensor_indices):
            runtime_tensor = list_args[index]
            runtime_shape = runtime_tensor.shape[0]

            # lazy initialization of buffer on first call
            if input_buffers[i] is None:
                input_buffers[i] = runtime_tensor.clone()

            static_tensor = input_buffers[i][:runtime_shape]  # type: ignore[index]
            static_tensor.copy_(runtime_tensor)
            list_args[index] = static_tensor
        return callable_fn(*list_args)

    return copy_and_call


95
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Create the compiler adaptor selected by ``compilation_config.backend``.

    - ``"inductor"``: ``InductorStandaloneAdaptor`` is constructed with
      ``compile_cache_save_format`` when ``VLLM_USE_STANDALONE_COMPILE`` is
      set and ``torch._inductor`` exposes ``standalone_compile``. Otherwise
      ``InductorAdaptor`` is used.
    - ``"eager"``: ``EagerAdaptor``.
    - any other value: the class named by
      ``current_platform.get_compile_backend()``, instantiated with no
      arguments. ``compilation_config.backend`` itself is only logged.

    Raises:
        AssertionError: if ``VLLM_USE_MEGA_AOT_ARTIFACT`` is enabled without
            ``VLLM_USE_STANDALONE_COMPILE``, or if the platform backend does
            not produce a ``CompilerInterface``.
    """
    assert not envs.VLLM_USE_MEGA_AOT_ARTIFACT or envs.VLLM_USE_STANDALONE_COMPILE, (
        "VLLM_USE_MEGA_AOT_ARTIFACT=1 requires VLLM_USE_STANDALONE_COMPILE=1"
    )

    if compilation_config.backend == "inductor":
        # Use standalone compile only if requested and the symbol actually
        # exists in this PyTorch build. The hasattr check is the only gate
        # (no explicit version comparison); builds without it fall back to
        # InductorAdaptor.
        if envs.VLLM_USE_STANDALONE_COMPILE and hasattr(
            torch._inductor, "standalone_compile"
        ):
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor(
                compilation_config.compile_cache_save_format
            )
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    if compilation_config.backend == "eager":
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    logger.debug("Using custom backend: %s", compilation_config.backend)
    compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
    assert isinstance(compiler, CompilerInterface), (
        f"Custom compile backend {type(compiler).__qualname__} "
        "must be a CompilerInterface"
    )
    return compiler
121
122


123
124
125
126
127
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(compile_range, graph_index, compiler_name)`
    to `{"graph_handle": ..., "cache_key": ...}` as stored by `compile`.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int (or tuple) as key.
    """

    def __init__(self, compilation_config: CompilationConfig) -> None:
        # (compile_range, graph_index, compiler name) -> cache entry dict
        self.cache: dict[tuple[Range, int, str], Any] = dict()
        # Set whenever `compile` adds an entry; `save_to_file` is a no-op
        # until then.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)
        # Compiled graphs held in memory, keyed by the AOTAutograd cache key,
        # so an identical graph can reuse an existing artifact.
        self.loaded_artifacts: dict[str, Any] = {}

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash of ``vllm_config``."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(compile_range):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        Unless ``disable_cache`` is set, existing entries are loaded from
        ``vllm_compile_cache.py``. The call is then forwarded to the
        underlying compiler.

        Raises:
            TypeError: if a key in the cache file has unexpected types.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                cache = ast.literal_eval(f.read())

            def check_type(value: Any, ty: type) -> None:
                if not isinstance(value, ty):
                    raise TypeError(f"Expected {ty} but got {type(value)} for {value}")

            def parse_key(key: Any) -> tuple[Range, int, str]:
                range_tuple, graph_index, compiler_name = key
                check_type(graph_index, int)
                check_type(compiler_name, str)
                # literal_eval yields plain tuples; rebuild the Range
                if isinstance(range_tuple, tuple):
                    start, end = range_tuple
                    check_type(start, int)
                    check_type(end, int)
                    range_tuple = Range(start=start, end=end)
                check_type(range_tuple, Range)
                return range_tuple, graph_index, compiler_name

            self.cache = {parse_key(key): value for key, value in cache.items()}

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self) -> None:
        """Write ``self.cache`` to ``cache_file_path`` as a pretty-printed literal.

        This is a no-op when caching is disabled or nothing new was added.
        ``initialize_cache`` must have been called first.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any] | None:
        """Load a previously compiled graph through the compiler's cache.

        Returns None when there is no entry for
        ``(compile_range, graph_index, compiler name)``, or when the entry is
        malformed (e.g. an outdated file), so that the caller recompiles. On
        success the graph is also recorded in ``loaded_artifacts`` under its
        cache key.
        """
        if (compile_range, graph_index, self.compiler.name) not in self.cache:
            return None

        def parse_value(value: Any) -> tuple[tuple[str, str], str]:
            assert isinstance(value, dict)
            handle = value["graph_handle"]
            assert isinstance(handle[0], str)
            assert isinstance(handle[1], str)
            cache_key = value["cache_key"]
            return handle, cache_key

        try:
            handle, cache_key = parse_value(
                self.cache[(compile_range, graph_index, self.compiler.name)]
            )
        except Exception:
            # When the cache is outdated, we should ignore the existing file.
            # This should cause the correct cache to be generated again.
            return None

        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, compile_range
        )
        self.loaded_artifacts[cache_key] = compiled_graph
        logger.debug(
            "Directly load the %s-th graph for compile range %s "
            "from %s via handle %s",
            graph_index,
            str(compile_range),
            self.compiler.name,
            handle,
        )
        return compiled_graph

    @instrument(span_name="Compile graph")
    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        additional_inductor_config: dict[str, Any],
        compilation_config: CompilationConfig,
        compile_range: Range,
        graph_index: int = 0,
        num_graphs: int = 1,
        is_encoder: bool = False,
    ) -> Any:
        """Return a compiled version of ``graph`` for ``compile_range``.

        Sources are tried in this order:

        1. The vLLM compile cache (``load``).
        2. An identical artifact already held in memory, detected through the
           AOTAutograd cache key.
        3. A fresh compilation with ``self.compiler``.

        A fresh result is recorded in ``self.cache`` when the compile cache
        is enabled and the compiler returned a handle.

        ``graph_index``/``num_graphs`` locate this piece among the piecewise
        graphs: the timer starts at index 0, and the elapsed time is logged
        at the last index. ``compilation_config`` and ``is_encoder`` are not
        read by this method.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.perf_counter()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                elapsed = time.perf_counter() - compilation_start_time
                logger.info_once(
                    "Directly load the compiled graph(s) for compile range %s "
                    "from the cache, took %.3f s",
                    str(compile_range),
                    elapsed,
                )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = "artifact_compile_range_"
            maybe_key += f"{compile_range.start}_{compile_range.end}"
            maybe_key += f"_subgraph_{graph_index}"
        with self.compile_context(compile_range):
            # There is a compilation time optimization here.
            #
            # If the (input metadata, graph, compiler config) are the same, then
            # we want to avoid compiling the same artifact again. If we didn't
            # do this optimization, the backend compilation (InductorAdaptor or
            # InductorStandaloneAdaptor)
            # is able to cache hit and produce an artifact faster if it was
            # already created, but it is still a duplicate artifact that
            # requires unnecessary things e.g. disk IO.
            #
            # The optimization is: If the backend compilation cache hits,
            # then do an early return from the backend compilation and look up
            # which of the previous in-memory artifacts we created to reuse.
            #
            # We implemented this by monkey-patching torch (torch does not
            # easily expose the cache_key function), but in the future torch
            # should expose the cache_key function that we can just call
            # directly before invoking backend compilation.
            cache_key = None
            orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key

            def autograd_cache_key(*args, **kwargs):
                nonlocal cache_key
                result = orig(*args, **kwargs)
                if result is None:
                    return None
                cache_key = result[0]
                if cache_key in self.loaded_artifacts:
                    raise StopCompiling()
                return result

            from unittest.mock import patch

            with (
                # Graphs that are isometric (different node names but same
                # structure) should be treated as the same.
                torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
                patch(
                    "torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
                    autograd_cache_key,
                ),
            ):
                try:
                    compiled_graph, handle = self.compiler.compile(
                        graph,
                        example_inputs,
                        additional_inductor_config,
                        compile_range,
                        maybe_key,
                    )
                except StopCompiling:
                    # An identical graph was already compiled/loaded: reuse it.
                    # No new handle exists, so nothing is persisted below, but
                    # the end-of-compilation timing log still runs.
                    assert cache_key is not None
                    compiled_graph = self.loaded_artifacts[cache_key]
                    handle = None
            if cache_key is not None and compiled_graph is not None:
                self.loaded_artifacts[cache_key] = compiled_graph

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(compile_range, graph_index, self.compiler.name)] = {
                "graph_handle": handle,
                "cache_key": cache_key,
            }
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info_once(
                    "Cache the graph of compile range %s for later use",
                    str(compile_range),
                )
            logger.debug_once(
                "Store the %s-th graph for compile range %s from %s via handle %s",
                graph_index,
                str(compile_range),
                self.compiler.name,
                handle,
            )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            elapsed = time.perf_counter() - compilation_start_time
            logger.info_once(
                "Compiling a graph for compile range %s takes %.2f s",
                str(compile_range),
                elapsed,
            )

        return compiled_graph
398
399


400
401
402
403
class StopCompiling(BaseException):
    """Abort an in-flight backend compilation early.

    ``CompilerManager.compile`` raises this from its patched AOTAutograd
    cache-key hook once the computed key matches an artifact already held
    in memory. The same method catches it around the compiler call.

    It derives from ``BaseException`` rather than ``Exception``, presumably
    so that generic ``except Exception`` handlers inside the compile stack
    do not swallow it. NOTE(review): confirm.
    """


404
405
406
@dataclasses.dataclass
class SplitItem:
    """One submodule produced by ``split_graph``."""

    # Attribute name of the submodule on the split GraphModule ("submod_<id>").
    submod_name: str
    # Integer id parsed from ``submod_name``; used to order the pieces.
    graph_id: int
    # True when this piece holds nodes matched by ``should_split``
    # (the splitting ops) rather than a stretch of ordinary nodes.
    is_splitting_graph: bool
    # The submodule itself.
    graph: fx.GraphModule


412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
def _is_empty_allocation_node(node: fx.Node) -> bool:
    """Return True if ``node`` allocates an uninitialized ("empty") tensor.

    The following count as empty allocations:

    - ``call_method`` nodes named ``new_empty``.
    - ``call_function`` nodes whose target is ``torch.empty``,
      ``torch.empty_like`` or ``torch.empty_strided``.
    - ``call_function`` nodes whose target is an aten op (packet or
      overload) whose qualified name starts with ``aten::empty`` or
      ``aten::new_empty``.

    Every other node, including placeholders and outputs, returns False.
    """
    op = node.op
    target = node.target

    if op == "call_method":
        return target == "new_empty"
    if op != "call_function":
        return False

    if target in (torch.empty, torch.empty_like, torch.empty_strided):
        return True

    # Resolve the qualified aten name for both packet and overload targets.
    if isinstance(target, torch._ops.OpOverloadPacket):
        qualified = target._qualified_op_name
    elif isinstance(target, torch._ops.OpOverload):
        qualified = target.name()
    else:
        return False
    return qualified.startswith(("aten::empty", "aten::new_empty"))


def _merge_empty_only_subgraphs(
    node_to_subgraph_id: dict[fx.Node, int],
    split_op_graphs: list[int],
) -> None:
    """
    Merge a partition that only contains an empty allocation op into the
    previous partition. This avoids generating standalone empty submodules,
    which can lead to empty cudagraph captures.

    The "previous partition" is the closest earlier subgraph id that is not
    in ``split_op_graphs`` and did not itself lose its node to a merge. The
    move is skipped when any non-placeholder input of the allocation lives
    in a later subgraph than that target. Moving it anyway would place the
    allocation before its producer.

    Args:
        node_to_subgraph_id: Mapping from node to subgraph id; updated in
            place for every merged node.
        split_op_graphs: Subgraph ids that hold splitting ops.
    """
    nodes_by_subgraph_id: dict[int, list[fx.Node]] = defaultdict(list)
    for node, subgraph_id in node_to_subgraph_id.items():
        nodes_by_subgraph_id[subgraph_id].append(node)

    splitting_subgraphs = set(split_op_graphs)
    prev_non_splitting_subgraph_id: int | None = None

    max_subgraph_id = max(node_to_subgraph_id.values(), default=-1)
    for subgraph_id in range(max_subgraph_id + 1):
        nodes = nodes_by_subgraph_id.get(subgraph_id, [])
        if not nodes:
            continue

        is_non_splitting_subgraph = subgraph_id not in splitting_subgraphs
        is_empty_only_subgraph = len(nodes) == 1 and _is_empty_allocation_node(nodes[0])
        merged = False

        if is_empty_only_subgraph and prev_non_splitting_subgraph_id is not None:
            # Safety check: don't move allocation before any input producer.
            empty_node = nodes[0]
            if all(
                input_node.op == "placeholder"
                or node_to_subgraph_id[input_node] <= prev_non_splitting_subgraph_id
                for input_node in empty_node.all_input_nodes
            ):
                node_to_subgraph_id[empty_node] = prev_non_splitting_subgraph_id
                merged = True

        # A subgraph emptied by the merge must not become a merge target.
        if not merged and is_non_splitting_subgraph:
            prev_non_splitting_subgraph_id = subgraph_id
475
476


477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
def _decompose_size_nodes(graph: fx.GraphModule) -> None:
    """Decompose x.size() into per-dim sym_size.int calls.

    torch.Size objects cannot cross split boundaries because aot_autograd
    cannot handle them as submodule outputs. This replaces each size() call
    with individual sym_size.int(x, dim) nodes:
      - Dynamic dims (SymInt) → new sym_size.int node
      - Static dims (plain int) → inlined as literal constant
    """
    # Dynamo captures x.size()/x.shape as call_method target="size".
    size_nodes = list(graph.graph.find_nodes(op="call_method", target="size"))

    for node in size_nodes:
        tensor_node = node.args[0]
        ev = tensor_node.meta.get("example_value")
        assert ev is not None, (
            f"Tensor node '{tensor_node.name}' has no example_value metadata. "
            f"Cannot decompose size node '{node.name}'."
        )

        # Build per-dim replacements: sym_size.int node or literal int.
        dims: list[fx.Node | int] = []
        with graph.graph.inserting_after(tensor_node):
            for i in range(ev.dim()):
                dim_val = ev.shape[i]
                if isinstance(dim_val, torch.SymInt):
                    dn = graph.graph.call_function(
                        torch.ops.aten.sym_size.int, args=(tensor_node, i)
                    )
                    dn.meta["example_value"] = dim_val
                    dims.append(dn)
                elif isinstance(dim_val, int):
                    dims.append(dim_val)
                else:
                    raise AssertionError(
                        f"dim_val is either torch.SymInt or int, "
                        f"got {type(dim_val)} for dim {i} of "
                        f"'{node.name}'"
                    )

        # Replace size node in each user's args.
        for user in list(node.users):
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
            if (
                user.op == "call_function"
                and user.target is operator.getitem
                and len(user.args) == 2
                and user.args[0] is node
            ):
                # getitem(size, idx) → replace with dims[idx] directly.
                idx = user.args[1]
                assert isinstance(idx, int), (
                    f"Expected literal int index for getitem on size(), "
                    f"got {type(idx).__name__}: {idx}"
                )
                user.replace_all_uses_with(dims[idx])
                graph.graph.erase_node(user)
            else:
                # User consumes the full size tuple (e.g. view(clone, size))
                # → view(clone, d0, d1, ...)
                new_args = []
                for arg in user.args:
                    if arg is node:
                        new_args.extend(dims)
                    else:
                        new_args.append(arg)
                user.args = tuple(new_args)
543
544
545
        graph.graph.erase_node(node)


546
def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into submodules around the nodes matched by ``should_split``.

    Steps:

    1. ``size()`` calls are decomposed first (``_decompose_size_nodes``).
    2. Every non-placeholder/non-output node gets a subgraph id. Each
       splitting op starts a new id; consecutive splitting ops share one id;
       ``getitem`` nodes follow their (non-placeholder) input's id.
    3. Partitions holding only an empty allocation are merged backwards
       (``_merge_empty_only_subgraphs``).

    Args:
        graph: The graph to split. It is modified in place by the size
            decomposition.
        splitting_ops: Op names passed to ``should_split``.

    Returns:
        ``(split_gm, items)``. ``split_gm`` is the stitching module built by
        ``torch.fx.passes.split_module.split_module``. ``items`` holds one
        ``SplitItem`` per direct child submodule, sorted by integer graph id.
    """
    _decompose_size_nodes(graph)

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    split_op_graphs: list[int] = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        # Check if this is a getitem operation on a node from an earlier subgraph.
        # If so, assign it to the same subgraph as its input to avoid passing entire
        # tuple as input to submodules, which is against standalone_compile and
        # AoTAutograd input requirement.
        if node.op == "call_function" and node.target == operator.getitem:
            # Assign this getitem to the same subgraph as its input
            input_node = node.args[0]
            if input_node.op != "placeholder":
                assert input_node in node_to_subgraph_id
                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
                continue

        if should_split(node, splitting_ops):
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)

            # keep consecutive splitting ops together
            # (we know node.next exists because node isn't the last (output) node)
            if should_split(node.next, splitting_ops):
                # this will get incremented by the next node
                subgraph_id -= 1
            else:
                subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    _merge_empty_only_subgraphs(node_to_subgraph_id, split_op_graphs)

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    with _use_lazy_graph_module(True):
        has_tuple_return = is_torch_equal_or_newer("2.12.0.dev")
        tuple_return_kwarg = {"tuple_return": True} if has_tuple_return else {}
        split_gm = torch.fx.passes.split_module.split_module(
            graph,
            None,
            lambda node: node_to_subgraph_id[node],
            keep_original_order=True,
            **tuple_return_kwarg,
        )

    splitting_ids = set(split_op_graphs)
    outputs = []
    # direct children only (the root and nested modules are not pieces)
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in splitting_ids, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
621
622


623
624
compilation_start_time = 0.0

625

626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
def wrap_with_cudagraph_if_needed(
    piecewise_backend: Any,
    vllm_config: VllmConfig,
    compilation_config: CompilationConfig,
    is_first_graph: bool,
    is_last_graph: bool,
) -> Any:
    """
    Return ``piecewise_backend``, wrapped in the platform's static-graph
    wrapper (e.g. CUDAGraphWrapper) when Dynamo-based piecewise cudagraphs
    apply. Used by both VllmBackend and
    construct_serializable_fn_from_inductor_cache.

    No wrapping happens in two cases: when the cudagraph mode has no
    piecewise component, or when Inductor graph partitioning is enabled.

    Args:
        piecewise_backend: The backend to wrap
        vllm_config: The vLLM configuration
        compilation_config: The compilation configuration
        is_first_graph: Whether this is the first graph in the sequence
        is_last_graph: Whether this is the last graph in the sequence

    Returns:
        The wrapper instance, or ``piecewise_backend`` unchanged
    """
    needs_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and not compilation_config.use_inductor_graph_partition
    )
    if not needs_wrapper:
        return piecewise_backend

    # Dynamo-based piecewise splitting: the whole subgraph goes behind a
    # static graph wrapper whose class depends on the platform.
    from .cuda_graph import CUDAGraphOptions

    wrapper_qualname = current_platform.get_static_graph_wrapper_cls()
    wrapper_cls = resolve_obj_by_qualname(wrapper_qualname)

    # Only the first graph logs and keeps GC enabled; only the last graph
    # returns weakly-referenced outputs.
    options = CUDAGraphOptions(
        debug_log_enable=is_first_graph,
        gc_disable=not is_first_graph,
        weak_ref_output=is_last_graph,
    )

    # The runtime mode is always PIECEWISE here, to distinguish it from the
    # FULL cudagraph runtime mode, whether the fx graph is full or piecewise.
    return wrapper_cls(
        runnable=piecewise_backend,
        vllm_config=vllm_config,
        runtime_mode=CUDAGraphMode.PIECEWISE,
        cudagraph_options=options,
    )


680
class PiecewiseCompileInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.

    It runs the given split graph interpreter, and for each submodule in
    `compile_submod_names`, creates a PiecewiseBackend and compiles all
    ranges up front.

    Submodules are not executed. `call_module` returns the `example_value`
    metadata recorded on each submodule's output node instead.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.

    Note: This class shares similar logic with
    reconstruct_serializable_fn_from_mega_artifact in caching.py.
    Both create PiecewiseBackend instances and wrap them with cudagraph.
    The key difference is:
    - reconstruct_serializable_fn_from_mega_artifact: PiecewiseBackend receives
      pre-compiled runnables (compiled_runnables is set, graph is None)
    - this class: PiecewiseBackend receives the FX graph to compile
      (graph is set, compiled_runnables is None)

    If modifying the backend creation/wrapping logic, consider updating both.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ) -> None:
        super().__init__(module)
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    @instrument(span_name="Inductor compilation")
    def run(self, *args: Any) -> Any:
        """Interpret the split graph (see ``torch.fx.Interpreter.run``)."""
        return super().run(*args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Install a piecewise backend for ``target`` and return fake outputs.

        For submodules listed in ``compile_submod_names``, a
        ``PiecewiseBackend`` is built and, if needed, wrapped for cudagraphs.
        The result is stored on the split module under the same name. In all
        cases the returned value is the submodule's recorded example output.
        """
        assert isinstance(target, str)

        # Use the recorded example values instead of executing the submodule.
        gm = getattr(self.module, target)
        outputs = gm.graph.output_node().args[0]
        output = fx.map_arg(outputs, lambda node: node.meta["example_value"])

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)

            # positions of SymInt arguments among the submodule's inputs
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]

            # Lazy import here to avoid circular import
            from torch._inductor.compile_fx import graph_returns_tuple

            from .piecewise_backend import PiecewiseBackend

            piecewise_backend = PiecewiseBackend(
                submod,
                self.vllm_config,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                self.vllm_backend,
                graph_returns_tuple(submod),
                submod_name=target,
            )

            # Writing into __dict__ bypasses nn.Module.__setattr__, so the
            # stored object need not be an nn.Module. An instance-dict entry
            # is found before nn.Module.__getattr__ consults the registered
            # submodules, so attribute lookups of `target` now return it
            # (presumably so the split module's forward calls the backend).
            self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
                piecewise_backend,
                self.vllm_config,
                self.compilation_config,
                piecewise_backend.is_first_graph,
                piecewise_backend.is_last_graph,
            )

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


772
773
774
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"
# whether the part of the model currently being compiled is an encoder
model_is_encoder: bool = False


@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
    """Temporarily set the module-level ``model_tag`` and ``model_is_encoder``.

    Both values are restored on exit, including when the body raises.

    Args:
        tag: New value for ``model_tag``; must differ from the current one.
        is_encoder: New value for ``model_is_encoder``.

    Raises:
        AssertionError: if ``tag`` equals the current ``model_tag``.
    """
    global model_tag, model_is_encoder
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag, old_is_encoder = model_tag, model_is_encoder
    model_tag, model_is_encoder = tag, is_encoder
    try:
        yield
    finally:
        model_tag, model_is_encoder = old_tag, old_is_encoder
796
797


798
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable[..., Any]
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable[..., Any]]
    compiler_manager: CompilerManager
    # Copy of CompilationConfig.inductor_compile_config +
    # an entry for PostGradPassManager
    inductor_config: dict[str, Any]

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        is_encoder: bool = False,
    ) -> None:
        """Create a backend bound to ``vllm_config``.

        Args:
            vllm_config: Global vLLM configuration; its ``compilation_config``
                drives the compilation.
            prefix: Name of the model part being compiled. Falls back to the
                module-level ``model_tag`` when empty.
            is_encoder: Whether an encoder is being compiled. Also true when
                the module-level ``model_is_encoder`` is set.
        """
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Mark compilation for encoder.
        self.is_encoder = is_encoder or model_is_encoder

        # Passes to run on the graph post-grad.
        self.pass_manager = resolve_obj_by_qualname(
            current_platform.get_pass_manager_cls()
        )()
        self.pass_key = current_platform.pass_key

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # Deepcopy the inductor config to detach the post-grad custom pass
        # from CompilationConfig.
        # We want to avoid PostGradPassManager in CompilationConfig because
        # in future we need PostGradPassManager.uuid() to be executed
        # only at compile time.
        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)
        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def collect_standalone_compile_artifacts(
        self,
    ) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]:
        """Collect inductor cache artifacts from all piecewise backends.

        Returns ``(None, None, None)`` unless ``VLLM_USE_MEGA_AOT_ARTIFACT``
        is enabled.

        Returns:
            tuple: (standalone_compile_artifacts, sym_shape_indices_map,
                    returns_tuple_map)
                - standalone_compile_artifacts: StandaloneCompiledArtifacts
                  with compiled artifacts
                - sym_shape_indices_map: dict mapping submod_name to
                  sym_shape_indices
                - returns_tuple_map: dict mapping submod_name to
                  returns_tuple
        """
        if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            return None, None, None

        from .caching import StandaloneCompiledArtifacts
        from .piecewise_backend import PiecewiseBackend

        standalone_compile_artifacts = StandaloneCompiledArtifacts()
        sym_shape_indices_map = {}
        returns_tuple_map = {}

        for submod_name, _ in self.split_gm.named_children():
            # get the actual attribute (shadowed by PiecewiseBackend in __dict__)
            child = getattr(self.split_gm, submod_name)
            # unwrap the static graph wrapper class if applicable
            piecewise_backend = getattr(child, "runnable", child)

            if not isinstance(piecewise_backend, PiecewiseBackend):
                continue

            sym_shape_indices_map[submod_name] = piecewise_backend.sym_shape_indices
            returns_tuple_map[submod_name] = piecewise_backend.returns_tuple

            for shape_str, bytes_data in piecewise_backend.to_bytes().items():
                standalone_compile_artifacts.insert(submod_name, shape_str, bytes_data)
                logger.debug(
                    "collected artifact for %s shape %s (%d bytes)",
                    submod_name,
                    shape_str,
                    len(bytes_data),
                )

        logger.info(
            "collected artifacts: %d entries, %d artifacts, %d bytes total",
            standalone_compile_artifacts.num_entries(),
            standalone_compile_artifacts.num_artifacts(),
            standalone_compile_artifacts.size_bytes(),
        )

        logger.debug(
            "standalone compile artifact keys: %s",
            list(standalone_compile_artifacts.submodule_bytes.keys()),
        )

        return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map

    def configure_post_pass(self) -> None:
        """Install the platform pass manager under ``self.pass_key``.

        The pass manager is configured with ``self.vllm_config``. If the
        compilation config already provides a pass under the same key, that
        pass is added to the pass manager. The pass manager then replaces the
        entry in ``self.inductor_config``, which is a copy of the compilation
        config's dict.

        Raises:
            ValueError: If the configured pass is already a PostGradPassManager.
        """
        self.pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        if self.pass_key in self.inductor_config:
            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
                raise ValueError(
                    "PostGradPassManager can not be kept in CompilationConfig."
                )
            # Config should automatically wrap all inductor passes.
            # The pass is taken from the original (not deep-copied) config.
            user_pass = self.compilation_config.inductor_compile_config[self.pass_key]
            assert isinstance(user_pass, InductorPass)
            self.pass_manager.add(user_pass)
        self.inductor_config[self.pass_key] = self.pass_manager

    def _log_compilation_config(self):
        """Log vLLM compilation config for TORCH_TRACE/tlparse."""
        cc = self.compilation_config
        pass_cfg = cc.pass_config

        # Helper to convert lists to comma-separated strings for tlparse display
        def list_to_str(lst: list | None) -> str:
            if lst is None:
                return ""
            return ", ".join(str(x) for x in lst)

        # Names of the boolean pass_config fields that are set to True.
        enabled_passes = [
            f.name
            for f in dataclasses.fields(pass_cfg)
            if isinstance(getattr(pass_cfg, f.name), bool) and getattr(pass_cfg, f.name)
        ]

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_compilation_config",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "model": self.vllm_config.model_config.model,
                    "prefix": self.prefix,
                    "mode": str(cc.mode),
                    "backend": cc.backend,
                    "custom_ops": list_to_str(cc.custom_ops),
                    "splitting_ops": list_to_str(cc.splitting_ops),
                    "cudagraph_mode": str(cc.cudagraph_mode),
                    "compile_sizes": list_to_str(cc.compile_sizes),
                    "compile_ranges_endpoints": list_to_str(
                        cc.compile_ranges_endpoints
                    ),
                    "use_inductor_graph_partition": cc.use_inductor_graph_partition,
                    "inductor_passes": list_to_str(list(cc.inductor_passes.keys())),
                    "enabled_passes": list_to_str(enabled_passes),
                    "dynamic_shapes_type": str(cc.dynamic_shapes_config.type),
                    "dynamic_shapes_evaluate_guards": cc.dynamic_shapes_config.evaluate_guards,  # noqa: E501
                }
            ),
        )

    @dynamo_timed("vllm_backend")
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
        """Compile ``graph`` into a piecewise callable.

        Steps, in order:
        1. Compute the cache-key factors (env, config, compiler, traced
           source files) and set up the cache directories.
        2. Split the graph at the configured splitting ops. No FX-level
           split happens when Inductor graph partitioning is enabled.
        3. Compile every non-splitting submodule through
           ``PiecewiseCompileInterpreter``.
        4. Save the compiler cache and dump the split graph for debugging.
        5. Generate the stitching execution code and wrap it in a
           ``VllmSerializableFunction``.

        When cudagraphs are enabled and ``cudagraph_copy_inputs`` is set,
        that function is additionally wrapped by ``make_copy_and_call``.

        Args:
            graph: The FX graph handed over by Dynamo.
            example_inputs: Example inputs for ``graph``. Tensor entries are
                replaced by the placeholders' ``example_value`` metadata
                before compilation.

        Returns:
            A ``VllmSerializableFunction`` wrapping the runtime callable.

        Raises:
            AssertionError: If this backend instance was already called.
        """
        from .caching import (
            VllmSerializableFunction,
        )

        # we control the compilation process, each instance can only be
        # called once. Check before touching shared state (traced files,
        # counters, cache dirs) so misuse fails without side effects.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config

        self._log_compilation_config()

        # Minimal hashing here with existing utilities, reused below.
        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        forward_code_files = sorted(self.compilation_config.traced_files)

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            hash_content.append(filepath)
            try:
                # Decode as UTF-8 explicitly so the code hash does not depend
                # on the locale's default encoding.
                with open(filepath, encoding="utf-8") as f:
                    hash_content.append(f.read())
            except (OSError, UnicodeDecodeError):
                logger.warning("Failed to read file %s", filepath)
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()

        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_index
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
        disable_cache = not is_compile_cache_enabled(self.inductor_config)

        # TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
        is_ngram_gpu_enabled = (
            vllm_config.speculative_config is not None
            and vllm_config.speculative_config.use_ngram_gpu()
        )
        disable_cache = disable_cache or is_ngram_gpu_enabled

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key
        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w", encoding="utf-8") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.perf_counter() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s",
            dynamo_time,
        )

        # Record Dynamo time in tracing if available
        start_time = int(torch_compile_start_time * 1e9)
        attributes = {"dynamo.time_seconds": dynamo_time}
        instrument_manual("Dynamo bytecode transform", start_time, None, attributes)

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

        # keep a split_gm copy from BEFORE the interpreter replaces
        # submodules with PiecewiseBackend -- used for serialization
        original_split_gm = None
        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            original_split_gm = deepcopy(self.split_gm)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        # Log the piecewise split graph for TORCH_TRACE/tlparse
        trace_structured(
            "graph_dump",
            metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
            payload_fn=lambda: self.split_gm.print_readable(print_output=False),
        )

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # Extract fake values from the graph to use them when needed.
        all_fake_values = [
            node.meta["example_value"]
            for node in graph.graph.find_nodes(op="placeholder")
        ]

        fake_args = [
            all_fake_values[i] if isinstance(t, torch.Tensor) else t
            for i, t in enumerate(example_inputs)
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes, and compile all ranges
        # up front so that compilation is complete before the callable
        # is returned.
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*fake_args)

        # All compilation is done. Save the cache.
        time_before_saving = time.perf_counter()
        self.compiler_manager.save_to_file()
        elapsed = time.perf_counter() - time_before_saving
        if elapsed > 1:
            logger.info_once(
                "Saved compiler manager cache in %.2f seconds.",
                elapsed,
            )

        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()

        if (
            self.compilation_config.dynamic_shapes_config.evaluate_guards
            and self.compilation_config.dynamic_shapes_config.type
            == DynamicShapesType.BACKED
        ):
            from torch.utils._sympy.value_ranges import ValueRanges

            # Drop counter-0/1 specializations guards; for backed dynamic shapes,
            # torch.compile will specialize for 0/1 inputs or otherwise guards that
            # shape is >= 2. This is because it's really hard not to hit a check
            # against 0/1. When we evaluate shape guards, we exclude checking those
            # guards (We would fail always otherwise).

            # We avoid that by widening every backed size whose lower bound is 2
            # so that its range starts at 0 instead.
            # (Only existing keys are reassigned, so mutating during iteration
            # does not change the dict's size.)
            for s, r in fake_mode.shape_env.var_to_range.items():
                if r.lower == 2:
                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w", encoding="utf-8") as f:
                f.write(src)

            logger.debug_once("Computation graph saved to %s", graph_path)

        self._called = True

        # original_split_gm is only populated when VLLM_USE_MEGA_AOT_ARTIFACT
        # was enabled above; use it directly instead of re-reading the env.
        graph_to_serialize = (
            original_split_gm if original_split_gm is not None else self.graph
        )

        execution_code, submod_names = generate_execution_code(self.split_gm)
        # Use getattr to get correct callables: __dict__ has PiecewiseBackend
        # instances (from PiecewiseCompileInterpreter), _modules has originals.
        # getattr checks __dict__ first, then falls back to _modules.
        submod_callables = {
            name: getattr(self.split_gm, name)
            for name, _ in self.split_gm.named_children()
        }
        runtime_callable = compile_execution_fn(
            execution_code, submod_callables, submod_names
        )

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph_to_serialize,
                example_inputs,
                self.prefix,
                runtime_callable,
                is_encoder=self.is_encoder,
                vllm_backend=self,
                execution_code=execution_code,
                submod_names=submod_names,
            )

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        copy_and_call = make_copy_and_call(
            sym_tensor_indices,
            [example_inputs[x].clone() for x in sym_tensor_indices],
            runtime_callable,
        )

        return VllmSerializableFunction(
            graph_to_serialize,
            example_inputs,
            self.prefix,
            copy_and_call,
            is_encoder=self.is_encoder,
            vllm_backend=self,
            sym_tensor_indices=sym_tensor_indices,
            execution_code=execution_code,
            submod_names=submod_names,
        )