utils.py 3.59 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

import importlib.util
5
6
import math
from dataclasses import dataclass
7
from functools import lru_cache
8
9
10
11
from typing import Any

import pybase64
import torch
12
from fastapi.responses import JSONResponse
13

14
from vllm.logger import init_logger
15
16
17
18
19
20
21
22
23
from vllm.outputs import PoolingRequestOutput
from vllm.utils.serial_utils import (
    EMBED_DTYPES,
    EmbedDType,
    Endianness,
    binary2tensor,
    tensor2binary,
)

24
25
logger = init_logger(__name__)

26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131

@dataclass
class MetadataItem:
    """Locate one request's serialized tensor inside a binary payload.

    ``decode_pooling_output`` slices ``body[start:end]`` and rebuilds a tensor
    of ``shape`` from it using ``embed_dtype`` and ``endianness``; items are
    ordered by ``index`` before decoding.
    """

    # Position of the originating request; decoding sorts on this value.
    index: int
    # Element dtype the tensor was serialized with.
    embed_dtype: EmbedDType
    # Byte order the tensor was serialized with.
    endianness: Endianness
    # First byte (inclusive) of this tensor within the payload.
    start: int
    # Last byte (exclusive) of this tensor within the payload.
    end: int
    # Shape to restore after decoding the raw bytes.
    shape: tuple[int, ...]


def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
) -> list[MetadataItem]:
    """Build metadata for ``n_request`` equally-shaped tensors laid out back to back.

    Every tensor shares ``shape`` and ``embed_dtype``, so each one occupies a
    fixed-size byte range: item ``i`` covers ``[i * stride, (i + 1) * stride)``.
    """
    # Bytes per tensor: the byte width recorded in EMBED_DTYPES times the
    # number of elements in ``shape``.
    stride = EMBED_DTYPES[embed_dtype].nbytes * math.prod(shape)

    items: list[MetadataItem] = []
    for idx in range(n_request):
        begin = idx * stride
        items.append(
            MetadataItem(
                index=idx,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=begin,
                end=begin + stride,
                shape=shape,
            )
        )
    return items


def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
    """Return the pooled tensor of ``output`` converted to Python lists.

    Note: ``Tensor.tolist`` yields nested lists for multi-dimensional
    tensors, so the ``list[float]`` annotation is exact only for 1-D data.
    """
    pooled = output.outputs.data
    return pooled.tolist()


def encode_pooling_output_binary(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Serialize the pooled tensor of ``output`` to raw bytes.

    The conversion is delegated to ``tensor2binary`` with the requested
    element dtype and byte order.
    """
    return tensor2binary(
        tensor=output.outputs.data,
        embed_dtype=embed_dtype,
        endianness=endianness,
    )


def encode_pooling_output_base64(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> str:
    """Serialize the pooled tensor of ``output`` and return it as base64 text.

    The raw bytes come from ``tensor2binary``. The base64 result is decoded
    to ``str`` (presumably so it can be placed directly in a JSON payload).
    """
    raw = tensor2binary(
        tensor=output.outputs.data,
        embed_dtype=embed_dtype,
        endianness=endianness,
    )
    encoded = pybase64.b64encode(raw)
    return encoded.decode("utf-8")


def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, Any]]:
    """Serialize pooling outputs into binary chunks plus per-chunk metadata.

    Args:
        pooling_outputs: Outputs to encode; their position becomes ``index``.
        embed_dtype: Element dtype used when serializing each tensor.
        endianness: Byte order used when serializing each tensor.

    Returns:
        A ``(body, items, usage)`` tuple:
        ``body`` holds one ``bytes`` chunk per output;
        ``items`` holds the dictionary form of ``MetadataItem`` for each
        chunk, whose ``start``/``end`` offsets assume the chunks are laid out
        back to back in ``body`` order;
        ``usage`` is the dictionary form of UsageInfo with the summed prompt
        token count.
    """
    num_prompt_tokens = 0
    items: list[dict[str, Any]] = []
    body: list[bytes] = []
    offset = 0
    for idx, output in enumerate(pooling_outputs):
        data = output.outputs.data
        binary = tensor2binary(
            tensor=data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        size = len(binary)

        body.append(binary)
        # Dictionary form of MetadataItem. ``data.shape`` is a tuple subclass
        # (``torch.Size`` for tensor data); convert it to a plain tuple so it
        # matches MetadataItem's ``tuple[int, ...]`` field and stays
        # serializable by encoders that reject tuple subclasses (e.g. orjson).
        items.append(
            {
                "index": idx,
                "embed_dtype": embed_dtype,
                "endianness": endianness,
                "start": offset,
                "end": offset + size,
                "shape": tuple(data.shape),
            }
        )

        num_prompt_tokens += len(output.prompt_token_ids)
        offset += size

    # Dictionary form of UsageInfo; total_tokens mirrors prompt_tokens.
    usage = {
        "prompt_tokens": num_prompt_tokens,
        "total_tokens": num_prompt_tokens,
    }

    return body, items, usage


def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Rebuild tensors from ``body`` as described by ``items``.

    The result is ordered by ascending ``item.index`` regardless of the order
    of ``items``. Each tensor is decoded from ``body[item.start:item.end]``.
    """
    ordered = sorted(items, key=lambda meta: meta.index)

    tensors: list[torch.Tensor] = []
    for meta in ordered:
        chunk = body[meta.start : meta.end]
        tensors.append(
            binary2tensor(chunk, meta.shape, meta.embed_dtype, meta.endianness)
        )
    return tensors
132
133
134
135
136
137
138
139
140
141
142
143


@lru_cache(maxsize=1)
def get_json_response_cls() -> type[JSONResponse]:
    """Pick the JSON response class used for API responses.

    Returns FastAPI's ``ORJSONResponse`` when the ``orjson`` package can be
    found, otherwise the standard ``JSONResponse`` after logging a warning
    via ``logger.warning_once``. The result is cached, so the lookup runs
    only on the first call.
    """
    has_orjson = importlib.util.find_spec("orjson") is not None
    if has_orjson:
        # Imported lazily so the module does not require orjson to load.
        from fastapi.responses import ORJSONResponse

        return ORJSONResponse

    logger.warning_once(
        "To make v1/embeddings API fast, please install orjson by `pip install orjson`"
    )
    return JSONResponse