# piecewise_backend.py 14.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import dataclasses
5
import io
6
import json
7
import pickle
8
from collections.abc import Callable
9
from pickle import Pickler
10
from typing import Any
11

12
import torch._functorch.config
13
import torch.fx as fx
14
from torch._dynamo.utils import dynamo_timed
15
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
16
from torch._logging._internal import trace_structured
17
18
19

from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig
20
from vllm.config.utils import Range
21
22
23
24
25
from vllm.logger import init_logger

# Module-level logger; PiecewiseBackend below uses it for debug output.
logger = init_logger(__name__)


26
27
28
29
30
31
32
33
34
35
36
37
def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
    """Return the example values recorded on the graph's leading placeholders.

    Nodes of ``graph.graph`` are visited in order; for each placeholder node
    its ``node.meta["example_value"]`` is collected. Collection stops at the
    first node whose ``op`` is not ``"placeholder"``.
    """
    example_values: list[Any] = []
    for current in graph.graph.nodes:
        if current.op != "placeholder":
            # Only the leading run of placeholders is considered.
            break
        example_values.append(current.meta["example_value"])
    return example_values


def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
38
    """Create Fake example inputs with symbolic dims replaced by a concrete size.
39

40
41
    Used for single-size compilation where we need concrete-shaped inputs.
    The Dynamo-captured graph gives us example inputs with SymInts in them.
42
43
    """
    from torch._prims_common import compute_required_storage_length
44
45
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_symbolic
46
47
48
49
50
51
52
53

    def concretize(sym_val: Any) -> int:
        """Replace all symbolic variables in a SymInt expression with size."""
        if not is_symbolic(sym_val):
            return int(sym_val)
        expr = sym_val.node.expr
        return int(expr.subs({s: size for s in expr.free_symbols}))

54
55
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())

56
    args: list[Any] = []
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
    with fake_mode:
        for node in graph.graph.nodes:
            if node.op != "placeholder":
                break
            val = node.meta["example_value"]
            if isinstance(val, torch.SymInt):
                args.append(concretize(val))
            elif isinstance(val, torch.Tensor):
                new_shape = tuple(concretize(d) for d in val.shape)
                new_strides = tuple(concretize(s) for s in val.stride())
                new_storage_offset = concretize(val.storage_offset())
                needed_size = compute_required_storage_length(
                    new_shape, new_strides, new_storage_offset
                )
                t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
                t = t.as_strided(new_shape, new_strides, new_storage_offset)
                args.append(t)
            else:
                args.append(val)
76
77
78
    return args


79
@dataclasses.dataclass
class RangeEntry:
    """Compilation state for one shape range of a piecewise subgraph."""

    # The range of runtime shapes this entry serves.
    compile_range: Range
    # Set to True once `runnable` has been compiled or loaded.
    compiled: bool = False
    # Callable invoked for shapes in `compile_range`; None until compiled/loaded.
    runnable: Callable[..., Any] = None  # type: ignore
84
85


