serial_utils.py 2.74 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import io
4
import sys
5
from collections.abc import Mapping
6
from dataclasses import dataclass
7
from typing import Literal, get_args
8
9

import numpy as np
10
11
import numpy.typing as npt
import pybase64
12
13
14
15
16
import torch

sys_byteorder = sys.byteorder


17
18
19
@dataclass(frozen=True)
class DTypeInfo:
    torch_dtype: torch.dtype
20

21
22
    torch_view_dtype: torch.dtype
    numpy_view_dtype: npt.DTypeLike
23

24
25
26
    @property
    def nbytes(self) -> int:
        return self.torch_dtype.itemsize
27
28
29
30


EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]

# Serialization recipe for every supported embedding dtype.
# NOTE: numpy has no bfloat16 or fp8 dtypes, so those values pass through
# numpy via a view dtype of the same width (float16 for bfloat16, uint8 for
# fp8). Only the bits are reinterpreted; no numeric conversion happens there.
# NOTE: the fp8 entries use fp8 purely as a storage format: tensors are cast
# to fp8 and their bytes reinterpreted, with no fp8 arithmetic. The cast is
# expected to run on the CPU, and fp8 support has not been verified on every
# CPU platform.
EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
    "float32": DTypeInfo(torch.float32, torch.float32, np.float32),
    "float16": DTypeInfo(torch.float16, torch.float16, np.float16),
    "bfloat16": DTypeInfo(torch.bfloat16, torch.float16, np.float16),
    "fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8),
    "fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8),
}
# Every accepted Endianness value, derived from the Literal to stay in sync.
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)

47

48
49
50
51
52
53
def tensor2base64(x: torch.Tensor) -> str:
    """Serialize ``x`` with ``torch.save`` and return the result as base64 text.

    Unlike ``tensor2binary``, the payload is torch's own serialization format
    (which also records dtype and shape), not raw element bytes.

    NOTE(review): the receiving side presumably decodes this with
    ``torch.load``. ``torch.load`` uses pickle, so it should only be given
    untrusted data with ``weights_only=True``.
    """
    with io.BytesIO() as buf:
        torch.save(x, buf)
        # getvalue() returns the full buffer contents without a seek/read.
        binary_data = buf.getvalue()

    return pybase64.b64encode(binary_data).decode("utf-8")
55
56


57
def tensor2binary(
    tensor: torch.Tensor,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Serialize ``tensor`` into a flat buffer of raw element bytes.

    The tensor is cast to the wire dtype selected by ``embed_dtype`` and
    flattened. Its bits are then reinterpreted through a same-width view
    dtype that numpy understands. The shape is NOT encoded; the caller must
    send it separately (``binary2tensor`` takes it as an argument).

    Args:
        tensor: Tensor to serialize. ``.numpy()`` is called on it, so it
            presumably must be a CPU tensor that does not require grad.
            TODO: confirm with callers.
        embed_dtype: Key into ``EMBED_DTYPES`` selecting the wire dtype.
        endianness: ``"native"`` keeps the host byte order; ``"big"`` or
            ``"little"`` forces that order.

    Returns:
        The element bytes in row-major (flattened) order.
    """
    assert isinstance(tensor, torch.Tensor), (
        f"expected a torch.Tensor, got {type(tensor).__name__}"
    )
    assert embed_dtype in EMBED_DTYPES, f"unsupported embed_dtype: {embed_dtype!r}"
    assert endianness in ENDIANNESS, f"unsupported endianness: {endianness!r}"

    dtype_info = EMBED_DTYPES[embed_dtype]

    np_array = (
        tensor.to(dtype_info.torch_dtype)
        .flatten()
        .contiguous()
        # Bit-level reinterpretation: numpy lacks bfloat16/fp8.
        .view(dtype_info.torch_view_dtype)
        .numpy()
    )

    # Swap only when an explicit order differs from the host's.
    # For 1-byte view dtypes (fp8 -> uint8) the swap is a no-op.
    if endianness != "native" and endianness != sys_byteorder:
        np_array = np_array.byteswap()

    return np_array.tobytes()


def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Rebuild a tensor from raw element bytes in the ``tensor2binary`` format.

    Args:
        binary: Raw element bytes in row-major order.
        shape: Shape of the result; its element count must match the number
            of elements contained in ``binary``.
        embed_dtype: Key into ``EMBED_DTYPES`` describing the wire dtype.
        endianness: Byte order of ``binary``: ``"native"``, ``"big"`` or
            ``"little"``.

    Returns:
        A tensor of dtype ``EMBED_DTYPES[embed_dtype].torch_dtype`` with the
        given ``shape``. It never aliases read-only input memory.

    Raises:
        ValueError: If the length of ``binary`` is not a multiple of the
            element size, or the element count does not fit ``shape``
            (raised by numpy).
    """
    assert embed_dtype in EMBED_DTYPES, f"unsupported embed_dtype: {embed_dtype!r}"
    assert endianness in ENDIANNESS, f"unsupported endianness: {endianness!r}"

    dtype_info = EMBED_DTYPES[embed_dtype]

    np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape)

    if endianness != "native" and endianness != sys_byteorder:
        # byteswap() returns a new, writable array in host byte order.
        np_array = np_array.byteswap()
    elif not np_array.flags.writeable:
        # np.frombuffer over immutable ``bytes`` gives a read-only array.
        # torch.from_numpy would share that memory, which torch warns makes
        # any write to the tensor undefined behavior, so take a private copy.
        np_array = np_array.copy()

    # Reinterpret the view-dtype bits back as the real torch dtype.
    return torch.from_numpy(np_array).view(dtype_info.torch_dtype)