"vllm/utils/__init__.py" did not exist on "233df6f5c4520ae57e4a24acfbaedcc9ce166074"
backends.py 26.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
import hashlib
7
8
import os
import pprint
9
import time
10
from collections.abc import Callable, Sequence
11
from contextlib import contextmanager
12
from typing import Any
13
14
15

import torch
import torch.fx as fx
16
from torch._dispatch.python import enable_python_dispatcher
17

18
import vllm.envs as envs
19
20
21
22
23
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)
24
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
25
from vllm.logger import init_logger
26
from vllm.platforms import current_platform
27
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
28

29
from .caching import VllmSerializableFunction
30
31
32
33
34
35
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
)
36
from .counter import compilation_counter
37
38
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
39
40
41

# Module-level logger shared by all compilation messages in this file.
logger = init_logger(__name__)

42

43
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Choose the CompilerInterface implementation for the configured backend.

    - ``"inductor"``: ``InductorStandaloneAdaptor`` when
      ``VLLM_USE_STANDALONE_COMPILE`` is set, torch is at least 2.8.0.dev and
      ``torch._inductor.standalone_compile`` exists; ``InductorAdaptor``
      otherwise.
    - ``"eager"``: ``EagerAdaptor``.

    Any other backend name fails the assertion below.
    """
    backend_name = compilation_config.backend

    if backend_name != "inductor":
        assert backend_name == "eager", (
            "Custom backends not supported with CompilationMode.VLLM_COMPILE"
        )
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # Standalone compile needs the opt-in flag, a new enough torch, and the
    # symbol actually being present in this PyTorch build.
    standalone_available = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if not standalone_available:
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    logger.debug("Using InductorStandaloneAdaptor")
    return InductorStandaloneAdaptor()


66
67
68
69
70
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        self.cache: dict[tuple[int | None, int, str], Any] = dict()
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)
        # Populated by `initialize_cache`. Until then there is no file to
        # persist to, and `save_to_file` is a no-op instead of raising
        # AttributeError.
        self.disable_cache = False
        self.cache_dir: str | None = None
        self.cache_file_path: str | None = None

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Delegate cache-key hashing to the underlying compiler."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: int | None = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                inductor_partition_ops = resolve_defined_ops(
                    self.compilation_config.splitting_ops
                )
                with inductor_partition_rule_context(inductor_partition_ops):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        A cache file that cannot be parsed is ignored with a warning (the
        in-memory cache starts empty and is rewritten on the next save).
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            self.cache = self._read_cache_file(self.cache_file_path)

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    @staticmethod
    def _read_cache_file(path: str) -> dict:
        """Parse the cache file at `path`; return {} if it is malformed."""
        try:
            with open(path, encoding="utf-8") as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                data = ast.literal_eval(f.read())
        except (SyntaxError, ValueError, TypeError, RecursionError) as e:
            # e.g. a truncated file; UnicodeDecodeError is a ValueError.
            logger.warning(
                "Ignoring unreadable vLLM compile cache file %s: %s", path, e
            )
            return {}
        if not isinstance(data, dict):
            logger.warning(
                "Ignoring vLLM compile cache file %s: expected a dict, got %s",
                path,
                type(data).__name__,
            )
            return {}
        return data

    def save_to_file(self):
        """Persist the in-memory cache to `cache_file_path` if it changed.

        The data is written to a temporary sibling file and then moved over
        the target with `os.replace`, so an interrupted write never leaves a
        truncated cache file behind.
        """
        if (
            self.disable_cache
            or not self.is_cache_updated
            or self.cache_file_path is None
        ):
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is re-raised either way.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: int | None = None,
    ) -> Callable | None:
        """Load a previously cached compiled graph, or return None on a miss."""
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: int | None = None,
    ) -> Any:
        """Return a compiled callable for the `graph_index`-th piecewise graph.

        The cache is tried first (via `load`). On a miss the graph is compiled
        inside `compile_context`, and the returned handle is recorded in the
        in-memory cache unless VLLM_DISABLE_COMPILE_CACHE is set. Wall-clock
        time from graph 0 to graph `num_graphs - 1` is added to
        `compilation_config.compilation_time`.
        """
        global compilation_start_time

        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                compilation_config.compilation_time += elapsed
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
283
284


285
286
287
@dataclasses.dataclass
class SplitItem:
    """One submodule produced by `split_graph`, in constructor field order."""

    # attribute name of the submodule on the split GraphModule
    submod_name: str
    # integer id of the submodule (used for ordering)
    graph_id: int
    # True when this submodule holds one of the splitting ops
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


293
def split_graph(
    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` into submodules around calls to `resolved_ops`.

    Every `call_function` node whose target (or, for an OpOverloadPacket,
    its `.default` overload) is in `resolved_ops` is placed in its own
    submodule. The runs of nodes between such calls form the other
    submodules. Original node order is kept.

    Returns:
        `(split_gm, items)`. `split_gm` is the stitching module produced by
        `split_module`. `items` holds one `SplitItem` per direct child
        submodule, sorted by integer graph id.
    """
    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id: dict[fx.Node, int] = {}
    # set, not list: it is queried once per resulting submodule below
    split_op_graph_ids: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Match node.target against resolved_ops
        # node.target can be OpOverloadPacket, need to check .default
        if node.op == "call_function" and (
            node.target in resolved_ops
            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
        ):
            # isolate the splitting op in its own subgraph, then start a
            # fresh subgraph for whatever follows it
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graph_ids.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    # direct children only: skips the root module and nested submodules
    for name, module in split_gm.named_children():
        graph_id = int(name.removeprefix("submod_"))
        outputs.append(
            SplitItem(name, graph_id, graph_id in split_op_graph_ids, module)
        )

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
342
343


344
345
# Wall-clock timestamp (from time.time()) recorded by CompilerManager.compile
# when it handles graph_index 0. It is read again after the last piecewise
# graph to add the elapsed time to compilation_config.compilation_time.
compilation_start_time = 0.0

346
347
348
349
350
351

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        # Position of each name in `compile_submod_names`. The first
        # occurrence wins, matching list.index. call_module then needs one
        # dict lookup instead of a membership scan plus an index scan.
        self._submod_index: dict[str, int] = {}
        for position, name in enumerate(compile_submod_names):
            self._submod_index.setdefault(name, position)
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph on fake-tensor copies of `args`.

        Tensors are converted with the fake mode detected at construction.
        Execution happens under that fake mode with the Python dispatcher
        enabled.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute submodule `target`, compiling it first-time if requested.

        For targets listed in `compile_submod_names`, the submodule is
        compiled for dynamic shape and wrapped in a PiecewiseBackend. It is
        then installed on `self.module` under the same name. When Dynamo-based
        piecewise cudagraphs are in use (piecewise cudagraph mode without
        inductor graph partition), the backend is further wrapped in the
        platform's static graph wrapper.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        index = self._submod_index.get(target)
        if index is None:
            return output

        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]

        compiled_graph_for_dynamic_shape = self.vllm_backend.compiler_manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=index,
            num_graphs=num_graphs,
            runtime_shape=None,
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


462
463
464
465
466
467
468
469
470
# Tag naming the part of the model currently being compiled, e.g.
# "backbone" (the default) or "eagle_head". VllmBackend falls back to it
# when it is constructed without an explicit prefix.
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily replace the module-level `model_tag` with `tag`.

    Asserts that `tag` differs from the current tag. The previous tag is
    restored on exit, including when the body raises.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    saved_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = saved_tag


482
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation mode of `CompilationMode.VLLM_COMPILE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.

    Each instance may only be called once.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    # set to True at the end of a successful __call__
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    # positions in example_inputs of tensors with a symbolic dimension
    sym_tensor_indices: list[int]
    # static buffers (clones of those inputs) used when copying cudagraph inputs
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's PostGradPassManager in the Inductor config.

        A pass already registered under `post_grad_custom_post_pass` is
        either verified to be this manager (same uuid) or added to it.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(
        self, graph: fx.GraphModule, example_inputs
    ) -> VllmSerializableFunction:
        """Compile the Dynamo-captured `graph` piecewise.

        Sets up the cache directories, splits the graph at the configured
        splitting ops, and compiles the non-splitting submodules through
        PiecewiseCompileInterpreter. Returns a VllmSerializableFunction that
        runs `split_gm`, or an input-copying wrapper around it when cudagraph
        input copying is enabled.
        """
        # we control the compilation process, each instance can only be
        # called once. Check before touching any shared state (cache dirs,
        # traced files, counters, compilation time) so a rejected call has
        # no side effects.
        assert not self._called, "VllmBackend can only be called once"

        local_cache_dir = self._init_cache_dirs()

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        resolved_split_ops = resolve_defined_ops(fx_split_ops)
        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        self._save_computation_graph(local_cache_dir)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return VllmSerializableFunction(
                graph, example_inputs, self.prefix, self.split_gm
            )

        copy_and_call = self._build_copy_and_call(example_inputs)
        return VllmSerializableFunction(
            graph, example_inputs, self.prefix, copy_and_call
        )

    def _init_cache_dirs(self) -> str:
        """Resolve/create the compile cache dirs and initialize the cache.

        Returns the per-rank, per-prefix local cache directory. As a side
        effect, fills `compilation_config.cache_dir` (when unset) and
        `compilation_config.local_cache_dir`.
        """
        from .caching import _compute_code_hash, compilation_config_hash_factors

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            # 1. factors derived from the vLLM config
            factors = compilation_config_hash_factors(vllm_config)

            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            code_hash = _compute_code_hash(self.compilation_config.traced_files)
            self.compilation_config.traced_files.clear()
            factors.append(code_hash)

            # 3. compiler hash
            compiler_hash = self.compiler_manager.compute_hash(vllm_config)
            factors.append(compiler_hash)

            # combine all factors to generate the cache dir
            hash_key = hashlib.md5(
                str(factors).encode(), usedforsecurity=False
            ).hexdigest()[:10]

            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info(
                "Using cache directory: %s for vLLM's torch.compile", local_cache_dir
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )
        return local_cache_dir

    def _save_computation_graph(self, local_cache_dir: str) -> None:
        """Write a readable dump of `split_gm` to computation_graph.py.

        Nothing is written if the file already exists in `local_cache_dir`.
        """
        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if os.path.exists(graph_path):
            return
        # code adapted from
        # https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30
        # use `print_readable` because it can include submodules
        src = (
            "from __future__ import annotations\nimport torch\n"
            + self.split_gm.print_readable(print_output=False)
        )
        src = src.replace("<lambda>", "GraphModule")
        with open(graph_path, "w") as f:
            f.write(src)

        logger.debug("Computation graph saved to %s", graph_path)

    def _build_copy_and_call(self, example_inputs) -> Callable:
        """Prepare static input buffers and return the callable Dynamo runs.

        Records which inputs have a symbolic dimension in
        `sym_tensor_indices` and clones them into `input_buffers`. The
        returned function copies each such runtime input into the leading
        `shape[0]` rows of its buffer, then calls `split_gm`.
        """
        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call