serial_utils.py 17.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from functools import partial
9
from inspect import isclass
10
from types import FunctionType
11
from typing import Any, TypeAlias
12

13
import cloudpickle
14
import msgspec
15
import numpy as np
16
import torch
17
import zmq
18
19
from msgspec import msgpack

20
from vllm import envs
21
from vllm.logger import init_logger
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
34
from vllm.utils.platform_utils import is_pin_memory_available
35
from vllm.v1.engine import UtilityResult
36
from vllm.v1.utils import tensor_data
37

38
39
# Module-level logger, created via vLLM's `init_logger` under this module's name.
logger = init_logger(__name__)

40
41
# msgpack extension type codes. The encoder tags `msgpack.Ext` payloads with
# these and the decoder's `ext_hook` dispatches on them.
# Payload produced by `pickle.dumps`.
CUSTOM_TYPE_PICKLE = 1
# Payload produced by `cloudpickle.dumps` (used for plain functions).
CUSTOM_TYPE_CLOUDPICKLE = 2
42
# msgpack extension type code for raw array/tensor bytes embedded inline in
# the message; the decoder's `ext_hook` returns this payload unchanged.
CUSTOM_TYPE_RAW_VIEW = 3
43

44
45
46
47
48
49
50
51
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
# The encoder writes the mapped name in place of the field class, and the
# decoder looks that name up on `MultiModalFieldConfig` to rebuild the field.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
52

53
# Union of the buffer types that the encoder returns as message frames and
# that the decoder accepts as input.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
54
55


56
def _log_insecure_serialization_warning():
    """Warn (via `logger.warning_once`) that pickle fallback is enabled.

    Called by the encoder and decoder constructors when
    `VLLM_ALLOW_INSECURE_SERIALIZATION` is set.
    """
    logger.warning_once(
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
61
62


63
def _typestr(val: Any) -> tuple[str, str] | None:
64
65
66
    if val is None:
        return None
    t = type(val)
67
68
69
    return t.__module__, t.__qualname__


70
71
72
73
74
75
76
77
78
79
80
81
82
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
83
84
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
85
86
87
88
89
90
91
92
93
94
95
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
96
97
98
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
99
100
101
102
103
104
105
106
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


107
108
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays will
    get sent via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: int | None = None):
        """Create the encoder.

        Args:
            size_threshold: Byte size below which arrays/tensors are encoded
                inline. When None, `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`
                is used.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of resulting buffers.

        Element 0 is the msgpack-encoded message; any further elements are
        backing buffers of tensors / arrays referenced by index from it.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always clear the stash so no buffer references outlive the call.
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the top-level message into `buf`.

        `buf` itself is element 0 of the returned list.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` hook converting otherwise unsupported objects.

        Raises:
            TypeError: If `obj` has no dedicated encoding and pickle fallback
                is not enabled via `VLLM_ALLOW_INSECURE_SERIALIZATION`.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode an ndarray as `(dtype string, shape, data)`.

        `data` is either an inline `msgpack.Ext` payload or an index into
        `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a tensor as `(dtype name, shape, data)`.

        The dtype name is `str(obj.dtype)` without the "torch." prefix;
        `data` is either an inline `msgpack.Ext` payload or an index into
        `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode each modality's item list with `_encode_mm_item`."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode an item as the list of its encoded field elements."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict of its four attributes."""
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        """Encode every value of `kw` with `_encode_nested_tensors`."""
        return {
            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Encode a tensor, a number, or an arbitrarily nested iterable of them."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField) -> tuple[Any, ...]:
        """Encode a field as `(factory name, *dataclass field values)`.

        Raises:
            TypeError: If the field's class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

278
279

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Any | None = None, share_mem: bool = True):
        """Create the decoder.

        Args:
            t: Optional target type passed through to `msgpack.Decoder`.
            share_mem: When False, arrays and tensors decoded from auxiliary
                buffers are copied instead of sharing the received buffer.
        """
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded; only populated for
        # the duration of a multi-buffer `decode` call.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a sequence whose first element is the
        msgpack message and whose remaining elements are auxiliary buffers.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` hook: convert native `obj` to type `t` where supported.

        Unrecognized types are returned unchanged.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if issubclass(t, MultiModalKwargs):
                return self._decode_mm_kwargs(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Decode a `(type info, result)` pair into a `UtilityResult`.

        Raises:
            TypeError: If type info is present but
                `VLLM_ALLOW_INSECURE_SERIALIZATION` is not set.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(
        self, result_type: Sequence[str] | None, result: Any
    ) -> Any:
        """Convert `result` to the type named by `(module, qualname)`.

        Returns `result` unchanged when `result_type` is None.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        # NOTE(security): this imports a module named by the message. It is
        # only reached from `_decode_utility_result` after the
        # VLLM_ALLOW_INSECURE_SERIALIZATION check.
        target: Any = importlib.import_module(mod_name)
        # `name` is a `__qualname__`, which is dotted for nested classes, so
        # resolve it one attribute at a time.
        for attr in name.split("."):
            target = getattr(target, attr)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from `(dtype, shape, data)`.

        `data` is either an index into `aux_buffers` or the inline buffer.
        """
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        decoded = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            decoded = decoded.copy()
        return decoded.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from `(dtype name, shape, data)`.

        `data` is either an index into `aux_buffers` or the inline buffer.
        """
        dtype, shape, data = arr
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        raw = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            raw = raw.clone()
        elif not self.share_mem:
            raw = raw.pin_memory() if self.pin_tensors else raw.clone()
        # Convert back to proper shape & type
        return raw.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Decode each modality's item list with `_decode_mm_item`."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Decode a list of encoded field elements into one item."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Decode one encoded field element dict."""
        data = obj["data"]
        if data is not None:
            data = self._decode_nested_tensors(data)

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        field = factory_meth(None, *field_args).field
        # Build the constructor kwargs in a fresh dict so that the decoded
        # message passed in by the caller is not modified.
        return MultiModalFieldElem(**{**obj, "data": data, "field": field})

    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        """Decode every value of `obj` with `_decode_nested_tensors`."""
        return MultiModalKwargs(
            {
                modality: self._decode_nested_tensors(data)
                for modality, data in obj.items()
            }
        )

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Decode a number, an encoded tensor, or a nested list of them.

        Raises:
            TypeError: If `obj` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a list whose first element is the dtype string.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Decode an encoded slice, or a nested list of encoded slices.

        A non-empty sequence whose first element is not itself a list/tuple
        is treated as the arguments of a single `slice`.
        """
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` hook decoding extension payloads by type code.

        Raises:
            NotImplementedError: For unknown codes, and for the pickle codes
                when `VLLM_ALLOW_INSECURE_SERIALIZATION` is not set.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # NOTE(security): unpickling can execute arbitrary code. Only
            # enable this when every sender is trusted.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)