backends.py 30.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
import json
8
import operator
9
10
import os
import pprint
11
import time
12
from collections.abc import Callable, Sequence
13
from contextlib import contextmanager
14
from functools import partial
15
from typing import Any
16
17
18

import torch
import torch.fx as fx
19
from torch._dispatch.python import enable_python_dispatcher
20

21
import vllm.envs as envs
22
23
24
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
25
    should_split,
26
)
27
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
28
from vllm.config.utils import hash_factors
29
from vllm.logger import init_logger
30
from vllm.logging_utils import lazy
31
from vllm.platforms import current_platform
32
from vllm.utils.import_utils import resolve_obj_by_qualname
33
from vllm.utils.torch_utils import is_torch_equal_or_newer
34

35
from .caching import VllmSerializableFunction
36
37
38
39
40
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
41
    is_compile_cache_enabled,
42
)
43
from .counter import compilation_counter
44
45
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
46
47
48

# Module-level logger, named after this module via `__name__`.
logger = init_logger(__name__)

49

50
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Instantiate the compiler adaptor selected by `compilation_config.backend`.

    - `"inductor"`: returns `InductorStandaloneAdaptor` (built with
      `compilation_config.compile_cache_save_format`) when
      `VLLM_USE_STANDALONE_COMPILE` is set, torch is at least `2.8.0.dev`,
      and `torch._inductor.standalone_compile` exists; otherwise returns
      `InductorAdaptor`.
    - `"eager"`: returns `EagerAdaptor`.
    - anything else: resolves the qualified name returned by
      `current_platform.get_compile_backend()`, instantiates it with no
      arguments and requires the result to be a `CompilerInterface`.
    """
    backend = compilation_config.backend

    if backend == "inductor":
        # Use standalone compile only if requested, version is new enough,
        # and the symbol actually exists in this PyTorch build.
        use_standalone = (
            envs.VLLM_USE_STANDALONE_COMPILE
            and is_torch_equal_or_newer("2.8.0.dev")
            and hasattr(torch._inductor, "standalone_compile")
        )
        if use_standalone:
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor(
                compilation_config.compile_cache_save_format
            )
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    if backend == "eager":
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # NOTE: the name that is logged comes from the config, while the class
    # that is instantiated comes from the current platform.
    logger.debug("Using custom backend: %s", backend)
    compiler = resolve_obj_by_qualname(current_platform.get_compile_backend())()
    assert isinstance(compiler, CompilerInterface)
    return compiler
74
75


76
77
78
79
80
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json object keys
    must be strings, while our keys are tuples.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle from compiler
        self.cache: dict[tuple[int | None, int, str], Any] = {}
        # set once a new handle is stored, so `save_to_file` knows to write
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the hash computed by the underlying compiler."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Write the cache dict to `cache_file_path` as a Python literal.

        Does nothing when the cache is disabled or no entry was added.
        Reads attributes that are set by `initialize_cache`.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        data = pprint.PrettyPrinter(indent=4).pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously compiled graph through its cached handle.

        Returns `None` when there is no cache entry for
        `(runtime_shape, graph_index, self.compiler.name)`; otherwise returns
        whatever `self.compiler.load` produces for the stored handle.
        """
        cache_key = (runtime_shape, graph_index, self.compiler.name)
        if cache_key not in self.cache:
            return None
        handle = self.cache[cache_key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Return a compiled version of `graph`, using the cache when possible.

        The cache is consulted first via `load`; on a miss the graph is
        compiled inside `compile_context` and, when the compile cache is
        enabled and the compiler returned a handle, the handle is stored.

        The module-level `compilation_start_time` is reset when
        `graph_index == 0`, and the elapsed time is added to
        `compilation_config.compilation_time` when
        `graph_index == num_graphs - 1`.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        is_last_graph = graph_index == num_graphs - 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if is_last_graph:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            cache_key = (runtime_shape, graph_index, self.compiler.name)
            self.cache[cache_key] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info_once(
                        "Cache the graph for dynamic shape for later use", scope="local"
                    )
                else:
                    logger.info_once(
                        "Cache the graph of shape %s for later use",
                        str(runtime_shape),
                        scope="local",
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if is_last_graph:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info_once(
                    "Compiling a graph for dynamic shape takes %.2f s",
                    elapsed,
                    scope="local",
                )
            else:
                logger.info_once(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                    scope="local",
                )

        return compiled_graph
301
302


303
304
305
@dataclasses.dataclass
class SplitItem:
    """One top-level submodule produced by `split_graph`."""

    # attribute name of the submodule on the split graph module
    submod_name: str
    # integer id parsed from the submodule name
    graph_id: int
    # True when this submodule was created for a splitting op
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


311
def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` into submodules around the splitting ops.

    Each node for which `should_split(node, splitting_ops)` is true gets a
    subgraph id of its own; the nodes between two such nodes share one id.
    `operator.getitem` nodes whose input is not a placeholder are kept in the
    subgraph of that input.

    Returns:
        The module returned by `split_module`, and one `SplitItem` per
        top-level submodule, sorted by integer graph id.
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of the subgraphs that hold a splitting op; only used for membership
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        # Check if this is a getitem operation on a node from an earlier subgraph.
        # If so, assign it to the same subgraph as its input to avoid passing entire
        # tuple as input to submodules, which is against standalone_compile and
        # AoTAutograd input requirement.
        if node.op == "call_function" and node.target == operator.getitem:
            # Assign this getitem to the same subgraph as its input
            input_node = node.args[0]
            if input_node.op != "placeholder":
                assert input_node in node_to_subgraph_id
                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
                continue

        if should_split(node, splitting_ops):
            # the splitting op sits alone between the preceding and the
            # following subgraph, hence the two increments
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    for name, module in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
368
369


370
371
# `time.time()` value recorded by `CompilerManager.compile` when it is called
# with `graph_index == 0`; read again there to compute the elapsed time.
compilation_start_time = 0.0

372
373
374
375
376
377

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph under the fake mode with tensor arguments faked."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute submodule `target` and, if it is listed in
        `compile_submod_names`, compile it and replace it on `self.module`.

        The return value is always the output of the original submodule call.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        # positions of the arguments that are symbolic integers
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]

        compiled_graph_for_dynamic_shape = self.vllm_backend.compiler_manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=index,
            num_graphs=num_graphs,
            runtime_shape=None,
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


488
489
490
491
492
493
494
495
496
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Context manager to set the model tag.

    The module-level `model_tag` is set to `tag` inside the `with` block and
    restored to its previous value on exit, including when the block raises.
    Asserts that `tag` differs from the current tag.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        model_tag = old_tag


508
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.pass_manager = resolve_obj_by_qualname(
            current_platform.get_pass_manager_cls()
        )()
        self.pass_key = current_platform.pass_key

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Configure the pass manager and install it in the Inductor config
        under `self.pass_key`."""
        config = self.compilation_config
        self.pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        if self.pass_key in inductor_config:
            existing = inductor_config[self.pass_key]
            if isinstance(existing, PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert existing.uuid() == self.pass_manager.uuid()
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(existing, InductorPass)
                self.pass_manager.add(existing)
        inductor_config[self.pass_key] = self.pass_manager

    def _hash_traced_files(self) -> str:
        """Return a SHA-256 over the traced file paths and their contents,
        then clear `compilation_config.traced_files`."""
        forward_code_files = sorted(self.compilation_config.traced_files)

        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            lazy(lambda: "\n".join(forward_code_files)),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            try:
                with open(filepath) as f:
                    hash_content.append(f.read())
            except Exception:
                # best-effort: an unreadable file contributes only its path
                logger.warning("Failed to read file %s", filepath)
                continue
        code_hash = hashlib.sha256("\n".join(hash_content).encode()).hexdigest()
        # Clear after consumption
        self.compilation_config.traced_files.clear()
        return code_hash

    def _write_cache_key_metadata(
        self,
        local_cache_dir: str,
        env_factors,
        config_hash: str,
        code_hash: str,
        compiler_hash: str,
    ) -> None:
        """Log the hash-relevant factors and write them to
        `cache_key_factors.json` in `local_cache_dir` if that file is absent.

        Failures are logged as warnings and never raised.
        """
        # Persist and log only hash-relevant factors together.
        try:
            logger.debug(
                "Compile env factors (raw):\n%s\nVllm config hash: %s",
                lazy(partial(pprint.pformat, env_factors, width=120)),
                config_hash,
            )
            meta_path = os.path.join(local_cache_dir, "cache_key_factors.json")
            if not os.path.exists(meta_path):
                with open(meta_path, "w") as f:
                    json.dump(
                        {
                            "env": env_factors,  # raw factors used for env_hash
                            "config_hash": config_hash,
                            "code_hash": code_hash,
                            "compiler_hash": compiler_hash,
                        },
                        f,
                        indent=2,
                        sort_keys=True,
                    )
        except Exception:
            # Best-effort only; metadata write failures are non-fatal.
            logger.warning(
                (
                    "Could not write compile cache metadata at %s; continuing without "
                    "metadata. Compiled cache remains valid; diagnostics may be "
                    "limited."
                ),
                local_cache_dir,
                exc_info=True,
            )

    def _dump_computation_graph(self, local_cache_dir: str) -> None:
        """Write the readable source of `self.split_gm` to
        `computation_graph.py` in `local_cache_dir` if it is not there yet."""
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if os.path.exists(graph_path):
            return
        # code adapted from
        # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
        # use `print_readable` because it can include submodules
        src = (
            "from __future__ import annotations\nimport torch\n"
            + self.split_gm.print_readable(print_output=False)
        )
        src = src.replace("<lambda>", "GraphModule")
        with open(graph_path, "w") as f:
            f.write(src)

        logger.debug_once("Computation graph saved to %s", graph_path, scope="local")

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Backend entry point: set up the compile cache, split `graph`,
        compile the pieces and return the callable wrapper.

        Each instance may be called only once (asserted below).
        """
        vllm_config = self.vllm_config

        # Minimal hashing here with existing utilities, reused below.
        env_factors = envs.compile_factors()
        env_hash = hash_factors(env_factors)
        # Compute config/compiler/code hashes once and reuse
        config_hash = vllm_config.compute_hash()
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        code_hash = self._hash_traced_files()

        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            factors = [env_hash, config_hash, code_hash, compiler_hash]
            # Use SHA-256 for cache key hashing to be consistent across
            # compute_hash functions. Truncate for a short cache dir name.
            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()[:10]
            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        # Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
        disable_cache = not is_compile_cache_enabled(
            self.compilation_config.inductor_compile_config
        )

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # Reuses existing cache key
        logger.debug(
            "torch.compile cache factors: env=%s cfg=%s comp=%s code=%s dir=%s",
            env_hash,
            config_hash,
            compiler_hash,
            code_hash,
            local_cache_dir,
        )

        self._write_cache_key_metadata(
            local_cache_dir, env_factors, config_hash, code_hash, compiler_hash
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        # we control the compilation process, each instance can only be
        # called once
        assert not self._called, "VllmBackend can only be called once"

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        self._dump_computation_graph(local_cache_dir)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )