# backends.py 28.2 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import ast
4
import dataclasses
5
6
import os
import pprint
7
import time
8
from contextlib import ExitStack
9
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
10
from unittest.mock import patch
11
12
13
14

import torch
import torch.fx as fx

15
import vllm.envs as envs
16
from vllm.config import CompilationConfig, VllmConfig
17
from vllm.logger import init_logger
18
from vllm.utils import weak_ref_tensors
19

20
from .compiler_interface import EagerAdaptor, InductorAdaptor
21
from .counter import compilation_counter
22
from .inductor_pass import InductorPass
23
from .monitor import end_monitoring_torch_compile
24
from .pass_manager import PostGradPassManager
25
26
27

# Module-level logger, created via vllm.logger.init_logger.
logger = init_logger(__name__)

28

29
30
31
32
33
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.

    `initialize_cache` must be called before `save_to_file`, which relies
    on the cache file path and the `disable_cache` flag it sets.
    """

    def __init__(self, use_inductor: bool):
        # (runtime_shape, graph_index, compiler name) -> compiler handle
        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
        cls = InductorAdaptor if use_inductor else EagerAdaptor
        self.compiler = cls()

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the compiler's hash of `vllm_config` (a cache-key factor)."""
        return self.compiler.compute_hash(vllm_config)

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        """Point this manager (and its compiler) at `cache_dir`.

        When caching is enabled and a previously saved cache file exists in
        `cache_dir`, its contents replace the in-memory cache. A cache file
        that cannot be parsed, or that does not contain a dict, is ignored
        with a warning, and the manager keeps its current (initially empty)
        cache. This avoids aborting startup because of, for example, a
        truncated file.

        Args:
            cache_dir: directory holding `vllm_compile_cache.py`.
            disable_cache: if True, neither read nor (later) write the file.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                content = f.read()
            try:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                loaded = ast.literal_eval(content)
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError) as e:
                logger.warning(
                    "Ignoring unreadable vLLM compile cache file %s (%s); "
                    "starting with an empty cache.", self.cache_file_path, e)
            else:
                if isinstance(loaded, dict):
                    self.cache = loaded
                else:
                    logger.warning(
                        "Ignoring vLLM compile cache file %s: expected a "
                        "dict, got %s; starting with an empty cache.",
                        self.cache_file_path,
                        type(loaded).__name__)

        self.compiler.initialize_cache(cache_dir=cache_dir,
                                       disable_cache=disable_cache)

    def save_to_file(self):
        """Persist the in-memory cache to `self.cache_file_path`.

        This is a no-op when caching is disabled. The data is written to a
        temporary file in the same directory and then moved into place with
        `os.replace`. An interrupted write therefore cannot leave a truncated
        cache file for the next run to trip over.
        """
        if self.disable_cache:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w") as f:
                f.write(data)
            os.replace(tmp_path, self.cache_file_path)
        finally:
            # after a successful os.replace the temp file is gone; otherwise
            # remove the partial write
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(self,
             graph: fx.GraphModule,
             example_inputs: List[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Optional[Callable]:
        """Load a previously compiled graph via its cached handle.

        Returns:
            None when no handle is cached for
            `(runtime_shape, graph_index, compiler name)`. Otherwise, the
            result of `self.compiler.load` for that handle.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(handle, graph, example_inputs,
                                            graph_index, runtime_shape)
        logger.debug(
            "Directly load the %s-th graph for shape %s from %s via "
            "handle %s", graph_index, str(runtime_shape), self.compiler.name,
            handle)
        return compiled_graph

    def compile(self,
                graph: fx.GraphModule,
                example_inputs,
                additional_inductor_config,
                compilation_config: CompilationConfig,
                graph_index: int = 0,
                num_graphs: int = 1,
                runtime_shape: Optional[int] = None) -> Any:
        """Return a compiled callable for one piecewise graph.

        If a handle is cached for `(runtime_shape, graph_index, compiler
        name)`, the graph is loaded from it. Otherwise, the graph is compiled
        and the returned handle (if any) is stored in the in-memory cache.

        Timing works as follows:
        - The clock (module global `compilation_start_time`) starts at
          `graph_index == 0`.
        - It is read at `graph_index == num_graphs - 1`.
        - The elapsed time is added to `compilation_config.compilation_time`
          on both the cache-hit and the compile path.

        Raises:
            AssertionError: if the compiler returns no compiled graph.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index,
                                   runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                # account for load time as well, consistent with the compile
                # path below
                compilation_config.compilation_time += elapsed
                logger.info(
                    "Directly load the compiled graph(s) for shape %s "
                    "from the cache, took %.3f s", str(runtime_shape), elapsed)
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, additional_inductor_config, runtime_shape)

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if handle is not None:
            self.cache[(runtime_shape, graph_index,
                        self.compiler.name)] = handle
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info("Cache the graph of shape %s for later use",
                            str(runtime_shape))
            logger.debug(
                "store the %s-th graph for shape %s from %s via handle %s",
                graph_index, str(runtime_shape), self.compiler.name, handle)

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for general shape takes %.2f s",
                            elapsed)
            else:
                logger.info("Compiling a graph for shape %s takes %.2f s",
                            runtime_shape, elapsed)

        return compiled_graph
155
156


157
158
159
@dataclasses.dataclass
class SplitItem:
    """Describes one direct submodule produced by `split_graph`."""

    # attribute name of the submodule on the stitching module, "submod_<id>"
    submod_name: str
    # integer id parsed from `submod_name`; used to order the pieces
    graph_id: int
    # True when this submodule holds a single splitting op (see split_graph)
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split `graph` into piecewise submodules around the given ops.

    Each `call_function` node whose `str(node.target)` is in `ops` is placed
    in a submodule of its own (a "splitting graph"). The nodes between two
    such ops form one submodule each. Placeholder and output nodes are left
    to `split_module`.

    Args:
        graph: the graph module to split.
        ops: op names to split at, compared against `str(node.target)`.
            `None` is treated as an empty list.

    Returns:
        A tuple of:
        - the stitching module returned by
          `torch.fx.passes.split_module.split_module`;
        - one `SplitItem` per direct submodule, sorted by integer graph id.
    """
    # build the lookup sets once: O(1) membership per node / submodule
    split_ops = frozenset(ops or ())
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_ops:
            # the splitting op gets a fresh id of its own; the nodes after
            # it start yet another id
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []
    # only the direct children of the root are pieces; nested modules
    # belong to a piece
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
211
212


213
214
215
# One CUDA graph memory pool handle shared by every VllmBackend instance.
# It stays None until the first backend is constructed and creates it.
global_graph_pool = None

# Wall-clock time (from time.time()) at which compilation of the first
# piecewise graph started; CompilerManager.compile resets it for
# graph_index 0.
compilation_start_time = 0.0

218
219
220
221
222
223

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # fake mode detected from the current tracing context; `run`
        # assumes one is active (NOTE(review): None would fail in `run`)
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend

    def run(self, *args):
        """Interpret the module on fake copies of the tensor `args`.

        Non-tensor arguments are passed through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Run submodule `target`, compiling it first if it is a piece.

        For a target listed in `compile_submod_names`:
        - compile the submodule for the general (symbolic) shape;
        - replace it on `self.module` with a `PiecewiseBackend` that serves
          later calls.
        The interpreter's own output is the original submodule's output.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of symbolic-int inputs; PiecewiseBackend reads the
            # runtime shape from the first of these
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_general_shape = \
                self.vllm_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.compilation_config.inductor_compile_config,
                    self.compilation_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    runtime_shape=None)

            # Writing into `__dict__` bypasses nn.Module.__setattr__, which
            # rejects a non-Module value for a registered child. The
            # instance attribute then takes precedence over the `_modules`
            # entry on attribute lookup.
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


284
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.

    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.

    Each instance may only be called once (see `__call__`).
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config.use_inductor)

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install the post-grad pass manager into the inductor config.

        A pass already configured under `post_grad_custom_post_pass` is
        added to the manager, and the manager then takes its place.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _compute_cache_dir(self) -> str:
        """Return a cache dir keyed by every factor that affects compilation.

        If none of the factors change, the same directory is returned, so
        compiled graphs can be reused across runs.

        Side effect: clears `compilation_config.traced_files`.
        """
        import hashlib

        vllm_config = self.vllm_config
        factors = []

        # 0. factors come from the env, for example, The values of
        # VLLM_PP_LAYER_PARTITION will affect the computation graph.
        env_hash = envs.compute_hash()
        factors.append(env_hash)

        # 1. factors come from the vllm_config (it mainly summarizes how the
        #    model is created)
        config_hash = vllm_config.compute_hash()
        factors.append(config_hash)

        # 2. factors come from the code files that are traced by Dynamo (
        #    it mainly summarizes how the model is used in forward pass)
        forward_code_files = sorted(self.compilation_config.traced_files)
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files))
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if not os.path.isfile(filepath):
                # e.g. "<string>" for code generated via exec(): there is no
                # file to read, so only the (pseudo) path is hashed
                continue
            with open(filepath) as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5("\n".join(hash_content).encode(),
                                usedforsecurity=False).hexdigest()
        factors.append(code_hash)

        # 3. compiler hash
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        factors.append(compiler_hash)

        # combine all factors to generate the cache dir
        hash_key = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()[:10]

        return os.path.join(
            envs.VLLM_CACHE_ROOT,
            "torch_compile_cache",
            hash_key,
        )

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Compile the Dynamo-traced `graph` piecewise and return a callable.

        Steps:
        1. Set up the per-rank cache directory.
        2. Split `graph` at `compilation_config.splitting_ops`.
        3. Compile each non-splitting piece via PiecewiseCompileInterpreter.
        4. Return either the stitching module, or a wrapper that copies
           symbolic-shape inputs into static buffers when cudagraph input
           copying is enabled.

        Raises:
            AssertionError: if this instance has already been called.
        """
        # we control the compilation process, each instance can only be
        # called once. Check first, so a repeated call fails before it
        # touches counters, cache state or the traced-file list.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affect the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            self.compilation_config.cache_dir = self._compute_cache_dir()

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        local_cache_dir)

        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        # imported here to read the module's current value
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = "from __future__ import annotations\nimport torch\n" + \
                self.split_gm.print_readable(print_output=False)
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
            not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
                any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
522
523
524
525
526
527


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-runtime-shape bookkeeping kept by `PiecewiseBackend`."""

    # the concrete runtime size this entry is specialized for
    runtime_shape: int
    # the size is in compile_sizes
    need_to_compile: bool
    # the size is in cudagraph_capture_sizes
    use_cudagraph: bool

    # set once the shape-specialized compile has been issued
    compiled: bool = False
    # callable used to run this shape; filled in lazily on first use
    runnable: Callable = None  # type: ignore
    # number of warmup runs completed before cudagraph capture
    num_finished_warmup: int = 0
    # the captured cudagraph, once capture has happened
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # weak-referenced output of the captured graph, returned on replay
    output: Optional[Any] = None

    # for cudagraph debugging: tensor input addresses recorded at capture
    # time, compared against the replay-time addresses in debug mode
    input_addresses: Optional[List[int]] = None

540
541
542

class PiecewiseBackend:
    """Runtime wrapper for one compiled piece of a piecewise-split graph."""

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the piece to recompile for specific shapes.
            vllm_config: provides compile/capture sizes, warmup count and
                inductor config through its compilation_config.
            graph_pool: memory pool handle passed to `torch.cuda.graph`.
            piecewise_compile_index: position of this piece among the
                compiled pieces (0 is the first).
            total_piecewise_compiles: number of compiled pieces.
            sym_shape_indices: indices of symbolic-int call arguments; the
                first one supplies the runtime shape.
            compiled_graph_for_general_shape: callable compiled for the
                general (symbolic) shape; the default runnable.
            vllm_backend: owning backend, used for its compiler manager.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Finish compilation once the last piece has no sizes left.

        Only the last piece acts: it saves the compiler manager's cache to
        disk and ends torch.compile monitoring.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this piece for `args`.

        The dispatch goes as follows:
        - The first call runs the general-shape graph.
        - Later calls look up the runtime shape (the arg at
          `sym_shape_indices[0]`). Shapes without an entry run the
          general-shape graph.
        - Other shapes are compiled on first sight if requested.
        - Shapes marked for cudagraph are then warmed up, captured, and
          replayed on subsequent calls.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes? (the check
            # itself only acts for the last graph with no sizes left)
            self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output