serial_utils.py 16.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
8
from collections.abc import Sequence
from inspect import isclass
9
from types import FunctionType
10
from typing import Any, Callable, Optional, Union
11

12
import cloudpickle
13
import msgspec
14
import numpy as np
15
import torch
16
import zmq
17
18
from msgspec import msgpack

19
from vllm import envs
20
from vllm.logger import init_logger
21
# yapf: disable
22
23
24
25
26
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
27
                                    MultiModalKwargsItems,
28
                                    MultiModalSharedField, NestedTensors)
29
# yapf: enable
30
from vllm.v1.engine import UtilityResult
31

32
33
# Module-level logger; used below for the one-time insecure-serialization
# warning.
logger = init_logger(__name__)

34
35
# msgpack extension type codes for the pickle-based fallbacks: payloads
# produced by `pickle.dumps` and by `cloudpickle.dumps` respectively.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
36
# msgpack extension type code for raw buffer data embedded inline in the
# message; the decoder's `ext_hook` hands this payload back unchanged.
CUSTOM_TYPE_RAW_VIEW = 3
37

38
39
40
41
42
43
44
45
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
# The encoder writes the factory name in place of the field class, and the
# decoder looks that name up on `MultiModalFieldConfig` to rebuild the field.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
46

47
# Buffer-like types that the encoder may return and the decoder accepts.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
48
49


50
51
52
53
54
def _log_insecure_serialization_warning():
    """Warn, via ``logger.warning_once``, that pickle fallback is enabled."""
    message = ("Allowing insecure serialization using pickle due to "
               "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
    logger.warning_once(message)


55
56
57
58
def _typestr(val: Any) -> Optional[tuple[str, str]]:
    if val is None:
        return None
    t = type(val)
59
60
61
    return t.__module__, t.__qualname__


62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
        type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any],
                                                        Any]) -> Any:
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
            # Exclude serialized tensors/numpy arrays.
            len(type_info) != 2 or not isinstance(type_info[0], str)):
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


98
99
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.
100

101
102
    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.
