backends.py 34.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import ast
4
import copy
5
import dataclasses
6
7
import os
import pprint
8
import time
9
from collections import defaultdict
10
from contextlib import ExitStack
11
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
12
from unittest.mock import patch
13
14
15
16

import torch
import torch.fx as fx

17
import vllm.envs as envs
18
from vllm.config import CompilationConfig, VllmConfig
19
from vllm.logger import init_logger
20
from vllm.utils import weak_ref_tensors
21

22
from .counter import compilation_counter
23
from .inductor_pass import InductorPass
24
from .monitor import end_monitoring_torch_compile
25
from .pass_manager import PostGradPassManager
26
27
28

logger = init_logger(__name__)

29

30
31
32
33
34
35
@dataclasses.dataclass
class InductorArtifact:
    """Where to find one Inductor-compiled graph again.

    Both fields default to the empty string, meaning "not known (yet)".
    """

    # key used with FxGraphCache to look the compiled graph up again
    hash_str: str = dataclasses.field(default="")
    # `co_filename` of the compiled graph's `current_callable`
    file_path: str = dataclasses.field(default="")


36
37
38
class InductorHashCache:
    """
    Persistent map from (runtime_shape, graph_index) to the
    InductorArtifact of a compiled graph.

    Disk format: a Python list of tuples, each tuple is
    (runtime_shape, graph_index, hash_str, file_path)
    We use list of tuple for readability. It is written with `pprint`
    and read back with `ast.literal_eval`.

    In-memory format: a defaultdict of dict, where the key is
    runtime_shape, and the value is a dict of graph_index to
    InductorArtifact, i.e.
    `Dict[Optional[int], Dict[int, InductorArtifact]]`.

    We don't use json here because json doesn't support int (or None)
    as dict keys.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        """
        Args:
            cache_dir: directory that holds `inductor_hash_cache.py` and,
                when the cache is enabled, the Inductor and Triton cache
                directories.
            disabled: when True, nothing is read from or written to disk,
                `in` checks always return False and reads raise KeyError.
        """
        self.cache: Dict[Optional[int],
                         Dict[int, "InductorArtifact"]] = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
        if os.path.exists(self.cache_file_path):
            # explicit encoding: the cache dir may be copied between
            # machines with different locale defaults
            with open(self.cache_file_path, encoding="utf-8") as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        """Merge the entries encoded in `data` (see class docstring) into
        the in-memory cache.

        Old-format 3-tuples without a file_path are accepted; their
        file_path is set to "".
        """
        # we use ast.literal_eval to parse the data
        # because it is a safe way to parse Python literals.
        # do not use eval(), it is unsafe.
        list_data = ast.literal_eval(data)
        for item in list_data:
            runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
            # for compatibility of old version,
            # where we don't have file_path.
            # NOTE: after running the new code, the file_path
            # will be updated.
            file_path = "" if len(item) == 3 else item[3]
            self.cache[runtime_shape][graph_index] = InductorArtifact(
                hash_str=hash_str, file_path=file_path)

    def serialize(self) -> str:
        """Return the cache as a pretty-printed list of 4-tuples
        `(runtime_shape, graph_index, hash_str, file_path)`."""
        data = [(runtime_shape, graph_index, artifact.hash_str,
                 artifact.file_path)
                for runtime_shape, value in self.cache.items()
                for graph_index, artifact in value.items()]
        printer = pprint.PrettyPrinter(indent=4)
        return printer.pformat(data)

    def save_to_file(self):
        """Write the cache to `cache_file_path`; a no-op when disabled.

        The data goes to a temporary file in the same directory first and
        is then moved into place with `os.replace`. An interrupted write
        therefore cannot leave a truncated file behind, which
        `ast.literal_eval` would reject on the next start-up.
        """
        if self.disabled:
            return
        tmp_path = self.cache_file_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(self.serialize())
        os.replace(tmp_path, self.cache_file_path)

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        """Return whether `(runtime_shape, graph_index)` is cached; always
        False when the cache is disabled."""
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return runtime_shape in self.cache and graph_index in self.cache[
            runtime_shape]

    def __getitem__(
            self, key: Tuple[Optional[int], int]) -> "InductorArtifact":
        """Return the artifact stored for `(runtime_shape, graph_index)`.

        Raises:
            KeyError: if the cache is disabled or the key is missing.
        """
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        runtime_shape, graph_index = key
        # `.get` avoids the defaultdict inserting an empty per-shape dict
        # as a side effect of a failed lookup
        return self.cache.get(runtime_shape, {})[graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int],
                    value: "InductorArtifact"):
        """Store `value` for `(runtime_shape, graph_index)` in memory."""
        # setitem for disabled cache is fine, because we
        # don't actually write to the disk
        runtime_shape, graph_index = key
        self.cache[runtime_shape][graph_index] = value


class AlwaysHitShapeEnv:
    """
    A dummy shape environment whose guards always pass.

    Why do we need this class:

    With normal `torch.compile` usage, every compilation consists of one
    Dynamo bytecode compilation and one Inductor compilation. The Inductor
    compilation runs inside the context of the Dynamo bytecode
    compilation, and that context supplies the dynamic shape information,
    etc.

    vLLM instead runs Dynamo bytecode compilation only once, and then
    runs Inductor compilation several times: once for a general shape and
    once per specific shape. The shape-specific compilations happen
    outside the Dynamo context. At that time no shape environment is
    available to hand to Inductor, and the Inductor code cache lookup
    would fail.

    Supplying this always-hit environment makes the Inductor code cache
    lookup always hit, so the graph can be compiled for whatever shapes
    are needed.

    The dummy methods below were found by trial and error, adding them
    until the lookup worked.
    """

    def __init__(self) -> None:
        # no guards are ever recorded
        self.guards: List[Any] = list()

    def evaluate_guards_expression(self, *args, **kwargs):
        # every guard expression counts as satisfied
        return True

    def get_pruned_guards(self, *args, **kwargs):
        # there are no guards, so nothing survives pruning
        return list()

    def produce_guards_expression(self, *args, **kwargs):
        # the empty expression: there is nothing to check
        return str()


165
def wrap_inductor(graph: fx.GraphModule,
                  example_inputs,
                  additional_inductor_config,
                  compilation_config: CompilationConfig,
                  vllm_backend: "VllmBackend",
                  graph_index: int = 0,
                  num_graphs: int = 1,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True) -> Any:
    """Compile one piecewise graph with Inductor, reusing vLLM's cache.

    On a hit in `vllm_backend.inductor_hash_cache`, the compiled graph is
    loaded from Inductor's FX graph cache by hash. On a miss, the graph is
    compiled with `compile_fx`. If the cache is enabled, its hash and file
    path are captured and the result is recorded in the cache.

    Args:
        graph: the FX graph to compile. When Inductor is used, it is
            deep-copied first because Inductor may modify it in place.
        example_inputs: inputs handed to Inductor for compilation or for
            the cache lookup.
        additional_inductor_config: extra Inductor config entries applied
            on top of a copy of the current Inductor config; may be None.
        compilation_config: its `compilation_time` is increased by the
            elapsed time once the last graph has been compiled.
        vllm_backend: provides `inductor_hash_cache`.
        graph_index: position of this graph among the piecewise graphs.
            Graph 0 starts the compilation timer and the last graph stops
            it. Graph 0 also emits the info-level cache log lines.
        num_graphs: total number of piecewise graphs.
        runtime_shape: None for the general shape, or a concrete size. A
            concrete size also enables `max_autotune` and
            `coordinate_descent_tuning`.
        use_inductor: if False, `graph` is returned unchanged. The timer is
            still started for graph 0, but no time is recorded.

    Returns:
        On a cache hit, a callable using Dynamo's `f(*args)` convention.
        On a miss, the result of `compile_fx`. When `use_inductor` is
        False, `graph` itself.
    """
    global compilation_start_time

    if graph_index == 0:
        # before compiling the first graph, record the start time
        compilation_start_time = time.time()

    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    from torch._inductor import config
    current_config = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)

    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        current_config["max_autotune"] = True
        current_config["coordinate_descent_tuning"] = True

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)

    cache_data = vllm_backend.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before
        # so we can directly lookup the compiled graph via hash
        inductor_artifact = cache_data[(runtime_shape, graph_index)]
        hash_str = inductor_artifact.hash_str
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info(
                "Directly lookup the graph for shape %s from the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
            )
            # refresh the file path (entries from older cache files
            # may not have one)
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        inductor_artifact = InductorArtifact()
        from torch._inductor.codecache import (FxGraphCache,
                                               compiled_fx_graph_hash)
        original_load = FxGraphCache.load

        def hijack_load(*args, **kwargs):
            inductor_compiled_graph = original_load(*args, **kwargs)
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
            return inductor_compiled_graph

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            inductor_artifact.hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            if not cache_data.disabled:
                # compilation cache is enabled, patch several functions

                # hijack to get the compiled graph itself
                stack.enter_context(
                    patch("torch._inductor.codecache.FxGraphCache.load",
                          hijack_load))

                # for hijacking the hash of the compiled graph
                stack.enter_context(
                    patch("torch._inductor.codecache.compiled_fx_graph_hash",
                          hijack_compiled_fx_graph_hash))

                # for providing a dummy shape environment
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._get_shape_env",
                        _get_shape_env))

                # for forcing the graph to be cached
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._check_can_cache",
                        _check_can_cache))

            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)

        # store the inductor_artifact in the cache
        cache_data[(runtime_shape, graph_index)] = inductor_artifact
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info("Cache the graph of shape %s for later use",
                        str(runtime_shape))
        logger.debug(
            "store the %s-th graph for shape %s via hash %s from file %s",
            graph_index, str(runtime_shape), inductor_artifact.hash_str,
            inductor_artifact.file_path)

    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
        now = time.time()
        elapsed = now - compilation_start_time
        compilation_config.compilation_time += elapsed
        if runtime_shape is None:
            logger.info("Compiling a graph for general shape takes %.2f s",
                        elapsed)
        else:
            logger.info("Compiling a graph for shape %s takes %.2f s",
                        runtime_shape, elapsed)

    return compiled_graph
330
331


332
333
334
@dataclasses.dataclass
class SplitItem:
    """One submodule of the graph module produced by `split_graph`.

    Attributes:
        submod_name: attribute name of the submodule on the split graph
            module, e.g. "submod_3".
        graph_id: the integer suffix of `submod_name`; `split_graph`
            sorts its items by it.
        is_splitting_graph: True if the submodule holds one of the
            splitting ops (these submodules are not compiled).
        graph: the submodule itself.
    """
    submod_name: str
    graph_id: int
    is_splitting_graph: bool
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split `graph` into submodules around the given splitting ops.

    Every `call_function` node whose `str(node.target)` is in `ops` goes
    into a submodule of its own. Each run of other nodes between two
    splitting ops goes into one shared submodule. Placeholder and output
    nodes are not assigned.

    Args:
        graph: the graph module to split.
        ops: string forms of the op targets to split at. Any iterable of
            strings works.

    Returns:
        The stitching graph module that calls the submodules in order, and
        one `SplitItem` per submodule, sorted by integer graph id.
    """
    from torch.fx.passes.split_module import split_module

    # sets: membership is tested once per node / per submodule
    split_ops = set(ops)
    split_op_graphs = set()

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_ops:
            # the splitting op gets a subgraph id of its own, and the
            # nodes after it start yet another one
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(graph,
                            None,
                            lambda node: node_to_subgraph_id[node],
                            keep_original_order=True)

    outputs = []
    # direct children only: the root and nested modules are not pieces
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
386
387


388
389
390
# CUDA graph memory pool shared by every VllmBackend instance; the first
# backend to be constructed creates it via `torch.cuda.graph_pool_handle()`.
global_graph_pool: Any = None

# `time.time()` recorded when the first piecewise graph starts compiling;
# `wrap_inductor` subtracts it from the end time to report the duration.
compilation_start_time: float = 0.0

393
394
395
396
397
398

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # the currently active fake mode (presumably Dynamo's, since this
        # runs inside the backend call); `run` uses it to fake the inputs
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend

    def run(self, *args):
        """Interpret the graph on fake versions of `args`.

        Tensors are converted with `self.fake_mode`; other arguments pass
        through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Run submodule `target`. If it is listed in
        `compile_submod_names`, also compile it for the general shape and
        replace it on `self.module` with a `PiecewiseBackend`."""
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of the symbolic sizes among the submodule inputs
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None,
                use_inductor=self.compilation_config.use_inductor)

            # written to the instance __dict__ directly: nn.Module's
            # __setattr__ rejects a non-Module value for a registered
            # submodule name, and instance attributes take precedence over
            # the `_modules` lookup done in nn.Module.__getattr__
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


460
461
462
463
class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]
    inductor_hash_cache: InductorHashCache

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install `self.post_grad_pass_manager` as Inductor's
        `post_grad_custom_post_pass`. Any pass already configured under
        that key is added to the manager first."""
        config = self.compilation_config
        self.post_grad_pass_manager.configure(config.pass_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Dynamo backend entry point.

        Sets up the compilation cache directory, splits `graph` at the
        splitting ops and compiles the piecewise submodules. Returns the
        callable Dynamo should run. That is `self.split_gm` unless
        cudagraphs with input copying are enabled; in that case it is a
        wrapper that first copies the symbolic-shape inputs into static
        buffers.
        """
        # we control the compilation process, each instance can only be
        # called once. Checked before any side effect (cache setup,
        # counters, timing), so a second call does not disturb them.
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.

            # 1. factors come from the vllm_config (it mainly summarizes how the
            #    model is created)
            config_hash = vllm_config.compute_hash()

            # 2. factors come from the code files that are traced by Dynamo (
            #    it mainly summarizes how the model is used in forward pass)
            forward_code_files = list(
                sorted(self.compilation_config.traced_files))
            self.compilation_config.traced_files.clear()
            logger.debug(
                "Traced files (to be considered for compilation cache):\n%s",
                "\n".join(forward_code_files))
            hash_content = []
            for filepath in forward_code_files:
                hash_content.append(filepath)
                if not os.path.isfile(filepath):
                    # e.g. "<string>" for code generated by exec(): there
                    # is no file to read, so only the name is hashed
                    continue
                with open(filepath) as f:
                    hash_content.append(f.read())
            import hashlib
            code_hash = hashlib.md5(
                "\n".join(hash_content).encode()).hexdigest()

            # combine the two hashes to generate the cache dir
            hash_key = hashlib.md5(
                f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )
            self.compilation_config.cache_dir = cache_dir

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        local_cache_dir = os.path.join(
            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
        # create the per-rank directory here. InductorHashCache only
        # creates it when the cache is enabled, but computation_graph.py
        # is written into it below in either case.
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
        self.inductor_hash_cache: InductorHashCache = InductorHashCache(
            local_cache_dir, disabled=disabled)
        if disabled:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        local_cache_dir)

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = "from __future__ import annotations\nimport torch\n" + \
                self.split_gm.print_readable(print_output=False)
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
            not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
                any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call
679
680
681
682
683
684


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Bookkeeping for one concrete runtime shape.

    Only the first three fields are required. The remaining fields hold
    per-size state and start out unset (False / None / 0).
    """
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    runnable: Optional[Callable] = None
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None

697
698
699

class PiecewiseBackend:
    """Runs one piece of a piecewise-split graph.

    Each call is dispatched to one of three runnables, chosen by runtime size:
    - the graph compiled for the general (symbolic) shape;
    - a graph compiled specifically for a size listed in `compile_sizes`;
    - a captured cudagraph for a size listed in `cudagraph_capture_sizes`.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the FX graph piece handled by this backend.
            vllm_config: vLLM config; its `compilation_config` supplies the
                compile sizes, cudagraph capture sizes and warmup count.
            graph_pool: memory pool passed as `pool=` to `torch.cuda.graph`
                when capturing.
            piecewise_compile_index: index of this piece among all pieces.
            total_piecewise_compiles: total number of pieces.
            sym_shape_indices: positions in the call arguments that hold the
                symbolic size. The first one is read as the runtime shape.
            compiled_graph_for_general_shape: runnable compiled for the
                general (symbolic) shape.
            vllm_backend: owning backend. It is passed to `wrap_inductor`,
                and its `inductor_hash_cache` is saved when compilation ends.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        # cudagraph sizes are ignored entirely when cudagraphs are disabled
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Finalize compilation once no size is left to compile.

        Acts only when this is the last graph piece and
        `to_be_compiled_sizes` is empty. It then saves the backend's inductor
        hash cache to file and calls `end_monitoring_torch_compile`.
        Otherwise it does nothing.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.inductor_hash_cache.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this graph piece on `args`.

        Dispatch rules:
        - The very first call always uses the general-shape graph.
        - Later calls read the runtime size from
          `args[self.sym_shape_indices[0]]`. A size with no entry also uses
          the general-shape graph.
        - For a size with an entry, the size-specific graph is compiled on
          its first call if requested.
        - If cudagraphs are used for that size, the runnable is first run
          `cudagraph_num_of_warmups` times. It is then captured once, and
          every later call replays the captured graph.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
                use_inductor=self.compilation_config.use_inductor)

            # finished compilations for all required shapes?
            # The helper itself checks "last graph and nothing left to
            # compile", so there is no need to repeat that condition here.
            self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            num_warmups = self.compilation_config.cudagraph_num_of_warmups
            if entry.num_finished_warmup < num_warmups:
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            # remember the tensor input addresses seen at capture time, so
            # that debug mode can verify replays use the same buffers
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output