# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
5
6
from collections.abc import Callable
from typing import Any
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

import torch.fx as fx

import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Bookkeeping record for one concrete runtime size.

    `PiecewiseBackend` keeps one of these per size it was asked to
    compile and looks it up by `runtime_shape` when it is called.
    """

    # The concrete size this record describes.
    runtime_shape: int
    # Flipped to True once a size-specific compilation has been started.
    compiled: bool = dataclasses.field(default=False)
    # Callable to invoke for this size; no callable is attached by default.
    runnable: Callable = dataclasses.field(default=None)  # type: ignore


26
class PiecewiseBackend:
    """Callable that dispatches one piecewise graph by runtime shape.

    An instance wraps a callable already compiled for the general shape and
    keeps one `ConcreteSizeEntry` per size in
    `compilation_config.compile_sizes`. Calling the instance runs either the
    general-shape callable or a callable compiled on demand for the concrete
    runtime shape.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        vllm_backend: VllmBackend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Args:
            graph: graph handed to the compiler manager whenever a
                concrete size is compiled.
            vllm_config: source of `compilation_config`; also passed to
                `end_monitoring_torch_compile` once compilation is done.
            piecewise_compile_index: index of this piece; 0 marks the first
                piece and `total_piecewise_compiles - 1` the last one.
            total_piecewise_compiles: number of pieces; 1 means this piece
                is the full graph.
            sym_shape_indices: positions in the call arguments; only the
                first one is read to obtain the runtime shape.
            compiled_graph_for_general_shape: callable used for the first
                call and for every shape that has no dedicated entry.
            vllm_backend: provides `compiler_manager` and `inductor_config`.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1

        self.compile_sizes: set[int] = set(self.compilation_config.compile_sizes)

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()

        # the entries for different shapes that we need to compile.
        # We only keep compilation management inside this class directly.
        # Every entry starts out pointing at the general-shape callable and
        # is re-pointed the first time its size is seen in `__call__`.
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {
            shape: ConcreteSizeEntry(
                runtime_shape=shape,
                runnable=self.compiled_graph_for_general_shape,
            )
            for shape in self.compile_sizes
        }

    def check_for_ending_compilation(self) -> None:
        """Finalize compilation once the last piece has nothing left to compile.

        Does nothing unless this is the last piece and
        `to_be_compiled_sizes` is empty.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run the piece with `args`, compiling for the runtime shape if needed.

        The very first call always uses the general-shape callable. Later
        calls read the runtime shape from `args[sym_shape_indices[0]]`;
        shapes without an entry use the general-shape callable, and shapes
        with an entry are compiled on first use and then reused.

        Returns:
            Whatever the selected callable returns.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]

        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        if not entry.compiled:
            # NOTE: the bookkeeping is updated before compiling, so a compile
            # call that raises is not attempted again for this shape.
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.vllm_backend.inductor_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape,
            )

            # finished compilations for all required shapes;
            # the helper itself checks `is_last_graph` and the remaining sizes
            self.check_for_ending_compilation()

        return entry.runnable(*args)