# serial_utils.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from functools import partial
9
from inspect import isclass
10
from types import FunctionType
11
from typing import Any, ClassVar, TypeAlias, cast, get_type_hints
12

13
import cloudpickle
14
import msgspec
15
import numpy as np
16
import torch
17
import zmq
18
from msgspec import msgpack
19
20
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
21

22
from vllm import envs
23
from vllm.logger import init_logger
24
25
26
27
28
29
30
31
32
33
34
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
35
from vllm.utils.platform_utils import is_pin_memory_available
36
from vllm.v1.utils import tensor_data
37

38
39
# Module-level logger; used for the one-time insecure-serialization warning.
logger = init_logger(__name__)

40
41
# msgpack extension type codes for the pickle-based fallbacks. The encoder
# wraps `pickle.dumps` / `cloudpickle.dumps` output in `msgpack.Ext` with
# these codes; the decoder only honours them when
# VLLM_ALLOW_INSECURE_SERIALIZATION is set.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
42
# msgpack extension type code for the raw bytes of small tensors / ndarrays
# that are encoded inline; the decoder's `ext_hook` returns the payload as-is.
CUSTOM_TYPE_RAW_VIEW = 3
43

44
45
46
47
48
49
50
51
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
# The encoder looks the field up by its exact class (no subclass matching)
# and sends the factory name; the decoder resolves that name with
# `getattr(MultiModalFieldConfig, name)`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
52

53
# Buffer-like types that the encoder may return and the decoder accepts.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
54
55


56
def _log_insecure_serialization_warning():
    """Warn (via ``logger.warning_once``) that pickle fallback is enabled."""
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
61
62


63
def _typestr(val: Any) -> tuple[str, str] | None:
64
65
66
    if val is None:
        return None
    t = type(val)
67
68
69
    return t.__module__, t.__qualname__


70
71
72
73
74
75
76
77
78
79
80
81
82
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
83
84
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
85
86
87
88
89
90
91
92
93
94
95
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
96
97
98
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
99
100
101
102
103
104
105
106
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


107
108
109
110
111
112
113
class UtilityResult:
    """Thin wrapper marking a value for special (de)serialization handling.

    The wrapped value is exposed unchanged as ``result``.
    """

    def __init__(self, r: Any = None):
        """Store ``r`` (default ``None``) as the wrapped result."""
        self.result: Any = r


114
115
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    returned as separate buffers so they can be sent via dedicated messages.
    Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: int | None = None):
        """Create the encoder.

        Args:
            size_threshold: Byte size below which tensors / arrays are
                inlined. Defaults to
                ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`` when ``None``.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj`` and return ``[main_buffer, *aux_buffers]``.

        Element 0 is the msgpack payload; the remaining elements are the
        backing buffers of large tensors / arrays referenced by index.
        """
        try:
            # Slot 0 is reserved up front so that aux buffer indices handed
            # out by the hook during encoding start at 1.
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the msgpack payload into ``buf``.

        ``buf`` itself is returned as element 0 of the result.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` hook converting unsupported objects to native types.

        Raises:
            TypeError: If ``obj`` has no dedicated encoding and the pickle
                fallback is disabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # The sentences are separated explicitly; adjacent string
            # literals are concatenated without any whitespace.
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode an ndarray as ``(dtype_str, shape, data)``.

        ``data`` is either an inline ``msgpack.Ext`` holding the raw bytes or
        an integer index into ``self.aux_buffers``.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a tensor as ``(dtype_name, shape, data)``.

        ``dtype_name`` is the torch dtype without the ``torch.`` prefix;
        ``data`` follows the same inline / aux-index convention as ndarrays.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode every item of every modality, keyed by modality."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> dict[str, Any]:
        """Encode each field element of ``item``, keyed by its field key."""
        return {key: self._encode_mm_field_elem(elem) for key, elem in item.items()}

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode one field element as a dict with ``data`` and ``field``."""
        return {
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors nested inside iterables."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory_name, factory_kwargs)``.

        Raises:
            TypeError: If the field's class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")

        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
        return name, factory_kw
275

276
277

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Any | None = None, share_mem: bool = True):
        """Create the decoder.

        Args:
            t: Optional target type passed to `msgpack.Decoder`; when
                ``None`` the decoder is created without a type.
            share_mem: When true, arrays / tensors backed by aux buffers are
                returned as views over the received buffers instead of copies.
        """
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded; only populated for
        # the duration of a multi-buffer `decode` call.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a sequence ``[main, *aux_buffers]``."""
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` hook converting native ``obj`` to custom type ``t``.

        Unrecognised types are returned unchanged.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its ``(type_info, result)`` pair.

        Raises:
            TypeError: If type information is present but insecure
                serialization is not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert ``result`` to the type named by ``(module, qualname)``.

        The qualified name is resolved attribute by attribute so that nested
        classes (whose ``__qualname__`` contains dots) can be located.
        """
        if result_type is None:
            return result
        mod_name, name = result_type
        # NOTE: this imports a module named by the incoming message; the only
        # caller in this class gates it on VLLM_ALLOW_INSECURE_SERIALIZATION.
        target: Any = importlib.import_module(mod_name)
        for attr in name.split("."):
            target = getattr(target, attr)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from ``(dtype, shape, data)``.

        ``data`` is either an index into ``self.aux_buffers`` or an inline
        buffer.
        """
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from ``(dtype_name, shape, data)``.

        Inline data is always cloned. Aux-buffer data is returned as a view
        when ``share_mem`` is set, otherwise it is pinned or cloned.
        """
        dtype, shape, data = arr
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild `MultiModalKwargsItems` from a per-modality dict of items."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: dict[str, Any]) -> MultiModalKwargsItem:
        """Rebuild a `MultiModalKwargsItem` from its encoded field elements."""
        return MultiModalKwargsItem(
            {key: self._decode_mm_field_elem(elem) for key, elem in obj.items()}
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a `MultiModalFieldElem`; ``obj`` is updated in place."""
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, factory_kw = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])

        obj["field"] = factory_meth("", **factory_kw).field
        return MultiModalFieldElem(**obj)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode nested lists of encoded tensors.

        Raises:
            TypeError: If ``obj`` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a list whose first element is the dtype name.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn encoded ``(start, stop, step)`` items into slices."""
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` hook for extension types.

        Raises:
            NotImplementedError: For unknown codes, and for the pickle codes
                when insecure serialization is not enabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # SECURITY: unpickling executes arbitrary code from the payload;
            # this path must only be enabled for trusted peers.
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
460
461
462


class PydanticMsgspecMixin:
    """Mixin that lets a ``msgspec.Struct`` take part in Pydantic
    **validation** (JSON/dict -> Struct) and **serialization**
    (Struct -> JSON-safe dict).

    A subclass can list extra public field names in
    ``__pydantic_msgspec_exclude__`` (a ``set[str]``) to have them removed
    from serialized output. Names beginning with ``_`` are always removed.
    """

    # Extra public-but-internal keys to drop on serialization; override in
    # subclasses as needed.
    __pydantic_msgspec_exclude__: ClassVar[set[str]] = set()

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Build the Pydantic core schema for a msgspec.Struct, carrying over
        the struct's defaults. Covers JSON=>msgspec.Struct, which is used when
        the struct is exposed through the API as input or in `/docs`. Pydantic
        caches the result, so this does not run on every validation.
        """
        struct_fields = {f.name: f for f in msgspec.structs.fields(source_type)}

        # One Pydantic typed_dict_field per public msgspec field.
        td_fields = {}
        for field_name, annotation in get_type_hints(source_type).items():
            info = struct_fields.get(field_name)
            # Annotations that are not struct fields (ClassVar etc.) are
            # ignored, as are private fields: those are excluded from
            # serialization and must not show up in the JSON/OpenAPI schema.
            if info is None or field_name.startswith("_"):
                continue

            # Let the handler produce the schema for the annotated type.
            inner_schema = handler(annotation)

            # Work out how (if at all) the field declares a default.
            if info.default_factory is not msgspec.NODEFAULT:
                default_kw = {"default_factory": info.default_factory}
            elif info.default is not msgspec.NODEFAULT:
                default_kw = {"default": info.default}
            else:
                default_kw = None

            if default_kw is None:
                # No default, so Pydantic will treat it as required.
                td_fields[field_name] = core_schema.typed_dict_field(inner_schema)
                continue

            # Fields with defaults are marked not required so the generated
            # JSON Schema stays consistent with ``omit_defaults=True``
            # serialization (fields at their default value may be absent).
            with_default = core_schema.with_default_schema(
                schema=inner_schema, **default_kw
            )
            td_fields[field_name] = core_schema.typed_dict_field(
                with_default, required=False
            )

        dict_branch = core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(td_fields),
        )
        instance_branch = core_schema.is_instance_schema(source_type)

        # Serializer that strips private / excluded fields.
        strip_serializer = core_schema.plain_serializer_function_ser_schema(
            cls._serialize_msgspec,
            info_arg=False,
        )

        # Accept either an already-constructed msgspec.Struct instance or a
        # JSON/dict-like payload.
        return core_schema.union_schema(
            [instance_branch, dict_branch],
            serialization=strip_serializer,
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Return ``value`` as an instance of ``cls``.

        Existing instances pass through, dicts are expanded into keyword
        arguments, and anything else goes through ``msgspec.convert``.
        """
        if isinstance(value, cls):
            return value
        if not isinstance(value, dict):
            return msgspec.convert(value, type=cls)
        return cls(**value)

    @staticmethod
    def _serialize_msgspec(value: Any) -> Any:
        """Turn a msgspec.Struct into a JSON-compatible dict without private
        (``_``-prefixed) or explicitly excluded fields.

        Relies on ``msgspec.to_builtins``, which respects
        ``omit_defaults=True``, so only fields that differ from their declared
        defaults are included. Non-dict results are returned unchanged.
        """
        builtins_form = msgspec.to_builtins(value)
        if not isinstance(builtins_form, dict):
            return builtins_form

        excluded: set[str] = cast(
            set[str],
            getattr(type(value), "__pydantic_msgspec_exclude__", set()),
        )
        return {
            key: item
            for key, item in builtins_form.items()
            if not (key.startswith("_") or key in excluded)
        }