piecewise_backend.py 13.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import dataclasses
5
import io
6
import json
7
import pickle
8
import time
9
from collections.abc import Callable
10
from pickle import Pickler
11
from typing import Any
12

13
import torch._functorch.config
14
import torch.fx as fx
15
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
16
from torch._logging._internal import trace_structured
17
18
19
20

from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
21
from vllm.config.utils import Range
22
23
24
25
26
27
from vllm.logger import init_logger

# Module-level logger for this backend, created via vllm.logger.init_logger.
logger = init_logger(__name__)


@dataclasses.dataclass()
class RangeEntry:
    """Compilation state for one compile range (or one concrete size).

    A ``PiecewiseBackend`` keeps one of these per range it may dispatch to;
    the entry starts uncompiled and receives a callable once the range has
    been compiled or loaded from precompiled artifacts.
    """

    # The shape range this entry is responsible for.
    compile_range: Range
    # Flipped to True once `runnable` has been populated.
    compiled: bool = dataclasses.field(default=False)
    # Callable used for shapes in `compile_range`; None until compiled.
    runnable: Callable[..., Any] = dataclasses.field(default=None)  # type: ignore


34
class PiecewiseBackend:
    """Compile and dispatch one piecewise submodule based on runtime shape.

    One ``RangeEntry`` is kept per compile range and per concrete compile
    size. Each entry is compiled (or loaded from precompiled callables)
    lazily, the first time a runtime shape falling into it is seen. Once the
    last piecewise graph has no ranges left to compile, the compiler
    manager cache is saved and the completion callback is fired.
    """

    def __init__(
        self,
        graph: fx.GraphModule | None,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
        submod_name: str = "",
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        This class supports two mutually exclusive modes:
        1. Compilation (graph is set, compiled_runnables is None):
           Used during initial compilation when we have the FX graph
           and need to compile it for each shape range.
        2. Precompilation (graph is None, compiled_runnables is set):
           Used when loading from cache/AOT artifacts where we already
           have pre-compiled callables and don't need the original graph.

        Exactly one of graph or compiled_runnables must be provided.

        Raises:
            NotImplementedError: if `compile_sizes` still contains the
                unexpanded "cudagraph_capture_sizes" placeholder.
        """
        assert (graph is not None) ^ (compiled_runnables is not None), (
            "exactly one of graph and compiled_runnables should be set."
        )

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables
        self.submod_name = submod_name

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Copy the ranges: the encoder branch below rewrites the last
        # element, and that must not leak into whatever list the config
        # handed back (another backend may read it later).
        self.compile_ranges: list[Range] = list(
            self.compilation_config.get_compile_ranges()
        )
        if self.is_encoder_compilation and self.compile_ranges:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        logger.debug_once(f"PiecewiseBackend: compile_ranges: {self.compile_ranges}")

        self.compile_sizes = self.compilation_config.compile_sizes
        logger.debug_once(f"PiecewiseBackend: compile_sizes: {self.compile_sizes}")

        self.sym_shape_indices = sym_shape_indices
        self.returns_tuple = returns_tuple

        # One entry per range (or concrete size) that we may need to
        # compile and later dispatch to.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the remaining ranges to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        # We only keep compilation management inside this class directly.
        for size in self.compile_sizes or ():
            if isinstance(size, str):
                assert size == "cudagraph_capture_sizes"
                raise NotImplementedError(
                    "cudagraph_capture_sizes not supported in compile_sizes. "
                    "This should be handled in `post_init_cudagraph_sizes`."
                )
            assert isinstance(size, int)
            size_range = Range(start=size, end=size)
            # A concrete size already present as a range gets its entry
            # from the loop over compile_ranges below.
            if size_range not in self.compile_ranges:
                self.range_entries[size_range] = RangeEntry(compile_range=size_range)
                self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(compile_range=compile_range)

        # Track whether we've logged the graph for this subgraph (only log once)
        self._graph_logged = False

        # get the on_compilation_complete callback from context...
        # PiecewiseBackend is created during the first call,
        # which is when the context is set (see compilation/decorators.py)
        from vllm.compilation.backends import _on_compilation_complete_callback

        self.on_compilation_complete = _on_compilation_complete_callback.get()

    def get_compiled_graph_wrapper(
        self, compiled_graph: Callable[..., Any]
    ) -> Callable[..., Any]:
        """Wrap `compiled_graph` so its output matches the expected arity.

        When the original graph does not return a tuple, a tuple/list output
        from the compiled graph is unpacked to its first element.
        """

        def compiled_graph_wrapper(*args: Any) -> Any:
            graph_output = compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
                return graph_output
            return graph_output[0]

        return compiled_graph_wrapper

    def check_for_ending_compilation(self) -> None:
        """Finalize compilation once the last graph has nothing left to compile.

        Does nothing unless this is the last piecewise graph and every range
        has been compiled. Otherwise it saves the compiler manager's cache,
        ends torch.compile monitoring, and invokes the completion callback
        if one was registered.
        """
        if not self.is_last_graph or self.to_be_compiled_ranges:
            return

        # no specific sizes to compile
        # save the hash of the inductor graph for the next run
        time_before_saving = time.perf_counter()
        self.vllm_backend.compiler_manager.save_to_file()
        elapsed = time.perf_counter() - time_before_saving
        if elapsed > 1:
            logger.info_once(
                "Saved compiler manager cache in %.2f seconds.",
                elapsed,
                scope="local",
            )

        end_monitoring_torch_compile(self.vllm_config)

        # Call the completion callback (e.g., to save AOT compiled function)
        if self.on_compilation_complete is not None:
            self.on_compilation_complete()

    def to_bytes(self) -> dict[str, bytes]:
        """Serialize every compiled, serializable runnable.

        Returns:
            A mapping from ``str(compile_range)`` to pickled bytes. Entries
            that are not compiled yet, or whose runnable has no
            ``serialize`` method, are skipped.
        """

        class StandaloneCompiledArtifactsPickler(Pickler):
            def reducer_override(self, obj: object) -> Any:
                # CachingAutotuner needs prepare_for_pickle() before it can
                # be pickled; defer everything else to the default logic.
                if isinstance(obj, CachingAutotuner):
                    obj.prepare_for_pickle()
                    return pickle.loads, (pickle.dumps(obj),)
                return NotImplemented

        def serialize(fn: Callable[..., Any]) -> bytes:
            assert hasattr(fn, "serialize"), "fn must have serialize method"
            with torch._functorch.config.patch("bundled_autograd_cache", True):
                entry = fn.serialize()
                buffer = io.BytesIO()
                StandaloneCompiledArtifactsPickler(buffer).dump(entry)
                return buffer.getvalue()

        out: dict[str, bytes] = {}
        for range_key, entry in self.range_entries.items():
            if not entry.compiled:
                logger.debug(
                    "entry with range %s not compiled, so cannot get its bytes",
                    range_key,
                )
                continue
            if hasattr(entry.runnable, "serialize"):
                out[str(range_key)] = serialize(entry.runnable)
        return out

    def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
        """Return the graph placeholders' fake example values, one per arg."""
        # We need to pass fake example_inputs, otherwise torch.compile
        # will fakify the example_inputs potentially causing some non dynamic
        # dimension to be be duck shaped to other existing shapes that have hints
        # matching their values.
        # This is problem because it can lead to unintended specializations!
        # if the new wrongly dynamic dim is specialized
        # it will force specializing the whole shape
        # torch.compile probably should not accept
        # non fake tensors as example inputs!
        # See issue https://github.com/vllm-project/vllm/issues/27899
        assert self.graph is not None
        fake_example_inputs = []
        for node in self.graph.graph.nodes:
            # All place holders come first
            if node.op != "placeholder":
                break
            fake_example_inputs.append(node.meta["example_value"])
        assert len(fake_example_inputs) == len(args)
        return fake_example_inputs

    def _log_compile_start(self, compile_range: Range) -> None:
        """Log compilation event for TORCH_TRACE/tlparse."""
        # Only a single-size range can correspond to a concrete compile
        # size; a wider range merely starting at such a size does not.
        is_cudagraph_size = bool(
            compile_range.is_single_size()
            and self.compile_sizes is not None
            and compile_range.start in self.compile_sizes
        )
        subgraph_index = self.piecewise_compile_index
        submod_name = self.submod_name
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_piecewise_compile_start",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "piecewise_index": subgraph_index,
                    "submod_name": submod_name,
                    "total_piecewise_compiles": self.total_piecewise_compiles,
                    "compile_range_start": compile_range.start,
                    "compile_range_end": compile_range.end,
                    "is_single_size": compile_range.is_single_size(),
                    "is_cudagraph_capture_size": is_cudagraph_size,
                }
            ),
        )

        # Log the subgraph graph dump only once per subgraph (not per size)
        # to reduce log file size. The graph code is the same for all sizes.
        if not self._graph_logged:
            self._graph_logged = True
            graph = self.graph
            assert graph is not None
            trace_structured(
                "graph_dump",
                metadata_fn=lambda: {
                    "name": f"vllm_{submod_name}",
                },
                payload_fn=lambda: graph.print_readable(print_output=False),
            )

    def _maybe_compile_for_range_entry(
        self, range_entry: RangeEntry, args: tuple[Any, ...]
    ) -> None:
        """Populate `range_entry.runnable` if it has not been compiled yet.

        In precompilation mode the runnable is looked up in
        `compiled_runnables`. Otherwise the graph is compiled for the entry's
        range. Either way the range is then marked done, and end-of-compile
        handling is triggered if this was the last outstanding range.

        Raises:
            KeyError: in precompilation mode, if no artifact exists for the
                entry's range.
        """
        if range_entry.compiled:
            return

        compile_range = range_entry.compile_range
        if self.compiled_runnables is not None:
            key = str(compile_range)
            if key not in self.compiled_runnables:
                raise KeyError(
                    f"No precompiled artifact for compile range {key}; "
                    f"available: {sorted(self.compiled_runnables)}"
                )
            range_entry.runnable = self.get_compiled_graph_wrapper(
                self.compiled_runnables[key]
            )
        else:
            self._log_compile_start(compile_range)

            # args are real arguments
            # fakify for range, real args for concrete size.
            # For concrete size, we clear the shape env in
            # compiler_manager.compile() so no need to fakify.
            if compile_range.is_single_size():
                args_list = list(args)
            else:
                args_list = self._fakify_args(args)

            with torch._functorch.config.patch("bundled_autograd_cache", True):
                range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                    self.graph,
                    args_list,
                    self.vllm_backend.inductor_config,
                    self.compilation_config,
                    compile_range=compile_range,
                    graph_index=self.piecewise_compile_index,
                    num_graphs=self.total_piecewise_compiles,
                )

        range_entry.compiled = True
        self.to_be_compiled_ranges.remove(compile_range)

        self.check_for_ending_compilation()

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the entry that should handle `runtime_shape`, or None.

        An exact match in `compile_sizes` takes priority. Otherwise the first
        compile range containing the shape is used. The range search runs
        even when `compile_sizes` is None.
        """
        if self.compile_sizes is not None and runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Dispatch to (compiling on demand) the runnable for the runtime shape."""
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
        )

        self._maybe_compile_for_range_entry(range_entry, args)
        return range_entry.runnable(*args)