piecewise_backend.py 11.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import dataclasses
5
6
import io
import pickle
7
from collections.abc import Callable
8
from pickle import Pickler
9
from typing import Any
10

11
import torch._functorch.config
12
import torch.fx as fx
13
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
14
15
16
17

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
18
from vllm.config.utils import Range
19
20
21
22
23
24
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class RangeEntry:
    """Per-range compilation state tracked by ``PiecewiseBackend``.

    Attributes:
        compile_range: The shape range (possibly a single size) this entry
            is responsible for.
        compiled: Whether ``runnable`` has been populated for this range.
        runnable: The callable dispatched to for shapes in ``compile_range``;
            ``None`` until the entry has been compiled or loaded.
    """

    compile_range: Range
    # Starts out False; flipped once a runnable exists for this range.
    compiled: bool = dataclasses.field(default=False)
    # Placeholder None until compilation/loading fills it in.
    runnable: Callable[..., Any] = dataclasses.field(default=None)  # type: ignore
29
30


31
class PiecewiseBackend:
    """Compile and dispatch one piece of a piecewise-split FX graph.

    Each instance keeps one ``RangeEntry`` per compile range and per extra
    concrete compile size. An entry is compiled (or loaded from precompiled
    runnables) the first time a call's runtime shape maps to it; the call is
    then forwarded to that entry's runnable.
    """

    def __init__(
        self,
        graph: fx.GraphModule | None,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        `self.graph` is compiled lazily: once for every compile range
        returned by `compilation_config.get_compile_ranges()`, and once for
        every concrete size in `compilation_config.compile_sizes` that is not
        already one of those ranges.

        This class supports two mutually exclusive modes:
        1. Compilation (graph is set, compiled_runnables is None):
           Used during initial compilation when we have the FX graph
           and need to compile it for each shape range.
        2. Precompilation (graph is None, compiled_runnables is set):
           Used when loading from cache/AOT artifacts where we already
           have pre-compiled callables and don't need the original graph.

        Exactly one of graph or compiled_runnables must be provided.

        Args:
            graph: FX graph of this piece (compilation mode only).
            vllm_config: Global config; its compilation config and
                scheduler config are read here.
            piecewise_compile_index: Position of this piece among all
                pieces; forwarded to the compiler as ``graph_index``.
            total_piecewise_compiles: Total number of pieces.
            sym_shape_indices: Indices into the call arguments of the
                symbolic shape values; the first one drives dispatch.
            vllm_backend: Owning backend; provides ``compiler_manager``,
                ``inductor_config`` and ``is_encoder``.
            returns_tuple: Used by ``get_compiled_graph_wrapper``; when
                False, a tuple/list output is unwrapped to its first element.
            compiled_runnables: Precompiled callables keyed by
                ``str(compile_range)`` (precompilation mode only).
        """
        assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
            "exactly one of graph and compiled_runnables should be set."
        )

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Take a private copy: the encoder branch below overwrites the last
        # element, and that must not alter whatever get_compile_ranges()
        # handed back (in case it is shared state).
        self.compile_ranges = list(self.compilation_config.get_compile_ranges())

        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices
        self.returns_tuple = returns_tuple

        # One RangeEntry per compile range and per extra concrete compile
        # size; each is compiled (or loaded) lazily on first use.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the remaining ranges to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # We only keep compilation management inside this class directly.
        if self.compile_sizes is not None:
            for size in self.compile_sizes:
                if isinstance(size, str):
                    assert size == "cudagraph_capture_sizes"
                    raise NotImplementedError(
                        "cudagraph_capture_sizes not supported in compile_sizes. "
                        "This should be handled in `post_init_cudagraph_sizes`."
                    )
                assert isinstance(size, int)
                size_range = Range(start=size, end=size)
                # A size that already is a compile range is registered by
                # the loop over compile_ranges below.
                if size_range not in self.compile_ranges:
                    self.range_entries[size_range] = RangeEntry(
                        compile_range=size_range,
                    )
                    self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

        # get the on_compilation_complete callback from context...
        # PiecewiseBackend is created during the first call,
        # which is when the context is set (see compilation/decorators.py)
        from vllm.compilation.backends import _on_compilation_complete_callback

        self.on_compilation_complete = _on_compilation_complete_callback.get()

    def get_compiled_graph_wrapper(
        self, compiled_graph: Callable[..., Any]
    ) -> Callable[..., Any]:
        """Wrap a precompiled graph so its output matches the caller's shape.

        When ``self.returns_tuple`` is False and the graph returns a tuple or
        list, only its first element is returned; otherwise the output is
        passed through unchanged.
        """

        def compiled_graph_wrapper(*args: Any) -> Any:
            graph_output = compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper

    def check_for_ending_compilation(self) -> None:
        """Finish compilation bookkeeping once the last piece is fully built.

        Only acts when this is the last piece and no ranges remain to be
        compiled: saves the compiler manager state, ends compile monitoring,
        and invokes the completion callback if one was registered.
        """
        if self.is_last_graph and not self.to_be_compiled_ranges:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)
            # Call the completion callback (e.g., to save AOT compiled function)
            if self.on_compilation_complete is not None:
                self.on_compilation_complete()

    def to_bytes(self) -> dict[str, bytes]:
        """Serialize every compiled, serializable runnable.

        Returns:
            A mapping from ``str(compile_range)`` to pickled bytes. Entries
            that are not compiled yet are skipped (with a debug log), as are
            runnables that expose no ``serialize`` method.
        """

        class StandaloneCompiledArtifactsPickler(Pickler):
            def reducer_override(self, obj: object) -> Any:
                # CachingAutotuner objects are prepared for pickling and then
                # embedded as their own pickled bytes, rebuilt via pickle.loads.
                if isinstance(obj, CachingAutotuner):
                    obj.prepare_for_pickle()
                    return pickle.loads, (
                        pickle.dumps(
                            obj,
                        ),
                    )
                return NotImplemented

        def serialize(fn: Callable[..., Any]) -> bytes:
            assert hasattr(fn, "serialize"), "fn must have serialize method"
            with torch._functorch.config.patch("bundled_autograd_cache", True):
                entry = fn.serialize()

                f = io.BytesIO()
                StandaloneCompiledArtifactsPickler(f).dump(entry)
                result = f.getvalue()
            return result

        out = {}

        for range_key, entry in self.range_entries.items():
            if not entry.compiled:
                logger.debug(
                    "entry with range %s not compiled, so cannot get its bytes",
                    range_key,
                )
                continue
            if hasattr(entry.runnable, "serialize"):
                out[str(range_key)] = serialize(entry.runnable)

        return out

    def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
        """Return the graph's recorded placeholder example values.

        One value is taken from ``node.meta["example_value"]`` for each
        leading placeholder node of ``self.graph``; their count must equal
        ``len(args)``.
        """
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs potentially causing some non dynamic
        # dimension to be duck shaped to other existing shapes that have hints
        # matching their values.
        # This is a problem because it can lead to unintended specializations!
        # If the new wrongly dynamic dim is specialized,
        # it will force specializing the whole shape.
        # torch.compile probably should not accept
        # non fake tensors as example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        fake_example_inputs = []
        assert self.graph is not None
        for node in self.graph.graph.nodes:
            # All place holders come first
            if node.op == "placeholder":
                fake_example_inputs.append(node.meta["example_value"])
            else:
                break
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _maybe_compile_for_range_entry(
        self, range_entry: RangeEntry, args: tuple[Any, ...]
    ) -> None:
        """Populate ``range_entry.runnable`` the first time it is needed.

        In precompilation mode the runnable is looked up in
        ``self.compiled_runnables`` by ``str(compile_range)``; otherwise
        ``self.graph`` is compiled for the entry's range. The range is then
        marked done and the end-of-compilation check runs.
        """
        if range_entry.compiled:
            return

        if self.compiled_runnables is not None:
            range_entry.runnable = self.get_compiled_graph_wrapper(
                self.compiled_runnables[str(range_entry.compile_range)]
            )
        else:
            # args are real arguments
            # fakify for range, real args for concrete size.
            # For concrete size, we clear the shape env in
            # compiler_manager.compile() so no need to fakify.
            if range_entry.compile_range.is_single_size():
                args_list = list(args)
            else:
                args_list = self._fakify_args(args)

            with torch._functorch.config.patch("bundled_autograd_cache", True):
                range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                    self.graph,
                    args_list,
                    self.vllm_backend.inductor_config,
                    self.compilation_config,
                    compile_range=range_entry.compile_range,
                    graph_index=self.piecewise_compile_index,
                    num_graphs=self.total_piecewise_compiles,
                )

        range_entry.compiled = True
        self.to_be_compiled_ranges.remove(range_entry.compile_range)

        self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Map a runtime shape to the entry that should handle it.

        A concrete compile size takes precedence; otherwise the first compile
        range containing the shape is used. Returns None if nothing matches.
        """
        # compile_sizes may be None (no concrete sizes); the compile ranges
        # registered in __init__ must still be searched in that case.
        if self.compile_sizes is not None and runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Dispatch on the runtime shape, compiling the target entry if needed."""
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)