# backends.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
import dataclasses
6
7
import os
import pprint
8
import time
9
from collections.abc import Sequence
10
from contextlib import contextmanager
11
from typing import Any, Callable, Optional
12
13
14

import torch
import torch.fx as fx
15
from torch._dispatch.python import enable_python_dispatcher
16

17
import vllm.envs as envs
18
from vllm.config import CompilationConfig, VllmConfig
19
from vllm.logger import init_logger
20
from vllm.platforms import current_platform
21
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
22

23
24
from .compiler_interface import (CompilerInterface, EagerAdaptor,
                                 InductorAdaptor, InductorStandaloneAdaptor)
25
from .counter import compilation_counter
26
27
from .inductor_pass import InductorPass
from .pass_manager import PostGradPassManager
28
29
30

# Logger for this module, named after the module via __name__.
logger = init_logger(__name__)

31

32
33
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    """Select the compiler adaptor that matches ``compilation_config``.

    - ``use_inductor`` unset: ``EagerAdaptor``.
    - ``use_inductor`` set, ``VLLM_USE_STANDALONE_COMPILE`` enabled and the
      installed torch is at least ``2.8.0.dev``: ``InductorStandaloneAdaptor``.
    - ``use_inductor`` set otherwise: ``InductorAdaptor``.
    """
    if not compilation_config.use_inductor:
        logger.debug("Using EagerAdaptor")
        return EagerAdaptor()

    # The standalone adaptor is only chosen on torch >= 2.8.0.dev.
    if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
            "2.8.0.dev"):
        logger.debug("Using InductorStandaloneAdaptor")
        return InductorStandaloneAdaptor()

    logger.debug("Using InductorAdaptor")
    return InductorAdaptor()


46
47
48
49
50
class CompilerManager:
    """
    A manager to manage the compilation process, including
    caching the compiled graph, loading the compiled graph,
    and compiling the graph.

    The cache is a dict mapping
    `(runtime_shape, graph_index, backend_name)`
    to `any_data` returned from the compiler.

    When serializing the cache, we save it to a Python file
    for readability. We don't use json here because json doesn't
    support int as key.
    """

    def __init__(self, compilation_config: CompilationConfig):
        # (runtime_shape, graph_index, compiler name) -> handle returned by
        # the compiler for that graph.
        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
        # Set when a new handle is stored; save_to_file() skips writing
        # otherwise.
        self.is_cache_updated = False
        self.compilation_config = compilation_config
        self.compiler = make_compiler(compilation_config)

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the compiler-specific hash for ``vllm_config``."""
        return self.compiler.compute_hash(vllm_config)

    def _cache_key(self, runtime_shape: Optional[int],
                   graph_index: int) -> tuple[Optional[int], int, str]:
        """Build the key under which a compiled graph's handle is stored."""
        return (runtime_shape, graph_index, self.compiler.name)

    def initialize_cache(self,
                         cache_dir: str,
                         disable_cache: bool = False,
                         prefix: str = ""):
        """
        Initialize the cache directory for the compiler.

        The organization of the cache directory is as follows:
        cache_dir=/path/to/hash_str/rank_i_j/prefix/
        inside cache_dir, there will be:
        - vllm_compile_cache.py
        - computation_graph.py
        - transformed_code.py

        for multiple prefixes, they can share the same
        base cache dir of /path/to/hash_str/rank_i_j/ ,
        to store some common compilation artifacts.

        When the cache is enabled and ``vllm_compile_cache.py`` already
        exists in ``cache_dir``, its contents replace ``self.cache``.
        """
        self.disable_cache = disable_cache
        self.cache_dir = cache_dir
        self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")

        if not disable_cache and os.path.exists(self.cache_file_path):
            # load the cache from the file
            with open(self.cache_file_path) as f:
                # we use ast.literal_eval to parse the data
                # because it is a safe way to parse Python literals.
                # do not use eval(), it is unsafe.
                self.cache = ast.literal_eval(f.read())

        self.compiler.initialize_cache(cache_dir=cache_dir,
                                       disable_cache=disable_cache,
                                       prefix=prefix)

    def save_to_file(self):
        """Write the handle cache to ``self.cache_file_path``.

        Does nothing when the cache is disabled or no new handle has been
        stored. Requires ``initialize_cache`` to have been called first.
        """
        if self.disable_cache or not self.is_cache_updated:
            return
        printer = pprint.PrettyPrinter(indent=4)
        data = printer.pformat(self.cache)
        with open(self.cache_file_path, "w") as f:
            f.write(data)

    def load(self,
             graph: fx.GraphModule,
             example_inputs: list[Any],
             graph_index: int,
             runtime_shape: Optional[int] = None) -> Optional[Callable]:
        """Load a previously compiled graph through the compiler.

        Returns None when no handle is cached for
        ``(runtime_shape, graph_index)`` under the current compiler name.
        """
        key = self._cache_key(runtime_shape, graph_index)
        if key not in self.cache:
            return None
        handle = self.cache[key]
        compiled_graph = self.compiler.load(handle, graph, example_inputs,
                                            graph_index, runtime_shape)
        logger.debug(
            "Directly load the %s-th graph for shape %s from %s via "
            "handle %s", graph_index, str(runtime_shape), self.compiler.name,
            handle)
        return compiled_graph

    def compile(self,
                graph: fx.GraphModule,
                example_inputs,
                additional_inductor_config,
                compilation_config: CompilationConfig,
                graph_index: int = 0,
                num_graphs: int = 1,
                runtime_shape: Optional[int] = None) -> Any:
        """Compile one piecewise graph, reusing a cached handle if present.

        Timing covers all ``num_graphs`` graphs of one shape. The module-level
        timer starts at ``graph_index == 0``. At
        ``graph_index == num_graphs - 1`` the elapsed time is logged, and,
        when that last graph was freshly compiled (not loaded), it is also
        added to ``compilation_config.compilation_time``.

        Raises:
            AssertionError: if the compiler returns no compiled graph.
        """
        global compilation_start_time
        if graph_index == 0:
            # before compiling the first graph, record the start time
            compilation_start_time = time.time()

        compilation_counter.num_backend_compilations += 1

        # try to load from the cache
        compiled_graph = self.load(graph, example_inputs, graph_index,
                                   runtime_shape)
        if compiled_graph is not None:
            if graph_index == num_graphs - 1:
                # after loading the last graph for this shape, record the time.
                # there can be multiple graphs due to piecewise compilation.
                now = time.time()
                elapsed = now - compilation_start_time
                logger.info(
                    "Directly load the compiled graph(s) for shape %s "
                    "from the cache, took %.3f s", str(runtime_shape), elapsed)
            return compiled_graph

        # no compiler cached the graph, or the cache is disabled,
        # we need to compile it
        if isinstance(self.compiler, InductorAdaptor):
            # Let compile_fx generate a key for us
            maybe_key = None
        else:
            maybe_key = (
                f"artifact_shape_{runtime_shape}_subgraph_{graph_index}")
        compiled_graph, handle = self.compiler.compile(
            graph, example_inputs, additional_inductor_config, runtime_shape,
            maybe_key)

        assert compiled_graph is not None, "Failed to compile the graph"

        # store the artifact in the cache
        if handle is not None:
            self.cache[self._cache_key(runtime_shape, graph_index)] = handle
            self.is_cache_updated = True
            if graph_index == 0:
                # adds some info logging for the first graph
                logger.info("Cache the graph of shape %s for later use",
                            str(runtime_shape))
            logger.debug(
                "store the %s-th graph for shape %s from %s via handle %s",
                graph_index, str(runtime_shape), self.compiler.name, handle)

        # after compiling the last graph, record the end time
        if graph_index == num_graphs - 1:
            now = time.time()
            elapsed = now - compilation_start_time
            compilation_config.compilation_time += elapsed
            if runtime_shape is None:
                logger.info("Compiling a graph for general shape takes %.2f s",
                            elapsed)
            else:
                logger.info("Compiling a graph for shape %s takes %.2f s",
                            runtime_shape, elapsed)

        return compiled_graph
200
201


202
203
204
@dataclasses.dataclass
class SplitItem:
    """One submodule of a graph that was split around splitting ops."""
    # attribute name of the submodule on the split graph, e.g. "submod_0"
    submod_name: str
    # integer id parsed from ``submod_name``; used to order the pieces
    graph_id: int
    # True when this piece was created for a splitting op
    is_splitting_graph: bool
    # the submodule itself
    graph: fx.GraphModule


