backends.py 45.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
import json
8
import operator
9
10
import os
import pprint
11
import time
12
from collections import defaultdict
13
from collections.abc import Callable, Generator, Sequence
14
from contextlib import contextmanager
15
from copy import deepcopy
16
from functools import partial
17
from typing import Any
18
19
20

import torch
import torch.fx as fx
21
from torch._dynamo.utils import dynamo_timed
22
from torch._logging._internal import trace_structured
23

24
import vllm.envs as envs
25
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
26
from vllm.config.compilation import DynamicShapesType
27
from vllm.config.utils import Range, hash_factors
28
from vllm.logger import init_logger
29
from vllm.logging_utils import lazy
30
from vllm.platforms import current_platform
31
from vllm.tracing import instrument, instrument_manual
32
from vllm.utils.import_utils import resolve_obj_by_qualname
33

34
35
36
37
38
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
39
    is_compile_cache_enabled,
40
)
41
from .counter import compilation_counter
42
43
44
45
46
47
from .partition_rules import (
    inductor_partition_rule_context,
    should_split,
)
from .passes.inductor_pass import InductorPass, pass_context
from .passes.pass_manager import PostGradPassManager
48
49
50

# Module-level logger, named after this module; used by the debug/info
# messages emitted throughout the compilation backend.
logger = init_logger(__name__)

51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def make_copy_and_call(
    sym_tensor_indices: list[int],
    input_buffers: list[torch.Tensor | None],
    callable_fn: Callable[..., Any],
) -> Callable[..., Any]:
    """Create a wrapper that copies inputs to static buffers before calling.

    This is used for cudagraph input copying where we need to copy dynamic
    tensors to static buffers before invoking the compiled graph.

    Args:
        sym_tensor_indices: Indices of tensors with symbolic shapes
        input_buffers: List of static buffers (can contain None for lazy init)
        callable_fn: The compiled function to call

    Returns:
        A wrapper function that copies inputs and calls the compiled function
    """

71
    def copy_and_call(*args: Any) -> Any:
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
        list_args = list(args)
        for i, index in enumerate(sym_tensor_indices):
            runtime_tensor = list_args[index]
            runtime_shape = runtime_tensor.shape[0]

            # lazy initialization of buffer on first call
            if input_buffers[i] is None:
                input_buffers[i] = runtime_tensor.clone()

            static_tensor = input_buffers[i][:runtime_shape]  # type: ignore[index]
            static_tensor.copy_(runtime_tensor)
            list_args[index] = static_tensor
        return callable_fn(*list_args)

    return copy_and_call


