caching.py 22.8 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import contextlib
5
import hashlib
6
import inspect
7
import os
8
import pickle
9
10
from collections.abc import Callable, Sequence
from typing import Any, Literal
11
12
13
from unittest.mock import patch

import torch
14
15
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
16
17
18
from torch.utils import _pytree as pytree

import vllm.envs as envs
19
from vllm.compilation.codegen import compile_execution_fn
20
from vllm.compilation.compiler_interface import get_inductor_factors
21
from vllm.compilation.counter import compilation_counter
22
from vllm.config import VllmConfig, get_current_vllm_config
Driss Guessous's avatar
Driss Guessous committed
23
from vllm.config.utils import hash_factors
24
from vllm.logger import init_logger
25
from vllm.utils.hashing import safe_hash
26
27
28
29
30
31
32
33
34
35
36

try:
    from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
    # Fallback for torch builds that do not provide
    # torch._dynamo.aot_compile: use a plain `object` base so that
    # VllmSerializableFunction (which subclasses this name) can still be
    # defined.
    SerializableCallable = object

# Whichever branch ran, the name must be a class usable as a base class.
assert isinstance(SerializableCallable, type)

logger = init_logger(__name__)


37
38
39
40
41
42
43
44
45
46
47
48
49
50
class StandaloneCompiledArtifacts:
    """Storage for standalone compiled artifacts with content-based deduplication.

    Deduplication works via a two-level indirection:
    1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
    2. `submodule_bytes_store` maps SHA256 hash -> actual bytes

    When inserting, we compute the SHA256 hash of the bytes. If the hash
    already exists in `submodule_bytes_store`, we reuse the existing entry
    rather than storing duplicate bytes. This is common because submodules
    often compile to identical artifacts (e.g., identical transformer layers
    split on attn)
    """

51
    def __init__(self) -> None:
52
        # dict from submodule name to byte hash
53
        self.submodule_bytes: dict[str, str] = {}
54
        # dict from byte hash to bytes
55
        self.submodule_bytes_store: dict[str, bytes] = {}
56
        # dict from byte hash to loaded module
57
        self.loaded_submodule_store: dict[str, Any] = {}
58

59
    def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
60
61
62
63
64
65
        hasher = hashlib.sha256()
        hasher.update(entry)
        hex_digest = hasher.hexdigest()
        self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
        if hex_digest not in self.submodule_bytes_store:
            self.submodule_bytes_store[hex_digest] = entry
66
            compilation_counter.num_compiled_artifacts_saved += 1
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
            logger.debug(
                "inserting new artifact for submod %s with shape %s "
                "(%s bytes) at hash %s",
                submod_name,
                shape,
                len(entry),
                hex_digest,
            )
        else:
            logger.debug(
                "reusing existing cache artifact for submod %s "
                "with shape %s (%s bytes) at hash %s",
                submod_name,
                shape,
                len(entry),
                hex_digest,
            )

    def get(self, submod_name: str, shape: str) -> bytes:
        logger.debug(
            "getting artifact for submod %s with shape %s",
            submod_name,
            shape,
        )
        return self.submodule_bytes_store[
            self.submodule_bytes[f"{submod_name}_{shape}"]
        ]

95
    def get_loaded(self, submod_name: str, shape: str) -> Any:
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
        logger.debug(
            "getting artifact for submod %s with shape %s",
            submod_name,
            shape,
        )
        return self.loaded_submodule_store[
            self.submodule_bytes[f"{submod_name}_{shape}"]
        ]

    def size_bytes(self) -> int:
        return sum(len(entry) for entry in self.submodule_bytes_store.values())

    def num_artifacts(self) -> int:
        return len(self.submodule_bytes_store)

    def num_entries(self) -> int:
        return len(self.submodule_bytes)

    def submodule_names(self) -> list[str]:
        # get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
        names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
        return list(dict.fromkeys(names))

    def load_all(self) -> None:
        import concurrent.futures

        # check already loaded
        if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
            return

        from torch._inductor.standalone_compile import AOTCompiledArtifact

128
        def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
129
            entry = pickle.loads(entry_bytes)
130
            compilation_counter.num_compiled_artifacts_loaded += 1
131
132
133
134
135
136
137
138
139
140
141
            return AOTCompiledArtifact.deserialize(entry)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            entries = list(self.submodule_bytes_store.values())
            loaded_entries = list(executor.map(_load_entry, entries))

        for i, k in enumerate(self.submodule_bytes_store.keys()):
            self.loaded_submodule_store[k] = loaded_entries[i]

        logger.debug("loaded all %s submodules", self.num_artifacts())

142
    def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
143
144
145
146
147
        return {
            "submodule_bytes": self.submodule_bytes,
            "submodule_bytes_store": self.submodule_bytes_store,
        }

148
    def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
149
150
151
152
153
        self.submodule_bytes = state["submodule_bytes"]
        self.submodule_bytes_store = state["submodule_bytes_store"]
        self.loaded_submodule_store = {}


154
155
156
157
158
159
160
161
162
163
164
165
@contextlib.contextmanager
def patch_pytree_map_over_slice():
    pytree._private_register_pytree_node(
        slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
    )

    try:
        yield
    finally:
        pytree._deregister_pytree_node(slice)


166
class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
    inputs to the compiled function and return the result.
    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.
    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule | bytes,
        example_inputs: Sequence[Any],
        prefix: str,
        optimized_call: Callable[..., Any],
        is_encoder: bool = False,
        vllm_backend: Any | None = None,
        sym_tensor_indices: list[int] | None = None,
        aot_autograd_config: dict[str, Any] | None = None,
        execution_code: str | None = None,
        submod_names: list[str] | None = None,
    ) -> None:
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call
        self.is_encoder = is_encoder
        self.shape_env = None
        self.vllm_backend = vllm_backend
        self.sym_tensor_indices = sym_tensor_indices
        self.execution_code = execution_code
        self.submod_names = submod_names
        # Set only by the non-mega deserialization path. While it is not None,
        # the backend has not been built yet and `finalize_loading` must run.
        self._fake_mode: Any | None = None

        import torch._functorch.config as functorch_config

        # Capture the AOT autograd config so it can be re-applied on load.
        self.aot_autograd_config = (
            aot_autograd_config or functorch_config.save_config_portable()
        )

        sym_input = next(
            (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
        )
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.optimized_call(*args, **kwargs)

    @staticmethod
    def _functorch_config_ctx(aot_autograd_config: dict[str, Any] | None) -> Any:
        """Return a context manager that re-applies a saved functorch config.

        Falls back to a no-op context when no config was saved.
        """
        if aot_autograd_config is None:
            return contextlib.nullcontext()
        return torch._functorch.config.patch(aot_autograd_config)

    @staticmethod
    def _strip_module_stack_meta(graph_module: torch.fx.GraphModule) -> None:
        """Drop `source_fn_stack` / `nn_module_stack` from node metadata.

        Covers the graph itself and the graphs of its direct children.
        NOTE: this mutates `graph_module` in place; the caller's state dict is
        only a shallow copy, so the live graph loses this metadata too.
        Presumably done because these entries reference live Python objects
        that should not be pickled — TODO confirm.
        """
        graphs = [graph_module.graph]
        graphs.extend(
            submod.graph for submod in graph_module.children() if hasattr(submod, "graph")
        )
        for graph in graphs:
            for node in graph.nodes:
                node.meta.pop("source_fn_stack", None)
                node.meta.pop("nn_module_stack", None)

    @classmethod
    def serialize_graph_module(cls, graph_module: torch.fx.GraphModule) -> bytes:
        """Pickle an FX graph module with GraphPickler.

        GraphPickler's reducer is patched so that sympy Function classes that
        carry a `_torch_unpickler` are reduced through it, and any
        FakeTensorMode instance is pickled as None. `slice` is temporarily
        registered as a pytree node for the duration of the dump.
        """
        import sympy

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(
            self: GraphPickler, obj: Any
        ) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
            if (
                inspect.isclass(obj)
                and issubclass(obj, sympy.Function)
                and hasattr(obj, "_torch_unpickler")
            ):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        with (
            patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
            patch_pytree_map_over_slice(),
        ):
            return GraphPickler.dumps(graph_module, Options(ops_filter=None))

    @classmethod
    def deserialize_graph_module(
        cls, data: bytes, fake_mode: FakeTensorMode
    ) -> torch.fx.GraphModule:
        """Inverse of `serialize_graph_module`, loading under `fake_mode`."""
        with patch_pytree_map_over_slice():
            return GraphPickler.loads(data, fake_mode)

    @classmethod
    def serialize_compile_artifacts(
        cls, compiled_fn: "VllmSerializableFunction"
    ) -> bytes:
        """Serialize `compiled_fn` into a pickled state dict.

        Runtime-only attributes (optimized_call, shape_env, vllm_backend,
        _fake_mode) are dropped. Tensor example inputs are replaced by meta
        tensors. When a vllm_backend is attached, its standalone compile
        artifacts and per-submodule metadata are included.
        """
        state = compiled_fn.__dict__.copy()
        state.pop("optimized_call")
        state.pop("shape_env")
        state.pop("vllm_backend", None)
        state.pop("_fake_mode", None)
        cls._strip_module_stack_meta(state["graph_module"])

        # Put tensor inputs on the meta device: their data is large and not
        # needed, yet their metadata is needed (e.g. by make_copy_and_call).
        state["example_inputs"] = pytree.tree_map_only(
            torch.Tensor,
            lambda inp: torch.empty_like(inp, device="meta"),
            state["example_inputs"],
        )

        state["graph_module"] = cls.serialize_graph_module(state["graph_module"])
        state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])

        if compiled_fn.vllm_backend:
            (
                standalone_compile_artifacts,
                sym_shape_indices_map,
                returns_tuple_map,
            ) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
            state["standalone_compile_artifacts"] = standalone_compile_artifacts
            state["sym_shape_indices_map"] = sym_shape_indices_map
            state["returns_tuple_map"] = returns_tuple_map

        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
        """Rebuild a VllmSerializableFunction from `serialize_compile_artifacts` output.

        With VLLM_USE_MEGA_AOT_ARTIFACT the function is reconstructed directly
        from the stored standalone artifacts. Otherwise the graph module is
        restored and backend construction is deferred to `finalize_loading`.
        """
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        # NOTE(review): pickle.loads can execute arbitrary code; serialized
        # artifacts must come from a trusted cache location.
        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())

        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

        standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
        sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
        returns_tuple_map = state.pop("returns_tuple_map", {})

        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            assert standalone_compile_artifacts is not None
            submod_names = standalone_compile_artifacts.submodule_names()
            num_submods = len(submod_names)
            num_artifacts = standalone_compile_artifacts.num_artifacts()

            with cls._functorch_config_ctx(state.get("aot_autograd_config")):
                fn = reconstruct_serializable_fn_from_mega_artifact(
                    state=state,
                    standalone_compile_artifacts=standalone_compile_artifacts,
                    vllm_config=get_current_vllm_config(),
                    sym_shape_indices_map=sym_shape_indices_map,
                    returns_tuple_map=returns_tuple_map,
                    fake_mode=fake_mode,
                )

            logger.info(
                "reconstructed serializable fn from standalone compile "
                "artifacts. num_artifacts=%d num_submods=%d",
                num_artifacts,
                num_submods,
            )

            return fn

        state["graph_module"] = cls.deserialize_graph_module(
            state["graph_module"], fake_mode
        )
        state["graph_module"].recompile()

        # Fall back to standard VllmBackend.
        # Use a lazy closure: the backend needs traced_files for cache
        # dir computation, but those are only populated after
        # _verify_source_unchanged runs in decorators.py (which happens
        # after deserialization completes).
        vllm_config = get_current_vllm_config()

        def optimized_call(*example_inputs: Any) -> Any:
            # Delegate to finalize_loading so the backend is built exactly once:
            # it replaces fn.optimized_call and clears fn._fake_mode, making any
            # later explicit finalize_loading call a no-op.
            fn.finalize_loading(vllm_config)
            return fn.optimized_call(*example_inputs)

        fn = cls(**state, optimized_call=optimized_call)
        fn._fake_mode = fake_mode
        return fn

    def finalize_loading(self, vllm_config: VllmConfig) -> None:
        """Eagerly initialize the compiled backend and perform all loading.

        Must be called after _verify_source_unchanged has populated
        compilation_config.traced_files, which is needed for cache dir
        computation. It is also invoked lazily by the deserialized
        `optimized_call` if the function is called first. It is a no-op once
        finalized, and on the mega-artifact path.
        """
        if self._fake_mode is None:
            return  # Already finalized, or mega path (no _fake_mode set)

        from torch._guards import TracingContext, tracing

        from vllm.compilation.backends import VllmBackend

        vllm_backend = VllmBackend(vllm_config, self.prefix, self.is_encoder)
        with (
            tracing(TracingContext(self._fake_mode)),
            self._functorch_config_ctx(self.aot_autograd_config),
        ):
            result = vllm_backend(self.graph_module, list(self.example_inputs))
            self.optimized_call = result.optimized_call
            self.vllm_backend = vllm_backend

        self._fake_mode = None

    @property
    def co_name(self) -> Literal["VllmSerializableFunction"]:
        """
        Used for depyf debugging.
        """
        return "VllmSerializableFunction"


409
410
411
412
413
414
def reconstruct_serializable_fn_from_mega_artifact(
    state: dict[str, Any],
    standalone_compile_artifacts: "StandaloneCompiledArtifacts",
    vllm_config: VllmConfig,
    sym_shape_indices_map: dict[str, list[int]],
    returns_tuple_map: dict[str, bool],
    fake_mode: FakeTensorMode,
) -> "VllmSerializableFunction":
    """Construct a VllmSerializableFunction from cached inductor artifacts.

    This function reconstructs a callable model from pre-compiled inductor
    artifacts without re-running the compilation. It:
    1. Loads all cached artifacts
    2. Builds compiled callables for each submodule/shape
    3. Creates PiecewiseBackend instances that dispatch to cached artifacts
    4. Wraps with cudagraph if needed
    5. Returns the final VllmSerializableFunction

    Note: This function shares similar logic with PiecewiseCompileInterpreter
    in backends.py. Both create PiecewiseBackend instances and wrap them with
    cudagraph. The key difference is:
    - this function: PiecewiseBackend receives pre-compiled runnables
      (compiled_runnables is set, graph is None)
    - PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
      to compile (graph is set, compiled_runnables is None)

    If modifying the backend creation/wrapping logic, consider updating both.

    Args:
        state: Deserialized state dict containing graph_module (still as
            serialized bytes), example_inputs, prefix, sym_tensor_indices,
            is_encoder, execution_code, submod_names, etc.
        standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
            pre-compiled artifacts for each submodule/shape combination.
        vllm_config: The vLLM configuration.
        sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
        returns_tuple_map: Mapping from submod_name to returns_tuple.
        fake_mode: FakeTensorMode used to deserialize the split graph module
            when no codegen'd execution code is available.

    Returns:
        A VllmSerializableFunction that can be called directly.

    Raises:
        AssertionError: if the artifacts reference submodules that are not in
            the graph (e.g. a stale cache), or a submodule lacks metadata.
    """
    from vllm.compilation.backends import (
        VllmBackend,
        make_copy_and_call,
        wrap_with_cudagraph_if_needed,
    )
    from vllm.compilation.piecewise_backend import PiecewiseBackend

    prefix = state["prefix"]
    is_encoder = state.get("is_encoder", False)
    compilation_config = vllm_config.compilation_config

    standalone_compile_artifacts.load_all()

    piecewise_submod_names = standalone_compile_artifacts.submodule_names()

    # Group loaded artifacts as {submod_name: {shape_str: compiled callable}}.
    compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
    for cache_key in standalone_compile_artifacts.submodule_bytes:
        submod_name, shape_str = cache_key.rsplit("_", 1)
        compiled_callables.setdefault(submod_name, {})[shape_str] = (
            standalone_compile_artifacts.get_loaded(submod_name, shape_str)
        )

    vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
    dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
    os.makedirs(dummy_cache_dir, exist_ok=True)
    vllm_backend.compiler_manager.initialize_cache(
        cache_dir=dummy_cache_dir,
        disable_cache=True,
        prefix=prefix,
    )

    # Use codegen'd execution code if available, fall back to split_gm.
    execution_code = state.get("execution_code")
    submod_names = state.get("submod_names")
    use_codegen = execution_code is not None and submod_names is not None
    split_gm: torch.fx.GraphModule | None = None
    if use_codegen:
        graph_children = set(submod_names)
    else:
        logger.warning(
            "No execution code found, falling back to graph module execution."
        )
        # Load with the same pytree/slice handling used when it was saved.
        split_gm = VllmSerializableFunction.deserialize_graph_module(
            state["graph_module"], fake_mode
        )
        split_gm.recompile()
        graph_children = {name for name, _ in split_gm.named_children()}

    # Spot check that cached submodules exist in the graph structure.
    # If an old cache is used, this will fail, but that's fine because
    # we will just catch this error and re-generate the new cache.
    missing = set(piecewise_submod_names) - graph_children
    assert not missing, (
        f"artifacts reference submodules not in graph: {missing}. "
        f"graph has: {sorted(graph_children)}"
    )

    submod_callables: dict[str, Any] = {}
    num_pieces = len(piecewise_submod_names)
    for i, submod_name in enumerate(piecewise_submod_names):
        assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map

        piecewise_backend = PiecewiseBackend(
            graph=None,  # not needed for cached artifacts
            vllm_config=vllm_config,
            piecewise_compile_index=i,
            total_piecewise_compiles=num_pieces,
            sym_shape_indices=sym_shape_indices_map[submod_name],
            vllm_backend=vllm_backend,
            returns_tuple=returns_tuple_map[submod_name],
            compiled_runnables=compiled_callables[submod_name],
        )

        wrapped_backend = wrap_with_cudagraph_if_needed(
            piecewise_backend,
            vllm_config,
            compilation_config,
            i == 0,  # is_first
            i == num_pieces - 1,  # is_last
        )

        submod_callables[submod_name] = wrapped_backend
        logger.debug(
            "Replaced submodule %s with piecewise backend from cache",
            submod_name,
        )

    if use_codegen:
        runtime_callable = compile_execution_fn(
            execution_code, submod_callables, submod_names
        )
    else:
        assert split_gm is not None
        # Route the split graph's submodule calls to the compiled piecewise
        # backends. An entry in the instance __dict__ is found before
        # nn.Module.__getattr__ consults the registered child modules, so this
        # shadows the uncompiled submodules (same pattern as
        # PiecewiseCompileInterpreter in backends.py).
        for submod_name, wrapped_backend in submod_callables.items():
            split_gm.__dict__[submod_name] = wrapped_backend
        runtime_callable = split_gm

    if compilation_config.cudagraph_copy_inputs:
        sym_tensor_indices = state["sym_tensor_indices"]
        # Real (non-meta) buffers shaped like the saved example inputs.
        input_buffers = [
            torch.empty_like(
                state["example_inputs"][idx], device=vllm_config.device_config.device
            )
            for idx in sym_tensor_indices
        ]
        optimized_call = make_copy_and_call(
            sym_tensor_indices, input_buffers, runtime_callable
        )
    else:
        optimized_call = runtime_callable

    fn = VllmSerializableFunction(
        **state,
        optimized_call=optimized_call,
        vllm_backend=None,
    )
    return fn