86
class PiecewiseBackend:
    """Compile (or load) one piecewise subgraph per shape range and dispatch."""

    def __init__(
        self,
        graph: fx.GraphModule | None,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        vllm_backend: VllmBackend,
        returns_tuple: bool,
        compiled_runnables: dict[str, Callable[..., Any]] | None = None,
        submod_name: str = "",
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        This class supports two mutually exclusive modes:
        1. Compilation (graph is set, compiled_runnables is None):
           Used during initial compilation when we have the FX graph
           and need to compile it for each shape range.
        2. Precompilation (graph is None, compiled_runnables is set):
           Used when loading from cache/AOT artifacts where we already
           have pre-compiled callables and don't need the original graph.

        Exactly one of graph or compiled_runnables must be provided.

        Args:
            graph: FX graph to compile for every range (mode 1), or None.
            vllm_config: Supplies `compilation_config` and, for encoder
                compilation, `scheduler_config.max_num_batched_tokens`.
            piecewise_compile_index: Index of this subgraph among all
                piecewise subgraphs.
            total_piecewise_compiles: Total number of piecewise subgraphs.
            sym_shape_indices: Positions in the call arguments holding the
                runtime shape; only the first index is used for dispatch.
            vllm_backend: Provides `compiler_manager`, `inductor_config` and
                `is_encoder`.
            returns_tuple: When False, loaded runnables return only the first
                element of a tuple/list output.
            compiled_runnables: Pre-compiled callables keyed by `str(Range)`
                (mode 2), or None.
            submod_name: Submodule name, used in TORCH_TRACE logging.

        Raises:
            NotImplementedError: If `compile_sizes` contains a string entry.
        """
        assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
            "exactly one of graph and compiled_runnables should be set."
        )

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend
        self.compiled_runnables = compiled_runnables
        self.submod_name = submod_name

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Take a private copy: the encoder branch below replaces the last
        # element in place, which must not leak into the list handed back by
        # `get_compile_ranges()` (NOTE(review): it may be shared/cached).
        self.compile_ranges = list(self.compilation_config.get_compile_ranges())

        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # to set the upper bound of the compile ranges
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (
                last_compile_range.end
                == vllm_config.scheduler_config.max_num_batched_tokens
            )
            self.compile_ranges[-1] = Range(
                start=last_compile_range.start, end=max_int32
            )

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices
        self.returns_tuple = returns_tuple

        # The entries for ranges that we need to either compile (when a graph
        # is given) or load from `compiled_runnables`.
        self.range_entries: dict[Range, RangeEntry] = {}

        # We only keep compilation management inside this class directly.
        # Explicit compile sizes get single-size entries; sizes that already
        # equal one of the compile ranges are covered by the loop after this.
        if self.compile_sizes is not None:
            for size in self.compile_sizes:
                if isinstance(size, str):
                    assert size == "cudagraph_capture_sizes"
                    raise NotImplementedError(
                        "cudagraph_capture_sizes not supported in compile_sizes. "
                        "This should be handled in `post_init_cudagraph_sizes`."
                    )
                assert isinstance(size, int)
                size_range = Range(start=size, end=size)
                if size_range not in self.compile_ranges:
                    self.range_entries[size_range] = RangeEntry(
                        compile_range=size_range,
                    )

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range,
            )

        # Track whether we've logged the graph for this subgraph (only log once)
        self._graph_logged = False

        if self.graph is not None:
            self.compile_all_ranges()
        else:
            self.load_all_ranges()

    def get_compiled_graph_wrapper(
        self, compiled_graph: Callable[..., Any]
    ) -> Callable[..., Any]:
        """Wrap a pre-compiled callable to match this subgraph's return style.

        When `self.returns_tuple` is False and the callable returns a tuple or
        list, the wrapper returns only its first element; any other output is
        passed through unchanged.
        """

        def compiled_graph_wrapper(*args: Any) -> Any:
            graph_output = compiled_graph(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph_wrapper

    def to_bytes(self) -> dict[str, bytes]:
        """Serialize the compiled runnables, keyed by `str(compile_range)`.

        Entries that are not compiled are skipped with a debug log; entries
        whose runnable has no `serialize` method are skipped silently.
        Inductor `CachingAutotuner` objects are prepared with
        `prepare_for_pickle()` and embedded as nested pickles.

        NOTE: the returned values are pickles; only load them from a trusted
        source.
        """

        class StandaloneCompiledArtifactsPickler(Pickler):
            def reducer_override(self, obj: object) -> Any:
                if isinstance(obj, CachingAutotuner):
                    obj.prepare_for_pickle()
                    return pickle.loads, (
                        pickle.dumps(
                            obj,
                        ),
                    )
                return NotImplemented

        def serialize(fn: Callable[..., Any]) -> bytes:
            assert hasattr(fn, "serialize"), "fn must have serialize method"
            with torch._functorch.config.patch("bundled_autograd_cache", True):
                entry = fn.serialize()

                f = io.BytesIO()
                StandaloneCompiledArtifactsPickler(f).dump(entry)
                result = f.getvalue()
            return result

        out = {}

        for range_key, entry in self.range_entries.items():
            if not entry.compiled:
                logger.debug(
                    "entry with range %s not compiled, so cannot get its bytes",
                    range_key,
                )
                continue
            if hasattr(entry.runnable, "serialize"):
                out[str(range_key)] = serialize(entry.runnable)

        return out

    def compile_all_ranges(self) -> None:
        """Compile all range entries for this piecewise subgraph up front.

        Single-size ranges are compiled with concrete inputs built by
        `create_concrete_args`; other ranges use the graph's own example
        values from `get_fake_args_from_graph`.
        """
        assert self.graph is not None, (
            "Cannot compile without a graph. "
            "When loading from cache/AOT artifacts, "
            "compile_all_ranges should not be called."
        )

        for range_entry in self.range_entries.values():
            if range_entry.compiled:
                continue

            self._log_compile_start(range_entry.compile_range)

            if range_entry.compile_range.is_single_size():
                args_list = create_concrete_args(
                    self.graph, range_entry.compile_range.start
                )
            else:
                args_list = get_fake_args_from_graph(self.graph)

            range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args_list,
                self.vllm_backend.inductor_config,
                self.compilation_config,
                compile_range=range_entry.compile_range,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                is_encoder=self.vllm_backend.is_encoder,
            )

            range_entry.compiled = True

    @dynamo_timed("vllm_log_compile_start_torch_trace_only")
    def _log_compile_start(self, compile_range: Range):
        """Log compilation event for TORCH_TRACE/tlparse.

        Emits one "vllm_piecewise_compile_start" artifact per call, and the
        readable graph dump only on the first call for this subgraph.
        """
        is_cudagraph_size = (
            self.compile_sizes is not None and compile_range.start in self.compile_sizes
        )
        subgraph_index = self.piecewise_compile_index
        submod_name = self.submod_name
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "vllm_piecewise_compile_start",
                "encoding": "json",
            },
            payload_fn=lambda: json.dumps(
                {
                    "piecewise_index": subgraph_index,
                    "submod_name": submod_name,
                    "total_piecewise_compiles": self.total_piecewise_compiles,
                    "compile_range_start": compile_range.start,
                    "compile_range_end": compile_range.end,
                    "is_single_size": compile_range.is_single_size(),
                    "is_cudagraph_capture_size": is_cudagraph_size,
                }
            ),
        )

        # Log the subgraph graph dump only once per subgraph (not per size)
        # to reduce log file size. The graph code is the same for all sizes.
        if not self._graph_logged:
            self._graph_logged = True
            assert self.graph is not None
            trace_structured(
                "graph_dump",
                metadata_fn=lambda: {
                    "name": f"vllm_{submod_name}",
                },
                payload_fn=lambda: self.graph.print_readable(print_output=False),
            )

    def load_all_ranges(self) -> None:
        """Load all pre-compiled runnables for this piecewise subgraph.

        Called during warm start to wrap all cached compiled_runnables
        into range_entry.runnable up front, analogous to compile_all_ranges()
        for the cold start path. Each entry is looked up by
        `str(compile_range)`; a missing key fails an assertion.
        """
        assert self.compiled_runnables is not None, (
            "load_all_ranges should only be called when compiled_runnables "
            "is set (warm start / cache loading path)."
        )
        for range_entry in self.range_entries.values():
            if range_entry.compiled:
                continue
            key = str(range_entry.compile_range)
            assert key in self.compiled_runnables, (
                f"Missing compiled runnable for range {range_entry.compile_range}. "
                f"Available keys: {list(self.compiled_runnables.keys())}"
            )
            range_entry.runnable = self.get_compiled_graph_wrapper(
                self.compiled_runnables[key]
            )
            range_entry.compiled = True

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the range entry that should serve `runtime_shape`, or None.

        An exact single-size entry for a configured compile size takes
        priority; otherwise the first compile range containing the shape is
        used. Compile ranges are searched even when `compile_sizes` is None,
        since their entries are always created in `__init__`.
        """
        # First we try to find the range entry for the concrete compile size.
        if self.compile_sizes is not None and runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        # If not found, we search for the range entry that contains the shape.
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args: Any) -> Any:
        """Run the runnable whose range covers the current runtime shape.

        The runtime shape is read from `args[self.sym_shape_indices[0]]`.
        With no symbolic-shape inputs, exactly one compiled entry must exist
        and is used for every call.
        """
        if self.sym_shape_indices:
            runtime_shape = args[self.sym_shape_indices[0]]
            range_entry = self._find_range_for_shape(runtime_shape)
            assert range_entry is not None, (
                f"Shape: {runtime_shape} out of considered ranges: "
                f"{self.compile_ranges}"
            )
        else:
            # All inputs have static shapes; use the only compiled range_entry
            compiled_entries = [re for re in self.range_entries.values() if re.compiled]
            assert len(compiled_entries) == 1, (
                f"Expected exactly one compiled range_entry for static shape "
                f"compilation, but found {len(compiled_entries)}"
            )
            range_entry = compiled_entries[0]

        assert range_entry.compiled, (
            "All ranges should be compiled or loaded up front in "
            "PiecewiseBackend.__init__. "
            f"range_entry={range_entry.compile_range}"
        )
        return range_entry.runnable(*args)