backends.py 27.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
import operator
8
9
import os
import pprint
10
import time
11
from collections.abc import Callable, Sequence
12
from contextlib import contextmanager
13
from typing import Any
14
15
16

import torch
import torch.fx as fx
17
from torch._dispatch.python import enable_python_dispatcher
18

19
import vllm.envs as envs
20
21
22
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
23
    should_split,
24
)
25
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
26
from vllm.logger import init_logger
27
from vllm.platforms import current_platform
28
from vllm.utils.import_utils import resolve_obj_by_qualname
29
from vllm.utils.torch_utils import is_torch_equal_or_newer
30

31
from .caching import VllmSerializableFunction
32
33
34
35
36
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
37
    is_compile_cache_enabled,
38
)
39
from .counter import compilation_counter
40
41
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
42
43
44

logger = init_logger(__name__)

45

46
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Create the compiler adaptor selected by ``compilation_config.backend``.

    Args:
        compilation_config: the compilation config. ``backend`` selects the
            adaptor; ``compile_cache_save_format`` is forwarded to the
            standalone Inductor adaptor.

    Returns:
        For ``backend == "inductor"``: an ``InductorStandaloneAdaptor`` when
        ``VLLM_USE_STANDALONE_COMPILE`` is set, torch is at least 2.8.0.dev and
        ``torch._inductor.standalone_compile`` exists; otherwise an
        ``InductorAdaptor``. For ``backend == "eager"``: an ``EagerAdaptor``.

    Raises:
        AssertionError: if the backend is neither "inductor" nor "eager".
    """
    if compilation_config.backend == "inductor":
        # Use standalone compile only if requested, version is new enough,
        # and the symbol actually exists in this PyTorch build.
        if (
            envs.VLLM_USE_STANDALONE_COMPILE
            and is_torch_equal_or_newer("2.8.0.dev")
            and hasattr(torch._inductor, "standalone_compile")
        ):
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor(
                compilation_config.compile_cache_save_format
            )
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    assert compilation_config.backend == "eager", (
        "Custom backends not supported with CompilationMode.VLLM_COMPILE"
    )
    logger.debug("Using EagerAdaptor")
    return EagerAdaptor()


71
72
73
74
75
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle returned by
        # the compiler; runtime_shape is None for the dynamic-shape graph.
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        # Set once a new handle is stored, so save_to_file only writes when
        # there is something new to persist.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the compiler adaptor's hash for ``vllm_config``."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist the cache dict to ``cache_file_path`` as a Python literal.

        No-op when the cache is disabled or no new handle was stored. Must be
        called after ``initialize_cache``, which sets ``disable_cache`` and
        ``cache_file_path``.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously compiled graph from the cache.

        Args:
            graph: the graph to load a compiled artifact for.
            example_inputs: example inputs forwarded to the compiler's loader.
            graph_index: index of this piecewise graph.
            runtime_shape: the concrete shape, or None for dynamic shape.

        Returns:
            The compiled callable, or None when no handle is cached for
            ``(runtime_shape, graph_index, compiler name)``.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Load ``graph`` from the cache, or compile it and cache the handle.

        Args:
            graph: the (piecewise) graph to compile.
            example_inputs: example inputs for compilation.
            additional_inductor_config: extra config passed to the compiler;
                also decides whether the produced handle is cached.
            compilation_config: its ``compilation_time`` is increased by the
                elapsed time once the last graph (``num_graphs - 1``) is done.
            graph_index: index of this graph among the piecewise graphs.
            num_graphs: total number of piecewise graphs for this shape.
            runtime_shape: the concrete shape, or None for dynamic shape.

        Returns:
            The compiled (or loaded) graph.
        """
        # Module-level so that the time recorded at graph_index == 0 is still
        # available when the last graph of the same shape finishes.
        global compilation_start_time

        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info_once(
                        "Cache the graph for dynamic shape for later use", scope="local"
                    )
                else:
                    logger.info_once(
                        "Cache the graph of shape %s for later use",
                        str(runtime_shape),
                        scope="local",
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info_once(
                    "Compiling a graph for dynamic shape takes %.2f s",
                    elapsed,
                    scope="local",
                )
            else:
                logger.info_once(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                    scope="local",
                )

        return compiled_graph
296
297


298
299
300
@dataclasses.dataclass
class SplitItem:
    """One direct submodule of the graph module produced by `split_graph`."""

    # attribute name of the submodule on the split GraphModule, e.g. "submod_0"
    submod_name: str
    # integer id parsed from `submod_name`; used to order the pieces
    graph_id: int
    # True if this piece was created for a splitting op
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


306
def split_graph(
    graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into piecewise submodules around splitting ops.

    Every node for which ``should_split(node, splitting_ops)`` holds opens a
    new subgraph id of its own (shared only with `getitem` nodes that unpack
    its outputs), and the nodes following it go to the next id.

    Args:
        graph: the FX graph module to split.
        splitting_ops: op names forwarded to ``should_split``.

    Returns:
        The stitching ``GraphModule`` built by ``split_module`` and a list of
        ``SplitItem`` describing its direct submodules, sorted by integer
        graph id.
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of the subgraphs created for splitting ops; a set because it is only
    # used for membership tests when building the outputs below.
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue

        # Check if this is a getitem operation on a node from an earlier subgraph.
        # If so, assign it to the same subgraph as its input to avoid passing entire
        # tuple as input to submodules, which is against standalone_compile and
        # AoTAutograd input requirement.
        if node.op == "call_function" and node.target == operator.getitem:
            # Assign this getitem to the same subgraph as its input
            input_node = node.args[0]
            if input_node.op != "placeholder":
                assert input_node in node_to_subgraph_id
                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
                continue

        if should_split(node, splitting_ops):
            # the splitting op gets its own subgraph id, and the nodes after it
            # start on a fresh one
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []

    names = [name for (name, module) in split_gm.named_modules()]

    for name in names:
        if "." in name or name == "":
            # recursive child module or the root module
            continue

        module = getattr(split_gm, name)

        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
363
364


365
366
# Timestamp from time.time(), set by CompilerManager.compile when it starts
# the first piecewise graph (graph_index == 0) and read back after the last
# graph to compute the elapsed compilation time.
compilation_start_time: float = 0.0

367
368
369
370
371
372

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # Fake mode detected from the current context; `run` uses it to fake
        # tensor arguments. NOTE(review): assumed non-None when `run` is called.
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Interpret the module on fake copies of the tensor ``args``.

        Non-tensor arguments (e.g. SymInts) are passed through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule ``target`` and, if it is listed in
        ``compile_submod_names``, compile it and install the result.

        The submodule is compiled once for dynamic shape (``runtime_shape`` is
        None) and wrapped in a ``PiecewiseBackend``. When piecewise cudagraphs
        are enabled and Inductor graph partition is not used, that backend is
        further wrapped in the platform's static graph wrapper class.

        Returns:
            The output of running the original submodule.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of the symbolic-size arguments of this piece
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_dynamic_shape = (
                self.vllm_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.compilation_config.inductor_compile_config,
                    self.compilation_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    runtime_shape=None,
                )
            )
            # Lazy import here to avoid circular import
            from .piecewise_backend import PiecewiseBackend

            piecewise_backend = PiecewiseBackend(
                submod,
                self.vllm_config,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                compiled_graph_for_dynamic_shape,
                self.vllm_backend,
            )

            # Writing into `__dict__` (instead of setattr) presumably shadows
            # the registered submodule for attribute lookup without going
            # through nn.Module.__setattr__.
            if (
                self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
                and not self.compilation_config.use_inductor_graph_partition
            ):
                # We're using Dynamo-based piecewise splitting, so we wrap
                # the whole subgraph with a static graph wrapper.
                from .cuda_graph import CUDAGraphOptions

                # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
                # class) as platform dependent.
                static_graph_wrapper_class = resolve_obj_by_qualname(
                    current_platform.get_static_graph_wrapper_cls()
                )

                # Always assign PIECEWISE runtime mode to the
                # CUDAGraphWrapper for piecewise_backend, to distinguish
                # it from the FULL cudagraph runtime mode, no matter it
                # is wrapped on a full or piecewise fx graph.
                self.module.__dict__[target] = static_graph_wrapper_class(
                    runnable=piecewise_backend,
                    vllm_config=self.vllm_config,
                    runtime_mode=CUDAGraphMode.PIECEWISE,
                    cudagraph_options=CUDAGraphOptions(
                        debug_log_enable=piecewise_backend.is_first_graph,
                        gc_disable=not piecewise_backend.is_first_graph,
                        weak_ref_output=piecewise_backend.is_last_graph,
                    ),
                )
            else:
                self.module.__dict__[target] = piecewise_backend

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


483
484
485
486
487
488
489
490
491
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily set the module-level ``model_tag`` to ``tag``.

    The previous tag is restored on exit, including when the body raises.

    Raises:
        AssertionError: if ``tag`` equals the current tag.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        model_tag = old_tag


503
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install ``post_grad_pass_manager`` as the Inductor
        ``post_grad_custom_post_pass`` hook.

        An existing ``InductorPass`` in that slot is added to the manager; an
        existing ``PostGradPassManager`` must have the same uuid.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile the Dynamo-captured ``graph`` and return the runnable.

        Sets up the (hash-keyed) cache directory, splits the graph into
        piecewise submodules, compiles them with symbolic shapes and saves
        the readable computation graph next to the cache.

        Returns:
            A ``VllmSerializableFunction`` that runs ``split_gm`` directly, or,
            when cudagraphs are enabled and ``cudagraph_copy_inputs`` is set,
            one that first copies symbolic-shape inputs into static buffers.
        """
        from .caching import _compute_code_hash, compilation_config_hash_factors

        # we control the compilation process, each instance can only be
        # called once. Check before touching the cache dir, the counters or
        # the accumulated compilation time.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            # 1. factors come from the vllm_config
            factors = compilation_config_hash_factors(vllm_config)
            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = not is_compile_cache_enabled(
            self.compilation_config.inductor_compile_config
        )

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug_once(
                "Computation graph saved to %s", graph_path, scope="local"
            )

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )