# SPDX-License-Identifier: Apache-2.0

3
import dataclasses
4
import pickle
5
6
from collections.abc import Sequence
from inspect import isclass
7
from types import FunctionType
8
from typing import Any, Optional, Union
9

10
import cloudpickle
11
import numpy as np
12
import torch
13
import zmq
14
15
from msgspec import msgpack

16
17
18
19
20
21
22
23
from vllm import envs
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, NestedTensors)

24
25
# msgpack extension type codes used for payloads that are not encoded as
# native msgpack types.
# Payload is the output of `pickle.dumps`.
CUSTOM_TYPE_PICKLE = 1
# Payload is the output of `cloudpickle.dumps` (used for plain functions).
CUSTOM_TYPE_CLOUDPICKLE = 2
# Payload is the raw data of a small array/scalar encoded inline.
CUSTOM_TYPE_RAW_VIEW = 3
27

28
29
30
31
32
33
34
35
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}

# Any of the bytes-like buffer types that the encoder/decoder in this module
# produce or accept as message parts.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
38
39


40
41
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        """Create the encoder.

        Args:
            size_threshold: Arrays whose size in bytes is below this value
                are encoded inline in the main message. When None, the value
                of `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD` is used.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` into a list of buffers.

        Returns:
            A list whose first element is the msgpack-encoded message and
            whose remaining elements are the backing buffers of the arrays
            that were not encoded inline.
        """
        try:
            # Slot 0 is reserved for the main message so that the indices
            # handed out by `_encode_ndarray` start at 1.
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            # Always drop the stash so buffers are not kept alive by the
            # encoder after the call (including when encoding raises).
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj`, writing the main message into the provided `buf`.

        Returns:
            A list whose first element is `buf` and whose remaining elements
            are the backing buffers of the arrays that were not encoded
            inline.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """Convert `obj` into something the msgpack encoder can handle.

        Passed to `msgpack.Encoder` as its `enc_hook`. Tensors and arrays are
        turned into `(dtype, shape, data)` tuples, `MultiModalKwargs` into
        plain dicts/lists, and everything else is pickled into an extension
        value.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_ndarray(obj.numpy())

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Encode a list of MultiModalKwargsItems as plain dicts
            # + special handling for .field.
            # Any tensors *not* indexed by modality will be ignored.
            return [[{
                "modality": elem.modality,
                "key": elem.key,
                "data": self._encode_nested_tensors(elem.data),
                "field": self._encode_mm_field(elem.field),
            } for elem in item.values()]
                    for itemlist in mm._items_by_modality.values()
                    for item in itemlist]

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode an array as a `(dtype string, shape, data)` tuple.

        `data` is either the array contents wrapped in a
        `CUSTOM_TYPE_RAW_VIEW` extension value (scalars and arrays smaller
        than `size_threshold`), or the integer index of the array's buffer
        within `aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        view = obj.data
        arr_data = view if view.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors nested inside lists.

        Tensors become `(dtype, shape, data)` tuples, ints/floats are passed
        through unchanged, and any other value is iterated and encoded
        element by element into a list.
        """
        if isinstance(nt, torch.Tensor):
            return self._encode_ndarray(nt.numpy())
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a multimodal field as `(factory name, *field values)`.

        Raises:
            TypeError: If the field's class is not listed in
                `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if name is None:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name)
                        for f in dataclasses.fields(field))
        return name, *field_values

156
157

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        """Create the decoder.

        Args:
            t: Optional type to decode messages into. When None, the
                underlying `msgpack.Decoder` is created without a type.
        """
        args = () if t is None else (t, )
        self.decoder = msgpack.Decoder(*args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        # Buffers of the message currently being decoded; only populated for
        # the duration of a multi-buffer `decode` call.
        self.aux_buffers: Sequence[bytestr] = ()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a message from a single buffer or a sequence of buffers.

        When a sequence is given, its first element is the msgpack message
        and the full sequence is made available to `_decode_ndarray` so that
        arrays encoded as buffer indices can be resolved.
        """
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Don't keep the message buffers alive after decoding.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Convert native msgpack data `obj` into an instance of type `t`.

        Passed to `msgpack.Decoder` as its `dec_hook`. Handles numpy arrays,
        torch tensors and `MultiModalKwargs`; any other `obj` is returned
        unchanged.
        """
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, torch.Tensor):
                return torch.from_numpy(self._decode_ndarray(obj))
            if issubclass(t, MultiModalKwargs):
                if isinstance(obj, list):
                    return MultiModalKwargs.from_items(
                        self._decode_mm_items(obj))
                return MultiModalKwargs({
                    k: self._decode_nested_tensors(v)
                    for k, v in obj.items()
                })
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an array from its `(dtype, shape, data)` encoding.

        `data` is either an integer index into `aux_buffers` or the inline
        array contents.
        """
        dtype, shape, data = arr
        if isinstance(data, int):
            # The array data lives in one of the auxiliary buffers.
            buffer = self.aux_buffers[data]
        else:
            # Copy from inline representation, otherwise Torch is unhappy
            # since the returned memory is non-writeable.
            buffer = bytearray(data)
        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        """Rebuild `MultiModalKwargsItem`s from their plain-dict encoding.

        Note that the element dicts in `obj` are updated in place.
        """
        decoded_items = []
        for item in obj:
            elems = []
            for v in item:
                v["data"] = self._decode_nested_tensors(v["data"])
                # Reconstruct the field processor using MultiModalFieldConfig
                factory_meth_name, *field_args = v["field"]
                factory_meth = getattr(MultiModalFieldConfig,
                                       factory_meth_name)
                v["field"] = factory_meth(None, *field_args).field
                elems.append(MultiModalFieldElem(**v))
            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
        return decoded_items

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode tensors nested inside lists.

        Raises:
            TypeError: If `obj` is neither a number nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # An encoded array is a list starting with its dtype string; any
        # other list is a container of further nested values.
        if obj and isinstance(obj[0], str):
            return torch.from_numpy(self._decode_ndarray(obj))
        return [self._decode_nested_tensors(x) for x in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode a msgpack extension value.

        Passed to `msgpack.Decoder` as its `ext_hook`.

        Raises:
            NotImplementedError: If `code` is not one of the custom type
                codes defined in this module.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data
        # SECURITY NOTE: unpickling can execute arbitrary code. Messages
        # handled by this decoder must only come from trusted peers.
        if code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")