backends.py 27.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
8
import os
import pprint
9
import time
10
from collections.abc import Callable, Sequence
11
from contextlib import contextmanager
12
from typing import Any
13
14
15

import torch
import torch.fx as fx
16
from torch._dispatch.python import enable_python_dispatcher
17

18
import vllm.envs as envs
19
20
21
22
23
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)
24
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
25
from vllm.logger import init_logger
26
from vllm.platforms import current_platform
27
from vllm.utils.import_utils import resolve_obj_by_qualname
28
from vllm.utils.torch_utils import is_torch_equal_or_newer
29

30
from .caching import VllmSerializableFunction
31
32
33
34
35
36
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
)
37
from .counter import compilation_counter
38
39
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
40
41
42

# Module-level logger used by the compiler-selection, cache-management and
# backend code in this file.
logger = init_logger(__name__)

43

44
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Create the compiler adaptor selected by `compilation_config.backend`.

    - ``"inductor"``: `InductorStandaloneAdaptor` when standalone compile is
      requested via `VLLM_USE_STANDALONE_COMPILE`, torch is >= 2.8.0.dev and
      `torch._inductor.standalone_compile` exists; otherwise `InductorAdaptor`.
    - ``"eager"``: `EagerAdaptor`.
    - anything else fails the assertion below (custom backends are not
      supported under `CompilationMode.VLLM_COMPILE`).
    """
    if compilation_config.backend != "inductor":
        assert compilation_config.backend == "eager", (
            "Custom backends not supported with CompilationMode.VLLM_COMPILE"
        )
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # Standalone compile needs the opt-in flag, a new enough torch, and the
    # symbol to actually be present in this PyTorch build.
    standalone_available = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if not standalone_available:
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    logger.debug("Using InductorStandaloneAdaptor")
    return InductorStandaloneAdaptor(compilation_config.compile_cache_save_format)


69
70
71
72
73
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> compiler handle
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        # set by `compile` when a new entry is added; `save_to_file` only
        # writes when this is True
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the cache-key contribution of the underlying compiler."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                with inductor_partition_rule_context(
                    self.compilation_config.splitting_ops
                ):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        If `vllm_compile_cache.py` exists but cannot be parsed (e.g. it was
        truncated by an interrupted write), a warning is logged and the
        in-memory cache is left as-is so graphs are recompiled.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                contents = f.read()
            try:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(contents)
            except (SyntaxError, ValueError) as e:
                # A corrupted cache file should cost a recompilation, not a
                # hard failure at startup. Newly compiled entries mark the
                # cache as updated, so `save_to_file` overwrites the file.
                logger.warning(
                    "Ignoring unreadable vLLM compile cache file %s: %s",
                    self.cache_file_path,
                    e,
                )

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist `self.cache` as a pretty-printed Python literal.

        No-op when caching is disabled or nothing changed since loading.
        The data is written to a per-process temporary file that is then
        moved over the real file with `os.replace`. An interrupted write
        therefore never leaves a truncated cache file behind.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        finally:
            # only still present if writing or replacing failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously compiled graph via its cached handle.

        Returns None when `self.cache` has no entry for
        `(runtime_shape, graph_index, self.compiler.name)`; otherwise returns
        whatever `self.compiler.load` returns for the stored handle.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Return a compiled callable for one piecewise graph.

        The cache is tried first via `load`. On a miss, the graph is compiled
        with `self.compiler` inside `compile_context`. The returned handle is
        stored in `self.cache` unless `VLLM_DISABLE_COMPILE_CACHE` is set or
        the compiler returned no handle.

        `graph_index` / `num_graphs` locate this graph among the piecewise
        graphs:
        - Graph 0 records the module-level `compilation_start_time`.
        - The last graph adds the elapsed time to
          `compilation_config.compilation_time`.

        Raises:
            AssertionError: if the compiler returns no compiled graph.
        """
        global compilation_start_time

        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                elapsed = time.time() - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info_once(
                        "Cache the graph for dynamic shape for later use", scope="local"
                    )
                else:
                    logger.info_once(
                        "Cache the graph of shape %s for later use",
                        str(runtime_shape),
                        scope="local",
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            elapsed = time.time() - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info_once(
                    "Compiling a graph for dynamic shape takes %.2f s",
                    elapsed,
                    scope="local",
                )
            else:
                logger.info_once(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                    scope="local",
                )

        return compiled_graph
294
295


296
297
298
@dataclasses.dataclass
class SplitItem:
    """One direct submodule of the GraphModule produced by `split_graph`."""

    # attribute name of the submodule on the split GraphModule, "submod_<id>"
    submod_name: str
    # integer id parsed from `submod_name` (the partition id)
    graph_id: int
    # True when this submodule holds exactly one splitting op
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` into piecewise submodules around the given ops.

    Each `call_function` node whose target is in `resolved_ops` gets a
    partition of its own. So does a node whose target is an
    `OpOverloadPacket` whose `.default` overload is in `resolved_ops`.
    The runs of nodes between such ops form the remaining partitions.

    Args:
        graph: the graph module to split.
        resolved_ops: the op overloads to split at.

    Returns:
        A pair:
        - the stitching GraphModule returned by `split_module`;
        - one `SplitItem` per direct submodule, sorted by integer
          `graph_id` (not by name, so "submod_10" sorts after "submod_2").
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of the partitions that contain a splitting op; a set keeps the
    # membership test below O(1) per submodule
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Match node.target against resolved_ops
        # node.target can be OpOverloadPacket, need to check .default
        is_split_op = node.op == "call_function" and (
            node.target in resolved_ops
            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
        )
        if is_split_op:
            # bump the id before and after, so the splitting op is alone in
            # its partition and following nodes start a fresh one
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    for name, module in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
353
354


355
356
# Wall-clock timestamp (from time.time()) set by CompilerManager.compile when
# it handles graph_index == 0; read again when the last piecewise graph is
# loaded/compiled to compute the elapsed compilation time.
compilation_start_time = 0.0

357
358
359
360
361
362

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # fake mode active in the caller's context (used by `run`)
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph with tensor args converted to fake tensors.

        Execution happens under `self.fake_mode` and the Python dispatcher.
        Non-tensor args (e.g. SymInts) are passed through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute a submodule and, if listed, compile and install it.

        For submodules named in `compile_submod_names`:
        1. The submodule is compiled for dynamic shape via the backend's
           `CompilerManager`.
        2. It is wrapped in a `PiecewiseBackend`, which is additionally
           wrapped in the platform's static graph (cudagraph) wrapper when
           Dynamo-level piecewise cudagraphs are enabled.
        3. The result is installed on `self.module` in place of the
           original submodule.

        The output of the original (fake) submodule call is returned so
        propagation through the graph continues.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        # positions of SymInt args, i.e. the symbolic (dynamic) sizes
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]

        compiled_graph_for_dynamic_shape = (
            self.vllm_backend.compiler_manager.compile(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=num_graphs,
                runtime_shape=None,
            )
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


473
474
475
476
477
478
479
480
481
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily switch the module-level `model_tag` to `tag`.

    The tag in effect before entering is restored on exit, whether the
    ``with`` body finishes normally or raises.

    Raises:
        AssertionError: if `tag` equals the tag currently in effect.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    previous_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = previous_tag


493
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.

    An instance may only be called once; a second call fails an assertion
    before doing any work.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    # set to True at the end of a successful `__call__`
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    # indices of example inputs that are tensors with symbolic sizes
    sym_tensor_indices: list[int]
    # static buffers those inputs are copied into before each call
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's PostGradPassManager in the Inductor config.

        If `post_grad_custom_post_pass` is already set, it must be either:
        - a PostGradPassManager with the same uuid as ours, or
        - an InductorPass, which is added to our manager.
        Afterwards the key always points at `self.post_grad_pass_manager`.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile the Dynamo-captured `graph` piecewise.

        Steps:
        1. Resolve (or derive from a hash of config, traced code and
           compiler) the cache directory and initialize the compiler cache.
        2. Split the graph at the configured splitting ops.
        3. Compile the non-splitting submodules via
           `PiecewiseCompileInterpreter`.
        4. Dump the split graph to `computation_graph.py` if not present.
        5. Return a `VllmSerializableFunction`. When cudagraphs are on and
           `cudagraph_copy_inputs` is set, the callable first copies
           symbolic-size inputs into static buffers.
        """
        # we control the compilation process, each instance can only be
        # called once. Check this before anything with side effects
        # (clearing traced files, counters, time accounting, cache setup).
        assert not self._called, "VllmBackend can only be called once"

        from .caching import _compute_code_hash, compilation_config_hash_factors

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            # 1. factors derived from the vLLM config
            factors = compilation_config_hash_factors(vllm_config)
            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        resolved_split_ops = resolve_defined_ops(fx_split_ops)
        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug_once(
                "Computation graph saved to %s", graph_path, scope="local"
            )

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )