serial_utils.py 3.28 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import io
4
import sys
5
from collections.abc import Mapping
6
from dataclasses import dataclass
7
from typing import Literal, get_args
8
9

import numpy as np
10
11
import numpy.typing as npt
import pybase64
12
13
14
15
16
import torch

# Byte order of the host ("little" or "big"), captured once at import time so
# the (de)serialization helpers can tell whether a byteswap is required.
sys_byteorder = sys.byteorder


17
18
19
@dataclass(frozen=True)
class DTypeInfo:
    torch_dtype: torch.dtype
20

21
22
    torch_view_dtype: torch.dtype
    numpy_view_dtype: npt.DTypeLike
23

24
25
26
    @property
    def nbytes(self) -> int:
        return self.torch_dtype.itemsize
27
28
29


# Names of the floating-point formats an embedding can be serialized as.
EmbedDType = Literal[
    "float32",
    "float16",
    "bfloat16",
    "fp8_e4m3",
    "fp8_e5m2",
]
# Names of the integer/boolean formats accepted for multimodal metadata.
MmMetadataDType = Literal[
    "int32",
    "int64",
    "uint8",
    "bool",
]
# Byte order of a binary payload; "native" means the host's own byte order.
Endianness = Literal[
    "native",
    "big",
    "little",
]
# Supported output encodings.
EncodingFormat = Literal[
    "float",
    "base64",
    "bytes",
    "bytes_only",
]
33

34
35
36
37
38
39
40
41
42
43
44
45
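# It is unclear whether the CPUs of every platform support the fp8 formats.
# EMBED_DTYPES only relies on the fp8 bit representation (no fp8 arithmetic
# is performed), and this handling only ever happens on the CPU.
# Apologies in advance if this breaks on some platform.
# NOTE: numpy supports neither bfloat16 nor fp8, so those entries are viewed
# through a same-width dtype that numpy does understand.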
EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
    # Natively representable in both torch and numpy.
    "float32": DTypeInfo(
        torch_dtype=torch.float32,
        torch_view_dtype=torch.float32,
        numpy_view_dtype=np.float32,
    ),
    "float16": DTypeInfo(
        torch_dtype=torch.float16,
        torch_view_dtype=torch.float16,
        numpy_view_dtype=np.float16,
    ),
    # 2-byte format carried through a float16 view of the same bits.
    "bfloat16": DTypeInfo(
        torch_dtype=torch.bfloat16,
        torch_view_dtype=torch.float16,
        numpy_view_dtype=np.float16,
    ),
    # 1-byte formats carried through a uint8 view of the same bits.
    "fp8_e4m3": DTypeInfo(
        torch_dtype=torch.float8_e4m3fn,
        torch_view_dtype=torch.uint8,
        numpy_view_dtype=np.uint8,
    ),
    "fp8_e5m2": DTypeInfo(
        torch_dtype=torch.float8_e5m2,
        torch_view_dtype=torch.uint8,
        numpy_view_dtype=np.uint8,
    ),
}
46
47
48
49
50
51
52
53
54
MM_METADATA_DTYPES: Mapping[MmMetadataDType, DTypeInfo] = {
    "int32": DTypeInfo(
        torch_dtype=torch.int32,
        torch_view_dtype=torch.int32,
        numpy_view_dtype=np.int32,
    ),
    "int64": DTypeInfo(
        torch_dtype=torch.int64,
        torch_view_dtype=torch.int64,
        numpy_view_dtype=np.int64,
    ),
    "uint8": DTypeInfo(
        torch_dtype=torch.uint8,
        torch_view_dtype=torch.uint8,
        numpy_view_dtype=np.uint8,
    ),
    # Booleans are carried through a uint8 view.
    "bool": DTypeInfo(
        torch_dtype=torch.bool,
        torch_view_dtype=torch.uint8,
        numpy_view_dtype=np.uint8,
    ),
}
# Every dtype name accepted by the binary (de)serialization helpers:
# embedding dtypes first, then multimodal-metadata dtypes.
_ALL_SERIAL_DTYPES: Mapping[str, DTypeInfo] = {
    **EMBED_DTYPES,
    **MM_METADATA_DTYPES,
}
# Tuple of the literal values allowed by ``Endianness``.
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)

57

58
59
60
61
62
63
def tensor2base64(x: torch.Tensor) -> str:
    """Serialize a tensor with ``torch.save`` and encode it as base64 text.

    Args:
        x: Tensor to serialize.

    Returns:
        The base64 encoding (as a ``str``) of the bytes that ``torch.save``
        writes for ``x``.
    """
    with io.BytesIO() as buf:
        torch.save(x, buf)
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no seek/read round trip is needed.
        binary_data = buf.getvalue()

    return pybase64.b64encode(binary_data).decode("utf-8")
65
66


67
def tensor2binary(
    tensor: torch.Tensor,
    embed_dtype: "EmbedDType | MmMetadataDType",
    endianness: Endianness,
) -> bytes:
    """Convert a tensor into a flat raw-bytes payload.

    The tensor is converted to the torch dtype registered for ``embed_dtype``,
    flattened, reinterpreted through the registered view dtype and dumped as
    bytes. Shape information is not part of the payload.

    Args:
        tensor: Tensor to serialize.
        embed_dtype: Key of ``_ALL_SERIAL_DTYPES`` selecting the element
            format of the payload.
        endianness: Byte order of the payload. ``"native"`` keeps the host
            byte order.

    Returns:
        The elements of ``tensor`` as contiguous raw bytes.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor`` or if
            ``embed_dtype`` / ``endianness`` is not a supported value.
    """
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in _ALL_SERIAL_DTYPES
    assert endianness in ENDIANNESS

    dtype_info = _ALL_SERIAL_DTYPES[embed_dtype]

    # Reinterpret the bits through a numpy-compatible dtype of the same
    # width before crossing over to numpy.
    np_array = (
        tensor.to(dtype_info.torch_dtype)
        .flatten()
        .contiguous()
        .view(dtype_info.torch_view_dtype)
        .numpy()
    )

    # Swap bytes only when an explicit byte order was requested and it
    # differs from the host's.
    if endianness not in ("native", sys_byteorder):
        np_array = np_array.byteswap()

    return np_array.tobytes()


def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: "EmbedDType | MmMetadataDType",
    endianness: Endianness,
) -> torch.Tensor:
    """Rebuild a tensor from a raw-bytes payload.

    Args:
        binary: Raw element bytes.
        shape: Shape the resulting tensor must have.
        embed_dtype: Key of ``_ALL_SERIAL_DTYPES`` selecting the element
            format of ``binary``.
        endianness: Byte order of ``binary``. ``"native"`` means the host
            byte order.

    Returns:
        A tensor of the given ``shape`` whose dtype is the torch dtype
        registered for ``embed_dtype``.

    Raises:
        AssertionError: If ``embed_dtype`` / ``endianness`` is not a
            supported value.
    """
    assert embed_dtype in _ALL_SERIAL_DTYPES
    assert endianness in ENDIANNESS

    dtype_info = _ALL_SERIAL_DTYPES[embed_dtype]

    np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape)

    if endianness not in ("native", sys_byteorder):
        # byteswap() returns a new, writable array in host byte order.
        np_array = np_array.byteswap()
    elif not np_array.flags.writeable:
        # frombuffer() over immutable ``bytes`` yields a read-only array.
        # torch.from_numpy() warns about such arrays and writing to the
        # resulting tensor is undefined behavior, so give it its own memory.
        np_array = np_array.copy()

    return torch.from_numpy(np_array).view(dtype_info.torch_dtype)