# backends.py
1
import ast
2
import copy
3
import dataclasses
4
5
import os
import pprint
6
import time
7
from collections import defaultdict
8
from contextlib import ExitStack
9
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
10
from unittest.mock import patch
11
12
13
14

import torch
import torch.fx as fx

15
import vllm.envs as envs
16
from vllm.config import CompilationConfig, VllmConfig
17
from vllm.logger import init_logger
18
from vllm.utils import weak_ref_tensors
19

20
from .counter import compilation_counter
21
from .inductor_pass import InductorPass
22
from .monitor import end_monitoring_torch_compile
23
from .pass_manager import PostGradPassManager
24
25
26

# module-level logger for this compilation backend
logger = init_logger(__name__)


@dataclasses.dataclass
class InductorArtifact:
    """Where to find one graph compiled by Inductor.

    ``hash_str`` is the key used to look the compiled graph up in
    Inductor's FX graph cache; ``file_path`` is the file that holds the
    generated code (recorded for debugging). Both default to "" until
    they are filled in.
    """
    hash_str: str = ""
    file_path: str = ""


class InductorHashCache:
    """
    Persistent map from ``(runtime_shape, graph_index)`` to the
    :class:`InductorArtifact` of that compiled piecewise graph.

    Disk format: a Python list of tuples, each tuple is
    ``(runtime_shape, graph_index, hash_str, file_path)``.
    We use a list of tuples for readability, and parse it back with
    ``ast.literal_eval``. We don't use json here because json doesn't
    support int as key.

    In-memory format: a defaultdict of dict, where the key is
    runtime_shape, and the value is a dict of graph_index to
    InductorArtifact, i.e. ``Dict[Optional[int], Dict[int, InductorArtifact]]``.

    When ``disabled`` is True nothing is read from or written to disk,
    membership tests always fail and reads raise ``KeyError``; writes are
    still accepted in memory.

    TODO: better off-the-shelf solution to serialize the data?
    """

    def __init__(self, cache_dir: str, disabled: bool = False):
        self.cache: Dict[Optional[int],
                         Dict[int, InductorArtifact]] = defaultdict(dict)
        self.disabled = disabled
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir,
                                            "inductor_hash_cache.py")
        if disabled:
            return
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache
        if os.path.exists(self.cache_file_path):
            # explicit encoding: the file may have been written on another
            # machine with a different locale
            with open(self.cache_file_path, encoding="utf-8") as f:
                self.deserialize(f.read())

    def deserialize(self, data: str):
        """Load entries from the text produced by :meth:`serialize`."""
        # we use ast.literal_eval to parse the data
        # because it is a safe way to parse Python literals.
        # do not use eval(), it is unsafe.
        list_data = ast.literal_eval(data)
        for item in list_data:
            runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
            # for compatibility of old version, where we don't have
            # file_path. NOTE: after running the new code, the file_path
            # will be updated.
            file_path = "" if len(item) == 3 else item[3]
            self.cache[runtime_shape][graph_index] = InductorArtifact(
                hash_str=hash_str, file_path=file_path)

    def serialize(self) -> str:
        """Return the cache as a pretty-printed list of 4-tuples."""
        data = []
        for runtime_shape, value in self.cache.items():
            for graph_index, inductor_artifact in value.items():
                data.append(
                    (runtime_shape, graph_index, inductor_artifact.hash_str,
                     inductor_artifact.file_path))
        printer = pprint.PrettyPrinter(indent=4)
        return printer.pformat(data)

    def save_to_file(self):
        """Write the cache to ``cache_file_path`` (no-op when disabled).

        The data is written to a temporary file first and then renamed
        over the real file, so an interrupted write cannot leave a
        truncated file that would fail to parse on the next start.
        """
        if self.disabled:
            return
        tmp_path = self.cache_file_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(self.serialize())
        os.replace(tmp_path, self.cache_file_path)

    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
        if self.disabled:
            return False
        runtime_shape, graph_index = key
        return runtime_shape in self.cache and graph_index in self.cache[
            runtime_shape]

    def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
        """Return the artifact for ``key``.

        Raises:
            KeyError: if the cache is disabled or ``key`` is not present.
        """
        if self.disabled:
            raise KeyError("cannot read from disabled cache")
        runtime_shape, graph_index = key
        # use .get so that a failed lookup does not insert an empty
        # per-shape dict into the defaultdict
        per_shape = self.cache.get(runtime_shape)
        if per_shape is None or graph_index not in per_shape:
            raise KeyError(key)
        return per_shape[graph_index]

    def __setitem__(self, key: Tuple[Optional[int], int],
                    value: InductorArtifact):
        # setitem for disabled cache is fine, because we
        # don't actually write to the disk
        runtime_shape, graph_index = key
        self.cache[runtime_shape][graph_index] = value


class AlwaysHitShapeEnv:
    """
    A stand-in shape environment whose guard checks always succeed.

    Background: with ordinary `torch.compile`, each Dynamo bytecode
    compilation is paired with exactly one Inductor compilation, and the
    Inductor run happens inside the Dynamo compilation's context, which
    is where dynamic-shape information comes from.

    vLLM instead runs the Dynamo bytecode compilation only once, then
    runs Inductor several times: for a general shape and for specific
    shapes. The specific-shape runs happen outside the Dynamo context,
    so no shape environment is available to hand to Inductor, and its
    code cache lookup would fail.

    Supplying this dummy environment makes the Inductor code cache
    lookup always hit, so each graph can be compiled for whatever
    shapes are needed.

    The dummy methods below were arrived at by trial and error, until
    things worked.
    """

    def __init__(self) -> None:
        # no guards are ever recorded
        self.guards: List[Any] = list()

    def produce_guards_expression(self, *_args, **_kwargs):
        # an empty guard expression
        return str()

    def get_pruned_guards(self, *_args, **_kwargs):
        # nothing to prune, nothing to return
        return list()

    def evaluate_guards_expression(self, *_args, **_kwargs):
        # every guard is considered satisfied
        return True


def wrap_inductor(graph: fx.GraphModule,
                  example_inputs,
                  additional_inductor_config,
                  compilation_config: CompilationConfig,
                  vllm_backend: "VllmBackend",
                  graph_index: int = 0,
                  num_graphs: int = 1,
                  runtime_shape: Optional[int] = None,
                  use_inductor: bool = True) -> Any:
    """Compile one piecewise ``graph`` with Inductor, using vLLM's cache.

    If ``(runtime_shape, graph_index)`` is already in
    ``vllm_backend.inductor_hash_cache``, the compiled graph is loaded
    directly from Inductor's FX graph cache by its hash. Otherwise the
    graph is compiled with ``compile_fx`` while a few Inductor internals
    are patched to capture the hash and generated file path, and the
    resulting artifact is recorded in the cache.

    Args:
        graph: the FX graph module to compile (deep-copied before use).
        example_inputs: inputs for compilation / cache lookup.
        additional_inductor_config: extra Inductor config entries merged
            over the current Inductor config, or None.
        compilation_config: accumulates the measured compilation time.
        vllm_backend: provides ``inductor_hash_cache``.
        graph_index: index of this graph among ``num_graphs`` piecewise
            graphs. Index 0 starts the timer; the last index stops it
            and logs the elapsed time.
        num_graphs: total number of piecewise graphs being compiled.
        runtime_shape: a concrete batch size, or None for the general
            shape. A concrete size turns on max-autotune and
            coordinate-descent tuning.
        use_inductor: when False, ``graph`` is returned unchanged (and
            the end-of-compilation timing is skipped).

    Returns:
        A callable that runs the graph.
    """
    # module-level timer shared by the piecewise graphs of one compilation
    global compilation_start_time

    if graph_index == 0:
        # before compiling the first graph, record the start time
        compilation_start_time = time.time()

    if not use_inductor:
        return graph

    compilation_counter.num_inductor_compilations += 1

    from torch._inductor import config
    current_config = config.get_config_copy()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)

    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
        # can be beneficial
        current_config["max_autotune"] = True
        current_config["coordinate_descent_tuning"] = True

    # inductor can inplace modify the graph, so we need to copy it
    # see https://github.com/pytorch/pytorch/issues/138980
    graph = copy.deepcopy(graph)

    cache_data = vllm_backend.inductor_hash_cache
    if (runtime_shape, graph_index) in cache_data:
        # we compiled this graph before
        # so we can directly lookup the compiled graph via hash
        inductor_artifact = cache_data[(runtime_shape, graph_index)]
        hash_str = inductor_artifact.hash_str
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info(
                "Directly lookup the graph for shape %s from the cache",
                str(runtime_shape))  # noqa
        logger.debug(
            "directly lookup the %s-th graph for shape %s via hash %s",
            graph_index, str(runtime_shape), hash_str)
        from torch._inductor.codecache import FxGraphCache
        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
            inductor_compiled_graph = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, False)
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache file {cache_data.cache_file_path} and try again."  # noqa
            )
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple
        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args):
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]
    else:
        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        inductor_artifact = InductorArtifact()
        from torch._inductor.codecache import (FxGraphCache,
                                               compiled_fx_graph_hash)
        original_load = FxGraphCache.load

        def hijack_load(*args, **kwargs):
            inductor_compiled_graph = original_load(*args, **kwargs)
            inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
            return inductor_compiled_graph

        def hijack_compiled_fx_graph_hash(*args, **kwargs):
            out = compiled_fx_graph_hash(*args, **kwargs)
            inductor_artifact.hash_str = out[0]
            return out

        def _check_can_cache(*args, **kwargs):
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        with ExitStack() as stack:
            if not cache_data.disabled:
                # compilation cache is enabled, patch several functions

                # hijack to get the compiled graph itself
                stack.enter_context(
                    patch("torch._inductor.codecache.FxGraphCache.load",
                          hijack_load))

                # for hijacking the hash of the compiled graph
                stack.enter_context(
                    patch("torch._inductor.codecache.compiled_fx_graph_hash",
                          hijack_compiled_fx_graph_hash))

                # for providing a dummy shape environment
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._get_shape_env",
                        _get_shape_env))

                # for forcing the graph to be cached
                stack.enter_context(
                    patch(
                        "torch._inductor.codecache.FxGraphCache._check_can_cache",
                        _check_can_cache))

            compiled_graph = compile_fx(graph,
                                        example_inputs,
                                        config_patches=current_config)

        # store the inductor_artifact in the cache
        cache_data[(runtime_shape, graph_index)] = inductor_artifact
        if graph_index == 0:
            # adds some info logging for the first graph
            logger.info("Cache the graph of shape %s for later use",
                        str(runtime_shape))
        logger.debug(
            "store the %s-th graph for shape %s via hash %s from file %s",
            graph_index, str(runtime_shape), inductor_artifact.hash_str,
            inductor_artifact.file_path)

    # after compiling the last graph, record the end time
    if graph_index == num_graphs - 1:
        now = time.time()
        elapsed = now - compilation_start_time
        compilation_config.compilation_time += elapsed
        if runtime_shape is None:
            logger.info("Compiling a graph for general shape takes %.2f s",
                        elapsed)
        else:
            logger.info("Compiling a graph for shape %s takes %.2f s",
                        runtime_shape, elapsed)

    return compiled_graph


@dataclasses.dataclass
class SplitItem:
    """One top-level submodule produced by :func:`split_graph`."""
    # attribute name of the submodule on the split graph, e.g. "submod_0"
    submod_name: str
    # integer id parsed from ``submod_name``
    graph_id: int
    # True if this submodule contains one of the splitting ops
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
    """Split ``graph`` into piecewise submodules around calls to ``ops``.

    Every ``call_function`` node whose ``str(node.target)`` is in ``ops``
    gets a submodule of its own; consecutive nodes between two such calls
    share one submodule.

    Args:
        graph: the graph module to split.
        ops: string forms of the call targets to split on.

    Returns:
        ``(split_gm, items)``: the stitching graph module that calls the
        submodules, and one :class:`SplitItem` per top-level submodule,
        sorted by integer ``graph_id``.
    """
    # set for O(1) membership tests per node
    split_ops = set(ops)

    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_ops:
            # isolate the splitting op in its own subgraph
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = torch.fx.passes.split_module.split_module(
        graph,
        None,
        lambda node: node_to_subgraph_id[node],
        keep_original_order=True)

    outputs = []
    for name, module in split_gm.named_modules():
        if "." in name or name == "":
            # recursive child module or the root module
            continue
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs


# we share the global graph pool among all the backends
global_graph_pool = None

# time.time() at which compilation of the first piecewise graph started;
# assigned by wrap_inductor and read when its last graph finishes
compilation_start_time: float = 0.0

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compiles the
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: List[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # reuse the fake mode detected from the current context
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend

    def run(self, *args):
        """Run the graph with tensor ``args`` converted to fake tensors.

        Tensor arguments go through ``self.fake_mode.from_tensor`` and the
        graph is executed inside that fake mode; non-tensor arguments are
        passed through unchanged.
        """
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode:
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: Tuple[torch.fx.node.Argument,
                                ...], kwargs: Dict[str, Any]) -> Any:
        """Run submodule ``target``; compile it if it is in the list.

        For a listed submodule, the general-shape graph is compiled via
        ``wrap_inductor`` and the submodule is replaced on the split graph
        by a ``PiecewiseBackend`` wrapping the compiled result.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of symbolic-int arguments of this submodule
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiled_graph_for_general_shape = wrap_inductor(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None,
                use_inductor=self.compilation_config.use_inductor)

            # Write into the instance __dict__ rather than via setattr:
            # nn.Module.__setattr__ rejects a non-Module value for a
            # registered submodule name, while a __dict__ entry takes
            # precedence over the `_modules` lookup in nn.Module.__getattr__.
            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


class VllmBackend:
    """The compilation backend for `torch.compile` with VLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: List[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: List[int]
    input_buffers: List[torch.Tensor]
    inductor_hash_cache: InductorHashCache

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = torch.cuda.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install the post-grad pass manager into the Inductor config.

        A pass already registered under ``post_grad_custom_post_pass`` is
        added to the pass manager, which then takes its place.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(config.pass_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            assert isinstance(inductor_config[PASS_KEY], InductorPass)
            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _compute_cache_dir(self) -> str:
        """Return the directory for vLLM's torch.compile cache.

        Uses ``compilation_config.cache_dir`` if set. Otherwise derives a
        directory from a hash of the vllm config and of the files traced
        by Dynamo, so that if none of those factors change the directory
        (and therefore the cache) is reused. In that case this also
        clears ``compilation_config.traced_files``.
        """
        if self.compilation_config.cache_dir:
            return self.compilation_config.cache_dir

        import hashlib

        # 1. factors come from the vllm_config (it mainly summarizes how the
        #    model is created)
        vllm_config = self.vllm_config
        config_hash = vllm_config.compute_hash()

        # 2. factors come from the code files that are traced by Dynamo (
        #    it mainly summarizes how the model is used in forward pass)
        forward_code_files = sorted(self.compilation_config.traced_files)
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files))
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            # Python source files are UTF-8 by default; don't let the
            # locale's default encoding decide how they are decoded
            with open(filepath, encoding="utf-8") as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5(
            "\n".join(hash_content).encode()).hexdigest()

        # combine the two hashes to generate the cache dir
        hash_key = hashlib.md5(
            f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
        return os.path.join(envs.VLLM_CACHE_ROOT, "torch_compile_cache",
                            hash_key,
                            f"rank_{vllm_config.parallel_config.rank}")

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Compile the graph captured by Dynamo and return what to run.

        The graph is split at ``compilation_config.splitting_ops`` and the
        non-splitting submodules are compiled through
        ``PiecewiseCompileInterpreter``. Returns the split graph module
        itself, or, when cudagraphs are used with ``cudagraph_copy_inputs``,
        a wrapper that first copies inputs with symbolic shapes into
        static buffers.
        """
        # we control the compilation process, each instance can only be
        # called once. Check before any side effect (clearing traced files,
        # bumping counters, creating the cache).
        assert not self._called, "VllmBackend can only be called once"

        cache_dir = self._compute_cache_dir()
        os.makedirs(cache_dir, exist_ok=True)

        disabled = envs.VLLM_DISABLE_COMPILE_CACHE
        self.inductor_hash_cache = InductorHashCache(cache_dir,
                                                     disabled=disabled)
        if disabled:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        cache_dir)

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)

        self._called = True

        if (not self.compilation_config.use_cudagraph
                or not self.compilation_config.cudagraph_copy_inputs):
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if (isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
                and any(is_symbolic(d) for d in x.size()))
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for i, index in enumerate(self.sym_tensor_indices):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = self.input_buffers[i][:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call


@dataclasses.dataclass
class ConcreteSizeEntry:
    """State kept for one concrete runtime shape (batch size)."""
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_cudagraph: bool  # the size is in capture_sizes

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None

class PiecewiseBackend:
    """Run one piece of a piecewise-split graph.

    Falls back to a graph compiled for the general shape, and for configured
    concrete shapes optionally compiles a specialized graph and/or captures
    and replays a cudagraph.
    """

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and cudagraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture cudagraph for different shapes.

        If a shape needs both compilation and cudagraph, we will
        compile it first, and then capture cudagraph.

        Args:
            graph: the FX graph of this piece; recompiled for each concrete
                shape in `compile_sizes`.
            vllm_config: global config; its `compilation_config` supplies
                compile sizes, capture sizes and warmup counts.
            graph_pool: memory pool passed to `torch.cuda.graph` when
                capturing.
            piecewise_compile_index: index of this piece among all pieces.
            total_piecewise_compiles: total number of pieces.
            sym_shape_indices: positions in the call arguments that hold
                symbolic sizes; the first one is read as the runtime shape.
            compiled_graph_for_general_shape: callable compiled for the
                general shape, used as the default runnable.
            vllm_backend: owning backend; its `inductor_hash_cache` is
                saved once all compilations have finished.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        # capture sizes are only relevant when cudagraphs are enabled
        self.capture_sizes: Set[int] = set(
            self.compilation_config.capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        # the very first call always runs the general-shape graph
        self.first_run_finished = False

        self.compiled_graph_for_general_shape = (
            compiled_graph_for_general_shape)

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_cudagraph=shape in self.capture_sizes,
            )

    def check_for_ending_compilation(self):
        """Finish compilation bookkeeping when nothing is left to compile.

        Only acts for the last piece, and only once `to_be_compiled_sizes`
        is empty: it saves the inductor hash cache and ends compile
        monitoring.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.inductor_hash_cache.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run this piece on `args`.

        The first call runs the general-shape graph. Afterwards the runtime
        shape is read from `args[self.sym_shape_indices[0]]`; shapes without
        an entry use the general-shape graph. For configured shapes the
        piece is compiled on first use (if in `compile_sizes`), and, if in
        `capture_sizes`, run `cudagraph_num_of_warmups` times as warmup,
        captured once, and replayed on every later call.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = wrap_inductor(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                self.vllm_backend,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
                use_inductor=self.compilation_config.use_inductor)

            # finishes compilation once all required shapes are compiled;
            # the "last graph and nothing left" condition is checked inside
            self.check_for_ending_compilation()

        if not entry.use_cudagraph:
            return entry.runnable(*args)

        if entry.cudagraph is None:
            num_warmups = self.compilation_config.cudagraph_num_of_warmups
            if entry.num_finished_warmup < num_warmups:
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        num_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a cudagraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of cudagraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output