"vllm/core/block_manager.py" did not exist on "be1e2163c9f9519310abe2519caa36b0a6966a1b"
serial_utils.py 12.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import dataclasses
4
import pickle
5
6
from collections.abc import Sequence
from inspect import isclass
7
from types import FunctionType
8
from typing import Any, Optional, Union
9

10
import cloudpickle
11
import numpy as np
12
import torch
13
import zmq
14
15
from msgspec import msgpack

16
from vllm import envs
17
from vllm.logger import init_logger
18
19
20
21
22
23
24
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, NestedTensors)

25
26
logger = init_logger(__name__)

# msgpack extension type codes used to tag custom-encoded payloads.
# Payload serialized with `pickle.dumps`.
CUSTOM_TYPE_PICKLE = 1
# Payload serialized with `cloudpickle.dumps`.
CUSTOM_TYPE_CLOUDPICKLE = 2
# Raw bytes of a tensor / ndarray that was encoded inline.
CUSTOM_TYPE_RAW_VIEW = 3

# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Buffer-like types that may carry an encoded message or a backing buffer.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
41
42


43
44
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        """Create the encoder.

        Args:
            size_threshold: Arrays / tensors whose byte size is below this
                value are encoded inline. When `None`, the value of
                `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD` is used.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            logger.warning(
                "Allowing insecure serialization using pickle due to "
                "VLLM_ALLOW_INSECURE_SERIALIZATION=1")

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of resulting buffers.

        The first element is the msgpack-encoded message; any further
        elements are backing buffers appended by the tensor / ndarray
        encoding helpers.
        """
        try:
            # Slot 0 is a placeholder until the main message is encoded, so
            # that indices handed out by the hooks start at 1.
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always clear the stash, even if encoding raised.
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into the caller-provided `buf`.

        Returns a list whose first element is `buf` itself, followed by any
        backing buffers appended by the tensor / ndarray encoding helpers.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            # Always clear the stash, even if encoding raised.
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """Hook passed to `msgpack.Encoder` to convert custom objects.

        Handles tensors, ndarrays, slices and `MultiModalKwargs`. Anything
        else is pickled when `VLLM_ALLOW_INSECURE_SERIALIZATION` is set.

        Raises:
            TypeError: If `obj` is of an unhandled type and insecure
                serialization is not allowed.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Encode a list of MultiModalKwargsItems as plain dicts
            # + special handling for .field.
            # Any tensors *not* indexed by modality will be ignored.
            return [[{
                "modality": elem.modality,
                "key": elem.key,
                "data": self._encode_nested_tensors(elem.data),
                "field": self._encode_mm_field(elem.field),
            } for elem in item.values()]
                    for itemlist in mm._items_by_modality.values()
                    for item in itemlist]

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(f"Object of type {type(obj)} is not serializable")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, msgpack.Ext]]:
        """Encode an ndarray as a `(dtype, shape, data)` tuple.

        `data` is either an inline `msgpack.Ext` holding the raw bytes or an
        integer index into `self.aux_buffers`. Must only be called while an
        encode call is in progress.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, msgpack.Ext]]:
        """Encode a tensor as a `(dtype, shape, data)` tuple.

        `dtype` is the torch dtype name without the `torch.` prefix. `data`
        is either an inline `msgpack.Ext` holding the raw bytes or an integer
        index into `self.aux_buffers`. Must only be called while an encode
        call is in progress.
        """
        assert self.aux_buffers is not None
        # this creates a copy of the tensor if it's not already contiguous
        obj = obj.contiguous()
        # view the tensor as a 1D array of bytes
        arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype)[6:]  # remove 'torch.' prefix
        return dtype, obj.shape, data

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors nested inside iterables."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a multimodal field as `(factory_name, *field_values)`.

        Raises:
            TypeError: If the field's class is not listed in
                `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name)
                        for f in dataclasses.fields(field))
        return name, *field_values

190
191

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        """Create the decoder.

        Args:
            t: Optional type passed positionally to `msgpack.Decoder`. When
                `None`, the decoder is constructed without a type argument.
        """
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        # Buffers of the message currently being decoded; the hooks index
        # into this to resolve data that was not encoded inline.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            logger.warning(
                "Allowing insecure deserialization using pickle due to "
                "VLLM_ALLOW_INSECURE_SERIALIZATION=1")

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a single buffer or a sequence of buffers.

        For a sequence, the first element is decoded as the main message and
        the whole sequence is made available to the hooks as `aux_buffers`
        for the duration of the call.
        """
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Always drop the references, even if decoding raised.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Hook passed to `msgpack.Decoder` to build custom types.

        Converts the native value `obj` to type `t` for ndarrays, tensors,
        slices and `MultiModalKwargs`; any other `obj` is returned unchanged.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargs):
                # A list is the per-item encoding; otherwise `obj` is the
                # plain dict emitted when there were no modalities.
                if isinstance(obj, list):
                    return MultiModalKwargs.from_items(
                        self._decode_mm_items(obj))
                return MultiModalKwargs({
                    k: self._decode_nested_tensors(v)
                    for k, v in obj.items()
                })
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` encoding."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its `(dtype, shape, data)` encoding."""
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) \
            else bytearray(data)
        # Create numpy wrapper around the bytes
        byte_view = np.ndarray(buffer=buffer,
                               dtype=np.uint8,
                               shape=(len(buffer), ))
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        # Convert back to proper shape & type
        return torch.from_numpy(byte_view).view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        """Rebuild `MultiModalKwargsItem`s from their encoded dict lists.

        Note that the element dicts in `obj` are updated in place while
        decoding.
        """
        decoded_items = []
        for item in obj:
            elems = []
            for v in item:
                v["data"] = self._decode_nested_tensors(v["data"])
                # Reconstruct the field processor using MultiModalFieldConfig
                factory_meth_name, *field_args = v["field"]
                factory_meth = getattr(MultiModalFieldConfig,
                                       factory_meth_name)

                # Special case: decode the union "slices" field of
                # MultiModalFlatField
                if factory_meth_name == "flat":
                    field_args[0] = self._decode_nested_slices(field_args[0])

                v["field"] = factory_meth(None, *field_args).field
                elems.append(MultiModalFieldElem(**v))
            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
        return decoded_items

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode tensors nested inside lists.

        Raises:
            TypeError: If `obj` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a list starting with its dtype string.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively convert nested lists / tuples back into slices."""
        assert isinstance(obj, (list, tuple))
        # A sequence whose first element is not itself a sequence is treated
        # as the (start, stop, step) triple of a single slice.
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Hook passed to `msgpack.Decoder` for extension types.

        Raises:
            NotImplementedError: If `code` is not a supported extension type
                (this includes the pickle types when insecure serialization
                is not allowed).
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # SECURITY: unpickling can execute arbitrary code. This path is
            # only reachable when VLLM_ALLOW_INSECURE_SERIALIZATION is set;
            # never enable it for messages from untrusted peers.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")