# backends.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
8
import os
import pprint
9
import time
10
from collections.abc import Callable, Sequence
11
from contextlib import contextmanager
12
from typing import Any
13
14
15

import torch
import torch.fx as fx
16
from torch._dispatch.python import enable_python_dispatcher
17

18
import vllm.envs as envs
19
20
21
22
23
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)
24
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
25
from vllm.logger import init_logger
26
from vllm.platforms import current_platform
27
from vllm.utils.import_utils import resolve_obj_by_qualname
28
from vllm.utils.torch_utils import is_torch_equal_or_newer
29

30
from .caching import VllmSerializableFunction
31
32
33
34
35
36
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
)
37
from .counter import compilation_counter
38
39
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
40
41
42

logger = init_logger(__name__)

43

44
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Pick the compiler adaptor matching ``compilation_config.backend``.

    Args:
        compilation_config: The compilation config; only its ``backend``
            field is inspected here.

    Returns:
        ``InductorStandaloneAdaptor`` when the backend is ``"inductor"``,
        ``VLLM_USE_STANDALONE_COMPILE`` is set, PyTorch is at least
        2.8.0.dev and ``torch._inductor.standalone_compile`` exists;
        ``InductorAdaptor`` for any other ``"inductor"`` case; and
        ``EagerAdaptor`` when the backend is ``"eager"``.

    Raises:
        AssertionError: If the backend is neither ``"inductor"`` nor
            ``"eager"``.
    """
    if compilation_config.backend == "inductor":
        # Use standalone compile only if requested, version is new enough,
        # and the symbol actually exists in this PyTorch build.
        if (
            envs.VLLM_USE_STANDALONE_COMPILE
            and is_torch_equal_or_newer("2.8.0.dev")
            and hasattr(torch._inductor, "standalone_compile")
        ):
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor()
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    assert compilation_config.backend == "eager", (
        "Custom backends not supported with CompilationMode.VLLM_COMPILE"
    )
    logger.debug("Using EagerAdaptor")
    return EagerAdaptor()


class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle returned by
        # `self.compiler.compile`; runtime_shape is None for the graph
        # compiled for dynamic shapes.
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        # Set once `compile` stores a new entry, so `save_to_file` can skip
        # rewriting an unchanged cache.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash of `vllm_config`."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                inductor_partition_ops = resolve_defined_ops(
                    self.compilation_config.splitting_ops
                )
                with inductor_partition_rule_context(inductor_partition_ops):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist `self.cache` to `self.cache_file_path` as a Python literal.

        Does nothing when the cache is disabled or when `compile` has never
        stored an entry in this manager.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        # Write to a sibling temp file and atomically swap it in, so an
        # interrupted write cannot leave a truncated file behind that would
        # make `ast.literal_eval` fail in `initialize_cache` next time.
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        finally:
            # Only present here if the write or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load graph `graph_index` for `runtime_shape` via its cached handle.

        Returns:
            Whatever `self.compiler.load` returns for the cached handle, or
            None when there is no cache entry for this
            (runtime_shape, graph_index, compiler) key.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Compile (or load from cache) graph `graph_index` of `num_graphs`.

        Graph 0 records the module-level `compilation_start_time`. The last
        graph (`graph_index == num_graphs - 1`) adds the elapsed time since
        then to `compilation_config.compilation_time`. A freshly compiled
        graph's handle is stored in `self.cache` unless
        `VLLM_DISABLE_COMPILE_CACHE` is set or the compiler returned no
        handle.

        Returns:
            The compiled (or loaded) graph.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                elapsed = time.time() - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info_once(
                        "Cache the graph for dynamic shape for later use", scope="local"
                    )
                else:
                    logger.info_once(
                        "Cache the graph of shape %s for later use",
                        str(runtime_shape),
                        scope="local",
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            elapsed = time.time() - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info_once(
                    "Compiling a graph for dynamic shape takes %.2f s",
                    elapsed,
                    scope="local",
                )
            else:
                logger.info_once(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                    scope="local",
                )

        return compiled_graph


@dataclasses.dataclass
class SplitItem:
    """One top-level submodule produced by `split_graph`."""

    # Attribute name of the submodule on the split GraphModule,
    # e.g. "submod_0".
    submod_name: str
    # Integer id parsed from `submod_name`; used to order the pieces.
    graph_id: int
    # True if this submodule holds one of the splitting ops rather than
    # ordinary computation.
    is_splitting_graph: bool
    # The submodule itself.
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` into submodules around every occurrence of `resolved_ops`.

    Every `call_function` node whose target is in `resolved_ops` (or, for an
    OpOverloadPacket target, whose `.default` overload is) gets a subgraph of
    its own. The nodes before and after it go into the neighbouring
    subgraphs.

    Args:
        graph: The FX graph module to split.
        resolved_ops: The op overloads to split at.

    Returns:
        A tuple of the stitching GraphModule returned by `split_module`, and
        one `SplitItem` per top-level submodule, sorted by integer graph id.
    """
    # Import explicitly rather than relying on `torch.fx.passes.split_module`
    # having already been imported as a side effect elsewhere.
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of the subgraphs that contain a single splitting op; a set because
    # it is queried once per submodule below
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Match node.target against resolved_ops
        # node.target can be OpOverloadPacket, need to check .default
        if node.op == "call_function" and (
            node.target in resolved_ops
            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
        ):
            # The splitting op gets its own id, and whatever follows it starts
            # a fresh one.
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    for name, module in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs


# Wall-clock timestamp (from `time.time()`) that CompilerManager.compile
# records when it starts on graph index 0; the last graph subtracts it to
# compute the elapsed compilation time.
compilation_start_time: float = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph on fake copies of the tensor arguments.

        Tensors are converted with `self.fake_mode.from_tensor`; non-tensor
        arguments are passed through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute submodule `target`; if it is listed in
        `compile_submod_names`, also compile it and replace it on
        `self.module` with a PiecewiseBackend, optionally wrapped in the
        platform's static graph wrapper."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        submod = self.fetch_attr(target)
        # positions of the SymInt arguments of this submodule call
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]
        compiled_graph_for_dynamic_shape = self.vllm_backend.compiler_manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=index,
            num_graphs=len(self.compile_submod_names),
            runtime_shape=None,
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            len(self.compile_submod_names),
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


# The tag for the part of the model being compiled, e.g. backbone/eagle_head.
# VllmBackend falls back to it as its cache-directory prefix when it is
# constructed without an explicit prefix.
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily set the module-level `model_tag` to `tag`.

    The previous tag is restored on exit, even if the body raises.

    Raises:
        AssertionError: If `tag` equals the current tag.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        model_tag = old_tag


class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Configure the post-grad pass manager and register it in the
        Inductor config under `post_grad_custom_post_pass`, absorbing any
        InductorPass that was already configured there."""
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _init_cache_dir(self) -> str:
        """Resolve and create the compile cache directories.

        Fills in `compilation_config.cache_dir` if it is empty, by hashing
        the compilation factors. Sets `compilation_config.local_cache_dir`
        and initializes the compiler manager's cache there.

        Returns:
            The per-rank, per-prefix local cache directory.
        """
        from .caching import _compute_code_hash, compilation_config_hash_factors

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            factors = compilation_config_hash_factors(vllm_config)

            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
        else:
            logger.info_once(
                "Using cache directory: %s for vLLM's torch.compile",
                local_cache_dir,
                scope="local",
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )
        return local_cache_dir

    def _save_computation_graph(self, local_cache_dir: str) -> None:
        """Write the readable source of `self.split_gm` to
        `computation_graph.py` in `local_cache_dir`, unless it already
        exists."""
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if os.path.exists(graph_path):
            return
        # code adapted from
        # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
        # use `print_readable` because it can include submodules
        src = (
            "from __future__ import annotations\nimport torch\n"
            + self.split_gm.print_readable(print_output=False)
        )
        src = src.replace("<lambda>", "GraphModule")
        with open(graph_path, "w") as f:
            f.write(src)

        logger.debug_once("Computation graph saved to %s", graph_path, scope="local")

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Split `graph`, compile its piecewise submodules and return the
        callable for Dynamo to run.

        Raises:
            AssertionError: If this backend instance was already called.
        """
        # we control the compilation process, each instance can only be
        # called once. Check it before touching the cache directories,
        # counters or compilation_time, so a second call has no side effects.
        assert not self._called, "VllmBackend can only be called once"

        local_cache_dir = self._init_cache_dir()

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info_once(
            "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
        )
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        resolved_split_ops = resolve_defined_ops(fx_split_ops)
        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        self._save_computation_graph(local_cache_dir)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )