serial_utils.py 16.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
8
from collections.abc import Sequence
from inspect import isclass
9
from types import FunctionType
10
from typing import Any, Callable, Optional, Union
11

12
import cloudpickle
13
import msgspec
14
import numpy as np
15
import torch
16
import zmq
17
18
from msgspec import msgpack

19
from vllm import envs
20
from vllm.logger import init_logger
21
22
23
24
25
26
27
28
29
30
31
32
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
33
from vllm.v1.engine import UtilityResult
34

35
36
# Module-level logger; used by `_log_insecure_serialization_warning` below.
logger = init_logger(__name__)

37
38
# msgpack extension type codes for the pickle-based fallback payloads:
# the Ext data is the output of `pickle.dumps` / `cloudpickle.dumps`
# respectively, and is only decoded when
# VLLM_ALLOW_INSECURE_SERIALIZATION is enabled.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
39
# msgpack extension type code for raw ndarray/tensor bytes that are encoded
# inline in the main message; the decoder's `ext_hook` returns the payload
# view unchanged for this code.
CUSTOM_TYPE_RAW_VIEW = 3
40

41
42
43
44
45
46
47
48
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
# The lookup is by exact class (`field.__class__`), and the encoder raises
# `TypeError` for any field class that is not listed here. On decode, the
# mapped string is resolved with `getattr(MultiModalFieldConfig, name)`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
49

50
# Buffer-like types used for message frames: the element type of the buffer
# lists returned by `MsgpackEncoder.encode*` and accepted by
# `MsgpackDecoder.decode`.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
51
52


53
def _log_insecure_serialization_warning():
    """Warn that the pickle-based serialization fallback is enabled.

    The message is emitted through ``logger.warning_once``. Callers in this
    module invoke it when ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is set.
    """
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
58
59


60
61
62
63
def _typestr(val: Any) -> Optional[tuple[str, str]]:
    if val is None:
        return None
    t = type(val)
64
65
66
    return t.__module__, t.__qualname__


67
68
69
70
71
72
73
74
75
76
77
78
79
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
80
81
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
82
83
84
85
86
87
88
89
90
91
92
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
93
94
95
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
96
97
98
99
100
101
102
103
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


104
105
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger will get sent
    via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        """Create the encoder.

        Args:
            size_threshold: Byte size below which array/tensor data is
                inlined into the main message. When ``None``, the value of
                ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`` is used.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj`` and return the list of message buffers.

        Element 0 is the msgpack-encoded main message; any further elements
        are backing buffers of tensors / arrays that were not inlined.
        """
        try:
            # Slot 0 is a placeholder so that indices handed out by the
            # array/tensor encoders (via `len(self.aux_buffers)`) start at 1.
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always clear the stash, even if encoding raised.
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode ``obj`` into the caller-provided ``buf``.

        Returns a list whose first element is ``buf`` itself, followed by the
        backing buffers of any tensors / arrays that were not inlined.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` encode hook for types it cannot serialize natively.

        Raises:
            TypeError: If ``obj`` is not one of the explicitly supported
                types and ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                # No type info is sent; the decoder returns the raw result.
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # Fix: the first fragment previously ran straight into "Set ...",
            # producing "...is not serializableSet VLLM_...".
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode an ndarray as ``(dtype str, shape, data)``.

        ``data`` is either a `msgpack.Ext` wrapping the raw bytes (inline
        case) or an integer index into ``self.aux_buffers``.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode a tensor as ``(dtype name, shape, data)``.

        ``data`` is either a `msgpack.Ext` wrapping the raw bytes (inline
        case) or an integer index into ``self.aux_buffers``. The dtype name
        is ``str(obj.dtype)`` without the ``"torch."`` prefix.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode each modality's item list with `_encode_mm_item`."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode an item as the list of its encoded field elements."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict of native types.

        The keys (``modality``, ``key``, ``data``, ``field``) are later
        passed as keyword arguments to `MultiModalFieldElem` on decode.
        """
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        """Encode each value of ``kw`` with `_encode_nested_tensors`."""
        return {
            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode a tensor, a number, or a nested iterable."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory name, *dataclass field values)``.

        Raises:
            TypeError: If the field's class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

275
276

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        """Create the decoder.

        Args:
            t: Optional target type passed positionally to
                `msgpack.Decoder`; omitted entirely when ``None``.
        """
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Stash of the current message's buffers so the hooks can resolve
        # integer buffer indices; only populated for the duration of `decode`.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a single buffer, or a sequence of buffers.

        For a sequence, element 0 is the main msgpack message and the other
        elements are the backing buffers referenced by index from it.
        """
        # Use an explicit tuple of classes: `isinstance` with a
        # `typing.Union` alias raises TypeError on Python < 3.10.
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Always drop the references, even if decoding raised.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` decode hook: convert native ``obj`` to type ``t``.

        Returns ``obj`` unchanged when ``t`` is not one of the handled types.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if issubclass(t, MultiModalKwargs):
                return self._decode_mm_kwargs(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its ``(type_info, result)`` pair.

        Raises:
            TypeError: If type info is present but
                ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert ``result`` to the type named by ``(module, qualname)``.

        NOTE(review): this imports the module named in the message; the only
        caller in this class (`_decode_utility_result`) reaches it only when
        ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is set.

        Raises:
            ImportError / AttributeError: If the module or name cannot be
                resolved.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        mod = importlib.import_module(mod_name)
        # `name` is a `__qualname__`, so a nested class arrives as
        # "Outer.Inner"; resolve it one attribute at a time instead of a
        # single `getattr(mod, name)`, which fails for dotted names.
        target: Any = mod
        for part in name.split("."):
            target = getattr(target, part)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from ``(dtype, shape, data)``."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from ``(dtype name, shape, data)``."""
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild `MultiModalKwargsItems` from a modality -> item-list dict."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Rebuild a `MultiModalKwargsItem` from its encoded field elements."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a `MultiModalFieldElem` from its encoded dict.

        Note: ``obj`` is updated in place (its ``data`` and ``field``
        entries are replaced with decoded objects) before being expanded
        into keyword arguments.
        """
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        """Rebuild `MultiModalKwargs`, decoding each value as nested tensors."""
        return MultiModalKwargs(
            {
                modality: self._decode_nested_tensors(data)
                for modality, data in obj.items()
            }
        )

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode numbers, encoded tensors and nested lists.

        Raises:
            TypeError: If ``obj`` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # A list whose first element is a string is an encoded tensor
        # (the first element is the dtype name).
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn encoded slice triples back into `slice` objects.

        A non-empty sequence whose first element is not itself a list/tuple
        is treated as the arguments of one slice; anything else is recursed
        into element by element.
        """
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` extension hook for the custom Ext type codes.

        Raises:
            NotImplementedError: For unknown codes, and for the pickle codes
                when ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not set.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # SECURITY: unpickling executes arbitrary code from the payload;
            # this path must only be enabled for trusted peers.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")