# backends.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
7
import os
import pprint
8
import time
9
from collections.abc import Sequence
10
from contextlib import contextmanager
11
from typing import Any, Callable, Optional
12
13
14

import torch
import torch.fx as fx
15
from torch._dispatch.python import enable_python_dispatcher
16

17
import vllm.envs as envs
18
19
20
21
22
from vllm.compilation.inductor_pass import pass_context
from vllm.compilation.partition_rules import (
    inductor_partition_rule_context,
    resolve_defined_ops,
)
23
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
24
from vllm.logger import init_logger
25
from vllm.platforms import current_platform
26
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
27

28
29
30
31
32
33
from .compiler_interface import (
    CompilerInterface,
    EagerAdaptor,
    InductorAdaptor,
    InductorStandaloneAdaptor,
)
34
from .counter import compilation_counter
35
36
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
37
38
39

logger = init_logger(__name__)

40

41
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Pick the compiler adaptor implied by ``compilation_config``.

    - ``use_inductor`` off: ``EagerAdaptor``.
    - ``use_inductor`` on and standalone compile usable: ``InductorStandaloneAdaptor``.
      "Usable" means ``VLLM_USE_STANDALONE_COMPILE`` is set, PyTorch is at least
      2.8.0.dev, and ``torch._inductor.standalone_compile`` exists in this build.
    - Otherwise: ``InductorAdaptor``.
    """
    if not compilation_config.use_inductor:
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # Use standalone compile only if requested, version is new enough,
    # and the symbol actually exists in this PyTorch build. The `and` chain
    # short-circuits, so the later checks only run when the earlier pass.
    standalone_ok = (
        envs.VLLM_USE_STANDALONE_COMPILE
        and is_torch_equal_or_newer("2.8.0.dev")
        and hasattr(torch._inductor, "standalone_compile")
    )
    if not standalone_ok:
        logger.debug("Using InductorAdaptor")
        return InductorAdaptor()

    logger.debug("Using InductorStandaloneAdaptor")
    return InductorStandaloneAdaptor()


60
61
62
63
64
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
        self.is_cache_updated = False
        # Nothing can be persisted until `initialize_cache` says where the
        # cache file lives. Start with the on-disk cache disabled so an early
        # `save_to_file()` is a no-op instead of an AttributeError.
        self.disable_cache = True
        self.cache_dir: Optional[str] = None
        self.cache_file_path: Optional[str] = None
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Delegate to the underlying compiler's hash of `vllm_config`."""
        return self.compiler.compute_hash(vllm_config)

    @contextmanager
    def compile_context(self, runtime_shape: Optional[int] = None):
        """Provide compilation context for the duration of compilation to set
        any torch global properties we want to scope to a single Inductor
        compilation (e.g. partition rules, pass context)."""
        with pass_context(runtime_shape):
            if self.compilation_config.use_inductor_graph_partition:
                inductor_partition_ops = resolve_defined_ops(
                    self.compilation_config.splitting_ops
                )
                with inductor_partition_rule_context(inductor_partition_ops):
                    yield
            else:
                yield

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            self.cache = self._read_cache_file(self.cache_file_path)

        self.compiler.initialize_cache(
            cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
        )

    @staticmethod
    def _read_cache_file(path: str) -> dict[tuple[Optional[int], int, str], Any]:
        """Parse a cache file written by `save_to_file`.

        A file that is not a valid Python literal dict is ignored with a
        warning and an empty cache is returned. A truncated file left by an
        interrupted write is one example. The graphs are then recompiled
        instead of aborting start-up.
        """
        with open(path) as f:
            text = f.read()
        try:
            # we use ast.literal_eval to parse the data
            # because it is a safe way to parse Python literals.
            # do not use eval(), it is unsafe.
            data = ast.literal_eval(text)
        except (SyntaxError, ValueError) as e:
            logger.warning(
                "Ignoring unreadable vLLM compile cache file %s: %s", path, e
            )
            return {}
        if not isinstance(data, dict):
            logger.warning(
                "Ignoring vLLM compile cache file %s: expected a dict, got %s",
                path,
                type(data).__name__,
            )
            return {}
        return data

    def save_to_file(self):
        """Write the cache to `cache_file_path`.

        This is skipped when caching is disabled or nothing was added.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        # Write a sibling temp file and atomically rename it into place. An
        # interrupted write then never leaves a truncated cache file behind.
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        except BaseException:
            # Best-effort cleanup of the partial temp file; the original
            # error is re-raised either way.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        runtime_shape: Optional[int] = None,
    ) -> Optional[Callable]:
        """Load the compiled graph for `(runtime_shape, graph_index)`.

        Returns None when no handle is cached for that key under the
        current compiler's name.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(
            handle, graph, example_inputs, graph_index, runtime_shape
        )
        if runtime_shape is None:
            logger.debug(
                "Directly load the %s-th graph for dynamic shape from %s via handle %s",
                graph_index,
                self.compiler.name,
                handle,
            )
        else:
            logger.debug(
                "Directly load the %s-th graph for shape %s from %s via handle %s",
                graph_index,
                str(runtime_shape),
                self.compiler.name,
                handle,
            )
        return compiled_graph

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs,
        additional_inductor_config,
        compilation_config: CompilationConfig,
        graph_index: int = 0,
        num_graphs: int = 1,
        runtime_shape: Optional[int] = None,
    ) -> Any:
        """Return a compiled callable for one piecewise graph.

        The cached artifact is used when available; otherwise the graph is
        compiled and the resulting handle is cached.

        Args:
            graph: the piecewise FX graph.
            example_inputs: example inputs forwarded to the compiler.
            additional_inductor_config: extra Inductor config forwarded to
                the compiler.
            compilation_config: its `compilation_time` is increased by the
                elapsed time after the last graph is compiled.
            graph_index: position of `graph` among the `num_graphs` pieces.
                Index 0 starts the timer; index `num_graphs - 1` reads it.
            num_graphs: total number of pieces compiled for this shape.
            runtime_shape: concrete shape to specialize for, or None for the
                dynamic-shape graph.
        """
        # module-level timer shared across the pieces of one compilation
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                elapsed = time.time() - compilation_start_time
                if runtime_shape is None:
                    logger.info(
                        "Directly load the compiled graph(s) for dynamic shape "
                        "from the cache, took %.3f s",
                        elapsed,
                    )
                else:
                    logger.info(
                        "Directly load the compiled graph(s) for shape %s "
                        "from the cache, took %.3f s",
                        str(runtime_shape),
                        elapsed,
                    )
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"

        with self.compile_context(runtime_shape):
            compiled_graph, handle = self.compiler.compile(
                graph,
                example_inputs,
                additional_inductor_config,
                runtime_shape,
                maybe_key,
            )

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None:
            self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
            compilation_counter.num_cache_entries_updated += 1
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                if runtime_shape is None:
                    logger.info("Cache the graph for dynamic shape for later use")
                else:
                    logger.info(
                        "Cache the graph of shape %s for later use", str(runtime_shape)
                    )
            if runtime_shape is None:
                logger.debug(
                    "Store the %s-th graph for dynamic shape from %s via handle %s",
                    graph_index,
                    self.compiler.name,
                    handle,
                )
            else:
                logger.debug(
                    "Store the %s-th graph for shape %s from %s via handle %s",
                    graph_index,
                    str(runtime_shape),
                    self.compiler.name,
                    handle,
                )

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            elapsed = time.time() - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed)
            else:
                logger.info(
                    "Compiling a graph for shape %s takes %.2f s",
                    runtime_shape,
                    elapsed,
                )

        return compiled_graph