def split_graph(graph: fx.GraphModule,
                ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
    """Split ``graph`` into pieces around the given splitting ops.

    Every ``call_function`` node whose ``str(node.target)`` is in ``ops`` is
    placed in a subgraph of its own; the runs of other nodes between them
    share a subgraph. Subgraph ids increase in node order.

    Args:
        graph: the graph module to split.
        ops: string forms of the op targets to split on.

    Returns:
        The stitching graph module produced by ``split_module`` and one
        ``SplitItem`` per direct submodule, sorted by integer graph id.
    """
    # Import explicitly instead of relying on `torch.fx.passes.split_module`
    # having been loaded as a side effect of some other import.
    from torch.fx.passes.split_module import split_module

    # set for O(1) membership per node
    split_ops = set(ops)

    # split graph by ops
    subgraph_id = 0
    node_to_subgraph_id = {}
    split_op_graphs = set()
    for node in graph.graph.nodes:
        if node.op in ("output", "placeholder"):
            continue
        if node.op == 'call_function' and str(node.target) in split_ops:
            # give the splitting op a subgraph of its own
            subgraph_id += 1
            node_to_subgraph_id[node] = subgraph_id
            split_op_graphs.add(subgraph_id)
            subgraph_id += 1
        else:
            node_to_subgraph_id[node] = subgraph_id

    # `keep_original_order` is important!
    # otherwise pytorch might reorder the nodes and
    # the semantics of the graph will change when we
    # have mutations in the graph
    split_gm = split_module(graph,
                            None,
                            lambda node: node_to_subgraph_id[node],
                            keep_original_order=True)

    # Only direct children are pieces; the root module and any nested
    # modules are not.
    outputs = []
    for name, module in split_gm.named_children():
        graph_id = int(name.replace("submod_", ""))
        outputs.append(
            SplitItem(name, graph_id, graph_id in split_op_graphs, module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs
256
257


258
259
260
# Graph pool shared by every VllmBackend instance; it is created by the
# first backend constructed while it is still None.
global_graph_pool: Any = None

# time.time() value captured when graph 0 of a shape starts compiling in
# CompilerManager.compile; read again after the last graph to compute the
# elapsed time.
compilation_start_time: float = 0.0

263
264
265
266
267
268

class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
    It runs the given graph with fake inputs, and compile some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
                 compile_submod_names: list[str], vllm_config: VllmConfig,
                 graph_pool, vllm_backend: "VllmBackend"):
        super().__init__(module)
        from torch._guards import detect_fake_mode
        # Fake mode active in the calling context (presumably the one set up
        # by Dynamo -- confirm); run() uses it to fakify the inputs.
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        """Interpret the graph on fake copies of the tensor ``args``."""
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(self, target: torch.fx.node.Target,
                    args: tuple[torch.fx.node.Argument, ...],
                    kwargs: dict[str, Any]) -> Any:
        """Run a submodule, then compile and swap it in if requested.

        The submodule is first executed normally so its outputs propagate
        through the interpreter. If ``target`` is listed in
        ``compile_submod_names``, it is then compiled for the general shape
        (``runtime_shape=None``) and replaced on the split module by an
        instance of the platform's piecewise backend.
        """
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            # positions of the symbolic integer arguments
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            compiler_manager = self.vllm_backend.compiler_manager
            compiled_graph_for_general_shape = compiler_manager.compile(
                submod,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=index,
                num_graphs=len(self.compile_submod_names),
                runtime_shape=None)

            piecewise_backend = resolve_obj_by_qualname(
                current_platform.get_piecewise_backend_cls())
            # An entry in the instance __dict__ shadows the submodule stored
            # in `_modules`, so the split module calls the backend instead.
            self.module.__dict__[target] = piecewise_backend(
                submod, self.vllm_config, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape, self.vllm_backend)

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output


333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# Name of the model part currently being compiled, such as "backbone"
# (the default) or "eagle_head".
model_tag: str = "backbone"


@contextmanager
def set_model_tag(tag: str):
    """Temporarily make ``tag`` the active model tag.

    The previously active tag is restored on exit, even when the body
    raises. Entering with the tag that is already active fails an assertion.
    """
    global model_tag
    assert tag != model_tag, \
        f"Model tag {tag} is the same as the current tag {model_tag}."
    previous_tag, model_tag = model_tag, tag
    try:
        yield
    finally:
        model_tag = previous_tag


352
class VllmBackend:
    """The compilation backend for `torch.compile` with vLLM.
    It is used for compilation level of `CompilationLevel.PIECEWISE`,
    where we customize the compilation.

    The major work of this backend is to split the graph into
    piecewise graphs, and pass them to the piecewise backend.

    This backend also adds the PostGradPassManager to Inductor config,
    which handles the post-grad passes.
    """

    vllm_config: VllmConfig
    compilation_config: CompilationConfig
    graph_pool: Any
    _called: bool = False
    # the graph we compiled
    graph: fx.GraphModule
    # the stitching graph module for all the piecewise graphs
    split_gm: fx.GraphModule
    piecewise_graphs: list[SplitItem]
    returned_callable: Callable
    # Inductor passes to run on the graph pre-defunctionalization
    post_grad_passes: Sequence[Callable]
    sym_tensor_indices: list[int]
    input_buffers: list[torch.Tensor]
    compiler_manager: CompilerManager

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):

        # if the model is initialized with a non-empty prefix,
        # then usually it's enough to use that prefix,
        # e.g. language_model, vision_model, etc.
        # when multiple parts are initialized as independent
        # models, we need to use the model_tag to distinguish
        # them, e.g. backbone (default), eagle_head, etc.
        self.prefix = prefix or model_tag

        global global_graph_pool
        if global_graph_pool is None:
            global_graph_pool = current_platform.graph_pool_handle()

        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = global_graph_pool

        # Passes to run on the graph post-grad.
        self.post_grad_pass_manager = PostGradPassManager()

        self.sym_tensor_indices = []
        self.input_buffers = []

        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config

        self.compiler_manager: CompilerManager = CompilerManager(
            self.compilation_config)

        # `torch.compile` is JIT compiled, so we don't need to
        # do anything here

    def configure_post_pass(self):
        """Install this backend's PostGradPassManager in the Inductor config.

        The manager is first configured from ``vllm_config``. If the Inductor
        config already has a ``post_grad_custom_post_pass``, it must either
        be a PostGradPassManager with the same uuid, or an InductorPass,
        which is added to this backend's manager. The key is then set to
        ``self.post_grad_pass_manager``.
        """
        config = self.compilation_config
        self.post_grad_pass_manager.configure(self.vllm_config)

        # Post-grad custom passes are run using the post_grad_custom_post_pass
        # hook. If a pass for that hook exists, add it to the pass manager.
        inductor_config = config.inductor_compile_config
        PASS_KEY = "post_grad_custom_post_pass"
        if PASS_KEY in inductor_config:
            # Config should automatically wrap all inductor passes
            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
                assert (inductor_config[PASS_KEY].uuid() ==
                        self.post_grad_pass_manager.uuid())
            else:
                assert isinstance(inductor_config[PASS_KEY], InductorPass)
                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
        inductor_config[PASS_KEY] = self.post_grad_pass_manager

    def _compute_cache_hash_key(self) -> str:
        """Hash all factors that affect compilation into a 10-char key.

        Factors, in order: the env hash, the vllm_config hash, the paths and
        contents of the source files Dynamo traced (the traced-file set is
        cleared afterwards), and the compiler hash. If none of the factors
        change, the key -- and so the cache directory -- stays the same, so
        that the compiled graphs can be reused.
        """
        import hashlib

        vllm_config = self.vllm_config
        factors = []

        # 0. factors come from the env, for example, The values of
        # VLLM_PP_LAYER_PARTITION will affects the computation graph.
        env_hash = envs.compute_hash()
        factors.append(env_hash)

        # 1. factors come from the vllm_config (it mainly summarizes how the
        #    model is created)
        config_hash = vllm_config.compute_hash()
        factors.append(config_hash)

        # 2. factors come from the code files that are traced by Dynamo (
        #    it mainly summarizes how the model is used in forward pass)
        forward_code_files = sorted(self.compilation_config.traced_files)
        self.compilation_config.traced_files.clear()
        logger.debug(
            "Traced files (to be considered for compilation cache):\n%s",
            "\n".join(forward_code_files))
        hash_content = []
        for filepath in forward_code_files:
            hash_content.append(filepath)
            if filepath == "<string>":
                # This means the function was dynamically generated, with
                # e.g. exec(). We can't actually check these.
                continue
            with open(filepath) as f:
                hash_content.append(f.read())
        code_hash = hashlib.md5("\n".join(hash_content).encode(),
                                usedforsecurity=False).hexdigest()
        factors.append(code_hash)

        # 3. compiler hash
        compiler_hash = self.compiler_manager.compute_hash(vllm_config)
        factors.append(compiler_hash)

        # combine all factors to generate the cache dir
        return hashlib.md5(str(factors).encode(),
                           usedforsecurity=False).hexdigest()[:10]

    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
        """Compile ``graph`` piecewise and return the callable for Dynamo.

        The steps are:
        1. Set up the on-disk cache directories.
        2. Split the graph on the configured splitting ops.
        3. Compile each non-splitting piece for the general shape.
        4. Save the readable computation graph.

        The return value is the split graph module itself or, when both
        ``use_cudagraph`` and ``cudagraph_copy_inputs`` are enabled, a
        wrapper that first copies symbolic-shape inputs into static buffers.

        Raises:
            AssertionError: if this backend instance was already called.
        """
        # we control the compilation process, each instance can only be
        # called once. Checked first so a repeated call fails before any
        # side effects (cache dirs, traced-file reset, counters, timing).
        assert not self._called, "VllmBackend can only be called once"

        vllm_config = self.vllm_config
        if not self.compilation_config.cache_dir:
            # no provided cache dir, generate one based on the known factors
            # that affects the compilation. if none of the factors change,
            # the cache dir will be the same so that we can reuse the compiled
            # graph.
            hash_key = self._compute_cache_hash_key()
            cache_dir = os.path.join(
                envs.VLLM_CACHE_ROOT,
                "torch_compile_cache",
                hash_key,
            )
            self.compilation_config.cache_dir = cache_dir

        cache_dir = self.compilation_config.cache_dir
        os.makedirs(cache_dir, exist_ok=True)

        rank = vllm_config.parallel_config.rank
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}",
                                       self.prefix)
        os.makedirs(local_cache_dir, exist_ok=True)
        self.compilation_config.local_cache_dir = local_cache_dir

        disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE

        if disable_cache:
            logger.info("vLLM's torch.compile cache is disabled.")
        else:
            logger.info("Using cache directory: %s for vLLM's torch.compile",
                        local_cache_dir)

        self.compiler_manager.initialize_cache(local_cache_dir, disable_cache,
                                               self.prefix)

        # when dynamo calls the backend, it means the bytecode
        # transform and analysis are done
        compilation_counter.num_graphs_seen += 1
        from .monitor import torch_compile_start_time
        dynamo_time = time.time() - torch_compile_start_time
        logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time)
        self.compilation_config.compilation_time += dynamo_time

        self.graph = graph
        self.configure_post_pass()

        self.split_gm, self.piecewise_graphs = split_graph(
            graph, self.compilation_config.splitting_ops)

        from torch._dynamo.utils import lazy_format_graph_code

        # depyf will hook lazy_format_graph_code and dump the graph
        # for debugging, no need to print the graph here
        lazy_format_graph_code("before split", self.graph)
        lazy_format_graph_code("after split", self.split_gm)

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
        submod_names_to_compile = [
            item.submod_name for item in self.piecewise_graphs
            if not item.is_splitting_graph
        ]

        # propagate the split graph to the piecewise backend,
        # compile submodules with symbolic shapes
        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
                                    self.vllm_config, self.graph_pool,
                                    self).run(*example_inputs)

        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
        if not os.path.exists(graph_path):
            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
            # use `print_readable` because it can include submodules
            src = "from __future__ import annotations\nimport torch\n" + \
                self.split_gm.print_readable(print_output=False)
            src = src.replace("<lambda>", "GraphModule")
            with open(graph_path, "w") as f:
                f.write(src)

            logger.debug("Computation graph saved to %s", graph_path)

        self._called = True

        if not self.compilation_config.use_cudagraph or \
            not self.compilation_config.cudagraph_copy_inputs:
            return self.split_gm

        # if we need to copy input buffers for cudagraph
        from torch._guards import detect_fake_mode
        fake_mode = detect_fake_mode()
        fake_args = [
            fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
        ]

        # index of tensors that have symbolic shapes (batch size)
        # for weights and static buffers, they will have concrete shapes.
        # symbolic shape only happens for input tensors.
        from torch.fx.experimental.symbolic_shapes import is_symbolic
        self.sym_tensor_indices = [
            i for i, x in enumerate(fake_args)
            if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
            and any(is_symbolic(d) for d in x.size())
        ]

        # compiler managed cudagraph input buffers
        # we assume the first run with symbolic shapes
        # has the maximum size among all the tensors
        self.input_buffers = [
            example_inputs[x].clone() for x in self.sym_tensor_indices
        ]

        # this is the callable we return to Dynamo to run
        def copy_and_call(*args):
            list_args = list(args)
            for index, buffer in zip(self.sym_tensor_indices,
                                     self.input_buffers):
                runtime_tensor = list_args[index]
                runtime_shape = runtime_tensor.shape[0]
                static_tensor = buffer[:runtime_shape]

                # copy the tensor to the static buffer
                static_tensor.copy_(runtime_tensor)

                # replace the tensor in the list_args to the static buffer
                list_args[index] = static_tensor
            return self.split_gm(*list_args)

        return copy_and_call