"cmake/vscode:/vscode.git/clone" did not exist on "96ad65b7fe515663da8ede09a1aa7f74aa500c97"
serial_utils.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import pickle
6
7
from collections.abc import Sequence
from inspect import isclass
8
from types import FunctionType
9
from typing import Any, Optional, Union
10

11
import cloudpickle
12
import numpy as np
13
import torch
14
import zmq
15
16
from msgspec import msgpack

17
from vllm import envs
18
from vllm.logger import init_logger
19
20
21
22
23
24
25
from vllm.multimodal.inputs import (BaseMultiModalField,
                                    MultiModalBatchedField,
                                    MultiModalFieldConfig, MultiModalFieldElem,
                                    MultiModalFlatField, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, NestedTensors)

26
27
# Module-level logger, used for the insecure-serialization warning below.
logger = init_logger(__name__)

28
29
# msgpack extension type codes. These tag `msgpack.Ext` payloads produced by
# the pickle / cloudpickle fallback in the encoder and are matched again in
# the decoder's `ext_hook`.
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
30
# msgpack extension type code for the raw bytes of a tensor / ndarray that is
# embedded inline in the message; the decoder's `ext_hook` returns the
# payload unchanged for this code.
CUSTOM_TYPE_RAW_VIEW = 3
31

32
33
34
35
36
37
38
39
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
# The encoder sends the mapped name in place of the field class, and the
# decoder resolves it again with `getattr(MultiModalFieldConfig, name)`.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
40

41
# Buffer-like types that the encoder may return and the decoder accepts as
# individual message buffers.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
42
43


44
45
46
47
48
def _log_insecure_serialization_warning():
    """Warn, via `logger.warning_once`, that pickle fallback is enabled."""
    message = ("Allowing insecure serialization using pickle due to "
               "VLLM_ALLOW_INSECURE_SERIALIZATION=1")
    logger.warning_once(message)


49
50
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.
51

52
53
    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
54
55
56

    By default, arrays below 256B are serialized inline. Larger arrays will
    get sent via dedicated messages. Note that this is a per-tensor limit.
57
58
    """

59
    def __init__(self, size_threshold: Optional[int] = None):
        """Create the underlying msgpack encoder.

        Args:
            size_threshold: Size in bytes below which tensor / ndarray data
                is embedded in the main message instead of being appended
                as a separate buffer. When ``None``, the value of
                ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD`` is used.
        """
        self.size_threshold = (envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
                               if size_threshold is None else size_threshold)
        # Local stash for the list of output buffers. The `enc_hook` callback
        # has no other way to receive per-call data, so `encode` and
        # `encode_into` park the list here while a call is in progress.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of resulting buffers.

        Element 0 is the msgpack message itself. Any further elements are
        backing buffers of tensors / ndarrays that the encode hook chose not
        to inline; the message refers to them by their index in this list,
        so their data is not copied into the main buffer.
        """
        # Slot 0 is a placeholder until the main message has been produced;
        # the hook appends backing buffers behind it during encoding.
        frames: list[bytestr] = [b'']
        self.aux_buffers = frames
        try:
            frames[0] = self.encoder.encode(obj)
        finally:
            # Always drop the stash, even if encoding raised.
            self.aux_buffers = None
        return frames

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into the caller-supplied `buf`.

        Returns a list whose first element is `buf` and whose remaining
        elements are the backing buffers appended by the encode hook.
        """
        frames = [buf]
        self.aux_buffers = frames
        try:
            self.encoder.encode_into(obj, buf)
        finally:
            # Always drop the stash, even if encoding raised.
            self.aux_buffers = None
        return frames

    def enc_hook(self, obj: Any) -> Any:
        """Convert `obj` into something the msgpack encoder can serialize.

        Registered as the `enc_hook` of the underlying `msgpack.Encoder`.
        Handled cases, in order:

        * `torch.Tensor` -> `_encode_tensor`
        * `np.ndarray` whose dtype kind is neither object nor void
          -> `_encode_ndarray`
        * `slice` -> `(start, stop, step)` tuple of ints / ``None``
        * `MultiModalKwargs` -> a plain dict when it has no modalities,
          otherwise a list of per-item lists of element dicts
        * anything else -> pickle / cloudpickle extension payload, but only
          when `VLLM_ALLOW_INSECURE_SERIALIZATION` is enabled

        Raises:
            TypeError: if `obj` matches none of the supported types and the
                pickle fallback is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step))

        if isinstance(obj, MultiModalKwargs):
            mm: MultiModalKwargs = obj
            if not mm.modalities:
                # just return the main dict if there are no modalities.
                return dict(mm)

            # ignore the main dict, it will be re-indexed.
            # Encode a list of MultiModalKwargsItems as plain dicts
            # + special handling for .field.
            # Any tensors *not* indexed by modality will be ignored.
            return [[{
                "modality": elem.modality,
                "key": elem.key,
                "data": self._encode_nested_tensors(elem.data),
                "field": self._encode_mm_field(elem.field),
            } for elem in item.values()]
                    for itemlist in mm._items_by_modality.values()
                    for item in itemlist]

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # The two sentences are separate string literals; keep the
            # ". " separator so they do not run together in the message.
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization.")

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode an ndarray as a `(dtype string, shape, data)` tuple.

        `data` is either a `CUSTOM_TYPE_RAW_VIEW` extension wrapping the
        array bytes (scalars and arrays smaller than `size_threshold`), or
        the integer index of the array's buffer within `aux_buffers`.
        Must only be called while an encode call is in progress.
        """
        assert self.aux_buffers is not None
        # A non-contiguous array has to be copied into contiguous bytes.
        if obj.data.c_contiguous:
            payload = obj.data
        else:
            payload = obj.tobytes()

        is_scalar = not obj.shape
        if is_scalar or obj.nbytes < self.size_threshold:
            # Small arrays and scalars travel inline. Using this extension
            # type ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, payload)
        else:
            # Otherwise send only the index of the backing buffer, which is
            # stashed in `aux_buffers`, to avoid a copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(payload)

        # The ndarray is serialized as a tuple of native types.
        return obj.dtype.str, obj.shape, data
157

158
159
160
161
    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode a tensor as a `(dtype name, shape, data)` tuple.

        The dtype name is `str(obj.dtype)` without its "torch." prefix.
        `data` is either a `CUSTOM_TYPE_RAW_VIEW` extension wrapping the
        tensor bytes (tensors smaller than `size_threshold`), or the integer
        index of the tensor's buffer within `aux_buffers`.
        Must only be called while an encode call is in progress.
        """
        assert self.aux_buffers is not None
        # View the tensor as a contiguous 1D array of bytes.
        flat_bytes = obj.flatten().contiguous().view(torch.uint8)
        payload = flat_bytes.numpy().data
        dtype_name = str(obj.dtype).removeprefix("torch.")

        if obj.nbytes >= self.size_threshold:
            # Send only the index of the backing buffer to avoid a copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(payload)
        else:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, payload)
        return dtype_name, obj.shape, data