276
277


278
279
280
@dataclasses.dataclass
class SplitItem:
    """One piece of a graph produced by `split_graph`."""

    # attribute name of the piece on the split (stitching) module, e.g. "submod_0"
    submod_name: str
    # integer id parsed from `submod_name`; pieces are ordered by it
    graph_id: int
    # True if this piece consists of exactly one call to a splitting op
    is_splitting_graph: bool
    # the piece itself
    graph: fx.GraphModule


def split_graph(
    graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split `graph` around every call to one of `resolved_ops`.

    Each matching `call_function` node is placed in a subgraph of its own.
    The runs of nodes between matches form the remaining subgraphs.

    Args:
        graph: the FX graph module to split.
        resolved_ops: op overloads to split on. A node whose target has a
            `.default` attribute (e.g. an `OpOverloadPacket`) also matches
            when that `.default` is listed.

    Returns:
        A pair `(split_gm, items)`:
        - `split_gm` is the stitching module returned by `split_module`.
        - `items` holds one `SplitItem` per direct submodule of `split_gm`,
          sorted by integer graph id.
    """
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    # a set, since it is probed once per produced piece below
    split_op_graphs: set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        # Match node.target against resolved_ops
        # node.target can be OpOverloadPacket, need to check .default
        if node.op == "call_function" and (
            node.target in resolved_ops
            or (hasattr(node.target, "default") and node.target.default in resolved_ops)
        ):
            # give the splitting op a fresh id of its own, then start a new id
            # for whatever follows it
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(
        graph, None, lambda node: node_to_subgraph_id[node], keep_original_order=True
    )

    outputs = []
    for name, _ in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        module = getattr(split_gm, name)
        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
335
336


337
338
# Wall-clock start time (from time.time()) of the current compilation round.
# CompilerManager.compile overwrites it when handling graph_index == 0. When it
# handles the last graph (graph_index == num_graphs - 1), it subtracts this
# value from time.time() to log the elapsed time. On the compile path, that
# elapsed time is also added to compilation_config.compilation_time.
compilation_start_time = 0.0

339
340
341
342
343
344

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # Fake mode found in the current tracing context. `run` converts
        # tensor inputs to fake tensors under it.
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Interpret the module on fake versions of the tensor `args`.

        Non-tensor args are passed through unchanged. The run happens under
        the detected fake mode and the Python dispatcher.
        """
        # Without a fake mode the conversion and the `with` below cannot
        # work; fail with a clear message instead of an opaque AttributeError.
        assert self.fake_mode is not None, (
            "PiecewiseCompileInterpreter requires an active fake mode"
        )
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Run submodule `target`, compiling it if it is listed.

        When `target` is in `compile_submod_names`, this also:
        - compiles it for the dynamic shape via the backend's compiler manager;
        - wraps the result in a `PiecewiseBackend`;
        - replaces `self.module.<target>` with that backend, or with a static
          graph wrapper around it when piecewise cudagraphs are used without
          Inductor graph partition.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target not in self.compile_submod_names:
            return output

        index = self.compile_submod_names.index(target)
        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)
        # positions of the torch.SymInt arguments passed to this submodule
        sym_shape_indices = [
            i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
        ]

        compiled_graph_for_dynamic_shape = (
            self.vllm_backend.compiler_manager.compile(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=num_graphs,
                runtime_shape=None,
            )
        )
        # Lazy import here to avoid circular import
        from .piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        if (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition
        ):
            # We're using Dynamo-based piecewise splitting, so we wrap
            # the whole subgraph with a static graph wrapper.
            from .cuda_graph import CUDAGraphOptions

            # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
            # class) as platform dependent.
            static_graph_wrapper_class = resolve_obj_by_qualname(
                current_platform.get_static_graph_wrapper_cls()
            )

            # Always assign PIECEWISE runtime mode to the
            # CUDAGraphWrapper for piecewise_backend, to distinguish
            # it from the FULL cudagraph runtime mode, no matter it
            # is wrapped on a full or piecewise fx graph.
            self.module.__dict__[target] = static_graph_wrapper_class(
                runnable=piecewise_backend,
                vllm_config=self.vllm_config,
                runtime_mode=CUDAGraphMode.PIECEWISE,
                cudagraph_options=CUDAGraphOptions(
                    debug_log_enable=piecewise_backend.is_first_graph,
                    gc_disable=not piecewise_backend.is_first_graph,
                    weak_ref_output=piecewise_backend.is_last_graph,
                ),
            )
        else:
            self.module.__dict__[target] = piecewise_backend

        compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


