serial_utils.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from functools import partial
9
from inspect import isclass
10
from types import FunctionType
11
from typing import Any, TypeAlias, get_type_hints
12

13
import cloudpickle
14
import msgspec
15
import numpy as np
16
import torch
17
import zmq
18
from msgspec import msgpack
19
20
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
21

22
from vllm import envs
23
from vllm.logger import init_logger
24
25
26
27
28
29
30
31
32
33
34
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
35
from vllm.utils.platform_utils import is_pin_memory_available
36
from vllm.v1.utils import tensor_data
37

38
39
logger = init_logger(__name__)

# Extension type codes carried by `msgpack.Ext` payloads. The encoder emits
# them and `MsgpackDecoder.ext_hook` dispatches on them.
CUSTOM_TYPE_PICKLE = 1  # payload produced by `pickle.dumps`
CUSTOM_TYPE_CLOUDPICKLE = 2  # payload produced by `cloudpickle.dumps`
CUSTOM_TYPE_RAW_VIEW = 3  # raw array/tensor bytes, handed back as a memoryview

# MultiModalField class serialization type map.
# Every concrete field type must appear here, and each value must name a
# factory method on `MultiModalFieldConfig` (used when decoding).
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Any buffer-like object that can hold (part of) an encoded message.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
54
55


56
def _log_insecure_serialization_warning():
    """Warn (via `logger.warning_once`) that pickle fallback is enabled."""
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
61
62


63
def _typestr(val: Any) -> tuple[str, str] | None:
64
65
66
    if val is None:
        return None
    t = type(val)
67
68
69
    return t.__module__, t.__qualname__


70
71
72
73
74
75
76
77
78
79
80
81
82
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
83
84
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
85
86
87
88
89
90
91
92
93
94
95
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
96
97
98
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
99
100
101
102
103
104
105
106
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


107
108
109
110
111
112
113
class UtilityResult:
    """Marker wrapper around the return value of a utility call.

    `MsgpackEncoder.enc_hook` and `MsgpackDecoder.dec_hook` special-case this
    type. When insecure serialization is allowed, the encoder also sends
    type information so that untyped results can be rebuilt on decode.
    """

    def __init__(self, r: Any = None):
        # The wrapped value (None when nothing was passed).
        self.result = r


114
115
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays. While `encode` /
    `encode_into` runs, the list of out-of-band buffers for the current
    message is stashed on the instance (`aux_buffers`).

    Arrays and tensors smaller than `size_threshold` bytes are serialized
    inline in the main message. Larger ones are returned as separate buffers,
    referenced from the message by index, so their data is not copied. Note
    that this is a per-tensor limit. When `size_threshold` is not given, it is
    read from `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`.
    """

    def __init__(self, size_threshold: int | None = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` into a list of buffers.

        Returns:
            A list whose first entry is the msgpack-encoded message. Any
            further entries are backing buffers of large arrays/tensors, which
            the message refers to by their index in this list.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the main message into `buf`.

        Returns:
            `[buf, *aux]`, where `aux` are the out-of-band array/tensor
            buffers referenced by index from the message.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """Convert an object msgspec cannot encode natively.

        Tensors, ndarrays (except object/void dtypes), slices, multimodal
        kwargs items and `UtilityResult`s get dedicated encodings. Anything
        else falls back to (cloud)pickle, but only when
        `VLLM_ALLOW_INSECURE_SERIALIZATION` is set.

        Raises:
            TypeError: if `obj` would need the pickle fallback but insecure
                serialization is not allowed.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # Note the ". " separator: without it the two sentences of the
            # message run together.
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode an ndarray as `(dtype.str, shape, data)`.

        `data` is either an inline `msgpack.Ext` (scalars and small arrays) or
        an int index into the aux buffer list.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a tensor as `(dtype name without "torch.", shape, data)`.

        `data` is either an inline `msgpack.Ext` or an int index into the aux
        buffer list, chosen by `size_threshold` as for ndarrays.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode as `{modality: [encoded item, ...]}`."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode an item as the list of its encoded field elements."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode a field element as a dict of its constructor arguments."""
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Encode tensors, numbers, or arbitrarily nested iterables of them."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as `(factory_name, *dataclass_field_values)`.

        Raises:
            TypeError: if the field class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

277
278

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays. While `decode`
    runs, the out-of-band buffers of the current message are stashed on the
    instance (`aux_buffers`).
    """

    def __init__(self, t: Any | None = None, share_mem: bool = True):
        # When False, decoded ndarrays and out-of-band tensors are copied
        # instead of viewing the received buffers.
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a sequence of buffers from `encode`.

        For a sequence, `bufs[0]` is the msgpack message. Large arrays and
        tensors inside it are stored as int indices into `bufs`.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Decode `(type_info, result)` into a `UtilityResult`.

        Raises:
            TypeError: if type info is present but insecure serialization is
                not allowed.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert `result` to the type named by `(module, qualname)`.

        NOTE(review): this imports whatever module the peer names. It is only
        reached when VLLM_ALLOW_INSECURE_SERIALIZATION is set, so the peer
        must be trusted.
        """
        if result_type is None:
            return result
        mod_name, qualname = result_type
        # `_typestr` records `__qualname__`, which is dotted for nested
        # classes (e.g. "Outer.Inner"). A single getattr on the module would
        # raise AttributeError for those, so resolve each part in turn.
        target: Any = importlib.import_module(mod_name)
        for attr in qualname.split("."):
            target = getattr(target, attr)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Decode `(dtype, shape, data)`, where data is bytes or an aux index."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Decode `(dtype name, shape, data)`, where data is bytes or an aux index."""
        dtype, shape, data = arr
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Inverse of `MsgpackEncoder._encode_mm_items`."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Inverse of `MsgpackEncoder._encode_mm_item`."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Inverse of `MsgpackEncoder._encode_mm_field_elem` (mutates `obj`)."""
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Decode numbers, encoded tensors, or nested lists of them.

        Raises:
            TypeError: for anything that is not a number or a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        if obj and isinstance(obj[0], str):
            # An encoded tensor starts with its dtype name.
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Rebuild `slice`s from `(start, stop, step)` sequences, at any depth."""
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode a msgpack extension value.

        Raw views are returned as-is (zero-copy). Pickle payloads are only
        unpickled when VLLM_ALLOW_INSECURE_SERIALIZATION is set, because
        unpickling runs arbitrary code from the sender.

        Raises:
            NotImplementedError: for unknown codes, or pickle codes while
                insecure serialization is disabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513


class PydanticMsgspecMixin:
    """Mixin that lets a `msgspec.Struct` subclass be used as a Pydantic type."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Make msgspec.Struct compatible with Pydantic, respecting defaults.
        Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
        API as input or in `/docs`. Note this is cached by Pydantic and not
        called on every validation.
        """
        msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
        type_hints = get_type_hints(source_type)

        # Build the Pydantic typed_dict_field for each msgspec field
        fields = {}
        for name, hint in type_hints.items():
            msgspec_field = msgspec_fields.get(name)
            if msgspec_field is None:
                # Annotated names that are not struct fields (e.g. `ClassVar`
                # attributes, which msgspec excludes from a Struct's fields)
                # are not part of the input. Skip them rather than raising
                # KeyError.
                continue

            # typed_dict_field using the handler to get the schema
            field_schema = handler(hint)

            # Add default value to the schema. A default_factory takes
            # precedence over a plain default, as before.
            if msgspec_field.default_factory is not msgspec.NODEFAULT:
                field_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default_factory=msgspec_field.default_factory,
                )
            elif msgspec_field.default is not msgspec.NODEFAULT:
                field_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default=msgspec_field.default,
                )
            # Fields with no default keep the bare schema, so Pydantic treats
            # them as required.
            fields[name] = core_schema.typed_dict_field(field_schema)
        return core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(fields),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Validate and convert input to msgspec.Struct instance."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)