89
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Instantiate the compiler adaptor selected by `compilation_config.backend`.

    Args:
        compilation_config: The compilation configuration; its `backend`
            field selects the adaptor.

    Returns:
        - for "inductor": an InductorStandaloneAdaptor when
          VLLM_USE_STANDALONE_COMPILE is set and this PyTorch build exposes
          `torch._inductor.standalone_compile`, otherwise an InductorAdaptor;
        - for "eager": an EagerAdaptor;
        - otherwise: an instance of the class named by
          `current_platform.get_compile_backend()`.

    Raises:
        AssertionError: if VLLM_USE_MEGA_AOT_ARTIFACT is set without
            VLLM_USE_STANDALONE_COMPILE, or if the platform-provided backend
            is not a CompilerInterface.
    """
    assert not envs.VLLM_USE_MEGA_AOT_ARTIFACT or envs.VLLM_USE_STANDALONE_COMPILE, (
        "VLLM_USE_MEGA_AOT_ARTIFACT=1 requires VLLM_USE_STANDALONE_COMPILE=1"
    )

    if compilation_config.backend == "inductor":
        # Use standalone compile only if requested and the symbol actually
        # exists in this PyTorch build.
        if envs.VLLM_USE_STANDALONE_COMPILE and hasattr(
            torch._inductor, "standalone_compile"
        ):
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor(
                compilation_config.compile_cache_save_format
            )
        else:
            logger.debug("Using InductorAdaptor")
            return InductorAdaptor()
    elif compilation_config.backend == "eager":
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()
    else:
        # The class is resolved from the current platform, not from the
        # backend string itself; the string is only used for logging here.
        logger.debug("Using custom backend: %s", compilation_config.backend)
        compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
        assert isinstance(compiler, CompilerInterface)
        return compiler
115
116


117
118
119
120
121
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(compile_range, graph_index, backend_name)`
    to a dict with the `graph_handle` returned from the compiler and the
    `cache_key` of the compiled artifact.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig) -> None:
        # persistent index: (compile_range, graph_index, compiler name) -> entry
        self.cache: dict[tuple[Range, int, str], Any] = dict()
        # set once `compile` adds an entry, so `save_to_file` knows to write
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)
        # in-memory artifacts keyed by cache_key, used to avoid compiling
        # (or keeping) the same artifact twice within this process
        self.loaded_artifacts: dict[str, Any] = {}

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash for `vllm_config`."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(compile_range):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        If caching is enabled and `vllm_compile_cache.py` already exists, its
        content is parsed and loaded into `self.cache`. The underlying
        compiler's cache is then initialized with the same arguments.

        Raises:
            TypeError: if a key in the cache file has an unexpected type.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                cache = ast.literal_eval(f.read())

            def check_type(value: Any, ty: type) -> None:
                if not isinstance(value, ty):
                    raise TypeError(f"Expected {ty} but got {type(value)} for {value}")

            def parse_key(key: Any) -> tuple[Range, int, str]:
                range_tuple, graph_index, compiler_name = key
                check_type(graph_index, int)
                check_type(compiler_name, str)
                # a plain (start, end) tuple is converted back into a Range
                if isinstance(range_tuple, tuple):
                    start, end = range_tuple
                    check_type(start, int)
                    check_type(end, int)
                    range_tuple = Range(start=start, end=end)
                check_type(range_tuple, Range)
                return range_tuple, graph_index, compiler_name

            self.cache = {parse_key(key): value for key, value in cache.items()}

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self) -> None:
        """Write `self.cache` to the cache file as a pretty-printed Python
        literal. Does nothing if caching is disabled or nothing changed."""
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any] | None:
        """Load a previously compiled graph through the cache index.

        Returns:
            The compiled graph returned by the compiler's `load`, or None if
            there is no entry for `(compile_range, graph_index, compiler)`
            or the entry cannot be parsed.
        """
        key = (compile_range, graph_index, self.compiler.name)
        if key not in self.cache:
            return None

        def parse_value(value: Any) -> tuple[tuple[str, str], str]:
            assert isinstance(value, dict)
            handle = value["graph_handle"]
            assert isinstance(handle[0], str)
            assert isinstance(handle[1], str)
            cache_key = value["cache_key"]
            return handle, cache_key

        try:
            handle, cache_key = parse_value(self.cache[key])
        except Exception:
            # When the cache is outdated, we should ignore the existing file.
            # This should cause the correct cache to be generated again.
            return None

        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, compile_range
        )
        # remember the artifact so an identical graph compiled later in this
        # process can reuse it (see `compile`)
        self.loaded_artifacts[cache_key] = compiled_graph
        logger.debug(
            "Directly load the %s-th graph for compile range %s "
            "from %s via handle %s",
            graph_index,
            str(compile_range),
            self.compiler.name,
            handle,
        )
        return compiled_graph

    @instrument(span_name="Compile graph")
    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        additional_inductor_config: dict[str, Any],
        compilation_config: CompilationConfig,
        compile_range: Range,
        graph_index: int = 0,
        num_graphs: int = 1,
    ) -> Any:
        """Return a compiled version of `graph`, using the cache if possible.

        Args:
            graph: The FX graph to compile.
            example_inputs: Example inputs forwarded to the compiler.
            additional_inductor_config: Extra config forwarded to the
                compiler; also consulted to decide whether to record the
                result in the cache index.
            compilation_config: Its `compilation_time` is incremented by the
                elapsed time once the last graph has been handled.
            compile_range: The range this compilation is for.
            graph_index: Index of this graph among the piecewise graphs.
            num_graphs: Total number of piecewise graphs.

        Returns:
            The compiled graph: loaded from the cache index, reused from the
            in-memory artifacts, or freshly compiled.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.perf_counter()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                elapsed = time.perf_counter() - compilation_start_time
                compilation_config.compilation_time += elapsed
                logger.info_once(
                    "Directly load the compiled graph(s) for compile range %s "
                    "from the cache, took %.3f s",
                    str(compile_range),
                    elapsed,
                    scope="local",
                )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = "artifact_compile_range_"
            maybe_key += f"{compile_range.start}_{compile_range.end}"
            maybe_key += f"_subgraph_{graph_index}"
        with self.compile_context(compile_range):
            # There is a compilation time optimization here.
            #
            # If the (input metadata, graph, compiler config) are the same, then
            # we want to avoid compiling the same artifact again. If we didn't
            # do this optimization, the backend compilation (InductorAdaptor or
            # InductorStandaloneAdaptor)
            # is able to cache hit and produce an artifact faster if it was
            # already created, but it is still a duplicate artifact that
            # requires unnecessary things e.g. disk IO.
            #
            # The optimization is: If the backend compilation cache hits,
            # then do an early return from the backend compilation and look up
            # which of the previous in-memory artifacts we created to reuse.
            #
            # We implemented this by monkey-patching torch (torch does not
            # easily expose the cache_key function), but in the future torch
            # should expose the cache_key function that we can just call
            # directly before invoking backend compilation.
            cache_key = None
            orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key

            def autograd_cache_key(*args, **kwargs):
                nonlocal cache_key
                result = orig(*args, **kwargs)
                if result is None:
                    return None
                cache_key = result[0]
                if cache_key in self.loaded_artifacts:
                    # abort the backend compilation; handled just below
                    raise StopCompiling()
                return result

            from unittest.mock import patch

            with (
                # Graphs that are isometric (different node names but same
                # structure) should be treated as the same.
                torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
                patch(
                    "torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
                    autograd_cache_key,
                ),
            ):
                try:
                    compiled_graph, handle = self.compiler.compile(
                        graph,
                        example_inputs,
                        additional_inductor_config,
                        compile_range,
                        maybe_key,
                    )
                except StopCompiling:
                    assert cache_key is not None
                    return self.loaded_artifacts[cache_key]
            if cache_key is not None and compiled_graph is not None:
                self.loaded_artifacts[cache_key] = compiled_graph

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(compile_range, graph_index, self.compiler.name)] = {
                "graph_handle": handle,
                "cache_key": cache_key,
            }
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info_once(
                    "Cache the graph of compile range %s for later use",
                    str(compile_range),
                )
            logger.debug(
                "Store the %s-th graph for compile range %s "
                "from %s via handle %s",
                graph_index,
                str(compile_range),
                self.compiler.name,
                handle,
            )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            elapsed = time.perf_counter() - compilation_start_time
            compilation_config.compilation_time += elapsed
            logger.info_once(
                "Compiling a graph for compile range %s takes %.2f s",
                str(compile_range),
                elapsed,
                scope="local",
            )

        return compiled_graph
