# backends.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
7
import os
import pprint
8
import time
9
from collections.abc import Sequence
10
from contextlib import contextmanager
11
from typing import Any, Callable, Optional
12
13
14

import torch
import torch.fx as fx
15
from torch._dispatch.python import enable_python_dispatcher
16

17
import vllm.envs as envs
18
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
22

23
24
25
26
27
28
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
)
29
from .counter import compilation_counter
30
31
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
32
33
34

logger = init_logger(__name__)

35

36
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Pick the compiler adaptor for ``compilation_config.backend``.

    * ``"inductor"`` gives an ``InductorStandaloneAdaptor`` when
      ``VLLM_USE_STANDALONE_COMPILE`` is set, PyTorch is at least 2.8.0.dev
      and ``torch._inductor.standalone_compile`` exists; otherwise an
      ``InductorAdaptor``.
    * ``"eager"`` gives an ``EagerAdaptor``.
    * Any other backend name fails the assertion below.
    """
    if compilation_config.backend != "inductor":
        assert compilation_config.backend == "eager", (
            "Custom backends not supported with CompilationLevel.PIECEWISE"
        )
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # Standalone compile needs an explicit opt-in, a new enough PyTorch, and
    # the symbol actually being present in this particular PyTorch build.
    standalone_available = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if standalone_available:
        logger.debug("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor()

    logger.debug("Using InductorAdaptor")
    return InductorAdaptor()


59
60
61
62
63
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle returned by
        # the compiler; runtime_shape is None for the dynamic-shape graph.
        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
        # Set by compile() when it stores a new handle; save_to_file() only
        # writes when this is True.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the underlying compiler's hash for ``vllm_config``."""
        return self.compiler.compute_hash(vllm_config)

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        Unless ``disable_cache`` is set, an existing
        ``vllm_compile_cache.py`` is loaded into ``self.cache``.

        A file that cannot be parsed as a dict literal (e.g. one truncated
        by a crash) is ignored with a warning instead of aborting start-up.
        The graphs are then compiled rather than loaded, and
        ``save_to_file()`` overwrites the file once new entries are cached.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                contents = f.read()
            try:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                loaded = ast.literal_eval(contents)
            except (SyntaxError, ValueError, TypeError, RecursionError) as e:
                logger.warning(
                    "Ignoring unreadable compile cache file %s: %s",
                    self.cache_file_path,
                    e,
                )
            else:
                if isinstance(loaded, dict):
                    self.cache = loaded
                else:
                    logger.warning(
                        "Ignoring compile cache file %s: expected a dict, got %s",
                        self.cache_file_path,
                        type(loaded).__name__,
                    )

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    def save_to_file(self):
        """Persist ``self.cache`` to ``self.cache_file_path`` if it changed.

        The data is first written to a temporary sibling file. It is then
        moved into place with ``os.replace``, so an interrupted write cannot
        leave a truncated cache file for the next start-up to choke on.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = f"{self.cache_file_path}.tmp.{os.getpid()}"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is re-raised to the caller below.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Optional[Callable]:
        """Return the cached compiled graph for this key, or None on a miss.

        The key is ``(runtime_shape, graph_index, self.compiler.name)``; the
        stored handle is passed to ``self.compiler.load``.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: Optional[int] = None,
    ) -> Any:
        """Load ``graph`` from the cache, or compile it and cache the handle.

        ``graph_index``/``num_graphs`` locate this piece among the piecewise
        graphs. Timing starts at index 0 and is reported at the last index;
        on a fresh compile the elapsed time is also added to
        ``compilation_config.compilation_time``.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, additional_inductor_config, runtime_shape, maybe_key
        )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
254
255


256
257
258
@dataclasses.dataclass()
class SplitItem:
    """One top-level submodule of the GraphModule built by ``split_graph``."""

    # attribute name of the submodule on the split module, e.g. "submod_3"
    submod_name: str
    # integer id parsed from ``submod_name``; used for ordering
    graph_id: int
    # True when this piece is the partition holding a splitting op
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


264
265
266
def split_graph(
    graph: fx.GraphModule, ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into submodules around the given splitting ops.

    Every ``call_function`` node whose ``str(node.target)`` appears in
    ``ops`` is placed alone in its own partition. Runs of other nodes between
    splitting ops share a partition. Placeholder and output nodes are skipped
    when assigning partitions.

    Returns:
        ``(split_gm, items)``. ``split_gm`` is the stitching module built by
        ``torch.fx.passes.split_module.split_module``. ``items`` holds one
        ``SplitItem`` per top-level ``submod_<id>`` child, sorted by integer
        id, so ``submod_10`` comes after ``submod_9``.
    """
    partition_of: dict[fx.Node, int] = {}
    splitting_partitions: set[int] = set()
    next_id = 0
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if not (node.op == "call_function" and str(node.target) in ops):
            partition_of[node] = next_id
            continue
        # A splitting op gets a partition of its own; whatever follows it
        # starts yet another fresh partition.
        next_id += 1
        partition_of[node] = next_id
        splitting_partitions.add(next_id)
        next_id += 1

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph, None, lambda n: partition_of[n], keep_original_order=True
    )

    pieces: list[SplitItem] = []
    for name, _ in split_gm.named_modules():
        if not name or "." in name:
            # skip the root module ("") and nested submodules
            continue
        piece_id = int(name.replace("submod_", ""))
        pieces.append(
            SplitItem(
                name,
                piece_id,
                piece_id in splitting_partitions,
                getattr(split_gm, name),
            )
        )

    # sort by integer graph_id, rather than string name
    return split_gm, sorted(pieces, key=lambda p: p.graph_id)
308
309


310
311
# Wall-clock timestamp (from time.time()) of the current compilation's start.
# CompilerManager.compile resets it when graph_index == 0 and subtracts it
# from time.time() once the last graph (graph_index == num_graphs - 1) has
# been loaded or compiled, to log the elapsed time.
compilation_start_time: float = 0.0

312
313
314
315
316
317

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Run the split graph on fake inputs and compile selected submodules.

    Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. While the
    stitching graph is interpreted, each submodule named in
    `compile_submod_names` is handled as follows:

    * it is compiled for a dynamic shape via the backend's `CompilerManager`;
    * it is wrapped in a `PiecewiseBackend`, plus a platform static-graph
      wrapper for Dynamo-based piecewise cudagraphs;
    * the result is installed in the module's `__dict__` under the
      submodule's name.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # torch.fx.Interpreter dumps the whole fx.Graph on errors when this
        # is True, which is mostly noise here.
        self.extra_traceback = False

    def run(self, *args):
        """Interpret the graph with every tensor argument made fake."""
        fake_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                fake_args.append(self.fake_mode.from_tensor(arg))
            else:
                fake_args.append(arg)
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule ``target``; compile and swap it in if selected."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        graph_index = self.compile_submod_names.index(target)
        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        sym_shape_indices = [
            position
            for position, value in enumerate(args)
            if isinstance(value, torch.SymInt)
        ]

        manager = self.vllm_backend.compiler_manager
        compiled_graph_for_dynamic_shape = manager.compile(
            submod,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=graph_index,
            num_graphs=num_graphs,
            runtime_shape=None,
        )

        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            graph_index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        self.module.__dict__[target] = self._wrap_piecewise_backend(
            piecewise_backend
        )
        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return output

    def _wrap_piecewise_backend(self, piecewise_backend):
        """Return what replaces a compiled submodule on the split module."""
        config = self.compilation_config
        dynamo_piecewise_cudagraphs = (
            config.cudagraph_mode.has_piecewise_cudagraphs()
            and not config.use_inductor_graph_partition
        )
        if not dynamo_piecewise_cudagraphs:
            return piecewise_backend

        # We're using Dynamo-based piecewise splitting, so we wrap
        # the whole subgraph with a static graph wrapper.
        from .cuda_graph import CUDAGraphOptions

        # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
        # class) as platform dependent.
        wrapper_cls = resolve_obj_by_qualname(
            current_platform.get_static_graph_wrapper_cls()
        )
        is_first = piecewise_backend.is_first_graph
        # Always assign PIECEWISE runtime mode to the
        # CUDAGraphWrapper for piecewise_backend, to distinguish
        # it from the FULL cudagraph runtime mode, no matter it
        # is wrapped on a full or piecewise fx graph.
        return wrapper_cls(
            runnable=piecewise_backend,
            vllm_config=self.vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=is_first,
                gc_disable=not is_first,
                weak_ref_output=piecewise_backend.is_last_graph,
            ),
        )


428
429
430
431
432
433
434
435
436
# Tag for the part of the model currently being compiled, e.g. backbone or
# eagle_head. VllmBackend falls back to it when constructed without a prefix.
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily switch the module-level ``model_tag`` to ``tag``.

    The previous tag is restored on exit, even when the body raises. Asking
    for the tag that is already active fails the assertion below.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    previous, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = previous


448
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.

    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install the post-grad pass manager as Inductor's custom post pass.

        Any pass already registered under that key is either checked to be
        this very manager or added to it.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _compute_cache_dir(self) -> str:
        """Derive a cache directory from every factor that affects compilation.

        If none of the factors change, the directory stays the same, so the
        compiled graphs can be reused. Side effect: clears
        ``compilation_config.traced_files`` after reading it.
        """
        import hashlib

        vllm_config = self.vllm_config
        factors = []
        # 0. factors come from the env, for example, The values of
        # VLLM_PP_LAYER_PARTITION will affect the computation graph.
        factors.append(envs.compute_hash())

        # 1. factors come from the vllm_config (it mainly summarizes how the
        #    model is created)
        factors.append(vllm_config.compute_hash())

        # 2. factors come from the code files that are traced by Dynamo (
        #    it mainly summarizes how the model is used in forward pass)
        forward_code_files = list(sorted(self.compilation_config.traced_files))
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            with open(filepath) as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5(
            "\n".join(hash_content).encode(), usedforsecurity=False
        ).hexdigest()
        factors.append(code_hash)

        # 3. compiler hash
        factors.append(self.compiler_manager.compute_hash(vllm_config))

        # combine all factors to generate the cache dir
        hash_key = hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return os.path.join(
            envs.VLLM_CACHE_ROOT,
            "torch_compile_cache",
            hash_key,
        )

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split ``graph`` into pieces, compile them, and return the runnable.

        This can only be called once per instance.
        """
        # we control the compilation process, each instance can only be
        # called once. Check this before touching any state:
        # - the cache dir,
        # - the traced-files set (cleared while hashing),
        # - the counters and compilation time.
        # That way a rejected second call leaves everything as it was.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation.
            self.compilation_config.cache_dir = self._compute_cache_dir()

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info(
                "Using cache directory: %s for vLLM's torch.compile", local_cache_dir
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops
        )

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return self.split_gm

        return self._make_copy_and_call(example_inputs)

    def _make_copy_and_call(self, example_inputs) -> Callable:
        """Build a callable that copies symbolic-shape inputs into static
        buffers before running ``self.split_gm`` (for cudagraph input copy).
        """
        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call