# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import inspect
import os
import pickle
from collections.abc import Callable, Sequence
from typing import Any, Literal
from unittest.mock import patch

import torch
from torch.utils import _pytree as pytree

import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash

logger = init_logger(__name__)

# Dynamo's AOT-compile ``SerializableCallable`` base is not available in every
# torch build; when the import fails, fall back to plain ``object`` so that a
# class can still inherit from ``SerializableCallable`` unconditionally.
try:
    from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
    SerializableCallable = object

# Whichever branch ran, the result must be a class (usable as a base class).
assert isinstance(SerializableCallable, type)


class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
    """
    A wrapper around a function compiled by vllm. It forwards its inputs to
    the compiled function and returns the result.

    It also implements a serialization interface to support PyTorch's
    precompile with a custom backend, so that the compiled function can be
    saved to and loaded from disk. There is no need to wrap a compiled
    function if it does not have to be serialized.

    Serialization for the custom backend is currently done by serializing the
    Dynamo FX graph plus the example inputs (with tensor example inputs
    replaced by ``None``; see ``serialize_compile_artifacts``).
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        example_inputs: Sequence[Any],
        prefix: str,
        optimized_call: Callable[..., Any],
        is_encoder: bool = False,
    ) -> None:
        """
        Args:
            graph_module: The FX graph that was compiled.
            example_inputs: Inputs the graph was compiled with. The first
                ``torch.SymInt`` among them (if any) provides ``shape_env``.
            prefix: Passed to ``VllmBackend`` when the function is recompiled
                after deserialization.
            optimized_call: Callable that actually executes the compiled graph.
            is_encoder: Passed to ``VllmBackend`` after deserialization.
        """
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call
        self.is_encoder = is_encoder
        # Remember the ShapeEnv owning the first symbolic input, if any. This
        # is runtime-only state: it is dropped by serialize_compile_artifacts.
        self.shape_env = None
        sym_input = next(
            (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
        )
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward all arguments to the current ``optimized_call``."""
        # Looked up on every call on purpose: after deserialization, the first
        # call replaces ``self.optimized_call`` with the compiled function.
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
        cls, compiled_fn: "VllmSerializableFunction"
    ) -> bytes:
        """
        Serialize ``compiled_fn`` to bytes.

        ``optimized_call`` and ``shape_env`` are dropped, tensor example inputs
        are replaced by ``None``, and the graph module and example inputs are
        pickled with ``GraphPickler`` before the whole state is pickled.
        """
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        state = compiled_fn.__dict__.copy()
        state.pop("optimized_call")
        state.pop("shape_env")
        # NOTE: ``state`` is a shallow copy, so this strips the metadata from
        # the nodes of ``compiled_fn.graph_module`` itself, not from a copy.
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(
            self: GraphPickler, obj: Any
        ) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
            # sympy Function classes carrying ``_torch_unpickler`` are pickled
            # as a call to that unpickler with their handler name.
            if (
                inspect.isclass(obj)
                and issubclass(obj, sympy.Function)
                and hasattr(obj, "_torch_unpickler")
            ):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            # FakeTensorMode instances are not serialized; they unpickle as
            # ``NoneType()``, i.e. ``None``.
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # Mask off tensor inputs since they are large and not needed.
        state["example_inputs"] = pytree.tree_map_only(
            torch.Tensor, lambda _: None, state["example_inputs"]
        )
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None)
            )
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
        """
        Rebuild a ``VllmSerializableFunction`` from bytes produced by
        ``serialize_compile_artifacts``.

        The FX graph is recompiled immediately, while the vLLM backend compile
        is deferred to the first call, which supplies real values for the
        tensor example inputs that were masked out during serialization.
        """
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        from vllm.compilation.backends import VllmBackend

        # SECURITY: pickle.loads can execute arbitrary code. ``data`` must only
        # come from a trusted source (presumably vLLM's own compile cache);
        # never pass externally supplied bytes here.
        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["graph_module"].recompile()
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

        # Tolerate serialized state without an ``is_encoder`` entry
        # (presumably artifacts from older versions).
        is_encoder = state.get("is_encoder", False)
        vllm_backend = VllmBackend(
            get_current_vllm_config(), state["prefix"], is_encoder
        )

        def optimized_call(*example_inputs: Any) -> Any:
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            # Tensor example inputs were serialized as None; use the runtime
            # argument at the same position in their place.
            compile_inputs = [
                inp if inp is not None else example_inputs[i]
                for i, inp in enumerate(fn.example_inputs)
            ]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = vllm_backend(
                    state["graph_module"], compile_inputs
                ).optimized_call
            return fn.optimized_call(*example_inputs)

        # ``fn`` is bound before ``optimized_call`` can ever run, so the
        # closure above may reference it.
        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self) -> Literal["VllmSerializableFunction"]:
        """
        Used for depyf debugging.
        """
        return "VllmSerializableFunction"


def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
    """
    Collect the hash factors that determine the compilation cache key.

    Args:
        vllm_config: The vLLM config whose ``compute_hash()`` is included.

    Returns:
        ``[env_hash, config_hash]``: the hash of ``envs.compile_factors()``
        followed by the hash of ``vllm_config``.
    """
    factors: list[str] = []
    # 0. factors come from the env, for example, the values of
    #    VLLM_PP_LAYER_PARTITION will affect the computation graph.
    env_hash = hash_factors(envs.compile_factors())
    factors.append(env_hash)

    # 1. factors come from the vllm_config (it mainly summarizes how the
    #    model is created)
    config_hash = vllm_config.compute_hash()
    factors.append(config_hash)
    return factors


def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
    """
    Hash a mapping of file path -> file content into a hex digest.

    Paths are visited in sorted order, so the result does not depend on the
    mapping's insertion order. Every path is hashed; the content of the
    ``"<string>"`` entry (dynamically generated code) is skipped.

    Args:
        file_contents: Mapping from file path to that file's text.

    Returns:
        ``safe_hash(...).hexdigest()`` of the newline-joined paths/contents.
    """
    hash_content: list[str] = []
    # Dict keys are unique, so sorting the keys yields the same order as
    # sorting the (path, content) items by path.
    for filepath in sorted(file_contents):
        hash_content.append(filepath)
        if filepath == "<string>":
            # This means the function was dynamically generated, with
            # e.g. exec(). We can't actually check these.
            continue
        hash_content.append(file_contents[filepath])
    return safe_hash(
        "\n".join(hash_content).encode(), usedforsecurity=False
    ).hexdigest()


def _compute_code_hash(files: set[str]) -> str:
    """
    Hash the source code of the traced ``files``.

    Paths that are not regular files (e.g. ``<string>`` or frozen modules)
    are recorded with empty content rather than dropped.

    Args:
        files: Paths of the source files traced during compilation.

    Returns:
        The digest computed by ``_compute_code_hash_with_content``.
    """
    # Sorted so the debug output is stable (set iteration order is not).
    logger.debug(
        "Traced files (to be considered for compilation cache):\n%s",
        "\n".join(sorted(files)),
    )
    file_contents: dict[str, str] = {}
    for filepath in files:
        if not os.path.isfile(filepath):
            # Not a readable file on disk (e.g., <string>, <frozen modules>,
            # etc.): keep the path in the hash with empty content.
            file_contents[filepath] = ""
        else:
            # Decode explicitly as UTF-8 (the default encoding of Python
            # source files) so the cache key does not depend on the locale.
            with open(filepath, encoding="utf-8") as f:
                file_contents[filepath] = f.read()
    return _compute_code_hash_with_content(file_contents)