serial_utils.py 17.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from functools import partial
9
from inspect import isclass
10
from types import FunctionType
11
from typing import Any, TypeAlias
12

13
import cloudpickle
14
import msgspec
15
import numpy as np
16
import torch
17
import zmq
18
19
from msgspec import msgpack

20
from vllm import envs
21
from vllm.logger import init_logger
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
34
from vllm.v1.engine import UtilityResult
35
from vllm.v1.utils import tensor_data
36

37
38
logger = init_logger(__name__)

# msgpack extension type codes used by the encoder/decoder in this module.
# PICKLE / CLOUDPICKLE payloads are only produced and accepted when
# VLLM_ALLOW_INSECURE_SERIALIZATION is set; RAW_VIEW carries inline
# tensor / ndarray bytes.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Buffer-like types accepted as a single encoded message part.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
53
54


55
def _log_insecure_serialization_warning() -> None:
    """Log (via `logger.warning_once`) that pickle-based serialization is on."""
    logger.warning_once(
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
60
61


62
def _typestr(val: Any) -> tuple[str, str] | None:
63
64
65
    if val is None:
        return None
    t = type(val)
66
67
68
    return t.__module__, t.__qualname__


69
70
71
72
73
74
75
76
77
78
79
80
81
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
82
83
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
84
85
86
87
88
89
90
91
92
93
94
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
95
96
97
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
98
99
100
101
102
103
104
105
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


106
107
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: int | None = None):
        """Create an encoder.

        Args:
            size_threshold: Tensors / arrays smaller than this many bytes are
                encoded inline. When `None`, the value of
                `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD` is used.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD

        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of message buffers.

        The first buffer is the msgpack payload; any further entries are the
        backing buffers collected by `enc_hook` while encoding.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always drop the stash so buffers are not kept alive by the
            # encoder after the call.
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into `buf` and return the list of message buffers.

        `buf` is the first entry of the returned list; any further entries are
        the backing buffers collected by `enc_hook` while encoding.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` encode hook: convert `obj` to natively encodable values.

        Raises:
            TypeError: If `obj` has no dedicated encoding here and
                `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                # No type information is attached in this mode.
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode an ndarray as `(dtype string, shape, data)`.

        `data` is an inline `msgpack.Ext` for scalars and arrays below
        `size_threshold`, otherwise the index of the array's buffer in
        `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a tensor as `(dtype name, shape, data)`.

        `data` is an inline `msgpack.Ext` for tensors below `size_threshold`,
        otherwise the index of the tensor's buffer in `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        # Store the dtype without its "torch." prefix, e.g. "float32".
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode each modality's item list with `_encode_mm_item`."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode every field element of `item`."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict of natively encodable values."""
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        """Encode each value of `kw` with `_encode_nested_tensors`."""
        return {
            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Encode a tensor, a number, or an arbitrarily nested sequence of them."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode `field` as `(factory name, *dataclass field values)`.

        Raises:
            TypeError: If the field's class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

277
278

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Any | None = None):
        """Create a decoder.

        Args:
            t: Optional target type passed to `msgpack.Decoder`; when `None`
                the decoder is created without a type argument.
        """
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded; index 0 is the
        # msgpack payload and the hooks look up the others by index.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a sequence whose first entry is the
        msgpack payload and whose remaining entries are auxiliary buffers."""
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Do not keep the message buffers referenced after the call.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` decode hook: convert native `obj` to type `t`.

        Values whose target type has no dedicated handling are returned
        unchanged.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if issubclass(t, MultiModalKwargs):
                return self._decode_mm_kwargs(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its `(type info, result)` pair.

        Raises:
            TypeError: If type info is present but
                `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    @staticmethod
    def _resolve_type(mod_name: str, qualname: str) -> Any:
        """Import `mod_name` and follow the dotted `qualname` from it."""
        target: Any = importlib.import_module(mod_name)
        # A qualified name contains dots for nested classes, so it has to be
        # resolved one attribute at a time rather than with one `getattr`.
        for part in qualname.split("."):
            target = getattr(target, part)
        return target

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert `result` to the type named by `(module, qualname)`.

        `result` is returned unchanged when `result_type` is `None`.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        result_type = self._resolve_type(mod_name, name)
        return msgspec.convert(result, result_type, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` triple."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        return np.frombuffer(buffer, dtype=dtype).reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its `(dtype, shape, data)` triple."""
        dtype, shape, data = arr
        # Copy from inline representation, to decouple the memory storage
        # of the message from the original buffer. And also make Torch
        # not complain about a readonly memoryview.
        buffer = self.aux_buffers[data] if isinstance(data, int) else bytearray(data)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild `MultiModalKwargsItems` from a modality -> item-list dict."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Rebuild one `MultiModalKwargsItem` from its encoded field elements."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a `MultiModalFieldElem`; `obj` is updated in place."""
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        """Rebuild `MultiModalKwargs`, decoding every value as nested tensors."""
        return MultiModalKwargs(
            {
                modality: self._decode_nested_tensors(data)
                for modality, data in obj.items()
            }
        )

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Decode a number, an encoded tensor, or a nested list of them.

        Raises:
            TypeError: If `obj` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a list that starts with its dtype string.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Decode a `[start, stop, step]` list, or a nested list of them."""
        assert isinstance(obj, (list, tuple))
        # A non-empty sequence whose first item is not itself a sequence
        # holds the arguments of a single slice.
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` extension hook for the custom type codes of this module.

        Raises:
            NotImplementedError: If `code` is unknown, or is a pickle code
                while `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # NOTE(security): unpickling can execute arbitrary code, which is
            # why these codes are only honoured behind the opt-in flag.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)