piecewise_backend.py 14.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import dataclasses
5
import io
6
import json
7
import pickle
8
from collections.abc import Callable
9
from pickle import Pickler
10
from typing import Any
11

12
import torch._functorch.config
13
import torch.fx as fx
14
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
15
from torch._logging._internal import trace_structured
16
17
18

from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig
19
from vllm.config.utils import Range
20
21
22
23
24
from vllm.logger import init_logger

logger = init_logger(__name__)


25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
    """Collect the example values of the graph's leading placeholder nodes.

    Nodes are visited in graph order and collection stops at the first node
    whose op is not ``"placeholder"``. Each value is taken verbatim from
    ``node.meta["example_value"]`` (no copying or concretization).
    """
    example_values: list[Any] = []
    for node in graph.graph.nodes:
        # Placeholders lead the node list; anything else ends the inputs.
        if node.op != "placeholder":
            break
        example_values.append(node.meta["example_value"])
    return example_values


def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
    """Create example inputs with symbolic dims replaced by a concrete size.

    Used for single-size eager compilation where we need concrete-shaped
    inputs but don't have real runtime tensors yet.
    """
    from torch._prims_common import compute_required_storage_length
    from torch.fx.experimental.symbolic_shapes import is_symbolic

    def concretize(sym_val: Any) -> int:
        """Replace all symbolic variables in a SymInt expression with size."""
        if not is_symbolic(sym_val):
            return int(sym_val)
        expr = sym_val.node.expr
        return int(expr.subs({s: size for s in expr.free_symbols}))

    args: list[Any] = []
    for node in graph.graph.nodes:
        if node.op != "placeholder":
            break
        val = node.meta["example_value"]
        if isinstance(val, torch.SymInt):
            args.append(concretize(val))
        elif isinstance(val, torch.Tensor):
            new_shape = tuple(concretize(d) for d in val.shape)
            new_strides = tuple(concretize(s) for s in val.stride())
            new_storage_offset = concretize(val.storage_offset())
            needed_size = compute_required_storage_length(
                new_shape, new_strides, new_storage_offset
            )
            t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
            t = t.as_strided(new_shape, new_strides, new_storage_offset)
            args.append(t)
        else:
            args.append(val)
    return args


74
@dataclasses.dataclass
class RangeEntry:
    """Per-range compilation state of one piecewise subgraph.

    ``compiled`` becomes True once ``runnable`` has been populated, either by
    compiling the subgraph for ``compile_range`` or by wrapping a
    pre-compiled artifact loaded for that range.
    """

    # The (start, end) size range this entry serves.
    compile_range: Range
    # True once `runnable` has been set.
    compiled: bool = dataclasses.field(default=False)
    # Callable executed for shapes in `compile_range`; None until prepared.
    runnable: Callable[..., Any] = dataclasses.field(default=None)  # type: ignore
79
80


81
class PiecewiseBackend:
    """Compile (or load) and dispatch one piecewise subgraph.

    Every range entry is prepared eagerly in ``__init__``; ``__call__`` only
    selects the entry matching the runtime size and invokes its runnable.
    """

    def __init__(
        self,
        graph: fx.GraphModule | None,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
        submod_name: str = "",
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        `self.graph` is compiled once per range returned by
        `compilation_config.get_compile_ranges()`, plus once per concrete
        size in `compilation_config.compile_sizes` that is not already one
        of those ranges.

        This class supports two mutually exclusive modes:
        1. Compilation (graph is set, compiled_runnables is None):
           Used during initial compilation when we have the FX graph
           and need to compile it for each shape range.
        2. Precompilation (graph is None, compiled_runnables is set):
           Used when loading from cache/AOT artifacts where we already
           have pre-compiled callables and don't need the original graph.

        Exactly one of graph or compiled_runnables must be provided.

        Args:
            graph: FX subgraph to compile (compilation mode), else None.
            vllm_config: Global config; its compilation_config supplies the
                compile ranges and compile sizes.
            piecewise_compile_index: Index of this subgraph among all
                piecewise subgraphs.
            total_piecewise_compiles: Number of piecewise subgraphs.
            sym_shape_indices: Positions in the call args that hold symbolic
                sizes; the first one selects the range at call time.
            vllm_backend: Owning backend; provides the compiler manager,
                the inductor config and the encoder flag.
            returns_tuple: Whether the tuple output of a pre-compiled
                runnable is returned as-is (see get_compiled_graph_wrapper).
            compiled_runnables: Pre-compiled callables keyed by
                ``str(compile_range)`` (precompilation mode), else None.
            submod_name: Submodule name; used only for trace logging.

        Raises:
            NotImplementedError: If compile_sizes contains the string
                "cudagraph_capture_sizes".
        """
        assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
            "exactly one of graph and compiled_runnables should be set."
        )

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables
        self.submod_name = submod_name

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Take a private copy: the encoder branch below rewrites the last
        # element in place, and that edit must not leak into the list the
        # compilation config returned (if that list is shared, a later
        # encoder backend would otherwise fail the assert below).
        self.compile_ranges = list(self.compilation_config.get_compile_ranges())

        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices
        self.returns_tuple = returns_tuple

        # One entry per range this subgraph is compiled (or loaded) for:
        # single-size ranges for compile_sizes not already present in
        # compile_ranges come first, followed by every compile range.
        # We only keep compilation management inside this class directly.
        self.range_entries: dict[Range, RangeEntry] = {}

        if self.compile_sizes is not None:
            for size in self.compile_sizes:
                if isinstance(size, str):
                    assert size == "cudagraph_capture_sizes"
                    raise NotImplementedError(
                        "cudagraph_capture_sizes not supported in compile_sizes. "
                        "This should be handled in `post_init_cudagraph_sizes`."
                    )
                assert isinstance(size, int)
                size_range = Range(start=size, end=size)
                if size_range not in self.compile_ranges:
                    self.range_entries[size_range] = RangeEntry(
                        compile_range=size_range,
                    )

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

        # Track whether we've logged the graph for this subgraph (only log once)
        self._graph_logged = False

        # Prepare every range eagerly so __call__ only has to dispatch.
        if self.graph is not None:
            self.compile_all_ranges()
        else:
            self.load_all_ranges()

    def get_compiled_graph_wrapper(
        self, compiled_graph: Callable[..., Any]
    ) -> Callable[..., Any]:
        """Wrap a pre-compiled runnable so its output matches the caller.

        The wrapper returns the graph output unchanged when
        ``self.returns_tuple`` is set or the output is not a tuple/list;
        otherwise it returns only the first element.
        """

        def compiled_graph_wrapper(*args: Any) -> Any:
            graph_output = compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper

    def to_bytes(self) -> dict[str, bytes]:
        """Serialize every compiled range's runnable.

        Returns:
            A dict mapping ``str(compile_range)`` to the pickled result of the
            runnable's ``serialize()``. Ranges that are not compiled, or whose
            runnable has no ``serialize`` method (e.g. wrappers created by
            load_all_ranges), are omitted.

        NOTE: the bytes are pickles; they should only ever be loaded from a
        trusted cache.
        """

        class StandaloneCompiledArtifactsPickler(Pickler):
            def reducer_override(self, obj: object) -> Any:
                # CachingAutotuner objects are prepared for pickling and then
                # embedded as a nested pickle blob.
                if isinstance(obj, CachingAutotuner):
                    obj.prepare_for_pickle()
                    return pickle.loads, (
                        pickle.dumps(
                            obj,
                        ),
                    )
                return NotImplemented

        def serialize(fn: Callable[..., Any]) -> bytes:
            assert hasattr(fn, "serialize"), "fn must have serialize method"
            with torch._functorch.config.patch("bundled_autograd_cache", True):
                entry = fn.serialize()

                f = io.BytesIO()
                StandaloneCompiledArtifactsPickler(f).dump(entry)
                result = f.getvalue()
            return result

        out = {}

        for range_key, entry in self.range_entries.items():
            if not entry.compiled:
                logger.debug(
                    "entry with range %s not compiled, so cannot get its bytes",
                    range_key,
                )
                continue
            if hasattr(entry.runnable, "serialize"):
                out[str(range_key)] = serialize(entry.runnable)

        return out

    def compile_all_ranges(self) -> None:
        """Compile all range entries for this piecewise subgraph up front.

        Single-size ranges are compiled with concrete inputs from
        create_concrete_args; other ranges use the graph's placeholder
        example values. Entries already marked compiled are skipped.
        """
        assert self.graph is not None, (
            "Cannot compile without a graph. "
            "When loading from cache/AOT artifacts, "
            "compile_all_ranges should not be called."
        )

        for range_entry in self.range_entries.values():
            if range_entry.compiled:
                continue

            self._log_compile_start(range_entry.compile_range)

            if range_entry.compile_range.is_single_size():
                args_list = create_concrete_args(
                    self.graph, range_entry.compile_range.start
                )
            else:
                args_list = get_fake_args_from_graph(self.graph)

            # TODO(https://github.com/vllm-project/vllm/issues/35766)
            # Can we remove strict_autograd_cache and
            # force_non_lazy_backward_lowering overrides?
            # I added them explicitly because this is what they are
            # set to before the refactor
            # (https://github.com/vllm-project/vllm/pull/35472).
            # They affect the aotautograd cache key computation
            # but they shouldn't have any effect on the actual
            # compilation.
            config_patches = dict(
                bundled_autograd_cache=True,
                strict_autograd_cache=False,
            )
            if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"):
                config_patches["force_non_lazy_backward_lowering"] = False
            with torch._functorch.config.patch(**config_patches):
                range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                    self.graph,
                    args_list,
                    self.vllm_backend.inductor_config,
                    self.compilation_config,
                    compile_range=range_entry.compile_range,
                    graph_index=self.piecewise_compile_index,
                    num_graphs=self.total_piecewise_compiles,
                )

            range_entry.compiled = True

    def _log_compile_start(self, compile_range: Range):
        """Log compilation event for TORCH_TRACE/tlparse."""
        is_cudagraph_size = (
            self.compile_sizes is not None and compile_range.start in self.compile_sizes
        )
        subgraph_index = self.piecewise_compile_index
        submod_name = self.submod_name
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_piecewise_compile_start",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "piecewise_index": subgraph_index,
                    "submod_name": submod_name,
                    "total_piecewise_compiles": self.total_piecewise_compiles,
                    "compile_range_start": compile_range.start,
                    "compile_range_end": compile_range.end,
                    "is_single_size": compile_range.is_single_size(),
                    "is_cudagraph_capture_size": is_cudagraph_size,
                }
            ),
        )

        # Log the subgraph graph dump only once per subgraph (not per size)
        # to reduce log file size. The graph code is the same for all sizes.
        if not self._graph_logged:
            self._graph_logged = True
            assert self.graph is not None
            trace_structured(
                "graph_dump",
                metadata_fn=lambda: {
                    "name": f"vllm_{submod_name}",
                },
                payload_fn=lambda: self.graph.print_readable(print_output=False),
            )

    def load_all_ranges(self) -> None:
        """Load all pre-compiled runnables for this piecewise subgraph.

        Called during warm start to wrap all cached compiled_runnables
        into range_entry.runnable up front, analogous to compile_all_ranges()
        for the cold start path.
        """
        assert self.compiled_runnables is not None, (
            "load_all_ranges should only be called when compiled_runnables "
            "is set (warm start / cache loading path)."
        )
        for range_entry in self.range_entries.values():
            if range_entry.compiled:
                continue
            key = str(range_entry.compile_range)
            assert key in self.compiled_runnables, (
                f"Missing compiled runnable for range {range_entry.compile_range}. "
                f"Available keys: {list(self.compiled_runnables.keys())}"
            )
            range_entry.runnable = self.get_compiled_graph_wrapper(
                self.compiled_runnables[key]
            )
            range_entry.compiled = True

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the range entry that should serve ``runtime_shape``.

        An exact single-size entry from ``compile_sizes`` takes priority;
        otherwise the first compile range containing the shape is used.
        A missing ``compile_sizes`` only skips the exact-size lookup; the
        compile ranges are still searched.

        Returns:
            The matching RangeEntry, or None if no entry covers the shape.
        """
        if self.compile_sizes is not None and runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Run the runnable whose range covers ``args[sym_shape_indices[0]]``."""
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        assert range_entry is not None, (
            f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
        )

        assert range_entry.compiled, (
            "All ranges should be compiled or loaded up front in "
            "PiecewiseBackend.__init__. "
            f"range_entry={range_entry.compile_range}"
        )
        return range_entry.runnable(*args)