174
175
    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode a nested structure of tensors.

        Ints and floats are returned unchanged, tensors go through
        `_encode_tensor`, and any other value is iterated and encoded
        element by element into a list.
        """
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        return list(map(self._encode_nested_tensors, nt))

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as `(factory name, *dataclass field values)`.

        Raises:
            TypeError: if the field's class has no entry in
                `MMF_CLASS_TO_FACTORY`.
        """
        field_cls = field.__class__
        # Figure out the factory name for the field type.
        factory_name = MMF_CLASS_TO_FACTORY.get(field_cls)
        if not factory_name:
            raise TypeError(f"Unsupported field type: {field_cls}")
        # Copy all of the field values in declaration order; they are used
        # as the factory arguments when the field is reconstructed.
        values = [
            getattr(field, spec.name) for spec in dataclasses.fields(field)
        ]
        return (factory_name, *values)

194
195

class MsgpackDecoder:
196
197
198
199
200
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """
201

202
    def __init__(self, t: Optional[Any] = None):
        """Create the underlying msgpack decoder.

        Args:
            t: Optional target type, passed positionally to
                `msgpack.Decoder`. When ``None``, no type is passed.
        """
        type_args = (t, ) if t is not None else ()
        # Buffers of the message currently being decoded; only populated
        # for the duration of a multi-buffer `decode` call.
        self.aux_buffers: Sequence[bytestr] = ()
        self.decoder = msgpack.Decoder(*type_args,
                                       ext_hook=self.ext_hook,
                                       dec_hook=self.dec_hook)
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a message from one buffer or a sequence of buffers.

        A single buffer is decoded directly. For a sequence, element 0 is
        decoded as the main message while the whole sequence is exposed to
        the decode hooks through `aux_buffers`, so that index references
        can be resolved.
        """
        # TODO - This check can become `isinstance(bufs, bytestr)`
        # as of Python 3.10.
        single_buffer_types = (bytes, bytearray, memoryview, zmq.Frame)
        if isinstance(bufs, single_buffer_types):
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Release the buffers again, even if decoding raised.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Given native types in `obj`, convert to type `t`.

        Registered as the `dec_hook` of the underlying `msgpack.Decoder`.
        Handles `np.ndarray`, `torch.Tensor`, `slice` and `MultiModalKwargs`
        targets; for anything else `obj` is returned unchanged.
        """
        if not isclass(t):
            return obj

        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
            return self._decode_tensor(obj)
        if t is slice:
            return slice(*obj)
        if issubclass(t, MultiModalKwargs):
            if isinstance(obj, list):
                # A list holds per-item element dicts.
                items = self._decode_mm_items(obj)
                return MultiModalKwargs.from_items(items)
            # Otherwise this is the plain main dict.
            decoded = {
                key: self._decode_nested_tensors(value)
                for key, value in obj.items()
            }
            return MultiModalKwargs(decoded)
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` encoding.

        `data` is either an integer index into `aux_buffers` or the inline
        buffer itself.
        """
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        if isinstance(data, int):
            buffer = self.aux_buffers[data]
        else:
            buffer = data
        flat = np.frombuffer(buffer, dtype=dtype)
        return flat.reshape(shape)
248
249
250
251
252
253

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its `(dtype name, shape, data)` encoding.

        `data` is either an integer index into `aux_buffers` or the inline
        buffer. The dtype name is resolved as an attribute of `torch`.
        """
        dtype, shape, data = arr
        if isinstance(data, int):
            buffer = self.aux_buffers[data]
        else:
            # Copy from inline representation, to decouple the memory
            # storage of the message from the original buffer. And also make
            # Torch not complain about a readonly memoryview.
            buffer = bytearray(data)

        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)

        if not buffer:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)

        # Wrap the bytes as a uint8 tensor, then reinterpret them with the
        # proper dtype and shape.
        byte_tensor = torch.frombuffer(buffer, dtype=torch.uint8)
        return byte_tensor.view(torch_dtype).view(shape)
265

266
267
268
269
270
271
272
273
274
275
    def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
        """Rebuild `MultiModalKwargsItem`s from their encoded form.

        `obj` is a list of items, each a list of element dicts with a
        "data" and a "field" entry plus the remaining `MultiModalFieldElem`
        keyword arguments. The element dicts are updated in place.
        """

        def decode_elem(raw: dict) -> MultiModalFieldElem:
            raw["data"] = self._decode_nested_tensors(raw["data"])
            # Reconstruct the field processor using MultiModalFieldConfig.
            factory_name, *factory_args = raw["field"]
            factory = getattr(MultiModalFieldConfig, factory_name)

            # Special case: decode the union "slices" field of
            # MultiModalFlatField
            if factory_name == "flat":
                factory_args[0] = self._decode_nested_slices(factory_args[0])

            # NOTE(review): the leading `None` is presumably the modality
            # argument of the factory, unused here since only `.field` is
            # kept - confirm against MultiModalFieldConfig.
            raw["field"] = factory(None, *factory_args).field
            return MultiModalFieldElem(**raw)

        return [
            MultiModalKwargsItem.from_elems(
                [decode_elem(raw) for raw in item]) for item in obj
        ]

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode a nested structure of encoded tensors.

        A non-empty list whose first element is a string is treated as one
        encoded tensor; any other list is decoded element by element. Ints
        and floats are returned unchanged.

        Raises:
            TypeError: if `obj` is not an int, a float or a list.
        """
        if isinstance(obj, list):
            # An encoded tensor starts with its dtype name.
            looks_like_tensor = bool(obj) and isinstance(obj[0], str)
            if looks_like_tensor:
                return self._decode_tensor(obj)
            return list(map(self._decode_nested_tensors, obj))
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")

298
299
300
301
302
303
    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn nested slice-argument sequences into slices.

        A non-empty list / tuple whose first element is not itself a list or
        tuple is expanded into `slice(*obj)`; any other list / tuple is
        decoded element by element into a list.
        """
        sequence_types = (list, tuple)
        assert isinstance(obj, sequence_types)
        is_leaf = bool(obj) and not isinstance(obj[0], sequence_types)
        if is_leaf:
            return slice(*obj)
        return list(map(self._decode_nested_slices, obj))

304
    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode a msgpack extension payload.

        Registered as the `ext_hook` of the underlying `msgpack.Decoder`.
        `CUSTOM_TYPE_RAW_VIEW` payloads are returned unchanged. Pickle and
        cloudpickle payloads are only loaded when
        `VLLM_ALLOW_INSECURE_SERIALIZATION` is enabled.

        Raises:
            NotImplementedError: for any other code, including the pickle
                codes when insecure serialization is not enabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        # SECURITY: unpickling executes arbitrary code from the payload, so
        # these branches must stay behind the explicit opt-in flag.
        insecure_allowed = envs.VLLM_ALLOW_INSECURE_SERIALIZATION
        if insecure_allowed and code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if insecure_allowed and code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)

        raise NotImplementedError(
            f"Extension type code {code} is not supported")