106
107
    """

108
    def __init__(self, size_threshold: Optional[int] = None):
        """Set up the underlying msgpack encoder.

        Args:
            size_threshold: Byte size below which tensors / arrays are
                embedded inline in the main message. When ``None``, the
                value of ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`` is used.
        """
        self.size_threshold = (envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
                               if size_threshold is None else size_threshold)
        # msgspec offers no way to hand per-call data to `enc_hook`, so the
        # buffer list of the encode call in progress is stashed on the
        # instance. It is `None` whenever no encode call is running.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj`` and return the list of resulting buffers.

        Element 0 is the msgpack message itself. Any further elements are
        backing buffers of tensors / arrays collected by the encode hooks,
        returned alongside the message instead of being copied into it.
        """
        buffers: list[bytestr] = [b'']
        self.aux_buffers = buffers
        try:
            buffers[0] = self.encoder.encode(obj)
        finally:
            # Always drop the stash so a later call starts clean.
            self.aux_buffers = None
        return buffers

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode ``obj`` into the caller-supplied ``buf``.

        Returns a list whose first element is ``buf`` and whose remaining
        elements are the backing buffers collected while encoding.
        """
        buffers: list[bytestr] = [buf]
        self.aux_buffers = buffers
        try:
            self.encoder.encode_into(obj, buf)
        finally:
            # Always drop the stash so a later call starts clean.
            self.aux_buffers = None
        return buffers

    def enc_hook(self, obj: Any) -> Any:
        """Convert an object that msgspec cannot encode natively.

        Tensors, numeric ndarrays, slices, the multimodal container types and
        ``UtilityResult`` are turned into msgpack-native structures. Anything
        else is pickled, but only when
        ``envs.VLLM_ALLOW_INSECURE_SERIALIZATION`` is set.

        Raises:
            TypeError: If ``obj`` has no dedicated encoding and insecure
                serialization is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                # No type information is sent in the secure mode.
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # The first literal ends with ". " so the two sentences do not
            # run together in the final message.
            raise TypeError(f"Object of type {type(obj)} is not serializable. "
                            "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                            "fallback to pickle-based serialization.")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode an ndarray as a ``(dtype string, shape, data)`` tuple.

        ``data`` is a ``CUSTOM_TYPE_RAW_VIEW`` extension holding the bytes
        when the array is a scalar or smaller than ``self.size_threshold``;
        otherwise it is the index of the array's buffer in
        ``self.aux_buffers``.
        """
        stash = self.aux_buffers
        assert stash is not None

        # A non-contiguous array has to be copied before it can be sent.
        if obj.flags.c_contiguous:
            payload = obj.data
        else:
            payload = obj.tobytes()

        send_inline = (not obj.shape) or obj.nbytes < self.size_threshold
        if send_inline:
            # The extension type lets the decoder take a view of the bytes
            # instead of copying them.
            return (obj.dtype.str, obj.shape,
                    msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, payload))

        # Large arrays travel as their own buffer; only the index is encoded.
        index = len(stash)
        stash.append(payload)
        return obj.dtype.str, obj.shape, index
205

206
207
208
209
    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode a tensor as a ``(dtype name, shape, data)`` tuple.

        ``data`` is a ``CUSTOM_TYPE_RAW_VIEW`` extension holding the bytes
        when the tensor is smaller than ``self.size_threshold``; otherwise it
        is the index of the tensor's buffer in ``self.aux_buffers``.
        """
        stash = self.aux_buffers
        assert stash is not None

        # Reinterpret the tensor as a flat, contiguous run of bytes.
        raw_bytes = obj.flatten().contiguous().view(torch.uint8).numpy()

        data: Any
        if obj.nbytes >= self.size_threshold:
            # Large tensors travel as their own buffer; only the index is
            # encoded.
            data = len(stash)
            stash.append(raw_bytes.data)
        else:
            # Small tensors are embedded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, raw_bytes.data)

        dtype_name = str(obj.dtype).removeprefix("torch.")
        return dtype_name, obj.shape, data

222
223
224
225
226
227
    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode each modality's items into a list of encoded items."""
        encoded: dict[str, Any] = {}
        for modality, itemlist in items.items():
            encoded[modality] = list(map(self._encode_mm_item, itemlist))
        return encoded

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
    def _encode_mm_item(self,
                        item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode every field element of ``item``, in iteration order."""
        return list(map(self._encode_mm_field_elem, item.values()))

    def _encode_mm_field_elem(self,
                              elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict of its four attributes.

        ``data`` stays ``None`` when the element carries no data; the field
        is encoded via ``_encode_mm_field``.
        """
        raw = elem.data
        encoded_data = (self._encode_nested_tensors(raw)
                        if raw is not None else None)
        encoded_field = self._encode_mm_field(elem.field)
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": encoded_data,
            "field": encoded_field,
        }

245
246
247
248
249
250
    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        """Encode each value of ``kw`` as nested tensors, keeping its key."""
        encoded: dict[str, Any] = {}
        for modality, data in kw.items():
            encoded[modality] = self._encode_nested_tensors(data)
        return encoded

251
252
    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Encode a tensor, a plain number, or an iterable of those.

        Numbers pass through unchanged, tensors go through
        ``_encode_tensor``, and anything else is iterated and encoded
        element by element.
        """
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        return list(map(self._encode_nested_tensors, nt))

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory name, *dataclass field values)``.

        Raises:
            TypeError: If the field's class has no entry in
                ``MMF_CLASS_TO_FACTORY``.
        """
        # The factory name stands in for the field class on the wire.
        factory_name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not factory_name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # The values are emitted in dataclass declaration order so that they
        # can be fed back to the factory when reconstructing the field.
        values = [
            getattr(field, spec.name) for spec in dataclasses.fields(field)
        ]
        return (factory_name, *values)

271
272

class MsgpackDecoder:
273
274
275
276
277
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """
278

279
    def __init__(self, t: Optional[Any] = None):
        """Set up the underlying msgpack decoder.

        Args:
            t: Optional target type handed to ``msgpack.Decoder``. When
                ``None`` the decoder is created without a type argument.
        """
        hooks = dict(ext_hook=self.ext_hook, dec_hook=self.dec_hook)
        if t is None:
            self.decoder = msgpack.Decoder(**hooks)
        else:
            self.decoder = msgpack.Decoder(t, **hooks)
        # Buffers of the decode call in progress, read by the decode hooks.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a single buffer or a sequence of buffers.

        For a sequence, element 0 is the msgpack message and the whole
        sequence is made available to the decode hooks for the duration of
        the call.
        """
        # TODO - This check can become `isinstance(bufs, bytestr)`
        # as of Python 3.10.
        is_single = isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame))
        if not is_single:
            self.aux_buffers = bufs
            try:
                return self.decoder.decode(bufs[0])
            finally:
                # Release the references to the message buffers.
                self.aux_buffers = ()
        return self.decoder.decode(bufs)

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Rebuild a value of type ``t`` from its msgpack-native form ``obj``.

        ``obj`` is returned unchanged when ``t`` is not a class or is not one
        of the types handled here.
        """
        if not isclass(t):
            return obj

        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
            return self._decode_tensor(obj)
        if t is slice:
            return slice(*obj)
        if issubclass(t, MultiModalKwargsItem):
            return self._decode_mm_item(obj)
        if issubclass(t, MultiModalKwargsItems):
            return self._decode_mm_items(obj)
        if issubclass(t, MultiModalKwargs):
            return self._decode_mm_kwargs(obj)
        if t is UtilityResult:
            return self._decode_utility_result(obj)
        return obj

319
320
321
322
323
324
    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a ``UtilityResult`` from a ``(type info, result)`` pair.

        Raises:
            TypeError: If type info is present but
                ``envs.VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set.
        """
        result_type, result = obj
        if result_type is None:
            # No type info was sent; the result is used as decoded.
            return UtilityResult(result)

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                            "be set to use custom utility result types")

        # The type info mirrors the nesting of the result, so decode both
        # structures in parallel.
        converted = _decode_type_info_recursive(result_type, result,
                                                self._convert_result)
        return UtilityResult(converted)