def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
    """Collect the cache-key factors for an AOT-compiled artifact.

    The returned list holds, in order:
      * a hash of the compile-relevant environment factors (for example, the
        value of VLLM_PP_LAYER_PARTITION affects the computation graph),
      * the hash of ``vllm_config`` (which mainly summarizes how the model
        is created),
      * the inductor factors, only when VLLM_USE_MEGA_AOT_ARTIFACT is set.
    """
    env_part = hash_factors(envs.compile_factors())
    config_part = vllm_config.compute_hash()
    inductor_part = get_inductor_factors() if envs.VLLM_USE_MEGA_AOT_ARTIFACT else []
    return [env_part, config_part, *inductor_part]


def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
    """Hash a mapping of file path -> source text into a hex digest.

    Paths are visited in sorted order so the digest does not depend on dict
    insertion order. Every path contributes its name. Its text is included
    too, except for "<string>", which marks dynamically generated code
    (e.g. via exec()) whose contents cannot be checked.
    """
    parts: list[str] = []
    for path in sorted(file_contents):
        parts.append(path)
        if path != "<string>":
            parts.append(file_contents[path])
    digest = safe_hash("\n".join(parts).encode(), usedforsecurity=False)
    result: str = digest.hexdigest()
    return result
595
596
597
598
599
600
601
602


def _compute_code_hash(files: set[str]) -> str:
    """Return a content hash over the given traced source files.

    Paths that are not regular files on disk (e.g. "<string>" or frozen
    modules) are not skipped. They are hashed by path with empty content.

    Files are decoded as UTF-8, Python's default source encoding, regardless
    of the platform locale, so identical sources give identical hashes on
    every machine. Undecodable bytes are replaced instead of raising, so one
    odd file cannot abort cache-key computation.
    """
    logger.debug(
        "Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
    )
    file_contents: dict[str, str] = {}
    for filepath in files:
        if not os.path.isfile(filepath):
            file_contents[filepath] = ""
            continue
        with open(filepath, encoding="utf-8", errors="replace") as f:
            file_contents[filepath] = f.read()
    return _compute_code_hash_with_content(file_contents)