"benchmarks/kernels/benchmark_per_token_quant_fp8.py" did not exist on "5c2acb270aad36e35750f617f6219bf95a4924c4"
serial_utils.py 16.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import importlib
6
import pickle
7
8
from collections.abc import Sequence
from inspect import isclass
9
from types import FunctionType
10
from typing import Any, Callable, Optional, Union
11

12
import cloudpickle
13
import msgspec
14
import numpy as np
15
import torch
16
import zmq
17
18
from msgspec import msgpack

19
from vllm import envs
20
from vllm.logger import init_logger
21

22
# yapf: disable
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.multimodal.inputs import (
    BaseMultiModalField,
    MultiModalBatchedField,
    MultiModalFieldConfig,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargs,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)

36
# yapf: enable
37
from vllm.v1.engine import UtilityResult
38

39
40
# Module-level logger obtained from vLLM's `init_logger` for this module.
logger = init_logger(__name__)

41
42
# msgpack extension type codes for payloads that need custom handling.
CUSTOM_TYPE_PICKLE = 1  # payload produced by `pickle.dumps`
CUSTOM_TYPE_CLOUDPICKLE = 2  # payload produced by `cloudpickle.dumps`
CUSTOM_TYPE_RAW_VIEW = 3  # raw bytes of a tensor / ndarray encoded inline
44

45
46
47
48
49
50
51
52
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
# The encoder writes the mapped name in place of the field class, and the
# decoder looks that name up as an attribute of `MultiModalFieldConfig`
# to rebuild the field.
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
    MultiModalFlatField: "flat",
    MultiModalSharedField: "shared",
    MultiModalBatchedField: "batched",
}
53

54
# Types accepted as a single encoded buffer by `MsgpackDecoder.decode`, and
# the element type of the buffer sequence returned by `MsgpackEncoder.encode`.
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
55
56