395
396


397
398
399
400
class StopCompiling(BaseException):
    """Control-flow signal used to abort a backend compilation early.

    It is raised from the patched `autograd_cache_key` hook inside
    `CompilerManager.compile` when the computed cache key already has an
    in-memory artifact, and caught there to return that artifact.

    It derives from BaseException rather than Exception, so `except Exception`
    handlers do not catch it (presumably so that handlers inside the compiler
    cannot swallow the signal).
    """


401
402
403
@dataclasses.dataclass
class SplitItem:
    # One piece of a graph produced by `split_graph`.

    # attribute name of the submodule inside the split graph module
    submod_name: str
    # integer id of the piece (parsed from the submodule name)
    graph_id: int
    # True if this piece was created for a splitting op
    is_splitting_graph: bool
    # the submodule holding this piece of the graph
    graph: fx.GraphModule


409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
def _is_empty_allocation_node(node: fx.Node) -> bool:
    if node.op == "call_method":
        return node.target == "new_empty"

    if node.op != "call_function":
        return False

    target = node.target
    if target in (torch.empty, torch.empty_like, torch.empty_strided):
        return True

    if isinstance(target, torch._ops.OpOverloadPacket):
        packet_name = target._qualified_op_name
    elif isinstance(target, torch._ops.OpOverload):
        packet_name = target.name()
    else:
        return False

    return packet_name.startswith("aten::empty") or packet_name.startswith(
        "aten::new_empty"
    )