455
456
457
458
459
460
461
462
463
# the tag for the part of model being compiled,
# e.g. backbone/eagle_head
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily set the module-level ``model_tag`` to ``tag``.

    ``VllmBackend`` uses ``model_tag`` as its prefix when it is constructed
    with an empty ``prefix``. The previous tag is restored on exit, even if
    the body raises.

    Raises:
        AssertionError: if ``tag`` equals the tag that is already active.
    """
    global model_tag
    assert tag != model_tag, (
        f"Model tag {tag} is the same as the current tag {model_tag}."
    )
    old_tag = model_tag
    model_tag = tag
    try:
        yield
    finally:
        # restore unconditionally so nested/raising uses unwind correctly
        model_tag = old_tag


475
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.

    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config
        )

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's PostGradPassManager as Inductor's
        `post_grad_custom_post_pass`.

        A custom pass already in that slot is folded into the manager.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                # PassManager already added to config, make sure it's correct
                assert (
                    inductor_config[PASS_KEY].uuid()
                    == self.post_grad_pass_manager.uuid()
                )
            else:
                # Config should automatically wrap all inductor passes
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _compute_cache_hash_key(self) -> str:
        """Return a 10-hex-digit key over everything that affects compilation.

        Factors, in order: the vLLM env hash, the vllm_config hash, the paths
        and contents of the files Dynamo traced, and the compiler's hash.
        Side effect: clears `compilation_config.traced_files`.
        """
        import hashlib

        vllm_config = self.vllm_config
        factors = []
        # 0. factors come from the env, for example, The values of
        # VLLM_PP_LAYER_PARTITION will affect the computation graph.
        env_hash = envs.compute_hash()
        factors.append(env_hash)

        # 1. factors come from the vllm_config (it mainly summarizes how the
        #    model is created)
        config_hash = vllm_config.compute_hash()
        factors.append(config_hash)

        # 2. factors come from the code files that are traced by Dynamo (
        #    it mainly summarizes how the model is used in forward pass)
        forward_code_files = sorted(self.compilation_config.traced_files)
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files),
        )
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            # Python sources are UTF-8; be explicit so the hash does not
            # depend on the process locale (it is encoded as UTF-8 below).
            with open(filepath, encoding="utf-8") as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5(
            "\n".join(hash_content).encode(), usedforsecurity=False
        ).hexdigest()
        factors.append(code_hash)

        # 3. compiler hash
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        factors.append(compiler_hash)

        # combine all factors to generate the cache dir
        return hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()[
            :10
        ]

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Backend entry point invoked by Dynamo with the traced `graph`.

        Steps:
        - Sets up the compile cache directories.
        - Splits the graph at the configured splitting ops.
        - Compiles the non-splitting pieces via `PiecewiseCompileInterpreter`.
        - Returns the stitching module, or a wrapper that first copies
          symbolic-shape inputs into static buffers when cudagraph input
          copying is enabled.
        """
        # we control the compilation process, each instance can only be
        # called once. Check this before any side effect (counters,
        # compilation time, traced_files, cache dirs) so a misuse fails cleanly.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            hash_key = self._compute_cache_hash_key()
            self.compilation_config.cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}", self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info(
                "Using cache directory: %s for vLLM's torch.compile", local_cache_dir
            )

        self.compiler_manager.initialize_cache(
            local_cache_dir, disable_cache, self.prefix
        )

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time

        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        if self.compilation_config.use_inductor_graph_partition:
            # Let Inductor decide partitioning; avoid FX-level pre-splitting.
            fx_split_ops: list[str] = []
        else:
            fx_split_ops = self.compilation_config.splitting_ops or []

        resolved_split_ops = resolve_defined_ops(fx_split_ops)
        self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name
            for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(
            self.split_gm, submod_names_to_compile, self.vllm_config, self
        ).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = (
                "from __future__ import annotations\nimport torch\n"
                + self.split_gm.print_readable(print_output=False)
            )
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if (
            self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
            or not self.compilation_config.cudagraph_copy_inputs
        ):
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode

        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic

        self.sym_tensor_indices = [
            i
            for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call