msgspec_utils.py 1.18 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from array import array
5
from typing import Any
6

7
from vllm.multimodal.inputs import MultiModalKwargs
8
9
10
11
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE


def encode_hook(obj: Any) -> Any:
    """Custom msgspec enc hook that supports array types and MultiModalKwargs.

    msgspec calls this hook for objects it cannot serialize natively.
    Supported conversions:

    * ``array.array`` whose typecode is ``VLLM_TOKEN_ID_ARRAY_TYPE`` is
      serialized as its raw bytes (``obj.tobytes()``). The typecode itself
      is not stored, so the decoder must rebuild the array with the same
      typecode.
    * ``MultiModalKwargs`` is converted with ``dict(obj)``.

    Args:
        obj: The object msgspec could not encode on its own.

    Returns:
        A msgspec-serializable representation of ``obj``.

    Raises:
        AssertionError: If ``obj`` is an array with a different typecode.
        NotImplementedError: If ``obj`` is of any other type. msgspec expects
            this to signal an unsupported type; returning ``None`` instead
            would silently encode the object as nil.

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
    """
    if isinstance(obj, array):
        # Only the raw item bytes are written, so the typecode must match the
        # one assumed when the bytes are turned back into an array.
        assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, (
            f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
            f"Given array has a type code of {obj.typecode}."
        )
        return obj.tobytes()

    if isinstance(obj, MultiModalKwargs):
        return dict(obj)

    # msgspec contract: unsupported objects must raise NotImplementedError
    # rather than fall through (an implicit None would be encoded as nil).
    raise NotImplementedError(
        f"Objects of type {type(obj)} are not supported by encode_hook."
    )


def decode_hook(type: type, obj: Any) -> Any:
    """Custom msgspec dec hook that supports array types and MultiModalKwargs.

    msgspec calls this hook when a field's declared type is not natively
    supported, passing the expected type and the decoded basic-msgpack value.
    Supported conversions:

    * ``array``: ``obj`` is expected to be the raw item bytes; an array with
      typecode ``VLLM_TOKEN_ID_ARRAY_TYPE`` is rebuilt from them
      (``array.frombytes`` raises ``ValueError`` if the byte length is not a
      multiple of the item size).
    * ``MultiModalKwargs``: constructed as ``MultiModalKwargs(obj)``.

    Args:
        type: The expected target type. The parameter name shadows the
            builtin inside this function; it is kept unchanged for
            backward compatibility.
        obj: The decoded representation, composed of basic msgpack types.

    Returns:
        An instance of ``type`` reconstructed from ``obj``.

    Raises:
        NotImplementedError: If ``type`` is not one of the supported types,
            as required by msgspec's dec_hook contract (returning ``None``
            would silently substitute a wrong value).

    See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Decoder
    """
    if type is array:
        # The typecode is not transmitted, so the same constant used when
        # encoding must be used here to interpret the bytes.
        deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE)
        deserialized.frombytes(obj)
        return deserialized

    if type is MultiModalKwargs:
        return MultiModalKwargs(obj)

    # msgspec contract: signal unsupported types instead of returning None.
    raise NotImplementedError(
        f"Objects of type {type} are not supported by decode_hook."
    )