def _merge_empty_only_subgraphs(
    node_to_subgraph_id: dict[fx.Node, int],
) -> None:
    """
    Fold every partition that consists of a single empty-allocation node into
    the closest preceding partition that was kept. This avoids generating
    standalone empty submodules, which can lead to empty cudagraph captures.

    `node_to_subgraph_id` is updated in place. Partitions are visited in the
    order their ids first appear in the mapping; the very first partition is
    never merged because nothing precedes it.
    """
    # group nodes per partition; dict insertion order gives first-seen order
    members_of: dict[int, list[fx.Node]] = {}
    for fx_node, sid in node_to_subgraph_id.items():
        members_of.setdefault(sid, []).append(fx_node)

    last_kept_id: int | None = None
    for sid, members in members_of.items():
        lone_empty = len(members) == 1 and _is_empty_allocation_node(members[0])
        if lone_empty and last_kept_id is not None:
            node_to_subgraph_id[members[0]] = last_kept_id
        else:
            # only partitions that survive can absorb later ones
            last_kept_id = sid


461
def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` into piecewise submodules around the splitting ops.

    Every non-placeholder, non-output node is assigned a subgraph id;
    nodes for which `should_split` is true get their own subgraph (consecutive
    splitting ops share one). The assignment is then handed to
    `torch.fx.passes.split_module.split_module`.

    Args:
        graph: The graph module to split.
        splitting_ops: Passed to `should_split` to decide where to split.

    Returns:
        A tuple of the stitched split graph module and the list of its
        direct submodules as SplitItems, sorted by integer graph id.
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    split_op_graphs: list[int] = []
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        # Check if this is a getitem operation on a node from an earlier subgraph.
        # If so, assign it to the same subgraph as its input to avoid passing entire
        # tuple as input to submodules, which is against standalone_compile and
        # AoTAutograd input requirement.
        if node.op == "call_function" and node.target == operator.getitem:
            # Assign this getitem to the same subgraph as its input
            input_node = node.args[0]
            if input_node.op != "placeholder":
                assert input_node in node_to_subgraph_id
                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
                continue

        if should_split(node, splitting_ops):
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.append(subgraph_id)

            # keep consecutive splitting ops together
            # (we know node.next exists because node isn't the last (output) node)
            if should_split(node.next, splitting_ops):
                # this will get incremented by the next node
                subgraph_id -= 1
            else:
                subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    _merge_empty_only_subgraphs(node_to_subgraph_id)

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []

    names = [name for (name, module) in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
527
528


529
530
# `time.perf_counter()` timestamp taken by `CompilerManager.compile` when it
# handles the first graph (graph_index == 0); read again when the last graph
# is handled to compute the elapsed compilation time.
compilation_start_time = 0.0

531

532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
def wrap_with_cudagraph_if_needed(
    piecewise_backend: Any,
    vllm_config: VllmConfig,
    compilation_config: CompilationConfig,
    is_first_graph: bool,
    is_last_graph: bool,
) -> Any:
    """
    Optionally wrap a piecewise backend in the platform's static graph
    (CUDA graph) wrapper.
    This function is shared between VllmBackend and
    construct_serializable_fn_from_inductor_cache.

    Args:
        piecewise_backend: The backend to wrap
        vllm_config: The vLLM configuration
        compilation_config: The compilation configuration
        is_first_graph: Whether this is the first graph in the sequence
            (enables debug logging; gc is disabled for all other graphs)
        is_last_graph: Whether this is the last graph in the sequence
            (its output is weak-referenced)

    Returns:
        `piecewise_backend` unchanged when piecewise cudagraphs are off or
        Inductor graph partitioning is in use; otherwise the wrapped backend.
    """
    needs_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and not compilation_config.use_inductor_graph_partition
    )
    if not needs_wrapper:
        return piecewise_backend

    # From here on we are on the Dynamo-based piecewise splitting path, so
    # the whole subgraph gets a static graph wrapper.
    from .cuda_graph import CUDAGraphOptions

    # The wrapper class (e.g. CUDAGraphWrapper) is platform dependent.
    wrapper_cls = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    options = CUDAGraphOptions(
        debug_log_enable=is_first_graph,
        gc_disable=not is_first_graph,
        weak_ref_output=is_last_graph,
    )

    # The runtime mode is always PIECEWISE for a piecewise_backend, to tell
    # it apart from the FULL cudagraph runtime mode, regardless of whether
    # it wraps a full or a piecewise fx graph.
    return wrapper_cls(
        runnable=piecewise_backend,
        vllm_config=vllm_config,
        runtime_mode=CUDAGraphMode.PIECEWISE,
        cudagraph_options=options,
    )


586
class PiecewiseCompileInterpreter(torch.fx.Interpreter):  # type: ignore[misc]
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.

    It runs the given split graph interpreter, and for each submodule in
    `compile_submod_names`, creates a PiecewiseBackend and compiles all
    ranges up front.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.

    Note: This class shares similar logic with the function in caching.py
    that rebuilds a serializable fn from cached artifacts (presumably the one
    named in `wrap_with_cudagraph_if_needed`'s docstring -- TODO confirm the
    current name).
    Both create PiecewiseBackend instances and wrap them with cudagraph.
    The key difference is:
    - the caching.py function: PiecewiseBackend receives
      pre-compiled runnables (compiled_runnables is set, graph is None)
    - this class: PiecewiseBackend receives the FX graph to compile
      (graph is set, compiled_runnables is None)

    If modifying the backend creation/wrapping logic, consider updating both.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ) -> None:
        super().__init__(module)
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    @instrument(span_name="Inductor compilation")
    def run(self, *args: Any) -> Any:
        """Run the interpreter; only wrapped here to attach a tracing span."""
        return super().run(*args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Handle one submodule call of the split graph.

        The submodule is not executed: its result is taken from the
        `example_value` metadata of the submodule's output nodes. If `target`
        is listed in `compile_submod_names`, the submodule attribute on
        `self.module` is additionally replaced with a (possibly
        cudagraph-wrapped) PiecewiseBackend.
        """
        assert isinstance(target, str)

        gm = getattr(self.module, target)
        outputs = gm.graph.output_node().args[0]
        output = fx.map_arg(outputs, lambda node: node.meta["example_value"])

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)

            # positions of the arguments that are symbolic ints
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]

            # Lazy import here to avoid circular import
            from torch._inductor.compile_fx import graph_returns_tuple

            from .piecewise_backend import PiecewiseBackend

            piecewise_backend = PiecewiseBackend(
                submod,
                self.vllm_config,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                self.vllm_backend,
                graph_returns_tuple(submod),
                submod_name=target,
            )

            # write through __dict__ so the backend replaces the submodule
            # under the same attribute name
            self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
                piecewise_backend,
                self.vllm_config,
                self.compilation_config,
                piecewise_backend.is_first_graph,
                piecewise_backend.is_last_graph,
            )

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


678
679
680
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"
# whether the part of model being compiled is an encoder;
# set together with `model_tag` by `set_model_tag`
model_is_encoder: bool = False
682
683
684


@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
    """Context manager to set the model tag.

    Temporarily sets the module-level `model_tag` and `model_is_encoder`
    and restores the previous values on exit, even if the body raises.

    Args:
        tag: The new model tag; must differ from the current one.
        is_encoder: The new value for `model_is_encoder`.

    Raises:
        AssertionError: if `tag` equals the current `model_tag`.
    """
    global model_tag
    global model_is_encoder
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag = model_tag
    old_is_encoder = model_is_encoder

    model_tag = tag
    model_is_encoder = is_encoder
    try:
        yield
    finally:
        model_tag = old_tag
        model_is_encoder = old_is_encoder
702
703


704
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    # set to True at the end of `__call__`; each instance can only be
    # called once
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable[..., Any]
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable[..., Any]]
    compiler_manager: CompilerManager
    # Copy of CompilationConfig.inductor_compile_config +
    # an entry for PostGradPassManager
    inductor_config: dict[str, Any]
731

732
733
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        is_encoder: bool = False,
    ) -> None:
        """Set up the backend state; no compilation happens here.

        Args:
            vllm_config: Full vLLM config; its `compilation_config` is kept
                as `self.compilation_config`.
            prefix: Name used for this backend (it becomes part of the local
                cache directory in `__call__`). Falls back to the
                module-level `model_tag` when empty.
            is_encoder: Marks the compilation as being for an encoder. The
                module-level `model_is_encoder` can also turn this on.
        """
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Mark compilation for encoder.
        self.is_encoder = is_encoder or model_is_encoder

        # Passes to run on the graph post-grad.
        # The platform names the pass-manager class; instantiate it here.
        self.pass_manager = resolve_obj_by_qualname(
            current_platform.get_pass_manager_cls()
        )()
        # Key in the inductor config under which the pass manager is stored.
        self.pass_key = current_platform.pass_key

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # Deepcopy the inductor config to detach the post-grad custom pass
        # from CompilationConfig.
        # We want to avoid PostGradPassManager in CompilationConfig because
        # in future we need PostGradPassManager.uuid() to be executed
        # only at compile time.
        self.inductor_config = deepcopy(self.compilation_config.inductor_compile_config)

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
    def collect_standalone_compile_artifacts(
        self,
    ) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]:
        """Gather serialized compile artifacts from the piecewise backends.

        When `envs.VLLM_USE_MEGA_AOT_ARTIFACT` is falsy nothing is collected
        and `(None, None, None)` is returned.

        Returns:
            tuple: (standalone_compile_artifacts, sym_shape_indices_map,
                    returns_tuple_map)
                - standalone_compile_artifacts: StandaloneCompiledArtifacts
                  filled with the bytes each backend returns from
                  `to_bytes()`, inserted per (submodule name, shape string)
                - sym_shape_indices_map: dict mapping submod_name to the
                  backend's `sym_shape_indices`
                - returns_tuple_map: dict mapping submod_name to the
                  backend's `returns_tuple`
        """

        if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            return None, None, None

        from .caching import StandaloneCompiledArtifacts
        from .piecewise_backend import PiecewiseBackend

        standalone_compile_artifacts = StandaloneCompiledArtifacts()
        sym_shape_indices_map = {}
        returns_tuple_map = {}

        for submod_name, _ in self.split_gm.named_children():
            # Resolve via getattr: the registered child is shadowed in
            # __dict__ by the PiecewiseBackend (or a wrapper around it).
            attr = getattr(self.split_gm, submod_name)
            # A static graph wrapper exposes the wrapped backend as
            # `.runnable`; otherwise the attribute is used directly.
            backend = getattr(attr, "runnable", attr)

            if isinstance(backend, PiecewiseBackend):
                sym_shape_indices_map[submod_name] = backend.sym_shape_indices
                returns_tuple_map[submod_name] = backend.returns_tuple

                serialized = backend.to_bytes()
                for shape_str in serialized:
                    bytes_data = serialized[shape_str]
                    standalone_compile_artifacts.insert(
                        submod_name, shape_str, bytes_data
                    )
                    logger.debug(
                        "collected artifact for %s shape %s (%d bytes)",
                        submod_name,
                        shape_str,
                        len(bytes_data),
                    )

        logger.info(
            "collected artifacts: %d entries, %d artifacts, %d bytes total",
            standalone_compile_artifacts.num_entries(),
            standalone_compile_artifacts.num_artifacts(),
            standalone_compile_artifacts.size_bytes(),
        )

        artifact_keys = list(standalone_compile_artifacts.submodule_bytes.keys())
        logger.debug("standalone compile artifact keys: %s", artifact_keys)

        return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map

833
    def configure_post_pass(self) -> None:
        """Configure the pass manager and install it in `self.inductor_config`.

        If the inductor config already has an entry under `self.pass_key`,
        that pass is added to the pass manager before the manager replaces
        the entry.

        Raises:
            ValueError: If the existing entry is itself a
                PostGradPassManager.
            AssertionError: If the existing entry is not an InductorPass.
        """
        self.pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        if self.pass_key in self.inductor_config:
            if isinstance(self.inductor_config[self.pass_key], PostGradPassManager):
                raise ValueError(
                    "PostGradPassManager can not be kept in CompilationConfig."
                )
            # The pass that gets registered is read from the original
            # CompilationConfig, not from the deep-copied inductor config.
            custom_pass = self.compilation_config.inductor_compile_config[
                self.pass_key
            ]
            # Config should automatically wrap all inductor passes
            assert isinstance(custom_pass, InductorPass)
            self.pass_manager.add(custom_pass)
        self.inductor_config[self.pass_key] = self.pass_manager
853

854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
    def _log_compilation_config(self):
        """Emit the vLLM compilation config as a TORCH_TRACE/tlparse artifact.

        The config summary is built inside the payload callback handed to
        `trace_structured`, so it is only serialized if that callback is
        invoked.
        """
        cc = self.compilation_config
        pass_cfg = cc.pass_config

        def _csv(values: list | None) -> str:
            # tlparse displays comma-separated strings; None becomes "".
            return "" if values is None else ", ".join(map(str, values))

        # Names of the pass-config dataclass fields whose value is True.
        enabled_passes = []
        for field in dataclasses.fields(pass_cfg):
            if getattr(pass_cfg, field.name) is True:
                enabled_passes.append(field.name)

        def _metadata():
            return {
                "name": "vllm_compilation_config",
                "encoding": "json",
            }

        def _payload():
            summary = {
                "model": self.vllm_config.model_config.model,
                "prefix": self.prefix,
                "mode": str(cc.mode),
                "backend": cc.backend,
                "custom_ops": _csv(cc.custom_ops),
                "splitting_ops": _csv(cc.splitting_ops),
                "cudagraph_mode": str(cc.cudagraph_mode),
                "compile_sizes": _csv(cc.compile_sizes),
                "compile_ranges_split_points": _csv(cc.compile_ranges_split_points),
                "use_inductor_graph_partition": cc.use_inductor_graph_partition,
                "inductor_passes": _csv(list(cc.inductor_passes.keys())),
                "enabled_passes": _csv(enabled_passes),
                "dynamic_shapes_type": str(cc.dynamic_shapes_config.type),
                "dynamic_shapes_evaluate_guards": (
                    cc.dynamic_shapes_config.evaluate_guards
                ),
            }
            return json.dumps(summary)

        trace_structured("artifact", metadata_fn=_metadata, payload_fn=_payload)

900
    @dynamo_timed("vllm_backend")
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
        """Compile `graph` piecewise and return the serializable callable.

        Steps, in order:
          1. Log the compilation config and compute the env/config/compiler/
             code hashes.
          2. Pick (or derive) the cache directory, create the per-rank local
             cache directory and initialize the compiler manager's cache.
          3. Best-effort write of the cache-key factors next to the cache.
          4. Record the Dynamo bytecode transform time.
          5. Split the graph and run `PiecewiseCompileInterpreter` over the
             split graph module with fake arguments.
          6. Save the compiler manager cache and dump the split graph to
             `computation_graph.py` if that file does not exist yet.

        Args:
            graph: The graph module handed over by Dynamo.
            example_inputs: Example inputs matching the graph placeholders.

        Returns:
            A `VllmSerializableFunction`. It wraps `self.split_gm` directly
            when cudagraph mode is NONE or `cudagraph_copy_inputs` is off;
            otherwise it wraps a callable built by `make_copy_and_call`.

        Raises:
            AssertionError: If this backend instance was already called.
        """
        from .caching import (
            VllmSerializableFunction,
        )

        vllm_config = self.vllm_config

        self._log_compilation_config()

        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        forward_code_files = sorted(self.compilation_config.traced_files)

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        # The code hash covers each traced file's path and, when readable,
        # its contents.
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
            except (OSError, UnicodeDecodeError):
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_index
        # One sub-directory per (rank, dp_rank) and per backend prefix.
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
        disable_cache = not is_compile_cache_enabled(self.inductor_config)

        # TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
        is_ngram_gpu_enabled = (
            vllm_config.speculative_config is not None
            and vllm_config.speculative_config.use_ngram_gpu()
        )
        disable_cache = disable_cache or is_ngram_gpu_enabled

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key

        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            # An existing metadata file is left untouched.
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.perf_counter() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        # Record Dynamo time in tracing if available
        start_time = int(torch_compile_start_time * 1e9)
        attributes = {"dynamo.time_seconds": dynamo_time}
        instrument_manual("Dynamo bytecode transform", start_time, None, attributes)

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

        # keep a split_gm copy from BEFORE the interpreter replaces
        # submodules with PiecewiseBackend -- used for serialization
        original_split_gm = None
        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            original_split_gm = deepcopy(self.split_gm)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        # Log the piecewise split graph for TORCH_TRACE/tlparse
        trace_structured(
            "graph_dump",
            metadata_fn=lambda: {"name": "vllm_piecewise_split_graph"},
            payload_fn=lambda: self.split_gm.print_readable(print_output=False),
        )

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # Extract fake values from the graph to use them when needed.
        all_fake_values = [
            node.meta["example_value"]
            for node in graph.graph.find_nodes(op="placeholder")
        ]

        # Tensor inputs are replaced by the fake value at the same position;
        # non-tensor inputs are passed through unchanged.
        fake_args = [
            all_fake_values[i] if isinstance(t, torch.Tensor) else t
            for i, t in enumerate(example_inputs)
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes, and compile all ranges
        # up front so that compilation is complete before the callable
        # is returned.
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*fake_args)

        # All compilation is done. Save the cache.
        time_before_saving = time.perf_counter()
        self.compiler_manager.save_to_file()
        elapsed = time.perf_counter() - time_before_saving
        if elapsed > 1:
            logger.info_once(
                "Saved compiler manager cache in %.2f seconds.",
                elapsed,
                scope="local",
            )

        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()

        if (
            self.compilation_config.dynamic_shapes_config.evaluate_guards
            and self.compilation_config.dynamic_shapes_config.type
            == DynamicShapesType.BACKED
        ):
            from torch.utils._sympy.value_ranges import ValueRanges

            # Drop the 0/1 specialization guards. For backed dynamic shapes,
            # torch.compile either specializes for 0/1 inputs or guards that
            # the shape is >= 2, because it is really hard not to hit a check
            # against 0/1. When we evaluate shape guards we must not check
            # those guards (we would always fail otherwise).

            # We do that by widening the ranges of backed sizes: any symbol
            # whose lower bound is 2 is treated as if its lower bound were 0.
            for s, r in fake_mode.shape_env.var_to_range.items():
                if r.lower == 2:
                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug_once(
                "Computation graph saved to %s", graph_path, scope="local"
            )

        self._called = True
        graph_to_serialize = (
            original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
        )

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph_to_serialize,
                example_inputs,
                self.prefix,
                self.split_gm,
                is_encoder=self.is_encoder,
                vllm_backend=self,
            )

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        copy_and_call = make_copy_and_call(
            sym_tensor_indices,
            [example_inputs[x].clone() for x in sym_tensor_indices],
            self.split_gm,
        )

        return VllmSerializableFunction(
            graph_to_serialize,
            example_inputs,
            self.prefix,
            copy_and_call,
            is_encoder=self.is_encoder,
            vllm_backend=self,
            sym_tensor_indices=sym_tensor_indices,
        )