330
331
332
    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert ``result`` to the type named by ``result_type``.

        Args:
            result_type: ``(module name, qualified name)`` of the target
                type, or ``None`` to return ``result`` unchanged.
            result: The msgpack-native value to convert.

        Returns:
            ``result`` converted with ``msgspec.convert``, using this
            decoder's ``dec_hook`` for custom types.

        Raises:
            ImportError: If the named module cannot be imported.
            AttributeError: If the qualified name cannot be resolved in it.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        # NOTE: this imports a module named by the incoming message. The only
        # caller in this file reaches it after checking
        # VLLM_ALLOW_INSECURE_SERIALIZATION.
        target: Any = importlib.import_module(mod_name)
        # The encoder sends `__qualname__`, which is dotted for nested
        # classes ("Outer.Inner"); resolve it one attribute at a time.
        for part in name.split("."):
            target = getattr(target, part)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

338
339
    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its ``(dtype, shape, data)`` encoding.

        An integer ``data`` is an index into ``self.aux_buffers``; anything
        else is used directly as the buffer.
        """
        dtype, shape, data = arr
        if isinstance(data, int):
            buffer = self.aux_buffers[data]
        else:
            buffer = data
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        flat = np.frombuffer(buffer, dtype=dtype)
        return flat.reshape(shape)
344
345
346
347
348
349

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its ``(dtype name, shape, data)`` encoding.

        An integer ``data`` is an index into ``self.aux_buffers``; inline
        data is copied into a ``bytearray`` first.
        """
        dtype, shape, data = arr
        if isinstance(data, int):
            buffer = self.aux_buffers[data]
        else:
            # Copying the inline bytes decouples the tensor's storage from
            # the message buffer, and gives Torch a writable buffer rather
            # than a readonly memoryview.
            buffer = bytearray(data)

        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)

        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)

        # Read the bytes as uint8, then reinterpret them with the real dtype
        # and shape.
        as_bytes = torch.frombuffer(buffer, dtype=torch.uint8)
        return as_bytes.view(torch_dtype).view(shape)
