backends.py 27.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
8
import os
import pprint
9
import time
10
from collections.abc import Callable, Sequence
11
from contextlib import contextmanager
12
from typing import Any
13
14
15

import torch
import torch.fx as fx
16
from torch._dispatch.python import enable_python_dispatcher
17

18
import vllm.envs as envs
19
20
21
22
23
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)
24
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
25
from vllm.logger import init_logger
26
from vllm.platforms import current_platform
27
from vllm.utils.import_utils import resolve_obj_by_qualname
28
from vllm.utils.torch_utils import is_torch_equal_or_newer
29

30
from .caching import VllmSerializableFunction
31
32
33
34
35
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
36
    is_compile_cache_enabled,
37
)
38
from .counter import compilation_counter
39
40
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
41
42
43

# Module-level logger for this file, created with vllm.logger.init_logger.
logger = init_logger(__name__)

44

45
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Build the compiler adaptor selected by ``compilation_config.backend``.

    - ``"inductor"``: an ``InductorStandaloneAdaptor`` when standalone compile
      is requested (``VLLM_USE_STANDALONE_COMPILE``), torch is at least
      ``2.8.0.dev`` and ``torch._inductor.standalone_compile`` exists in this
      build; otherwise a plain ``InductorAdaptor``.
    - ``"eager"``: an ``EagerAdaptor``.
    - anything else: fails the assertion below.
    """
    backend = compilation_config.backend

    if backend != "inductor":
        assert backend == "eager", (
            "Custom backends not supported with CompilationMode.VLLM_COMPILE"
        )
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # All three conditions must hold; evaluation short-circuits left to right.
    use_standalone = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if not use_standalone:
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    logger.debug("Using InductorStandaloneAdaptor")
    return InductorStandaloneAdaptor(compilation_config.compile_cache_save_format)


70
71
72
73
74
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle returned by the
        # compiler; a runtime_shape of None denotes the dynamic-shape graph.
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        # Set whenever a new entry is stored, so save_to_file only rewrites
        # the cache file when there is something new to persist.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash for ``vllm_config``."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist ``self.cache`` as a pretty-printed Python literal.

        No-op when caching is disabled or nothing new was stored. The data
        is written to a temporary file that is then swapped in with
        ``os.replace``, so an interrupted write can never leave a truncated
        cache file behind for ``initialize_cache`` to choke on.
        """
        if self.disable_cache or not self.is_cache_updated:
            return

        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        # pid suffix keeps concurrent writers from clobbering each other's
        # temporary file
        tmp_path = f"{self.cache_file_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        finally:
            # after a successful replace the temp file is gone; this only
            # cleans up when writing or replacing failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously compiled graph from the cache.

        Returns None when there is no cache entry for
        ``(runtime_shape, graph_index, self.compiler.name)``; otherwise the
        result of ``self.compiler.load`` on the cached handle.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Return a compiled version of one piecewise graph.

        The graph is first looked up in the cache; on a miss it is compiled
        inside ``compile_context`` and, when the compile cache is enabled and
        the compiler returned a handle, the handle is recorded in the cache.

        ``graph_index`` / ``num_graphs`` locate this graph among the piecewise
        graphs for ``runtime_shape`` (None means dynamic shape): timing
        starts at index 0 and the elapsed time is added to
        ``compilation_config.compilation_time`` at index ``num_graphs - 1``.
        """
        global compilation_start_time

        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info_once(
                        "Cache the graph for dynamic shape for later use", scope="local"
                    )
                else:
                    logger.info_once(
                        "Cache the graph of shape %s for later use",
                        str(runtime_shape),
                        scope="local",
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info_once(
                    "Compiling a graph for dynamic shape takes %.2f s",
                    elapsed,
                    scope="local",
                )
            else:
                logger.info_once(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                    scope="local",
                )

        return compiled_graph
295
296


297
298
299
@dataclasses.dataclass
class SplitItem:
    """One piece of a graph produced by ``split_graph``."""

    # attribute name of the submodule on the stitching GraphModule
    # (``submod_<graph_id>``)
    submod_name: str
    # integer partition id parsed from ``submod_name``
    graph_id: int
    # True when this piece holds one of the ops the graph was split on
    is_splitting_graph: bool
    # the submodule containing this piece's nodes
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into piecewise submodules around ``resolved_ops``.

    Each ``call_function`` node whose target is in ``resolved_ops`` (or is an
    OpOverloadPacket whose ``.default`` overload is) is placed in its own
    partition; the runs of nodes between such ops form the other partitions.

    Returns:
        ``(split_gm, items)``: the stitching GraphModule produced by
        ``split_module`` and one ``SplitItem`` per top-level submodule,
        sorted by integer ``graph_id``.
    """
    # Import explicitly instead of relying on torch.fx.passes.split_module
    # having been loaded as a side effect elsewhere.
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of partitions holding a splitting op; a set so the membership test
    # per submodule below is O(1) rather than a list scan
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Match node.target against resolved_ops
        # node.target can be OpOverloadPacket, need to check .default
        if node.op == "call_function" and (
            node.target in resolved_ops
            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
        ):
            # the splitting op gets a partition of its own; nodes after it
            # start yet another partition
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    for name, module in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
354
355


356
357
# Wall-clock time (time.time()) recorded by CompilerManager.compile when it
# starts on graph_index 0; used to compute the elapsed time once the last
# piecewise graph has been compiled or loaded.
compilation_start_time: float = 0.0

358
359
360
361
362
363

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # Fake mode used by run() to fakeify tensor inputs; taken from the
        # current context (presumably the one Dynamo traced under — confirm).
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Interpret the graph on fake copies of the tensor arguments.

        Non-tensor arguments are passed through unchanged; execution happens
        under ``self.fake_mode`` with the Python dispatcher enabled.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule ``target``; compile it if it is a piecewise target.

        For names in ``compile_submod_names`` the submodule is compiled for
        dynamic shape, wrapped in a ``PiecewiseBackend`` (optionally inside
        the platform's static graph wrapper) and installed on the module in
        place of the original submodule. The interpreted output is returned
        either way.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        submod = self.fetch_attr(target)
        # positions of symbolic-int arguments (the dynamic sizes)
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]

        compiled_graph_for_dynamic_shape = self.vllm_backend.compiler_manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=index,
            num_graphs=len(self.compile_submod_names),
            runtime_shape=None,
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            len(self.compile_submod_names),
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        # Installing into the instance __dict__ means ordinary attribute
        # lookup (self.<target>) finds this object before nn.Module's
        # __getattr__ would fall back to the registered submodule.
        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


474
475
476
477
478
479
480
481
482
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Context manager to set the model tag."""
    global model_tag
483
    assert tag != model_tag, (
484
        f"Model tag {tag} is the same as the current tag {model_tag}."
485
    )
486
487
488
489
490
491
492
493
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        model_tag = old_tag


494
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    # set at the end of a successful __call__; each instance compiles once
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    # positions in example_inputs of tensors with a symbolic dimension
    sym_tensor_indices: list[int]
    # static buffers those inputs are copied into (cudagraph input copying)
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's PostGradPassManager into the Inductor config.

        Any pass already registered under ``post_grad_custom_post_pass`` is
        either verified to be this manager (same uuid) or added to it, and the
        config entry is then set to the manager.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _prepare_cache_dirs(self) -> str:
        """Resolve and create the compile cache directories.

        When ``compilation_config.cache_dir`` is unset it is derived from a
        hash of the factors that affect compilation, so unchanged factors map
        to the same directory and compiled graphs can be reused. Also sets
        ``compilation_config.local_cache_dir`` (per rank / dp rank / prefix)
        and returns it.
        """
        from .caching import _compute_code_hash, compilation_config_hash_factors

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # 1. factors derived from the vLLM config
            factors = compilation_config_hash_factors(vllm_config)

            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir
        return local_cache_dir

    def _save_computation_graph(self, local_cache_dir: str) -> None:
        """Write a readable dump of ``self.split_gm`` to computation_graph.py.

        Skipped when the file already exists in ``local_cache_dir``.
        """
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if os.path.exists(graph_path):
            return
        # code adapted from
        # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
        # use `print_readable` because it can include submodules
        src = (
            "from __future__ import annotations\nimport torch\n"
            + self.split_gm.print_readable(print_output=False)
        )
        src = src.replace("<lambda>", "GraphModule")
        with open(graph_path, "w") as f:
            f.write(src)

        logger.debug_once("Computation graph saved to %s", graph_path, scope="local")

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile the graph Dynamo hands over and return the callable to run.

        The graph is split around the splitting ops, the non-splitting pieces
        are compiled via ``PiecewiseCompileInterpreter``, and the result is
        wrapped in a ``VllmSerializableFunction``. When cudagraphs are enabled
        and ``cudagraph_copy_inputs`` is set, the returned callable first
        copies inputs with symbolic shapes into static buffers.
        """
        # we control the compilation process, each instance can only be
        # called once. Checked first so a repeated call fails before touching
        # the cache directories or the global compilation counters.
        assert not self._called, "VllmBackend can only be called once"

        local_cache_dir = self._prepare_cache_dirs()

        disable_cache = not is_compile_cache_enabled(
            self.compilation_config.inductor_compile_config
        )

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        resolved_split_ops = resolve_defined_ops(fx_split_ops)
        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        self._save_computation_graph(local_cache_dir)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )