backends.py 30.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
import json
8
import operator
9
10
import os
import pprint
11
import time
12
from collections.abc import Callable, Sequence
13
from contextlib import contextmanager
14
from functools import partial
15
from typing import Any
16
17
18

import torch
import torch.fx as fx
19
from torch._dispatch.python import enable_python_dispatcher
20

21
import vllm.envs as envs
22
23
24
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
25
    should_split,
26
)
27
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
28
from vllm.config.utils import hash_factors
29
from vllm.logger import init_logger
30
from vllm.logging_utils import lazy
31
from vllm.platforms import current_platform
32
from vllm.utils.import_utils import resolve_obj_by_qualname
33
from vllm.utils.torch_utils import is_torch_equal_or_newer
34

35
from .caching import VllmSerializableFunction
36
37
38
39
40
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
41
    is_compile_cache_enabled,
42
)
43
from .counter import compilation_counter
44
45
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
46
47
48

logger = init_logger(__name__)

49

50
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Build the compiler adaptor that matches ``compilation_config.backend``.

    ``"inductor"`` yields an Inductor-based adaptor: the standalone variant
    when it is requested and available, the regular one otherwise.
    ``"eager"`` yields the eager adaptor. Any other backend trips an assertion.
    """
    backend = compilation_config.backend

    if backend != "inductor":
        assert backend == "eager", (
            "Custom backends not supported with CompilationMode.VLLM_COMPILE"
        )
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # Standalone compile needs three things: the env opt-in, a recent enough
    # PyTorch, and the symbol actually being present in this torch build.
    standalone_available = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if not standalone_available:
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    logger.debug("Using InductorStandaloneAdaptor")
    return InductorStandaloneAdaptor(compilation_config.compile_cache_save_format)


class CompilerManager:
    """
    A manager for the compilation process: it compiles graphs, caches the
    handles the compiler returns, and loads compiled graphs from those
    handles.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        # set when compile() stores a new handle; gates save_to_file()
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash for ``vllm_config``."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        A cache file that cannot be parsed as a dict literal is ignored with
        a warning instead of aborting; the in-memory cache is left as is.
        """

        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                content = f.read()
            # we use ast.literal_eval to parse the data
            # because it is a safe way to parse Python literals.
            # do not use eval(), it is unsafe.
            try:
                loaded = ast.literal_eval(content)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                loaded = None
            if isinstance(loaded, dict):
                self.cache = loaded
            else:
                # The file is only a cache: a malformed one should cause
                # recompilation rather than make startup fail.
                logger.warning(
                    "Ignoring malformed vLLM compile cache file %s",
                    self.cache_file_path,
                )

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist the cache dict to ``self.cache_file_path``.

        No-op when caching is disabled or nothing new was stored. The data
        is written to a temporary sibling file first and then swapped in with
        ``os.replace``. That way a crash mid-write never leaves a truncated
        cache file behind.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        except BaseException:
            # best-effort removal of the partial temp file, then re-raise
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously compiled graph from the cache.

        Returns None when no handle is cached for
        ``(runtime_shape, graph_index, compiler name)``; otherwise returns
        whatever ``self.compiler.load`` produces from the cached handle.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Return a compiled version of ``graph``, loading it from the cache
        when possible.

        ``graph_index`` / ``num_graphs`` locate this graph among the
        piecewise graphs. Index 0 starts the wall-clock timer and index
        ``num_graphs - 1`` adds the elapsed time to
        ``compilation_config.compilation_time``. Freshly compiled handles are
        stored in ``self.cache`` when the compile cache is enabled.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info_once(
                        "Cache the graph for dynamic shape for later use", scope="local"
                    )
                else:
                    logger.info_once(
                        "Cache the graph of shape %s for later use",
                        str(runtime_shape),
                        scope="local",
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info_once(
                    "Compiling a graph for dynamic shape takes %.2f s",
                    elapsed,
                    scope="local",
                )
            else:
                logger.info_once(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                    scope="local",
                )

        return compiled_graph


@dataclasses.dataclass()
class SplitItem:
    """One piece of a graph produced by ``split_graph``.

    Fields are declared in positional-constructor order.
    """

    # attribute name of the piece on the split module, e.g. "submod_0"
    submod_name: str
    # integer id parsed from ``submod_name``; used to order the pieces
    graph_id: int
    # True when the piece was created around a splitting op
    is_splitting_graph: bool
    # the submodule holding this piece
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into piecewise submodules around the splitting ops.

    Every node for which ``should_split(node, splitting_ops)`` is true gets a
    subgraph of its own, and the nodes between two such ops share a
    subgraph. An ``operator.getitem`` node is kept in the subgraph of the
    node it indexes into, unless that node is a placeholder.

    Returns:
        The stitching module returned by ``split_module``, and one
        ``SplitItem`` per submodule sorted by integer graph id.
    """
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of the subgraphs that hold a splitting op (set: O(1) membership)
    split_op_graphs: set[int] = set()

    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        # Check if this is a getitem operation on a node from an earlier subgraph.
        # If so, assign it to the same subgraph as its input to avoid passing entire
        # tuple as input to submodules, which is against standalone_compile and
        # AoTAutograd input requirement.
        if node.op == "call_function" and node.target == operator.getitem:
            input_node = node.args[0]
            if isinstance(input_node, fx.Node) and input_node.op != "placeholder":
                assert input_node in node_to_subgraph_id
                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
                continue

        if should_split(node, splitting_ops):
            # The splitting op gets a fresh id of its own. The nodes after it
            # start yet another fresh id.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    # The direct children of the split module are the "submod_<id>" pieces.
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs


# Wall-clock timestamp (time.time()) taken by CompilerManager.compile when it
# handles graph_index 0; the call for the last graph subtracts it to get the
# elapsed compilation time. Module-level so it persists across compile() calls.
compilation_start_time: float = 0.0

371
372
373
374
375
376

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        # Position of each name in compile_submod_names, built once so that
        # call_module does one dict lookup instead of two list scans.
        # setdefault keeps the first occurrence, matching list.index().
        self._submod_index: dict[str, int] = {}
        for i, name in enumerate(compile_submod_names):
            self._submod_index.setdefault(name, i)
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph with every tensor argument converted to a fake
        tensor under the detected fake mode."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute submodule ``target``. If it is listed for compilation,
        also compile it and swap it out on the split module for a
        PiecewiseBackend, wrapped in the platform's static-graph wrapper when
        Dynamo-level piecewise cudagraphs are in use."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        index = self._submod_index.get(target)
        if index is None:
            return output

        submod = self.fetch_attr(target)
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]
        num_graphs = len(self.compile_submod_names)

        compiled_graph_for_dynamic_shape = (
            self.vllm_backend.compiler_manager.compile(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=num_graphs,
                runtime_shape=None,
            )
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        # Writing straight into __dict__ (rather than setattr) shadows the
        # registered child for plain attribute lookup. nn.Module.__setattr__
        # would presumably reject a non-Module value for a child name.
        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily replace the module-level ``model_tag`` with ``tag``.

    The previous tag comes back on exit, even if the body raises. Entering
    with the tag that is already active fails the assertion.
    """
    global model_tag
    previous = model_tag
    assert tag != previous, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    model_tag = tag
    try:
        yield
    finally:
        model_tag = previous


class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install ``self.post_grad_pass_manager`` in the Inductor config
        under ``post_grad_custom_post_pass``.

        If that key already holds a PostGradPassManager, it must have the
        same uuid as ours. Any other value must be an InductorPass, which is
        added to our manager before the manager replaces it.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile the Dynamo-captured ``graph`` with vLLM's piecewise pipeline.

        The steps, in order:
        - derive the cache key from env/config/compiler/code hashes and set
          up the cache directories;
        - record the Dynamo bytecode-transform time;
        - split the graph at the splitting ops;
        - compile each non-splitting piece through PiecewiseCompileInterpreter;
        - dump the split graph to the cache directory for debugging;
        - return a VllmSerializableFunction, wrapping an input-copying
          callable when cudagraph input copying is enabled.

        May be called only once per instance.
        """
        # we control the compilation process, each instance can only be
        # called once. Check before touching shared state (traced files,
        # cache dirs, counters, compilation time) so a rejected second call
        # leaves that state untouched.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        # Minimal hashing here with existing utilities, reused below.

        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        forward_code_files = sorted(self.compilation_config.traced_files)

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
            except Exception:
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()

        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
        disable_cache = not is_compile_cache_enabled(
            self.compilation_config.inductor_compile_config
        )

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key

        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug_once(
                "Computation graph saved to %s", graph_path, scope="local"
            )

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            # Copy each symbolic-shape input into the leading rows of its
            # static buffer (only shape[0] is used, so the symbolic dim is
            # assumed to be dim 0), then run on the static buffers.
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )