backends.py 29 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import ast
4
import dataclasses
5
6
import os
import pprint
7
import time
8
from contextlib import ExitStack
9
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
10
from unittest.mock import patch
11
12
13
14

import torch
import torch.fx as fx

15
import vllm.envs as envs
16
from vllm.config import CompilationConfig, VllmConfig
17
from vllm.logger import init_logger
18
from vllm.utils import weak_ref_tensors
19

20
from .compiler_interface import EagerAdaptor, InductorAdaptor
21
from .counter import compilation_counter
22
from .inductor_pass import InductorPass
23
from .monitor import end_monitoring_torch_compile
24
from .pass_manager import PostGradPassManager
25
26
27

# module-level logger for this file
logger = init_logger(__name__)


class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, use_inductor: bool):
        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
        # pick the compiler adaptor; its `name` is part of every cache key
        cls = InductorAdaptor if use_inductor else EagerAdaptor
        self.compiler = cls()
        # set when `compile` adds a new handle; `save_to_file` is a no-op
        # otherwise
        self.is_cache_updated = False

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the compiler adaptor's hash for `vllm_config`."""
        return self.compiler.compute_hash(vllm_config)

    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
        """Point the manager at `cache_dir` and load any existing cache index.

        Unless `disable_cache` is set, `vllm_compile_cache.py` inside
        `cache_dir` is parsed (if present) into `self.cache`. The same
        arguments are then forwarded to the compiler adaptor.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(cache_dir=cache_dir,
                                       disable_cache=disable_cache)

    def save_to_file(self):
        """Persist `self.cache` if caching is enabled and it changed.

        The data is written to a temporary file that then replaces the cache
        file via `os.replace`, so an interrupted write cannot leave a
        truncated file that `ast.literal_eval` would fail to parse on the
        next start.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        tmp_path = self.cache_file_path + ".tmp"
        with open(tmp_path, "w") as f:
            f.write(data)
        os.replace(tmp_path, self.cache_file_path)

    def load(self,
             graph: fx.GraphModule,
             example_inputs: List[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Optional[Callable]:
        """Load a previously compiled graph from the cache.

        Returns None when no handle is cached for
        `(runtime_shape, graph_index, compiler name)`.
        """
        key = (runtime_shape, graph_index, self.compiler.name)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(handle, graph, example_inputs,
                                            graph_index, runtime_shape)
        logger.debug(
            "Directly load the %s-th graph for shape %s from %s via "
            "handle %s", graph_index, str(runtime_shape), self.compiler.name,
            handle)
        return compiled_graph

    def compile(self,
                graph: fx.GraphModule,
                example_inputs,
                additional_inductor_config,
                compilation_config: CompilationConfig,
                graph_index: int = 0,
                num_graphs: int = 1,
                runtime_shape: Optional[int] = None) -> Any:
        """Return a compiled version of `graph`, from the cache if possible.

        `graph_index` / `num_graphs` locate this piece among the piecewise
        graphs: index 0 starts the module-level timer, and the last index
        reports the elapsed time (and, on a real compile, adds it to
        `compilation_config.compilation_time`). `runtime_shape=None` means
        the general (symbolic) shape.
        """
        if graph_index == 0:
            # before compiling the first graph, record the start time
            global compilation_start_time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index,
                                   runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                logger.info(
                    "Directly load the compiled graph(s) for shape %s "
                    "from the cache, took %.3f s", str(runtime_shape), elapsed)
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, additional_inductor_config, runtime_shape)

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if handle is not None:
            self.cache[(runtime_shape, graph_index,
                        self.compiler.name)] = handle
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info("Cache the graph of shape %s for later use",
                            str(runtime_shape))
            logger.debug(
                "store the %s-th graph for shape %s from %s via handle %s",
                graph_index, str(runtime_shape), self.compiler.name, handle)

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for general shape takes %.2f s",
                            elapsed)
            else:
                logger.info("Compiling a graph for shape %s takes %.2f s",
                            runtime_shape, elapsed)

        return compiled_graph


@dataclasses.dataclass
class SplitItem:
    """One submodule produced by `split_graph`."""

    # attribute name of the submodule inside the split (stitching) module
    submod_name: str
    # integer parsed from `submod_name` (the `submod_` prefix removed);
    # used to order the pieces
    graph_id: int
    # True when this piece holds one of the splitting ops
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split `graph` into consecutive submodules around the given ops.

    Every `call_function` node whose `str(node.target)` is in `ops` is put
    into a subgraph of its own; the nodes between splitting ops form the
    other subgraphs.

    Args:
        graph: the module to split.
        ops: string names of the splitting ops.

    Returns:
        `(split_gm, items)`: the stitching module produced by
        `torch.fx.passes.split_module.split_module`, and one `SplitItem` per
        direct child submodule, sorted by integer graph id.
    """
    # set lookups instead of scanning the list for every node
    split_ops = set(ops)

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graph_ids = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_ops:
            # the splitting op gets a dedicated subgraph, and whatever
            # follows it starts a fresh one
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graph_ids.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []
    # direct children only (no root, no nested modules); their names are
    # `submod_<id>`
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, graph_id in split_op_graph_ids, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs


# we share the global graph pool among all the backends
global_graph_pool = None

# `time.time()` recorded by `CompilerManager.compile` just before the first
# piecewise graph (graph_index == 0) is compiled or loaded; the elapsed time
# is reported once the last graph is done
compilation_start_time = 0.0

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # fake mode detected from the current context; `run` uses it to
        # turn real tensor inputs into fake tensors
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend

    def run(self, *args):
        """Interpret the graph on fake copies of the tensor arguments."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Run submodule `target`; if it is listed for compilation, compile
        it for the general (symbolic) shape and install a `PiecewiseBackend`
        in its place.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of the symbolic-int arguments of this submodule
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiler_manager = self.vllm_backend.compiler_manager
            compiled_graph_for_general_shape = compiler_manager.compile(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None)

            # NOTE(review): written straight into the instance __dict__
            # (bypassing nn.Module.__setattr__), presumably so attribute
            # lookup of `target` finds the PiecewiseBackend before the
            # registered submodule.
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.

    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config.use_inductor)

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's PostGradPassManager as Inductor's
        `post_grad_custom_post_pass`, merging in any pass already set there.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                assert (inductor_config[PASS_KEY].uuid() ==
                        self.post_grad_pass_manager.uuid())
            else:
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _compute_default_cache_dir(self) -> str:
        """Derive a cache directory from the factors that affect compilation.

        If none of the factors change, the directory is the same across runs
        so that compiled graphs can be reused. Side effect: clears
        `compilation_config.traced_files` after reading it.
        """
        import hashlib

        vllm_config = self.vllm_config
        factors = []

        # 0. factors come from the env, for example, The values of
        # VLLM_PP_LAYER_PARTITION will affects the computation graph.
        env_hash = envs.compute_hash()
        factors.append(env_hash)

        # 1. factors come from the vllm_config (it mainly summarizes how the
        #    model is created)
        config_hash = vllm_config.compute_hash()
        factors.append(config_hash)

        # 2. factors come from the code files that are traced by Dynamo (
        #    it mainly summarizes how the model is used in forward pass)
        forward_code_files = list(
            sorted(self.compilation_config.traced_files))
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files))
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            with open(filepath) as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5("\n".join(hash_content).encode(),
                                usedforsecurity=False).hexdigest()
        factors.append(code_hash)

        # 3. compiler hash
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        factors.append(compiler_hash)

        # combine all factors to generate the cache dir
        hash_key = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()[:10]

        return os.path.join(
            envs.VLLM_CACHE_ROOT,
            "torch_compile_cache",
            hash_key,
        )

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Compile `graph` piecewise and return the callable Dynamo will run.

        Sets up the (per-rank) cache directory, splits the graph at the
        configured splitting ops, compiles the non-splitting pieces, and
        returns either the stitching module or, when cudagraph input copying
        is enabled, a wrapper that copies symbolic-shape inputs into static
        buffers first.
        """
        # we control the compilation process, each instance can only be
        # called once. Check this before touching the cache directory or the
        # global compilation counters, so a repeated call has no side effects.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            self.compilation_config.cache_dir = \
                self._compute_default_cache_dir()

        if compilation_counter.num_graphs_seen > 0:
            cache_dir = self.compilation_config.cache_dir + \
                f'-{compilation_counter.num_graphs_seen}'
        else:
            cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.compilation_config.cache_dir = cache_dir

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        local_cache_dir)

        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache)

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = "from __future__ import annotations\nimport torch\n" + \
                self.split_gm.print_readable(print_output=False)
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
            not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
                any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-runtime-shape state kept by `PiecewiseBackend`."""

    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    # set once compilation for this shape has been triggered
    compiled: bool = False
    # callable used for this shape; filled in on first use
    runnable: Callable = None  # type: ignore
    # warmup runs completed before cudagraph capture
    num_finished_warmup: int = 0
    # the captured cudagraph, None until captured
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # output kept from capture and returned on replay
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class PiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the piecewise submodule this backend runs.
            vllm_config: global config; its `compilation_config` supplies
                compile sizes and cudagraph settings.
            graph_pool: memory pool handle passed to `torch.cuda.graph`.
            piecewise_compile_index: position of this piece among the
                compiled pieces.
            total_piecewise_compiles: number of compiled pieces.
            sym_shape_indices: positions of symbolic-int arguments; the
                first one selects the runtime shape.
            compiled_graph_for_general_shape: callable compiled for the
                symbolic shape, used as the fallback.
            vllm_backend: owning backend; its `compiler_manager` compiles
                the shape-specific graphs.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """If this is the last piece and no sizes remain to be compiled,
        save the compiler cache and end torch.compile monitoring."""
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this piece for the runtime shape given by the first
        symbolic-int argument.

        The first call uses the general-shape graph. Later calls compile
        (once) for sizes in `compile_sizes`, warm up and then capture a
        cudagraph for sizes in `cudagraph_capture_sizes`, and replay the
        captured graph afterwards. Other sizes use the general-shape graph.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes?
            # (check_for_ending_compilation tests that condition itself)
            self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output