serial_utils.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from functools import partial
9
from inspect import isclass
10
from types import FunctionType
11
from typing import Any, TypeAlias, get_type_hints
12

13
import cloudpickle
14
import msgspec
15
import numpy as np
16
import torch
17
import zmq
18
from msgspec import msgpack
19
20
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
21

22
from vllm import envs
23
from vllm.logger import init_logger
24
25
26
27
28
29
30
31
32
33
34
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
35
from vllm.utils.platform_utils import is_pin_memory_available
36
from vllm.v1.utils import tensor_data
37

38
39
logger = init_logger(__name__)

# Extension type codes carried in `msgpack.Ext` payloads.
# PICKLE / CLOUDPICKLE wrap fallback-serialized objects (only produced and
# accepted when VLLM_ALLOW_INSECURE_SERIALIZATION is set); RAW_VIEW wraps the
# raw bytes of small arrays/tensors that are inlined into the main message.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Any byte-like container that may appear as a message frame/buffer.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
54
55


56
def _log_insecure_serialization_warning():
    """Warn (once per process via `warning_once`) that pickle fallback is on."""
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
61
62


63
def _typestr(val: Any) -> tuple[str, str] | None:
64
65
66
    if val is None:
        return None
    t = type(val)
67
68
69
    return t.__module__, t.__qualname__


70
71
72
73
74
75
76
77
78
79
80
81
82
def _encode_type_info_recursive(obj: Any) -> Any:
    """Mirror the list/dict structure of `obj`, replacing leaves by type info.

    Exact ``list`` and ``dict`` instances (not subclasses) are walked
    recursively; every other non-``None`` value is replaced by its
    ``(module, qualname)`` pair from `_typestr`, and ``None`` stays ``None``.
    """
    obj_type = type(obj)
    if obj_type is list:
        return [_encode_type_info_recursive(elem) for elem in obj]
    if obj_type is dict:
        return {key: _encode_type_info_recursive(val) for key, val in obj.items()}
    return None if obj is None else _typestr(obj)


def _decode_type_info_recursive(
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
    """Walk `type_info` and `data` in lockstep, converting each typed leaf.

    `type_info` has the shape produced by `_encode_type_info_recursive`:
    ``None`` (leave `data` untouched), a dict or list of nested type info,
    or a leaf ``[module, qualname]`` pair, which is handed to `convert_fn`
    together with the matching piece of `data`.
    """
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            key: _decode_type_info_recursive(sub_info, data[key], convert_fn)
            for key, sub_info in type_info.items()
        }
    if isinstance(type_info, list):
        # A 2-element list whose first entry is a string is a leaf
        # (module, qualname) pair; nested type info entries are never
        # strings, so anything else describes a list of values.
        is_type_pair = len(type_info) == 2 and isinstance(type_info[0], str)
        if not is_type_pair:
            assert isinstance(data, list)
            return [
                _decode_type_info_recursive(sub_info, item, convert_fn)
                for sub_info, item in zip(type_info, data)
            ]
    return convert_fn(type_info, data)


107
108
109
110
111
112
113
class UtilityResult:
    """Box around an arbitrary utility-call return value.

    `MsgpackEncoder.enc_hook` and `MsgpackDecoder.dec_hook` recognize this
    wrapper and encode/decode `result` together with optional type info.
    """

    def __init__(self, r: Any = None):
        # The wrapped, untyped return value.
        self.result: Any = r


114
115
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays, because the
    out-of-band buffers are collected in the per-instance `aux_buffers`.

    Arrays smaller than `size_threshold` bytes (taken from
    `VLLM_MSGPACK_ZERO_COPY_THRESHOLD` when not passed explicitly) are
    serialized inline. Larger ones are sent via dedicated buffers that are
    returned alongside the main message. Note that this is a per-tensor
    limit.
    """

    def __init__(self, size_threshold: int | None = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` into a sequence of buffers.

        Element 0 is the msgpack-encoded message; any further elements are
        backing buffers of large arrays/tensors, referenced by index from
        inside the message.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the main message into `buf`.

        `buf` itself is returned as element 0 of the buffer sequence.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """Convert objects msgspec cannot encode natively into encodable forms.

        Tensors and (non-object, non-void) numpy arrays become
        ``(dtype, shape, data)`` tuples, slices become ``(start, stop, step)``
        tuples, multimodal kwargs items are flattened into plain containers,
        and `UtilityResult` is paired with its (optional) type info. Anything
        else is pickled, but only when VLLM_ALLOW_INSECURE_SERIALIZATION is
        set.

        Raises:
            TypeError: if `obj` is unsupported and pickle fallback is disabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # Keep ". " between the two sentences; they were previously
            # concatenated without a separator ("serializableSet ...").
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a numpy array as ``(dtype.str, shape, data)``.

        `data` is an inline `msgpack.Ext` for scalars and arrays below the
        size threshold, otherwise the index of the array's buffer within
        `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        """Encode a torch tensor as ``(dtype name, shape, data)``.

        The dtype name has its ``torch.`` prefix removed; `data` follows the
        same inline-vs-buffer-index scheme as `_encode_ndarray`.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode items as ``{modality: [encoded item, ...]}``."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode an item as the list of its encoded field elements."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode a field element as a dict of its attributes."""
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors (and stray numbers) in nested lists."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory name, dataclass field values)``.

        Raises:
            TypeError: if the field class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")

        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
        return name, factory_kw
277

278
279

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays, because the
    out-of-band buffers of the message being decoded are held in the
    per-instance `aux_buffers`.
    """

    def __init__(self, t: Any | None = None, share_mem: bool = True):
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded; integer `data`
        # fields of encoded arrays/tensors index into this sequence.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a buffer sequence from `MsgpackEncoder`.

        For a sequence, element 0 is the msgpack message and the whole
        sequence is made available to the hooks for index lookups.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Convert native msgpack values in `obj` into an instance of `t`."""
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its ``(type info, result)`` pair.

        Raises:
            TypeError: if type info is present but insecure serialization is
                not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert `result` to the type named by a ``(module, qualname)`` pair."""
        if result_type is None:
            return result
        mod_name, qualname = result_type
        # `_typestr` records `__qualname__`, which is dotted for nested
        # classes (e.g. "Outer.Inner"); a single `getattr` on the module
        # cannot resolve those, so walk the path one attribute at a time.
        target: Any = importlib.import_module(mod_name)
        for attr in qualname.split("."):
            target = getattr(target, attr)
        return msgspec.convert(result, target, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild a numpy array from ``(dtype str, shape, data)``."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a torch tensor from ``(dtype name, shape, data)``."""
        dtype, shape, data = arr
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild items from ``{modality: [encoded item, ...]}``."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Rebuild an item from its list of encoded field elements."""
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a field element; note this mutates `obj` in place."""
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, factory_kw = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])

        obj["field"] = factory_meth("", **factory_kw).field
        return MultiModalFieldElem(**obj)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode tensors (and stray numbers) in nested lists.

        Raises:
            TypeError: on contents that are neither numbers nor lists.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a (dtype name, shape, data) list; its first
        # entry is the only place a string can appear.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn ``(start, stop, step)`` sequences into slices."""
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode msgpack extension payloads.

        Raw views are returned as-is; pickle/cloudpickle payloads are only
        unpickled when VLLM_ALLOW_INSECURE_SERIALIZATION is set, since
        unpickling can execute arbitrary code.

        Raises:
            NotImplementedError: for any other (or disallowed) extension code.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514


class PydanticMsgspecMixin:
    """Mixin that lets a `msgspec.Struct` subclass be used as a Pydantic type."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Make msgspec.Struct compatible with Pydantic, respecting defaults.
        Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
        API as input or in `/docs`. Note this is cached by Pydantic and not
        called on every validation.
        """
        msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
        type_hints = get_type_hints(source_type)

        # Build the Pydantic typed_dict_field for each msgspec field
        fields = {}
        for name, hint in type_hints.items():
            msgspec_field = msgspec_fields.get(name)
            if msgspec_field is None:
                # `get_type_hints` also returns annotations that are not struct
                # fields (e.g. `ClassVar`s). They are not part of the
                # serialized form, so skip them instead of raising KeyError.
                continue

            # typed_dict_field using the handler to get the schema
            field_schema = handler(hint)

            # Add default value to the schema.
            if msgspec_field.default_factory is not msgspec.NODEFAULT:
                wrapped_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default_factory=msgspec_field.default_factory,
                )
                fields[name] = core_schema.typed_dict_field(wrapped_schema)
            elif msgspec_field.default is not msgspec.NODEFAULT:
                wrapped_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default=msgspec_field.default,
                )
                fields[name] = core_schema.typed_dict_field(wrapped_schema)
            else:
                # No default, so Pydantic will treat it as required
                fields[name] = core_schema.typed_dict_field(field_schema)
        return core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(fields),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Validate and convert input to msgspec.Struct instance."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)