serial_utils.py 12.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import dataclasses
4
import pickle
5
6
from collections.abc import Sequence
from inspect import isclass
7
from types import FunctionType
8
from typing import Any, Optional, Union
9

10
import cloudpickle
11
import numpy as np
12
import torch
13
import zmq
14
15
from msgspec import msgpack

16
from vllm import envs
17
from vllm.logger import init_logger
18
19
20
21
22
23
24
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, NestedTensors)

25
26
logger = init_logger(__name__)

# msgpack extension type codes used to tag custom payloads:
# pickled objects, cloudpickled objects, and raw byte views of
# tensor / ndarray data that are encoded inline.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

# Serialization type map for MultiModalField classes.
# Every supported field class must be listed here and mapped to the name
# of the matching factory method on `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Any of the byte-like objects that can carry one frame of a message.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
41
42


43
44
45
46
47
def _log_insecure_serialization_warning():
    """Warn (via `logger.warning_once`) that pickle fallback is enabled."""
    message = ("Allowing insecure serialization using pickle due to "
               "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
    logger.warning_once(message)


48
49
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays: the list of
    out-of-band buffers is stashed on the instance for the duration of each
    `encode` / `encode_into` call.

    By default, arrays below 256B are serialized inline. Larger ones will
    get sent via dedicated messages. Note that this is a per-tensor limit.
    (When `size_threshold` is not given, it is read from
    `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`.)
    """

    def __init__(self, size_threshold: Optional[int] = None):
        """Create the encoder.

        Args:
            size_threshold: Size in bytes below which a tensor / ndarray is
                encoded inline instead of as a separate buffer. `None`
                selects `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of buffers making up the message.

        The first element is the msgpack-encoded top-level message; any
        further elements are backing buffers of tensors / arrays that were
        too large to be inlined.
        """
        try:
            # Slot 0 is reserved for the main message so that the indices
            # handed out by the hooks (via `len(self.aux_buffers)`) start at 1.
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always drop the stash, even if encoding raised.
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the top-level message into `buf`.

        Returns a list whose first element is `buf` itself, followed by any
        out-of-band tensor / array buffers.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` hook converting unsupported objects to native types.

        Handles torch tensors, numpy arrays (except object / void kinds),
        slices and `MultiModalKwargs`. Anything else is pickled, but only if
        `envs.VLLM_ALLOW_INSECURE_SERIALIZATION` is set.

        Raises:
            TypeError: If `obj` has no dedicated handling and insecure
                (pickle-based) serialization is not allowed.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Encode a list of MultiModalKwargsItems as plain dicts
            # + special handling for .field.
            # Any tensors *not* indexed by modality will be ignored.
            return [[{
                "modality": elem.modality,
                "key": elem.key,
                "data": self._encode_nested_tensors(elem.data),
                "field": self._encode_mm_field(elem.field),
            } for elem in item.values()]
                    for itemlist in mm._items_by_modality.values()
                    for item in itemlist]

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # The trailing ". " keeps the two sentences from running together.
            raise TypeError(f"Object of type {type(obj)} is not serializable. "
                            "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                            "fallback to pickle-based serialization.")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode an ndarray as a `(dtype string, shape, data)` tuple.

        `data` is either the inline bytes wrapped in a `CUSTOM_TYPE_RAW_VIEW`
        extension, or the integer index of the buffer appended to
        `self.aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode a tensor as a `(dtype name, shape, data)` tuple.

        The dtype name is `str(obj.dtype)` without the `torch.` prefix;
        `data` follows the same inline-or-index convention as
        `_encode_ndarray`.
        """
        assert self.aux_buffers is not None
        # View the tensor as a 1D array of bytes. `flatten()` hands back the
        # input itself when there is nothing to flatten (e.g. a strided 1-D
        # slice), so force contiguity first: reinterpreting as uint8 requires
        # unit stride, and the emitted buffer must be dense. This is a no-op
        # for tensors that are already contiguous.
        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors nested in lists, keeping the nesting."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as `(factory name, *dataclass field values)`.

        Raises:
            TypeError: If the field's class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name)
                        for f in dataclasses.fields(field))
        return name, *field_values

193
194

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays: `decode` keeps the
    received buffers on the instance while the hooks run.
    """

    def __init__(self, t: Optional[Any] = None):
        """Create the decoder, optionally targeting the type `t`."""
        # Only forward a target type to msgspec when one was requested.
        type_args = (t, ) if t is not None else ()
        self.decoder = msgpack.Decoder(*type_args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        # Buffers of the message currently being decoded; the hooks resolve
        # integer buffer indices against this sequence.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a single buffer, or a sequence of buffers.

        For a sequence, the first element is the msgpack message and the
        whole sequence is made available to the hooks for index lookups.
        """
        # TODO - This check can become `isinstance(bufs, bytestr)`
        # as of Python 3.10.
        single_frame = isinstance(bufs,
                                  (bytes, bytearray, memoryview, zmq.Frame))
        if single_frame:
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Release the references whether or not decoding succeeded.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` hook: given native types in `obj`, convert to type `t`.

        Objects whose target type has no special handling are returned as-is.
        """
        if not isclass(t):
            return obj

        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
            return self._decode_tensor(obj)
        if t is slice:
            return slice(*obj)
        if issubclass(t, MultiModalKwargs):
            if not isinstance(obj, list):
                # Plain mapping form: decode each value independently.
                decoded_kwargs = {
                    key: self._decode_nested_tensors(value)
                    for key, value in obj.items()
                }
                return MultiModalKwargs(decoded_kwargs)
            # List form: a list of items, each a list of element dicts.
            return MultiModalKwargs.from_items(self._decode_mm_items(obj))

        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` encoding."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        if isinstance(data, int):
            backing = self.aux_buffers[data]
        else:
            backing = data
        flat = np.frombuffer(backing, dtype=dtype)
        return flat.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its `(dtype, shape, data)` encoding."""
        dtype, shape, data = arr
        if isinstance(data, int):
            buffer = self.aux_buffers[data]
        else:
            # Copy from inline representation, to decouple the memory storage
            # of the message from the original buffer. And also make Torch
            # not complain about a readonly memoryview.
            buffer = bytearray(data)

        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)

        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)

        # Interpret the raw bytes as uint8 first, then reinterpret them as
        # the real dtype and restore the shape.
        raw_bytes = torch.frombuffer(buffer, dtype=torch.uint8)
        typed = raw_bytes.view(torch_dtype)
        return typed.view(shape)

    def _decode_mm_elem(self, raw: dict) -> MultiModalFieldElem:
        """Turn one encoded element dict into a `MultiModalFieldElem`.

        `raw` is updated in place: its "data" and "field" entries are
        replaced by their decoded values before it is expanded as kwargs.
        """
        raw["data"] = self._decode_nested_tensors(raw["data"])
        # Reconstruct the field processor using MultiModalFieldConfig
        factory_name, *factory_args = raw["field"]
        factory = getattr(MultiModalFieldConfig, factory_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_name == "flat":
            factory_args[0] = self._decode_nested_slices(factory_args[0])

        raw["field"] = factory(None, *factory_args).field
        return MultiModalFieldElem(**raw)

    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        """Decode a list of items, each given as a list of element dicts."""
        return [
            MultiModalKwargsItem.from_elems(
                [self._decode_mm_elem(raw) for raw in item])
            for item in obj
        ]

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode tensors nested in lists.

        Raises:
            TypeError: If `obj` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # A list that starts with a string is treated as an encoded tensor;
        # any other list is a nesting level.
        looks_like_tensor = bool(obj) and isinstance(obj[0], str)
        if looks_like_tensor:
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(child) for child in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn encoded slice sequences back into `slice`s."""
        assert isinstance(obj, (list, tuple))
        # A non-empty sequence whose first entry is not itself a sequence
        # holds the arguments of a single slice.
        is_leaf = bool(obj) and not isinstance(obj[0], (list, tuple))
        if is_leaf:
            return slice(*obj)
        return [self._decode_nested_slices(child) for child in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` hook decoding the custom extension types.

        Raises:
            NotImplementedError: If `code` is unknown, or is a pickle-based
                code while insecure serialization is not allowed.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        # NOTE: unpickling executes arbitrary code from the payload, hence
        # the explicit opt-in through the environment flag.
        insecure_allowed = envs.VLLM_ALLOW_INSECURE_SERIALIZATION
        if insecure_allowed and code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if insecure_allowed and code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")