serial_utils.py 15.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
8
from collections.abc import Sequence
from inspect import isclass
9
from types import FunctionType
10
from typing import Any, Optional, Union
11

12
import cloudpickle
13
import msgspec
14
import numpy as np
15
import torch
16
import zmq
17
18
from msgspec import msgpack

19
from vllm import envs
20
from vllm.logger import init_logger
21
22
23
24
25
26
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, NestedTensors)
27
from vllm.v1.engine import UtilityResult
28

29
30
logger = init_logger(__name__)

# `msgpack.Ext` type codes understood by `MsgpackDecoder.ext_hook`:
#   1 -> payload is a `pickle` dump (only accepted when insecure
#        serialization is enabled),
#   2 -> payload is a `cloudpickle` dump (same restriction),
#   3 -> payload is a raw byte view of inline tensor / ndarray data.
CUSTOM_TYPE_PICKLE, CUSTOM_TYPE_CLOUDPICKLE, CUSTOM_TYPE_RAW_VIEW = 1, 2, 3

# MultiModalField class serialization type map.
# Every concrete field class that may be serialized must appear here, and
# each value must name a factory method on `MultiModalFieldConfig` that can
# rebuild that field on the decoding side.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Buffer-like types accepted as (or produced as) message frames.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
45
46


47
48
49
50
51
def _log_insecure_serialization_warning():
    """Warn (once per process via `warning_once`) that the pickle fallback is on."""
    message = ("Allowing insecure serialization using pickle due to "
               "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
    logger.warning_once(message)


52
53
54
55
def _typestr(val: Any) -> Optional[tuple[str, str]]:
    if val is None:
        return None
    t = type(val)
56
57
58
    return t.__module__, t.__qualname__


59
60
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    Arrays and tensors smaller than `size_threshold` bytes (by default
    `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`) are serialized inline. Larger
    ones are not copied into the main message: their backing buffers are
    returned as additional frames and referenced from the message by index.
    Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj`` into a list of frames.

        Element 0 is the msgpack-encoded message. Any further elements are
        the backing buffers of large tensors / arrays, which the message
        refers to by their index in this list.
        """
        try:
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the main message into ``buf``.

        ``buf`` itself is returned as element 0 of the frame list.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """Convert objects msgspec cannot encode natively into native types.

        Raises:
            TypeError: if ``obj`` has no custom encoding and insecure
                (pickle-based) serialization is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Any tensors *not* indexed by modality will be ignored.
            return [
                self._encode_mm_item(item)
                for itemlist in mm._items_by_modality.values()
                for item in itemlist
            ]

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we also encode
            # the type (or a list of types in the case it's a list) to
            # help with correct msgspec deserialization.
            if type(result) is list:
                type_info = [_typestr(v) for v in result]
            else:
                type_info = _typestr(result)
            return type_info, result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # Separate the two sentences; previously the implicit string
            # concatenation produced "...serializableSet ...".
            raise TypeError(f"Object of type {type(obj)} is not serializable. "
                            "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                            "fallback to pickle-based serialization.")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, msgpack.Ext]]:
        """Encode an ndarray as ``(dtype_str, shape, data)``.

        ``data`` is either an inline `msgpack.Ext` raw view (small arrays and
        0-d scalars) or an int index into `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, msgpack.Ext]]:
        """Encode a tensor as ``(dtype_name, shape, data)``.

        ``dtype_name`` is the torch dtype without the ``torch.`` prefix
        (e.g. ``"float32"``). ``data`` is an inline `msgpack.Ext` raw view
        for small tensors, otherwise an int index into `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_item(self,
                        item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode each field element of a multimodal item."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self,
                              elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict of its four components."""
        return {
            "modality":
            elem.modality,
            "key":
            elem.key,
            "data": (None if elem.data is None else
                     self._encode_nested_tensors(elem.data)),
            "field":
            self._encode_mm_field(elem.field),
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode a tensor, a scalar, or a nested list of them."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory_name, *dataclass_field_values)``.

        Raises:
            TypeError: if the field class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name)
                        for f in dataclasses.fields(field))
        return name, *field_values

230
231

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        # Optional target type handed straight to `msgpack.Decoder`; when
        # omitted, the decoder is built without a type argument.
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        # Auxiliary frames of the message currently being decoded. Only
        # populated for the duration of a multi-buffer `decode` call so the
        # hooks can resolve index references into it.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a message.

        ``bufs`` is either a single buffer, or a sequence whose first element
        is the main msgpack message and whose remaining elements are the
        auxiliary buffers it references by index (the frame list returned by
        `MsgpackEncoder.encode`).
        """
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Convert native msgpack values in ``obj`` into an instance of ``t``.

        Types without a custom rule are returned unchanged.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargs):
                # The encoder emits a list of items when modalities are
                # present, otherwise a plain dict.
                if isinstance(obj, list):
                    return MultiModalKwargs.from_items(
                        self._decode_mm_items(obj))
                return MultiModalKwargs({
                    k: self._decode_nested_tensors(v)
                    for k, v in obj.items()
                })
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from ``(type_info, result)``.

        ``type_info`` is None, a ``[module, qualname]`` pair for a single
        result, or a list of such pairs (or None) for a list result.

        Raises:
            TypeError: if type info is present but insecure serialization
                is not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                                "be set to use custom utility result types")
            assert isinstance(result_type, list)
            # A single (module, qualname) pair has a str first element; a
            # per-element list has pairs (lists) or None as its elements.
            if len(result_type) == 2 and isinstance(result_type[0], str):
                result = self._convert_result(result_type, result)
            else:
                assert isinstance(result, list)
                result = [
                    self._convert_result(rt, r)
                    for rt, r in zip(result_type, result)
                ]
        return UtilityResult(result)

    def _convert_result(self, result_type: Optional[Sequence[str]],
                        result: Any) -> Any:
        """Import the named type and convert ``result`` into it.

        Returns ``result`` unchanged when ``result_type`` is None.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        mod = importlib.import_module(mod_name)
        result_type = getattr(mod, name)
        return msgspec.convert(result, result_type, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from ``(dtype_str, shape, data)``."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from ``(dtype_name, shape, data)``."""
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = (self.aux_buffers[data]
                  if isinstance(data, int) else bytearray(data))
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: list[Any]) -> list[MultiModalKwargsItem]:
        """Decode a list of encoded multimodal items."""
        return [self._decode_mm_item(v) for v in obj]

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Decode one multimodal item from its list of field elements."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj])

    def _decode_mm_field_elem(self, obj: dict[str,
                                              Any]) -> MultiModalFieldElem:
        """Decode one field element dict (mutates ``obj`` in place)."""
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode a tensor, a scalar, or a nested list of them.

        Raises:
            TypeError: on content that is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a sequence whose first element is the dtype
        # name string.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Turn ``[start, stop, step]`` triples (possibly nested) into slices."""
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode `msgpack.Ext` payloads produced by `MsgpackEncoder`.

        Raw views are returned as-is. Pickle payloads are only accepted when
        insecure serialization is enabled.

        Raises:
            NotImplementedError: for any other (or disallowed) type code.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")