361

362
363
364
365
366
    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild ``MultiModalKwargsItems`` from a modality -> items dict."""
        decoded = {}
        for modality, itemlist in obj.items():
            decoded[modality] = list(map(self._decode_mm_item, itemlist))
        return MultiModalKwargsItems(decoded)
367

368
    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Rebuild one item from its list of encoded field elements."""
        elems = list(map(self._decode_mm_field_elem, obj))
        return MultiModalKwargsItem.from_elems(elems)

372
373
374
375
376
    def _decode_mm_field_elem(self, obj: dict[str,
                                              Any]) -> MultiModalFieldElem:
        """Rebuild a field element from its encoded dict.

        The ``"data"`` and ``"field"`` entries of ``obj`` are replaced in
        place with their decoded values before the element is constructed.
        """
        raw = obj["data"]
        if raw is not None:
            obj["data"] = self._decode_nested_tensors(raw)

        # The encoded field is the factory method name on
        # MultiModalFieldConfig followed by that factory's arguments.
        factory_name, *factory_args = obj["field"]
        factory = getattr(MultiModalFieldConfig, factory_name)

        # Special case: the first argument of the "flat" factory holds
        # slices, which arrive as plain sequences and must be rebuilt.
        if factory_name == "flat":
            factory_args[0] = self._decode_nested_slices(factory_args[0])

        obj["field"] = factory(None, *factory_args).field
        return MultiModalFieldElem(**obj)
388

389
390
391
392
393
394
    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        """Rebuild ``MultiModalKwargs``, decoding each value as nested
        tensors."""
        decoded = {}
        for modality, data in obj.items():
            decoded[modality] = self._decode_nested_tensors(data)
        return MultiModalKwargs(decoded)

395
396
397
398
399
400
401
402
    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Decode a number, an encoded tensor, or a nested list of those.

        Raises:
            TypeError: If ``obj`` is neither a number nor a list.
        """
        if isinstance(obj, list):
            # An encoded tensor is a list that starts with its dtype name.
            looks_like_tensor = bool(obj) and isinstance(obj[0], str)
            if looks_like_tensor:
                return self._decode_tensor(obj)
            return list(map(self._decode_nested_tensors, obj))
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")

406
407
408
409
410
411
    def _decode_nested_slices(self, obj: Any) -> Any:
        """Turn a nested list/tuple of encoded slices back into ``slice``s.

        A non-empty sequence whose first element is not itself a list or
        tuple is one encoded slice; any other sequence is decoded element by
        element.
        """
        assert isinstance(obj, (list, tuple))
        is_container = not obj or isinstance(obj[0], (list, tuple))
        if is_container:
            return list(map(self._decode_nested_slices, obj))
        return slice(*obj)

412
    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode a msgpack extension payload.

        Raw views are returned as-is. Pickle and cloudpickle payloads are
        loaded only when ``envs.VLLM_ALLOW_INSECURE_SERIALIZATION`` is set.

        Raises:
            NotImplementedError: For any other extension code, or for a
                pickle-based code while insecure serialization is disabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        # Unpickling executes code chosen by the sender, so it stays behind
        # the explicit opt-in flag.
        pickle_allowed = envs.VLLM_ALLOW_INSECURE_SERIALIZATION
        if pickle_allowed and code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if pickle_allowed and code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")