57
def _log_insecure_serialization_warning():
    """Warn, via `logger.warning_once`, that pickle serialization is enabled."""
    message = (
        "Allowing insecure serialization using pickle due to "
        "VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    logger.warning_once(message)
62
63


64
65
66
67
def _typestr(val: Any) -> Optional[tuple[str, str]]:
    if val is None:
        return None
    t = type(val)
68
69
70
    return t.__module__, t.__qualname__


71
72
73
74
75
76
77
78
79
80
81
82
83
def _encode_type_info_recursive(obj: Any) -> Any:
    """Recursively encode type information for nested structures of
    lists/dicts."""
    if obj is None:
        return None
    if type(obj) is list:
        return [_encode_type_info_recursive(item) for item in obj]
    if type(obj) is dict:
        return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
    return _typestr(obj)


def _decode_type_info_recursive(
84
85
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
86
87
88
89
90
91
92
93
94
95
96
    """Recursively decode type information for nested structures of
    lists/dicts."""
    if type_info is None:
        return data
    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        return {
            k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
            for k in type_info
        }
    if isinstance(type_info, list) and (
97
98
99
        # Exclude serialized tensors/numpy arrays.
        len(type_info) != 2 or not isinstance(type_info[0], str)
    ):
100
101
102
103
104
105
106
107
        assert isinstance(data, list)
        return [
            _decode_type_info_recursive(ti, d, convert_fn)
            for ti, d in zip(type_info, data)
        ]
    return convert_fn(type_info, data)


108
109
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays will
    get sent via dedicated messages. Note that this is a per-tensor limit.
    """

    def __init__(self, size_threshold: Optional[int] = None):
        """Create the encoder.

        Args:
            size_threshold: Byte size below which tensors / arrays are
                inlined into the main message. When `None`, the value of
                `envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD` is used.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
        self.size_threshold = size_threshold
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj` and return the list of resulting buffers.

        Element 0 is the msgpack-encoded message; any further elements are
        backing buffers of tensors / arrays that were not inlined.
        """
        try:
            # Slot 0 is reserved for the main message so that indices handed
            # out by the hook during encoding start at 1.
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode `obj` into the caller-provided `buf`.

        Returns a list whose first element is `buf` itself, followed by the
        backing buffers of any tensors / arrays that were not inlined.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """`msgspec` hook converting unsupported objects to native types.

        Raises:
            TypeError: If `obj` has no dedicated encoding and
                `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, MultiModalKwargs):
            return self._encode_mm_kwargs(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                # No type info is attached in this mode.
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # The sentence break after "serializable" keeps the two adjacent
            # string literals from running together in the final message.
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode an ndarray as `(dtype string, shape, data)`.

        `data` is an inline `msgpack.Ext` for scalars and arrays below the
        size threshold, otherwise the index of the array's buffer in
        `self.aux_buffers`.
        """
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Encode a tensor as `(dtype name, shape, data)`.

        The dtype name is `str(obj.dtype)` without the `torch.` prefix.
        `data` is an inline `msgpack.Ext` for tensors below the size
        threshold, otherwise the index of the tensor's bytes in
        `self.aux_buffers`.
        """
        assert self.aux_buffers is not None
        # view the tensor as a contiguous 1D array of bytes
        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr.data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode each modality's item list with `_encode_mm_item`."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
        """Encode an item as the list of its encoded field elements."""
        return [self._encode_mm_field_elem(elem) for elem in item.values()]

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode a field element as a dict of its four attributes.

        `data` is left as `None` when the element has none; `field` is
        encoded via `_encode_mm_field`.
        """
        return {
            "modality": elem.modality,
            "key": elem.key,
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_mm_kwargs(self, kw: MultiModalKwargs) -> dict[str, Any]:
        """Encode each value of `kw` with `_encode_nested_tensors`."""
        return {
            modality: self._encode_nested_tensors(data) for modality, data in kw.items()
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode tensors inside nested iterables.

        Tensors go through `_encode_tensor`, ints and floats are returned
        unchanged, and anything else is iterated and encoded element-wise
        into a list.
        """
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as `(factory name, *dataclass field values)`.

        Raises:
            TypeError: If the field's class is not in `MMF_CLASS_TO_FACTORY`.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")
        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        field_values = (getattr(field, f.name) for f in dataclasses.fields(field))
        return name, *field_values

279
280

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        """Create the decoder, optionally typed to decode into `t`."""
        type_args = (t,) if t is not None else ()
        self.decoder = msgpack.Decoder(
            *type_args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        # Buffers of the message currently being decoded; the hooks read
        # out-of-band tensor / array data from here by index.
        self.aux_buffers: Sequence[bytestr] = ()
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a single buffer, or a sequence of buffers.

        For a sequence, element 0 is the msgpack message and the sequence
        as a whole is made available to the hooks for index lookups.
        """
        # TODO - This check can become `isinstance(bufs, bytestr)`
        # as of Python 3.10.
        single_buffer = isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame))
        if single_buffer:
            return self.decoder.decode(bufs)

        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            # Always drop the reference to the message buffers afterwards.
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """`msgspec` hook: convert native types in `obj` to type `t`.

        `obj` is returned unchanged when `t` is not a class or has no
        dedicated decoding here.
        """
        if not isclass(t):
            return obj
        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
            return self._decode_tensor(obj)
        if t is slice:
            return slice(*obj)
        if issubclass(t, MultiModalKwargsItem):
            return self._decode_mm_item(obj)
        if issubclass(t, MultiModalKwargsItems):
            return self._decode_mm_items(obj)
        if issubclass(t, MultiModalKwargs):
            return self._decode_mm_kwargs(obj)
        if t is UtilityResult:
            return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a `UtilityResult` from its `(type info, result)` pair.

        Raises:
            TypeError: If type info is present but
                `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        result_type, payload = obj
        if result_type is None:
            return UtilityResult(payload)

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(
                "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                "be set to use custom utility result types"
            )
        # Use recursive decoding to handle nested structures
        converted = _decode_type_info_recursive(
            result_type, payload, self._convert_result
        )
        return UtilityResult(converted)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert `result` to the class named by `(module, name)`.

        The module is imported and the class looked up on it, then
        `msgspec.convert` performs the conversion using this decoder's
        `dec_hook`. A `None` type leaves `result` untouched.
        """
        if result_type is None:
            return result
        module_name, attr_name = result_type
        target_cls = getattr(importlib.import_module(module_name), attr_name)
        return msgspec.convert(result, target_cls, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its `(dtype, shape, data)` triple."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        if isinstance(data, int):
            # An int is an index into the out-of-band buffers.
            source = self.aux_buffers[data]
        else:
            source = data
        flat = np.frombuffer(source, dtype=dtype)
        return flat.reshape(shape)

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from its `(dtype name, shape, data)` triple."""
        dtype, shape, data = arr
        if isinstance(data, int):
            # An int is an index into the out-of-band buffers.
            source = self.aux_buffers[data]
        else:
            # Copy from inline representation, to decouple the memory storage
            # of the message from the original buffer. And also make Torch
            # not complain about a readonly memoryview.
            source = bytearray(data)

        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)

        if not source:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)

        # Create uint8 array, then convert back to proper shape & type
        raw_bytes = torch.frombuffer(source, dtype=torch.uint8)
        return raw_bytes.view(torch_dtype).view(shape)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Decode each modality's list of encoded items."""
        decoded = {
            modality: [self._decode_mm_item(encoded) for encoded in encoded_items]
            for modality, encoded_items in obj.items()
        }
        return MultiModalKwargsItems(decoded)

    def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
        """Decode a list of encoded field elements into one item."""
        elems = [self._decode_mm_field_elem(encoded) for encoded in obj]
        return MultiModalKwargsItem.from_elems(elems)

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Decode one encoded field element.

        Note that `obj` is updated in place: its "data" and "field" entries
        are replaced by their decoded values before construction.
        """
        encoded_data = obj["data"]
        if encoded_data is not None:
            obj["data"] = self._decode_nested_tensors(encoded_data)

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, *field_args = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            field_args[0] = self._decode_nested_slices(field_args[0])

        # The factory is called with `None` as its first argument; only the
        # resulting `.field` is kept.
        obj["field"] = factory_meth(None, *field_args).field
        return MultiModalFieldElem(**obj)

    def _decode_mm_kwargs(self, obj: dict[str, Any]) -> MultiModalKwargs:
        """Decode each value of `obj` with `_decode_nested_tensors`."""
        decoded = {
            modality: self._decode_nested_tensors(encoded)
            for modality, encoded in obj.items()
        }
        return MultiModalKwargs(decoded)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode tensors inside nested lists.

        A non-empty list whose first element is a string is decoded as an
        encoded tensor; any other list is decoded element-wise.

        Raises:
            TypeError: If `obj` is neither an int, a float, nor a list.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        looks_like_tensor = bool(obj) and isinstance(obj[0], str)
        if looks_like_tensor:
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(child) for child in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn encoded slice sequences back into `slice`s.

        A non-empty sequence whose first element is not itself a list or
        tuple is decoded as one slice; otherwise each element is decoded
        in turn.
        """
        assert isinstance(obj, (list, tuple))
        is_single_slice = bool(obj) and not isinstance(obj[0], (list, tuple))
        if is_single_slice:
            return slice(*obj)
        return [self._decode_nested_slices(child) for child in obj]

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """`msgspec` hook decoding extension-typed payloads.

        Raises:
            NotImplementedError: If `code` is unknown, or is a pickle code
                while `VLLM_ALLOW_INSECURE_SERIALIZATION` is not enabled.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            # NOTE(security): unpickling runs arbitrary code from the
            # message; it is only reachable behind the opt-in flag above.
            unpicklers = {
                CUSTOM_TYPE_PICKLE: pickle.loads,
                CUSTOM_TYPE_CLOUDPICKLE: cloudpickle.loads,
            }
            loader = unpicklers.get(code)
            if loader is not None:
                return loader(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")