# backends.py (28.9 KB; pasted from a web source view)
1
import ast
2
import copy
3
import dataclasses
4
5
import os
import pprint
6
import time
7
from collections import defaultdict
8
from contextlib import ExitStack
9
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
10
from unittest.mock import patch
11
12
13
14

import torch
import torch.fx as fx

15
import vllm.envs as envs
16
from vllm.config import CompilationConfig, VllmConfig
17
from vllm.logger import init_logger
18
from vllm.utils import weak_ref_tensors
19

20
from .counter import compilation_counter
21
from .inductor_pass import InductorPass
22
from .monitor import end_monitoring_torch_compile
23
from .pass_manager import PostGradPassManager
24
25
26

# Module-level logger used by every compilation helper in this file.
logger = init_logger(__name__)
class InductorHashCache:
    """
    Persistent mapping from ``(runtime_shape, graph_index)`` to the hash
    string of an Inductor-compiled graph.

    Disk format: a Python list of tuples, each tuple is
    (runtime_shape, graph_index, hash_str)
    We use list of tuple for readability.

    In-memory format: a defaultdict of dict, where the key is
    runtime_shape, and the value is a dict of graph_index to hash_str.

    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
    we don't use json here because json doesn't support int as key.

    When ``disabled`` is True, lookups always miss and nothing is read from
    or written to disk, but ``__setitem__`` still records entries in memory.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        """
        Args:
            cache_dir: directory holding the hash file plus the Inductor and
                Triton cache directories.
            disabled: if True, skip all disk I/O and environment setup.
        """
        self.cache: defaultdict = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
        if os.path.exists(self.cache_file_path):
            with open(self.cache_file_path) as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        """Merge the entries encoded in ``data`` (disk format) into memory."""
        # we use ast.literal_eval to parse the data
        # because it is a safe way to parse Python literals.
        # do not use eval(), it is unsafe.
        list_data = ast.literal_eval(data)
        for runtime_shape, graph_index, hash_str in list_data:
            self.cache[runtime_shape][graph_index] = hash_str

    def serialize(self) -> str:
        """Return the in-memory entries rendered in the disk format."""
        data = []
        for runtime_shape, graph_index_to_hash_str in self.cache.items():
            for graph_index, hash_str in graph_index_to_hash_str.items():
                data.append((runtime_shape, graph_index, hash_str))
        printer = pprint.PrettyPrinter(indent=4)
        return printer.pformat(data)

    def save_to_file(self):
        """Write the cache to ``cache_file_path`` (no-op when disabled).

        The content is first written to a process-specific temporary file
        and then moved over the target with ``os.replace``. A crash or
        another process saving at the same time therefore never leaves a
        truncated cache file behind.
        """
        if self.disabled:
            return
        tmp_path = f"{self.cache_file_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w") as f:
                f.write(self.serialize())
            os.replace(tmp_path, self.cache_file_path)
        except BaseException:
            # do not leave a half-written temporary file behind
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return runtime_shape in self.cache and graph_index in self.cache[
            runtime_shape]

    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
        """Return the stored hash for ``key``.

        Raises:
            KeyError: if the cache is disabled or ``key`` is not stored.
        """
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        runtime_shape, graph_index = key
        # check first, so the defaultdict does not silently insert an empty
        # entry for an unknown runtime_shape
        if runtime_shape not in self.cache:
            raise KeyError(key)
        return self.cache[runtime_shape][graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
        # setitem for disabled cache is fine, because we
        # don't actually write to the disk
        runtime_shape, graph_index = key
        self.cache[runtime_shape][graph_index] = value


class AlwaysHitShapeEnv:
    """
    A stand-in shape environment whose guard checks always succeed.

    Background: an ordinary `torch.compile` call performs one Dynamo
    bytecode compilation followed by one Inductor compilation. The Inductor
    step runs inside the Dynamo context, and that context supplies the
    dynamic-shape information.

    vLLM instead runs Dynamo once and then invokes Inductor several times:
    once for a general shape and once per specific shape. The per-shape
    compilations run outside the Dynamo context, so no shape environment
    is available there, and the Inductor code-cache lookup would fail.

    Handing Inductor this dummy environment makes every cache lookup hit,
    so the graph can be compiled (or reloaded) for any shape on demand.

    The stub methods below were found by trial-and-error until the
    lookup path worked.
    """

    def __init__(self) -> None:
        # no guards are ever recorded; each instance owns its own list
        self.guards: List[Any] = list()

    def evaluate_guards_expression(self, *_args, **_kwargs):
        """Report every guard expression as satisfied."""
        return True

    def get_pruned_guards(self, *_args, **_kwargs):
        """Return a fresh empty list: there are no guards to prune."""
        return list()

    def produce_guards_expression(self, *_args, **_kwargs):
        """Return an empty guard expression string."""
        return str()


144
def wrap_inductor(graph: fx.GraphModule,
                  example_inputs,
                  additional_inductor_config,
                  compilation_config: CompilationConfig,
                  graph_index: int = 0,
                  num_graphs: int = 1,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True) -> Any:
    """Compile one piecewise graph with Inductor, reusing cached artifacts.

    If ``compilation_config.inductor_hash_cache`` already holds a hash for
    ``(runtime_shape, graph_index)``, the compiled graph is loaded straight
    from the Inductor FX graph cache. Otherwise the graph is compiled with
    ``compile_fx``, and the resulting hash is recorded in that cache.

    Args:
        graph: the FX graph module to compile (deep-copied before use).
        example_inputs: inputs used to compile or look up the graph.
        additional_inductor_config: optional dict merged into a copy of the
            current Inductor config.
        compilation_config: vLLM compilation config. It provides
            ``inductor_hash_cache`` and accumulates ``compilation_time``.
        graph_index: position of this graph among the ``num_graphs`` graphs
            of one compilation round. Index 0 starts the round's timer and
            index ``num_graphs - 1`` stops it.
        num_graphs: total number of graphs in the round.
        runtime_shape: ``None`` for the general shape, or a concrete size,
            which additionally enables max-autotune and coordinate-descent
            tuning.
        use_inductor: when False, ``graph`` is returned unchanged.

    Returns:
        On a cache hit, a wrapper using the Dynamo calling convention
        (``f(*args)``). On a cache miss, whatever ``compile_fx`` returns.
        When ``use_inductor`` is False, ``graph`` itself.
    """
    global compilation_start_time

    if graph_index == 0:
        # before compiling the first graph, record the start time
        compilation_start_time = time.time()

    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    from torch._inductor import config
    current_config = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)

    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        current_config["max_autotune"] = True
        current_config["coordinate_descent_tuning"] = True

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)

    cache_data = compilation_config.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before
        # so we can directly lookup the compiled graph via hash
        hash_str = cache_data[(runtime_shape, graph_index)]
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info(
                "Directly lookup the graph for shape %s from the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            # NOTE(review): the two positional booleans are private
            # FxGraphCache._lookup_graph flags; confirm their meaning
            # against the pinned torch version.
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
            )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            # store the hash (first element of the result) in the cache
            cache_data[(runtime_shape, graph_index)] = out[0]
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info("Cache the graph of shape %s for later use",
                            str(runtime_shape))
            logger.debug("store the %s-th graph for shape %s via hash %s",
                         graph_index, str(runtime_shape), out[0])
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch("torch._inductor.codecache.compiled_fx_graph_hash",
                      hijack_compiled_fx_graph_hash))
            # for providing a dummy shape environment
            stack.enter_context(
                patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                      _get_shape_env))
            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache))
            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)

    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
        now = time.time()
        elapsed = now - compilation_start_time
        compilation_config.compilation_time += elapsed
        if runtime_shape is None:
            logger.info("Compiling a graph for general shape takes %.2f s",
                        elapsed)
        else:
            logger.info("Compiling a graph for shape %s takes %.2f s",
                        runtime_shape, elapsed)

    return compiled_graph
279
280


281
282
283
@dataclasses.dataclass
class SplitItem:
    """One direct submodule of the graph produced by `split_graph`."""
    # attribute name of the submodule on the split graph, e.g. "submod_3"
    submod_name: str
    # integer parsed from ``submod_name``; used as the ordering key
    graph_id: int
    # True if this submodule contains one of the splitting ops
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split ``graph`` around every call to one of ``ops``.

    Each ``call_function`` node whose ``str(node.target)`` is in ``ops``
    gets a subgraph of its own. The nodes between such calls are grouped
    into the subgraphs before and after them.

    Args:
        graph: the graph module to split.
        ops: operator names to split on, compared with ``str(node.target)``.

    Returns:
        ``(split_gm, items)``: the stitching module that calls the
        submodules in order, and one `SplitItem` per direct submodule,
        sorted by integer ``graph_id``.
    """
    # explicit import: do not rely on torch.fx.passes being loaded elsewhere
    from torch.fx.passes.split_module import split_module

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs: Set[int] = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in ops:
            # the splitting op gets a dedicated subgraph, and the nodes
            # after it start yet another subgraph
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []
    # named_children() yields only direct submodules: no root, no nesting
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
335
336


337
338
339
# One CUDA graph memory pool shared by every VllmBackend instance; the
# first backend constructed creates it.
global_graph_pool: Optional[Any] = None

# time.time() timestamp taken when the first graph (index 0) of a
# compilation round starts compiling.
compilation_start_time: float = 0.0

342
343
344
345
346
347

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compiles the
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool):
        """
        Args:
            module: the split (stitching) graph module to run.
            compile_submod_names: names of the submodules to compile, in
                piecewise order.
            vllm_config: the vLLM config; its compilation_config drives
                compilation.
            graph_pool: CUDA graph memory pool handed to each
                PiecewiseBackend.
        """
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # fake mode active in the surrounding tracing context, if any
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config

    def run(self, *args):
        """Run the graph with every tensor argument replaced by a fake one."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Execute submodule ``target``. If it is one to compile, also
        compile it for the general shape and install a PiecewiseBackend
        in its place on ``self.module``."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of the symbolic integer arguments of this submodule
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None,
                use_inductor=self.compilation_config.use_inductor)

            # Written to the instance __dict__, presumably to bypass
            # nn.Module.__setattr__, which only accepts nn.Module values for
            # registered submodule names. The __dict__ entry then shadows
            # the registered submodule on attribute lookup.
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


407
408
409
410
class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    # NOTE(review): annotation only; never assigned in this class
    returned_callable: Callable

    # Inductor passes to run on the graph pre-defunctionalization
    # NOTE(review): annotation only; this class uses
    # `post_grad_pass_manager` instead
    post_grad_passes: Sequence[Callable]

    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install the post-grad pass manager into the Inductor config."""
        config = self.compilation_config
        self.post_grad_pass_manager.configure(config.pass_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Split ``graph``, compile its pieces, and return the callable
        Dynamo should run: ``self.split_gm``, or a wrapper that first copies
        inputs into static buffers when cudagraph input copying is on."""

        # we control the compilation process, each instance can only be
        # called once. Check this first so that a rejected second call does
        # not touch the counters or the accumulated compilation time.
        assert not self._called, "VllmBackend can only be called once"

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config,
                                    self.graph_pool).run(*example_inputs)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
            not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # NOTE(review): every torch.Tensor argument was converted with
        # fake_mode.from_tensor above, so this selects all tensor inputs,
        # not only those with a symbolic leading dimension. Confirm intent.
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            # Copy each runtime tensor into the leading slice of its static
            # buffer, so the graph always sees the same input storage.
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
556
557
558
559
560
561
562
563
564
565
566
567
568
569


@dataclasses.dataclass
class ConcreteSizeEntry:
    """State kept by PiecewiseBackend for one concrete ``runtime_shape``."""
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in capture_sizes

    # set once the shape-specific compilation has been triggered
    compiled: bool = False
    # callable used to run this shape; None until first assigned
    runnable: Callable = None  # type: ignore
    # eager runs performed so far before cudagraph capture
    num_finished_warmup: int = 0
    # captured CUDA graph; None until captured
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    # output of the captured graph (stored as weak references)
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None

574
575
576

class PiecewiseBackend:
    """Runs one piecewise graph, compiling and cudagraph-capturing it for
    specific sizes on demand."""

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the piecewise graph module.
            vllm_config: the vLLM config.
            graph_pool: memory pool used for cudagraph capture.
            piecewise_compile_index: position of this graph among all
                piecewise graphs.
            total_piecewise_compiles: number of piecewise graphs.
            sym_shape_indices: positions in ``args`` of the symbolic size
                arguments; the first one selects the runtime shape.
            compiled_graph_for_general_shape: the graph already compiled
                for the general shape.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.capture_sizes: Set[int] = set(
            self.compilation_config.capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # Sizes still waiting to be compiled; each one is removed once it is
        # compiled. Capture-only sizes must not be included: nothing ever
        # removes them, so the set would never become empty, and the end of
        # compilation would never be detected.
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()

        for shape in self.compile_sizes.union(self.capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.capture_sizes,
            )

    def _check_for_ending_compilation(self):
        """On the last piecewise graph, once no sizes remain to compile,
        persist the Inductor hash cache and stop compile-time monitoring."""
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # save the hash of the inductor graph for the next run
            self.compilation_config.inductor_hash_cache.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        if not self.first_run_finished:
            self.first_run_finished = True
            # the general-shape graph is already compiled; if there are no
            # specific sizes to compile, compilation ends here
            self._check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        if not self.sym_shape_indices:
            # no symbolic size argument: nothing to specialize on
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
                use_inductor=self.compilation_config.use_inductor)

            # finished compilations for all required shapes?
            self._check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output