backends.py 26.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
8
import os
import pprint
9
import time
10
from collections.abc import Callable, Sequence
11
from contextlib import contextmanager
12
from typing import Any
13
14
15

import torch
import torch.fx as fx
16
from torch._dispatch.python import enable_python_dispatcher
17

18
import vllm.envs as envs
19
20
21
22
23
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)
24
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
25
from vllm.logger import init_logger
26
from vllm.platforms import current_platform
27
from vllm.utils.import_utils import resolve_obj_by_qualname
28
from vllm.utils.torch_utils import is_torch_equal_or_newer
29

30
from .caching import VllmSerializableFunction
31
32
33
34
35
36
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
)
37
from .counter import compilation_counter
38
39
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
40
41
42

logger = init_logger(__name__)

43

44
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Pick the compiler adaptor matching ``compilation_config.backend``.

    Args:
        compilation_config: The compilation config; only its ``backend``
            field is consulted here.

    Returns:
        ``InductorStandaloneAdaptor`` or ``InductorAdaptor`` when the backend
        is ``"inductor"``, otherwise ``EagerAdaptor``.

    Raises:
        AssertionError: If the backend is neither ``"inductor"`` nor
            ``"eager"`` (custom backends are not supported with
            ``CompilationMode.VLLM_COMPILE``).
    """
    if compilation_config.backend == "inductor":
        # Use standalone compile only if requested, version is new enough,
        # and the symbol actually exists in this PyTorch build.
        if (
            envs.VLLM_USE_STANDALONE_COMPILE
            and is_torch_equal_or_newer("2.8.0.dev")
            and hasattr(torch._inductor, "standalone_compile")
        ):
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor()
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    assert compilation_config.backend == "eager", (
        "Custom backends not supported with CompilationMode.VLLM_COMPILE"
    )

    logger.debug("Using EagerAdaptor")
    return EagerAdaptor()


67
68
69
70
71
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle returned by the
        # compiler; runtime_shape is None for the dynamic-shape graph.
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        # Set whenever a new handle is stored; save_to_file is a no-op while
        # this stays False.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the hash the underlying compiler computes for the config."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                inductor_partition_ops = resolve_defined_ops(
                    self.compilation_config.splitting_ops
                )
                with inductor_partition_rule_context(inductor_partition_ops):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Write ``self.cache`` to ``cache_file_path`` as a Python literal.

        Does nothing when the cache is disabled or when no new handle has been
        stored. ``initialize_cache`` must have been called first, since it
        sets ``disable_cache`` and ``cache_file_path``.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously compiled graph through its cached handle.

        Returns:
            The compiled callable produced by ``self.compiler.load``, or
            ``None`` when no handle is cached for
            ``(runtime_shape, graph_index, compiler name)``.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Return a compiled version of one piecewise graph.

        The cache is tried first. On a miss, the graph is compiled with
        ``self.compiler`` and the returned handle is stored in the cache,
        unless ``VLLM_DISABLE_COMPILE_CACHE`` is set or the handle is None.

        Timing spans all ``num_graphs`` piecewise graphs of one shape. The
        start time is recorded for ``graph_index == 0``. After the last graph
        (``graph_index == num_graphs - 1``), the elapsed time is added to
        ``compilation_config.compilation_time``.

        Raises:
            AssertionError: If the compiler returns no compiled graph.
        """
        # module-level timer shared across the piecewise graphs of one shape
        global compilation_start_time

        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
284
285


286
287
288
@dataclasses.dataclass
class SplitItem:
    """One submodule produced by `split_graph`."""

    # attribute name of the submodule on the split module, e.g. "submod_3"
    submod_name: str
    # integer id parsed from `submod_name`; items are sorted by it
    graph_id: int
    # True when this submodule holds exactly one splitting op
    is_splitting_graph: bool
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` into submodules around every op in `resolved_ops`.

    Each matching `call_function` node is isolated into its own submodule.
    The runs of nodes between splitting ops form the remaining submodules.

    Args:
        graph: The FX graph module to split.
        resolved_ops: Ops to split on. A node whose target is an op packet
            also matches when the packet's `.default` overload is listed.

    Returns:
        `(split_gm, items)`: the stitching module returned by `split_module`,
        and one `SplitItem` per direct submodule, sorted by integer
        `graph_id`.
    """
    # Import explicitly: `torch.fx.passes.split_module` is not guaranteed to
    # be loaded as an attribute of `torch.fx` by `import torch` alone.
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # ids of subgraphs that contain a single splitting op (set: O(1) lookup)
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Match node.target against resolved_ops
        # node.target can be OpOverloadPacket, need to check .default
        if node.op == "call_function" and (
            node.target in resolved_ops
            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
        ):
            # give the splitting op its own subgraph, then start a fresh
            # subgraph for whatever follows it
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    # the direct children of split_gm are the generated `submod_<id>` modules
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
343
344


345
346
# Wall-clock timestamp (from time.time()) recorded by CompilerManager.compile
# when it starts the first piecewise graph (graph_index == 0); the last graph
# subtracts it to compute the elapsed compile/load time.
compilation_start_time: float = 0.0

347
348
349
350
351
352

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compiles the
    submodules named in `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # fake mode of the current tracing context; presumably always present
        # when invoked from Dynamo (detect_fake_mode returns None otherwise)
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph on fake copies of the tensor arguments.

        Tensors are converted with `self.fake_mode.from_tensor`; non-tensor
        arguments are passed through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute a submodule; compile and swap it in if it is listed.

        For a target listed in `compile_submod_names`, the submodule is
        compiled for the dynamic shape and wrapped in a `PiecewiseBackend`.
        When piecewise cudagraphs are used without Inductor graph
        partitioning, the backend is further wrapped in the platform's static
        graph wrapper. The result replaces the submodule in
        `self.module.__dict__`. The (fake) output of the original submodule
        is returned so propagation can continue.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of symbolic-int arguments in this submodule's inputs
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_dynamic_shape = (
                self.vllm_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.compilation_config.inductor_compile_config,
                    self.compilation_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    runtime_shape=None,
                )
            )
            # Lazy import here to avoid circular import
            from .piecewise_backend import PiecewiseBackend

            piecewise_backend = PiecewiseBackend(
                submod,
                self.vllm_config,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                compiled_graph_for_dynamic_shape,
                self.vllm_backend,
            )

            if (
                self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
                and not self.compilation_config.use_inductor_graph_partition
            ):
                # We're using Dynamo-based piecewise splitting, so we wrap
                # the whole subgraph with a static graph wrapper.
                from .cuda_graph import CUDAGraphOptions

                # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
                # class) as platform dependent.
                static_graph_wrapper_class = resolve_obj_by_qualname(
                    current_platform.get_static_graph_wrapper_cls()
                )

                # Always assign PIECEWISE runtime mode to the
                # CUDAGraphWrapper for piecewise_backend, to distinguish
                # it from the FULL cudagraph runtime mode, no matter it
                # is wrapped on a full or piecewise fx graph.
                self.module.__dict__[target] = static_graph_wrapper_class(
                    runnable=piecewise_backend,
                    vllm_config=self.vllm_config,
                    runtime_mode=CUDAGraphMode.PIECEWISE,
                    cudagraph_options=CUDAGraphOptions(
                        debug_log_enable=piecewise_backend.is_first_graph,
                        gc_disable=not piecewise_backend.is_first_graph,
                        weak_ref_output=piecewise_backend.is_last_graph,
                    ),
                )
            else:
                self.module.__dict__[target] = piecewise_backend

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


463
464
465
466
467
468
469
470
471
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily set the module-level `model_tag` to `tag`.

    The previous tag is restored on exit, including when the body raises.

    Raises:
        AssertionError: If `tag` equals the current `model_tag`.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        model_tag = old_tag


483
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    # positions of example inputs whose sizes contain symbolic dims
    sym_tensor_indices: list[int]
    # static buffers that symbolic-shape inputs are copied into before running
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install `self.post_grad_pass_manager` as Inductor's post-grad hook.

        A custom pass already present under `post_grad_custom_post_pass` is
        added to the manager. The hook is then replaced by the manager. If the
        hook already holds a `PostGradPassManager`, its uuid must match ours.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile the Dynamo-captured `graph` piecewise.

        The steps are:
        - resolve the cache directories;
        - split `graph` on the splitting ops;
        - compile every non-splitting submodule through
          `PiecewiseCompileInterpreter`;
        - wrap the result in a `VllmSerializableFunction`.

        When cudagraphs are enabled and `cudagraph_copy_inputs` is set, the
        returned callable first copies symbolic-shape inputs into static
        buffers.

        Raises:
            AssertionError: If this backend instance was already called.
        """
        from .caching import _compute_code_hash, compilation_config_hash_factors

        # we control the compilation process, each instance can only be
        # called once. Check before any side effects (cache dirs, counters,
        # compilation time, traced files) so a repeated call fails cleanly.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            # 1. factors derived from the vLLM config
            factors = compilation_config_hash_factors(vllm_config)
            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )
            self.compilation_config.cache_dir = cache_dir

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info(
                "Using cache directory: %s for vLLM's torch.compile", local_cache_dir
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        resolved_split_ops = resolve_defined_ops(fx_split_ops)
        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from
            # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                # only the leading dim varies; slice the static buffer to it
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )