Commit 83449a5f (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Clean up pooling serial utils (#33665)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent dad2d6a5
......@@ -12,11 +12,7 @@ import base64
import requests
import torch
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
binary2tensor,
)
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS, binary2tensor
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
......@@ -45,7 +41,7 @@ def main(args):
] * 2
# The OpenAI client does not support the embed_dtype and endianness parameters.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
prompt = {
"model": model,
......
......@@ -12,13 +12,12 @@ import json
import requests
import torch
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
from vllm.entrypoints.pooling.utils import (
MetadataItem,
build_metadata_items,
decode_pooling_output,
)
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
......@@ -51,7 +50,7 @@ def main(args):
# The OpenAI client does not support the bytes encoding_format.
# The OpenAI client does not support the embed_dtype and endianness parameters.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
prompt = {
"model": model,
......@@ -74,7 +73,7 @@ def main(args):
# The vllm server always sorts the returned embeddings in the order of input. So
# returning metadata is not necessary. You can set encoding_format to bytes_only
# to let the server not return metadata.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
prompt = {
"model": model,
......
......@@ -17,16 +17,14 @@ from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
from vllm.entrypoints.pooling.utils import (
MetadataItem,
binary2tensor,
build_metadata_items,
decode_pooling_output,
)
from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS, binary2tensor
MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
......@@ -535,7 +533,7 @@ async def test_base64_embed_dtype_and_endianness(
)
float_data = [d.embedding for d in responses_float.data]
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
......@@ -574,7 +572,7 @@ async def test_bytes_embed_dtype_and_endianness(
)
float_data = [d.embedding for d in responses_float.data]
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
responses_bytes = requests.post(
server.url_for("/v1/embeddings"),
......@@ -618,7 +616,7 @@ async def test_bytes_only_embed_dtype_and_endianness(
float_data = [d.embedding for d in responses_float.data]
embedding_size = len(float_data[0])
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
responses_bytes = requests.post(
server.url_for("/v1/embeddings"),
......
......@@ -12,15 +12,13 @@ import torch
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
from vllm.entrypoints.pooling.utils import (
MetadataItem,
binary2tensor,
build_metadata_items,
decode_pooling_output,
)
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS, binary2tensor
MODEL_NAME = "internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
......@@ -342,7 +340,7 @@ async def test_base64_embed_dtype_and_endianness(
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
responses_base64 = requests.post(
url,
......@@ -389,7 +387,7 @@ async def test_bytes_embed_dtype_and_endianness(
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
responses_bytes = requests.post(
url,
......@@ -438,7 +436,7 @@ async def test_bytes_only_embed_dtype_and_endianness(
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
n_tokens = responses_float.usage.prompt_tokens // len(input_texts)
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS:
responses_bytes = requests.post(
url,
......
......@@ -5,17 +5,19 @@ import torch
from tests.models.utils import check_embeddings_close
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
EMBED_DTYPES,
ENDIANNESS,
EmbedDType,
Endianness,
binary2tensor,
tensor2binary,
)
@pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPES.keys())
@torch.inference_mode()
def test_encode_and_decode(embed_dtype: str, endianness: str):
def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness):
for i in range(10):
tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
shape = tensor.shape
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, TypeAlias
from collections.abc import AsyncGenerator, Callable, Mapping
from functools import partial
from typing import Any, Final, Literal, TypeAlias, cast
import torch
from fastapi import Request
......@@ -22,16 +23,18 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse,
EmbeddingResponseData,
)
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import (
encode_pooling_bytes,
encode_pooling_output,
)
from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__)
......@@ -113,79 +116,120 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def request_output_to_embed_json_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: Literal["float", "base64"],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> EmbeddingResponse:
    """Build the JSON embedding response for the "float" / "base64" formats.

    Each entry of ``final_res_batch`` becomes one ``EmbeddingResponseData``
    whose ``index`` is its position in the batch. With ``"float"`` the
    embedding is encoded by ``encode_pooling_output_float``; otherwise it is
    encoded by ``encode_pooling_output_base64`` using ``embed_dtype`` and
    ``endianness``.

    The reported usage counts prompt tokens only: ``total_tokens`` is set to
    the same value as ``prompt_tokens``.
    """
    # Pick the per-item encoder once, outside the loop.
    if encoding_format == "float":
        selected_encoder: Any = encode_pooling_output_float
    else:
        selected_encoder = partial(
            encode_pooling_output_base64,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
    encode_fn = cast(
        Callable[[PoolingRequestOutput], list[float] | str],
        selected_encoder,
    )

    response_data: list[EmbeddingResponseData] = []
    prompt_token_total = 0
    for position, result in enumerate(final_res_batch):
        response_data.append(
            EmbeddingResponseData(index=position, embedding=encode_fn(result))
        )
        prompt_token_total += len(result.prompt_token_ids)

    return EmbeddingResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=response_data,
        usage=UsageInfo(
            prompt_tokens=prompt_token_total,
            total_tokens=prompt_token_total,
        ),
    )
def request_output_to_embed_bytes_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
    encoding_format: Literal["bytes", "bytes_only"],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> EmbeddingBytesResponse:
    """Build the binary embedding response for the "bytes" / "bytes_only" formats.

    The batch is serialized by ``encode_pooling_bytes``. For ``"bytes"`` the
    response carries a ``metadata`` header holding a JSON document with the
    request id, creation time, model name, per-item metadata and usage. For
    ``"bytes_only"`` no headers are attached.
    """
    content, items, usage = encode_pooling_bytes(
        pooling_outputs=final_res_batch,
        embed_dtype=embed_dtype,
        endianness=endianness,
    )

    # "bytes_only" callers asked for the payload without any metadata.
    if encoding_format == "bytes_only":
        return EmbeddingBytesResponse(content=content, headers=None)

    metadata = {
        "id": request_id,
        "created": created_time,
        "model": model_name,
        "data": items,
        "usage": usage,
    }
    return EmbeddingBytesResponse(
        content=content,
        headers={"metadata": json.dumps(metadata)},
    )
def _build_response(
self,
ctx: EmbeddingServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
final_res_batch_checked = ctx.final_res_batch
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
def encode_float_base64():
items: list[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch_checked):
item = EmbeddingResponseData(
index=idx,
embedding=encode_pooling_output(
final_res,
encoding_format=encoding_format,
embed_dtype=embed_dtype,
endianness=endianness,
),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=ctx.request_id,
created=ctx.created_time,
model=ctx.model_name,
data=items,
usage=usage,
if encoding_format == "float" or encoding_format == "base64":
return self.request_output_to_embed_json_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
ctx.model_name,
encoding_format,
embed_dtype,
endianness,
)
def encode_bytes(bytes_only: bool) -> EmbeddingBytesResponse:
content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch_checked,
embed_dtype=embed_dtype,
endianness=endianness,
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_embed_bytes_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
ctx.model_name,
encoding_format,
embed_dtype,
endianness,
)
headers = (
None
if bytes_only
else {
"metadata": json.dumps(
{
"id": ctx.request_id,
"created": ctx.created_time,
"model": ctx.model_name,
"data": items,
"usage": usage,
}
)
}
)
return EmbeddingBytesResponse(content=content, headers=headers)
if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64()
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
assert_never(encoding_format)
def _get_max_position_embeddings(self) -> int:
"""Get the model's effective maximum sequence length for chunking."""
......
......@@ -4,8 +4,9 @@
import asyncio
import json
import time
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Final, cast
from collections.abc import AsyncGenerator, Callable, Sequence
from functools import partial
from typing import Any, Final, Literal, cast
import jinja2
from fastapi import Request
......@@ -27,17 +28,16 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse,
PoolingResponseData,
)
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import (
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
logger = init_logger(__name__)
......@@ -256,79 +256,119 @@ class OpenAIServingPooling(OpenAIServing):
return response
def request_output_to_pooling_response(
def request_output_to_pooling_json_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: EncodingFormat,
encoding_format: Literal["float", "base64"],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
def encode_float_base64():
items: list[PoolingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=encode_pooling_output(
final_res,
encoding_format=encoding_format,
embed_dtype=embed_dtype,
endianness=endianness,
),
) -> PoolingResponse:
encode_fn = cast(
Callable[[PoolingRequestOutput], list[float] | str],
(
encode_pooling_output_float
if encoding_format == "float"
else partial(
encode_pooling_output_base64,
embed_dtype=embed_dtype,
endianness=endianness,
)
prompt_token_ids = final_res.prompt_token_ids
),
)
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
items: list[PoolingResponseData] = []
num_prompt_tokens = 0
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=encode_fn(final_res),
)
prompt_token_ids = final_res.prompt_token_ids
return PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
def encode_bytes(bytes_only: bool) -> PoolingBytesResponse:
content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch,
embed_dtype=embed_dtype,
endianness=endianness,
)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
headers = (
None
if bytes_only
else {
"metadata": json.dumps(
{
"id": request_id,
"created": created_time,
"model": model_name,
"data": items,
"usage": usage,
}
)
}
return PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
def request_output_to_pooling_bytes_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["bytes", "bytes_only"],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingBytesResponse:
content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch,
embed_dtype=embed_dtype,
endianness=endianness,
)
headers = (
None
if encoding_format == "bytes_only"
else {
"metadata": json.dumps(
{
"id": request_id,
"created": created_time,
"model": model_name,
"data": items,
"usage": usage,
}
)
}
)
return PoolingBytesResponse(content=content, headers=headers)
def request_output_to_pooling_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
if encoding_format == "float" or encoding_format == "base64":
return self.request_output_to_pooling_json_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
return PoolingBytesResponse(
content=content,
headers=headers,
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_pooling_bytes_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64()
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
assert_never(encoding_format)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import Any
import pybase64
import torch
from vllm.outputs import PoolingRequestOutput
from vllm.utils.serial_utils import (
EMBED_DTYPES,
EmbedDType,
Endianness,
binary2tensor,
tensor2binary,
)
# Describes where one encoded tensor lives inside a binary response body and
# how to decode it (see build_metadata_items / decode_pooling_output).
@dataclass
class MetadataItem:
    # Position of the tensor in the batch; decoding sorts items by this value.
    index: int
    # Name of the element dtype the bytes were encoded with.
    embed_dtype: EmbedDType
    # Byte order the bytes were encoded with.
    endianness: Endianness
    # Byte offset in the body where this tensor begins (used as slice start).
    start: int
    # Byte offset in the body where this tensor stops (used as slice end).
    end: int
    # Shape handed to binary2tensor when rebuilding the tensor.
    shape: tuple[int, ...]
def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
) -> list[MetadataItem]:
    """Create metadata for ``n_request`` same-shaped tensors stored back to back.

    Every item shares ``embed_dtype``, ``endianness`` and ``shape``. Item
    ``i`` covers the byte range ``[i * stride, (i + 1) * stride)``, where
    ``stride`` is the element count of ``shape`` times the byte width of
    ``embed_dtype``.
    """
    # All tensors have the same byte length, so offsets are multiples of it.
    stride = EMBED_DTYPES[embed_dtype].nbytes * math.prod(shape)

    metadata: list[MetadataItem] = []
    for position in range(n_request):
        begin = position * stride
        metadata.append(
            MetadataItem(
                index=position,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=begin,
                end=begin + stride,
                shape=shape,
            )
        )
    return metadata
def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
    """Return the pooled data of ``output`` converted with ``tolist()``."""
    pooled = output.outputs.data
    return pooled.tolist()
def encode_pooling_output_binary(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Serialize the pooled data of ``output`` to raw bytes via ``tensor2binary``."""
    pooled = output.outputs.data
    return tensor2binary(
        tensor=pooled,
        embed_dtype=embed_dtype,
        endianness=endianness,
    )
def encode_pooling_output_base64(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> str:
    """Serialize the pooled data of ``output`` and return it as base64 text.

    The data is first turned into bytes by ``tensor2binary`` and the
    base64-encoded result is decoded as UTF-8 to produce a ``str``.
    """
    raw = tensor2binary(
        tensor=output.outputs.data,
        embed_dtype=embed_dtype,
        endianness=endianness,
    )
    encoded = pybase64.b64encode(raw)
    return encoded.decode("utf-8")
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, Any]]:
    """Serialize a batch of pooling outputs into binary chunks plus metadata.

    Returns a triple ``(body, items, usage)``:

    * ``body`` - one ``bytes`` chunk per output, in batch order.
    * ``items`` - one dict per output with the same keys as ``MetadataItem``;
      ``start`` / ``end`` are running byte offsets across the chunks.
    * ``usage`` - dict with the same keys as ``UsageInfo``; both counters hold
      the summed length of every output's ``prompt_token_ids``.
    """
    chunks: list[bytes] = []
    metadata: list[dict[str, Any]] = []
    cursor = 0
    token_count = 0

    for position, pooled_output in enumerate(pooling_outputs):
        tensor = pooled_output.outputs.data
        chunk = tensor2binary(
            tensor=tensor,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        next_cursor = cursor + len(chunk)

        chunks.append(chunk)
        # Dictionary form of MetadataItem
        metadata.append(
            {
                "index": position,
                "embed_dtype": embed_dtype,
                "endianness": endianness,
                "start": cursor,
                "end": next_cursor,
                "shape": tensor.shape,
            }
        )

        token_count += len(pooled_output.prompt_token_ids)
        cursor = next_cursor

    # Dictionary form of UsageInfo
    usage = {
        "prompt_tokens": token_count,
        "total_tokens": token_count,
    }
    return chunks, metadata, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Rebuild tensors from ``body`` using the byte ranges recorded in ``items``.

    Items are processed in ascending ``index`` order; the caller's list is
    left untouched because sorting works on a copy. Each tensor is decoded by
    ``binary2tensor`` from the slice ``body[start:end]`` with the item's
    shape, dtype name and endianness.
    """
    ordered = sorted(items, key=lambda entry: entry.index)

    tensors: list[torch.Tensor] = []
    for entry in ordered:
        payload = body[entry.start : entry.end]
        tensors.append(
            binary2tensor(
                payload,
                entry.shape,
                entry.embed_dtype,
                entry.endianness,
            )
        )
    return tensors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import math
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal
from typing import Literal, get_args
import numpy as np
import numpy.typing as npt
import pybase64
import torch
from typing_extensions import assert_never
if TYPE_CHECKING:
from vllm import PoolingRequestOutput
else:
PoolingRequestOutput = Any
sys_byteorder = sys.byteorder
EMBED_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}
EMBED_DTYPE_TO_N_BYTES = {
"float32": 4,
"float16": 2,
"bfloat16": 2,
"fp8_e4m3": 1,
"fp8_e5m2": 1,
}
@dataclass(frozen=True)
class DTypeInfo:
torch_dtype: torch.dtype
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
"float32": torch.float32,
"float16": torch.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": torch.float16,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}
torch_view_dtype: torch.dtype
numpy_view_dtype: npt.DTypeLike
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
"float32": np.float32,
"float16": np.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": np.float16,
"fp8_e4m3": np.uint8,
"fp8_e5m2": np.uint8,
}
@property
def nbytes(self) -> int:
return self.torch_dtype.itemsize
ENDIANNESS = ["native", "big", "little"]
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
# NOTE: numpy does not support bfloat16 and fp8
EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
"float32": DTypeInfo(torch.float32, torch.float32, np.float32),
"float16": DTypeInfo(torch.float16, torch.float16, np.float16),
"bfloat16": DTypeInfo(torch.bfloat16, torch.float16, np.float16),
"fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8),
"fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8),
}
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)
def tensor2base64(x: torch.Tensor) -> str:
with io.BytesIO() as buf:
......@@ -71,21 +51,26 @@ def tensor2base64(x: torch.Tensor) -> str:
buf.seek(0)
binary_data = buf.read()
return base64.b64encode(binary_data).decode("utf-8")
return pybase64.b64encode(binary_data).decode("utf-8")
def tensor2binary(
tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
tensor: torch.Tensor,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> bytes:
assert isinstance(tensor, torch.Tensor)
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
assert embed_dtype in EMBED_DTYPES
assert endianness in ENDIANNESS
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]
dtype_info = EMBED_DTYPES[embed_dtype]
np_array = (
tensor.to(torch_dtype).flatten().contiguous().view(torch_view_dtype).numpy()
tensor.to(dtype_info.torch_dtype)
.flatten()
.contiguous()
.view(dtype_info.torch_view_dtype)
.numpy()
)
if endianness != "native" and endianness != sys_byteorder:
......@@ -100,115 +85,14 @@ def binary2tensor(
embed_dtype: EmbedDType,
endianness: Endianness,
) -> torch.Tensor:
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
assert embed_dtype in EMBED_DTYPES
assert endianness in ENDIANNESS
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
dtype_info = EMBED_DTYPES[embed_dtype]
np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)
np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape)
if endianness != "native" and endianness != sys_byteorder:
np_array = np_array.byteswap()
return torch.from_numpy(np_array).view(torch_dtype)
def encode_pooling_output(
output: PoolingRequestOutput,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> list[float] | str | bytes:
if encoding_format == "float":
return output.outputs.data.tolist()
elif encoding_format == "base64":
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
return base64.b64encode(embedding_bytes).decode("utf-8")
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return tensor2binary(output.outputs.data, embed_dtype, endianness)
assert_never(encoding_format)
@dataclass
class MetadataItem:
index: int
embed_dtype: EmbedDType
endianness: Endianness
start: int
end: int
shape: tuple[int, ...]
def build_metadata_items(
embed_dtype: EmbedDType,
endianness: Endianness,
shape: tuple[int, ...],
n_request: int,
):
n_bytes = EMBED_DTYPE_TO_N_BYTES[embed_dtype]
size = math.prod(shape)
items = [
MetadataItem(
index=i,
embed_dtype=embed_dtype,
endianness=endianness,
start=i * size * n_bytes,
end=(i + 1) * size * n_bytes,
shape=shape,
)
for i in range(n_request)
]
return items
def encode_pooling_bytes(
pooling_outputs: list[PoolingRequestOutput],
embed_dtype: EmbedDType,
endianness: Endianness,
):
num_prompt_tokens = 0
items: list[dict[str, MetadataItem]] = []
body = []
offset = 0
for idx, output in enumerate(pooling_outputs):
binary = tensor2binary(
tensor=output.outputs.data,
embed_dtype=embed_dtype,
endianness=endianness,
)
size = len(binary)
item = {
"index": idx,
"embed_dtype": embed_dtype,
"endianness": endianness,
"start": offset,
"end": offset + size,
"shape": output.outputs.data.shape,
}
body.append(binary)
items.append(item)
prompt_token_ids = output.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
offset += size
usage = {
"prompt_tokens": num_prompt_tokens,
"total_tokens": num_prompt_tokens,
}
return body, items, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
items.sort(key=lambda x: x.index)
tensor_list: list[torch.Tensor] = []
for item in items:
binary = body[item.start : item.end]
tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness)
tensor_list.append(tensor)
return tensor_list
return torch.from_numpy(np_array).view(dtype_info.torch_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment