"tests/tokenizers_/test_detokenize.py" did not exist on "69f46359dd5b36c1a059a0a8b729be1bd86394e8"
serial_utils.py 20 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
from collections.abc import Callable, Sequence
8
from functools import partial
9
from inspect import isclass
10
from types import FunctionType
11
from typing import Any, TypeAlias, get_type_hints
12

13
import cloudpickle
14
import msgspec
15
import numpy as np
16
import torch
17
import zmq
18
from msgspec import msgpack
19
20
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
21

22
from vllm import envs
23
from vllm.logger import init_logger
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
36
from vllm.utils.platform_utils import is_pin_memory_available
37
from vllm.v1.utils import tensor_data
38

39
40
logger = init_logger(__name__)

# msgpack extension type codes used by MsgpackEncoder / MsgpackDecoder.
# PICKLE / CLOUDPICKLE payloads are only produced by the encoder and only
# accepted by the decoder when VLLM_ALLOW_INSECURE_SERIALIZATION is set.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
# Raw bytes of a small array/tensor inlined into the main message.
CUSTOM_TYPE_RAW_VIEW = 3

# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Buffer-like types the encoder emits and the decoder accepts.
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
55
56


57
def _log_insecure_serialization_warning():
    """Log (via ``logger.warning_once``) that pickle-based serialization is on.

    Called when an encoder/decoder is constructed with
    ``VLLM_ALLOW_INSECURE_SERIALIZATION`` enabled.
    """
    logger.warning_once(
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
62
63


64
def _typestr(val: Any) -> tuple[str, str] | None:
65
66
67
    if val is None:
        return None
    t = type(val)
68
69
70
    return t.__module__, t.__qualname__


71
72
73
74
75
76
77
78
79
80
81
82
83
def _encode_type_info_recursive(obj: Any) -> Any:
    """Build a type-info tree mirroring the list/dict nesting of ``obj``.

    ``None`` stays ``None``. Exact ``list`` and ``dict`` instances (not
    subclasses) are walked element-wise / value-wise. Any other value is
    replaced by its ``_typestr`` descriptor.
    """
    if obj is None:
        return None
    container_type = type(obj)
    if container_type is dict:
        return {
            key: _encode_type_info_recursive(value) for key, value in obj.items()
        }
    if container_type is list:
        return list(map(_encode_type_info_recursive, obj))
    return _typestr(obj)


def _decode_type_info_recursive(
84
85
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
86
87
88
89
90
91
92
93
94
95
96
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
97
98
99
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
100
101
102
103
104
105
106
107
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


108
109
110
111
112
113
114
class UtilityResult:
    """Box around the return value of a utility call.

    The msgpack encoder/decoder hooks special-case this wrapper. When insecure
    serialization is enabled, they send per-element type information
    alongside the boxed ``result`` so it can be reconstructed on decode.
    """

    def __init__(self, r: Any = None) -> None:
        # The wrapped value; read by the (de)serialization hooks.
        self.result = r


115
116
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    Arrays and tensors smaller than ``size_threshold`` bytes are serialized
    inline; larger ones are sent as separate buffers without copying. Note
    that this is a per-tensor limit. When ``size_threshold`` is not given it
    is taken from ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD``.
    """

    def __init__(self, size_threshold: int | None = None):
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj``.

        Returns:
            A list whose element 0 is the msgpack payload and whose remaining
            elements are backing buffers of large tensors/arrays, referenced
            from the payload by index.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but write the msgpack payload into ``buf``.

        Returns ``[buf, *aux_buffers]``.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` fallback hook for objects it cannot encode natively.

        Raises:
            TypeError: for unsupported types when insecure (pickle-based)
                serialization is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], int | msgpack.Ext]:
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        # e.g. "torch.float16" -> "float16"; resolved via getattr(torch, ...)
        # when decoding.
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        return {
            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

286
287

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Any | None = None, share_mem: bool = True):
        # When share_mem is True, decoded arrays and large (aux) tensors may
        # alias the received buffers; when False they are copied.
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Aux buffers of the message currently being decoded; only populated
        # during `decode` so the hooks can resolve buffer indices.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a single buffer, or a multi-buffer message.

        For a sequence, element 0 is the msgpack payload and the remaining
        elements are aux buffers that the payload references by index.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if issubclass(t, MultiModalKwargs):
                return self._decode_mm_kwargs(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its ``(type_info, result)`` pair.

        Raises:
            TypeError: if type info is present but insecure serialization
                is not enabled.
        """
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str] | None, result: Any) -> Any:
        """Convert ``result`` to the type named by ``(module, qualname)``."""
        if result_type is None:
            return result
        mod_name, qualname = result_type
        # NOTE: imports a module named by the sender. Only reached after the
        # VLLM_ALLOW_INSECURE_SERIALIZATION check in `_decode_utility_result`.
        mod = importlib.import_module(mod_name)
        # `__qualname__` is dotted for nested classes (e.g. "Outer.Inner"),
        # so resolve it attribute by attribute.
        target_type: Any = mod
        for attr in qualname.split("."):
            target_type = getattr(target_type, attr)
        return msgspec.convert(result, target_type, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        dtype, shape, data = arr
        # An int is an index into the aux buffers; otherwise the bytes were
        # inlined in the payload (CUSTOM_TYPE_RAW_VIEW).
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Inline data is always cloned so the tensor is backed by
        # pytorch-owned memory (safe for future async CPU->GPU transfer).
        # Aux data is aliased when share_mem is set; otherwise it is copied
        # into pinned memory (when available) or cloned.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        return MultiModalKwargsItem.from_elems(
            [self._decode_mm_field_elem(v) for v in obj]
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        return MultiModalKwargs(
            {
                modality: self._decode_nested_tensors(data)
                for modality, data in obj.items()
            }
        )

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded tensor is a (dtype_str, shape, data) triple.
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        # Unpickling is only permitted when insecure serialization is enabled.
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479


def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532


class PydanticMsgspecMixin:
    """Mixin that lets ``msgspec.Struct`` subclasses be validated by Pydantic."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Make msgspec.Struct compatible with Pydantic, respecting defaults.
        Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
        API as input or in `/docs`. Note this is cached by Pydantic and not
        called on every validation.
        """
        msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
        type_hints = get_type_hints(source_type)

        # Build the Pydantic typed_dict_field for each msgspec field
        fields = {}
        for name, hint in type_hints.items():
            msgspec_field = msgspec_fields.get(name)
            if msgspec_field is None:
                # `get_type_hints` also reports annotations that are not
                # struct fields (e.g. `ClassVar`s); skip them rather than
                # failing with a KeyError.
                continue

            # typed_dict_field using the handler to get the schema
            field_schema = handler(hint)

            # Attach the msgspec default (or default factory) so Pydantic can
            # fill in missing values; fields with neither stay required.
            if msgspec_field.default_factory is not msgspec.NODEFAULT:
                field_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default_factory=msgspec_field.default_factory,
                )
            elif msgspec_field.default is not msgspec.NODEFAULT:
                field_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default=msgspec_field.default,
                )
            fields[name] = core_schema.typed_dict_field(field_schema)
        return core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(fields),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Validate and convert input to msgspec.Struct instance."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)