serial_utils.py 16.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from inspect import isclass
9
from types import FunctionType
10
from typing import Any, TypeAlias
11

12
import cloudpickle
13
import msgspec
14
import numpy as np
15
import torch
16
import zmq
17
18
from msgspec import msgpack

19
from vllm import envs
20
from vllm.logger import init_logger
21
22
23
24
25
26
27
28
29
30
31
32
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
33
from vllm.v1.engine import UtilityResult
34
from vllm.v1.utils import tensor_data
35

36
37
logger = init_logger(__name__)

38
39
# msgpack extension type codes for payloads that need custom handling.
# Payload produced by `pickle.dumps` (insecure fallback).
CUSTOM_TYPE_PICKLE = 1
# Payload produced by `cloudpickle.dumps` (insecure fallback for functions).
CUSTOM_TYPE_CLOUDPICKLE = 2
# Raw bytes of a small tensor / ndarray embedded inline in the message.
CUSTOM_TYPE_RAW_VIEW = 3
41

42
43
44
45
46
47
48
49
# MultiModalField class serialization type map.
# Every supported field class must be listed here and mapped to the name of
# the matching factory method on `MultiModalFieldConfig`. The encoder emits
# this name in place of the class, and the decoder looks the factory up by
# name (via `getattr`) to rebuild the field.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
50

51
# Buffer-like types that may appear as individual message frames handled by
# the encoder / decoder in this module.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
52
53


54
def _log_insecure_serialization_warning():
    """Warn that pickle-based (insecure) serialization has been enabled.

    The message is emitted through `logger.warning_once`.
    """
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
59
60


61
def _typestr(val: Any) -> tuple[str, str] | None:
62
63
64
    if val is None:
        return None
    t = type(val)
65
66
67
    return t.__module__, t.__qualname__


68
69
70
71
72
73
74
75
76
77
78
79
80
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
81
82
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
83
84
85
86
87
88
89
90
91
92
93
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
94
95
96
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
97
98
99
100
101
102
103
104
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


105
106
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays: the list of
    auxiliary buffers is kept on the instance for the duration of a call.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.

    Args:
        size_threshold: Size in bytes below which tensor / array data is
            embedded inline. When `None`, the value of
            `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD` is used.
    """

    def __init__(self, size_threshold: int | None = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of resulting buffers.

        The first buffer is the msgpack-encoded message; any further
        buffers are the backing data of large tensors / arrays that the
        message refers to by index.
        """
        try:
            # Slot 0 is reserved for the main message so that indices handed
            # out by the hooks during encoding start at 1.
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into the caller-supplied `buf`.

        Returns the list of buffers: `buf` itself first, followed by any
        auxiliary tensor / array buffers collected while encoding.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` hook converting unsupported objects to encodable ones.

        Raises:
            TypeError: If `obj` has no dedicated encoding and the pickle
                fallback is not enabled through
                `VLLM_ALLOW_INSECURE_SERIALIZATION`.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                # No type information is attached in the secure mode.
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode an ndarray as a `(dtype string, shape, data)` tuple.

        `data` is either an inline `msgpack.Ext` payload or the index of the
        array's buffer in `self.aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a tensor as a `(dtype name, shape, data)` tuple.

        The dtype name is the `torch.dtype` string without its `torch.`
        prefix. `data` is either an inline `msgpack.Ext` payload or the
        index of the tensor's buffer in `self.aux_buffers`.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode every item list in `items`, keyed by modality."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode the field elements of `item` as a list of dicts."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict of its four attributes."""
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        """Encode each value of `kw` as nested tensors, keeping its keys."""
        return {
            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors inside arbitrarily nested sequences."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode `field` as `(factory name, *dataclass field values)`.

        Raises:
            TypeError: If the class of `field` is not registered in
                `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

276
277

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays: the auxiliary
    buffers of the message being decoded are kept on the instance for the
    duration of a call.

    Args:
        t: Optional target type passed on to `msgpack.Decoder`. When `None`,
            the decoder is created without a type argument.
    """

    def __init__(self, t: Any | None = None):
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded; read by the hooks.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a message given as one buffer or a sequence of buffers.

        For a sequence, the first buffer is the msgpack message and the
        whole sequence is made available to the hooks so that indices in
        the message can be resolved to their backing buffers.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` hook converting native decoded data to type `t`.

        Returns `obj` unchanged when `t` has no dedicated decoding here.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if issubclass(t, MultiModalKwargs):
                return self._decode_mm_kwargs(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its `(type info, result)` pair.

        Raises:
            TypeError: If type information is present but
                `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert `result` to the type named by `(module, qualname)`.

        Returns `result` unchanged when `result_type` is `None`.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        # `name` is a qualified name, so a nested class arrives as a dotted
        # path ("Outer.Inner"). Resolve it one attribute at a time, since a
        # single `getattr` on the module would fail for such names.
        target: Any = importlib.import_module(mod_name)
        for attr in name.split("."):
            target = getattr(target, attr)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` triple.

        `data` is either an index into `self.aux_buffers` or the inline
        buffer itself.
        """
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its `(dtype name, shape, data)` triple.

        `data` is either an index into `self.aux_buffers` or inline bytes,
        which are copied before use.
        """
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild `MultiModalKwargsItems` from a dict of encoded item lists."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Rebuild a `MultiModalKwargsItem` from its encoded field elements."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a `MultiModalFieldElem` from its encoded dict.

        Note that `obj` is updated in place before being unpacked into the
        `MultiModalFieldElem` constructor.
        """
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        """Rebuild `MultiModalKwargs` from a dict of encoded nested tensors."""
        return MultiModalKwargs(
            {
                modality: self._decode_nested_tensors(data)
                for modality, data in obj.items()
            }
        )

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode tensors inside nested lists.

        Raises:
            TypeError: If `obj` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # A list starting with a string is an encoded tensor (its first
        # element is the dtype name) rather than a nested list.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn encoded slice triples back into `slice` objects."""
        assert isinstance(obj, (list, tuple))
        # A non-empty sequence whose first element is not itself a sequence
        # holds the arguments of a single slice.
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` hook decoding the custom extension types.

        Raises:
            NotImplementedError: If `code` is unknown, or is a pickle-based
                type while `VLLM_ALLOW_INSECURE_SERIALIZATION` is disabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # NOTE(security): unpickling can execute arbitrary code, so these
            # branches must only be reachable when the peer is trusted. They
            # are deliberately gated behind the explicit opt-in above.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")