Unverified commit 83449a5f, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Clean up pooling serial utils (#33665)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent dad2d6a5
...@@ -12,11 +12,7 @@ import base64 ...@@ -12,11 +12,7 @@ import base64
import requests import requests
import torch import torch
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS, binary2tensor
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
binary2tensor,
)
def post_http_request(prompt: dict, api_url: str) -> requests.Response: def post_http_request(prompt: dict, api_url: str) -> requests.Response:
...@@ -45,7 +41,7 @@ def main(args): ...@@ -45,7 +41,7 @@ def main(args):
] * 2 ] * 2
# The OpenAI client does not support the embed_dtype and endianness parameters. # The OpenAI client does not support the embed_dtype and endianness parameters.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
prompt = { prompt = {
"model": model, "model": model,
......
...@@ -12,13 +12,12 @@ import json ...@@ -12,13 +12,12 @@ import json
import requests import requests
import torch import torch
from vllm.utils.serial_utils import ( from vllm.entrypoints.pooling.utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
MetadataItem, MetadataItem,
build_metadata_items, build_metadata_items,
decode_pooling_output, decode_pooling_output,
) )
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS
def post_http_request(prompt: dict, api_url: str) -> requests.Response: def post_http_request(prompt: dict, api_url: str) -> requests.Response:
...@@ -51,7 +50,7 @@ def main(args): ...@@ -51,7 +50,7 @@ def main(args):
# The OpenAI client does not support the bytes encoding_format. # The OpenAI client does not support the bytes encoding_format.
# The OpenAI client does not support the embed_dtype and endianness parameters. # The OpenAI client does not support the embed_dtype and endianness parameters.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
prompt = { prompt = {
"model": model, "model": model,
...@@ -74,7 +73,7 @@ def main(args): ...@@ -74,7 +73,7 @@ def main(args):
# The vllm server always sorts the returned embeddings in the order of input. So # The vllm server always sorts the returned embeddings in the order of input. So
# returning metadata is not necessary. You can set encoding_format to bytes_only # returning metadata is not necessary. You can set encoding_format to bytes_only
# to let the server not return metadata. # to let the server not return metadata.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
prompt = { prompt = {
"model": model, "model": model,
......
...@@ -17,16 +17,14 @@ from tests.models.utils import check_embeddings_close ...@@ -17,16 +17,14 @@ from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.platforms import current_platform from vllm.entrypoints.pooling.utils import (
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
MetadataItem, MetadataItem,
binary2tensor,
build_metadata_items, build_metadata_items,
decode_pooling_output, decode_pooling_output,
) )
from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS, binary2tensor
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
...@@ -535,7 +533,7 @@ async def test_base64_embed_dtype_and_endianness( ...@@ -535,7 +533,7 @@ async def test_base64_embed_dtype_and_endianness(
) )
float_data = [d.embedding for d in responses_float.data] float_data = [d.embedding for d in responses_float.data]
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
responses_base64 = requests.post( responses_base64 = requests.post(
server.url_for("/v1/embeddings"), server.url_for("/v1/embeddings"),
...@@ -574,7 +572,7 @@ async def test_bytes_embed_dtype_and_endianness( ...@@ -574,7 +572,7 @@ async def test_bytes_embed_dtype_and_endianness(
) )
float_data = [d.embedding for d in responses_float.data] float_data = [d.embedding for d in responses_float.data]
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
responses_bytes = requests.post( responses_bytes = requests.post(
server.url_for("/v1/embeddings"), server.url_for("/v1/embeddings"),
...@@ -618,7 +616,7 @@ async def test_bytes_only_embed_dtype_and_endianness( ...@@ -618,7 +616,7 @@ async def test_bytes_only_embed_dtype_and_endianness(
float_data = [d.embedding for d in responses_float.data] float_data = [d.embedding for d in responses_float.data]
embedding_size = len(float_data[0]) embedding_size = len(float_data[0])
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
responses_bytes = requests.post( responses_bytes = requests.post(
server.url_for("/v1/embeddings"), server.url_for("/v1/embeddings"),
......
...@@ -12,15 +12,13 @@ import torch ...@@ -12,15 +12,13 @@ import torch
from tests.models.utils import check_embeddings_close from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.tokenizers import get_tokenizer from vllm.entrypoints.pooling.utils import (
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
MetadataItem, MetadataItem,
binary2tensor,
build_metadata_items, build_metadata_items,
decode_pooling_output, decode_pooling_output,
) )
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import EMBED_DTYPES, ENDIANNESS, binary2tensor
MODEL_NAME = "internlm/internlm2-1_8b-reward" MODEL_NAME = "internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
...@@ -342,7 +340,7 @@ async def test_base64_embed_dtype_and_endianness( ...@@ -342,7 +340,7 @@ async def test_base64_embed_dtype_and_endianness(
responses_float = PoolingResponse.model_validate(float_response.json()) responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE: for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
responses_base64 = requests.post( responses_base64 = requests.post(
url, url,
...@@ -389,7 +387,7 @@ async def test_bytes_embed_dtype_and_endianness( ...@@ -389,7 +387,7 @@ async def test_bytes_embed_dtype_and_endianness(
responses_float = PoolingResponse.model_validate(float_response.json()) responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
responses_bytes = requests.post( responses_bytes = requests.post(
url, url,
...@@ -438,7 +436,7 @@ async def test_bytes_only_embed_dtype_and_endianness( ...@@ -438,7 +436,7 @@ async def test_bytes_only_embed_dtype_and_endianness(
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data] float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
n_tokens = responses_float.usage.prompt_tokens // len(input_texts) n_tokens = responses_float.usage.prompt_tokens // len(input_texts)
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()): for embed_dtype in EMBED_DTYPES:
for endianness in ENDIANNESS: for endianness in ENDIANNESS:
responses_bytes = requests.post( responses_bytes = requests.post(
url, url,
......
...@@ -5,17 +5,19 @@ import torch ...@@ -5,17 +5,19 @@ import torch
from tests.models.utils import check_embeddings_close from tests.models.utils import check_embeddings_close
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE, EMBED_DTYPES,
ENDIANNESS, ENDIANNESS,
EmbedDType,
Endianness,
binary2tensor, binary2tensor,
tensor2binary, tensor2binary,
) )
@pytest.mark.parametrize("endianness", ENDIANNESS) @pytest.mark.parametrize("endianness", ENDIANNESS)
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys()) @pytest.mark.parametrize("embed_dtype", EMBED_DTYPES.keys())
@torch.inference_mode() @torch.inference_mode()
def test_encode_and_decode(embed_dtype: str, endianness: str): def test_encode_and_decode(embed_dtype: EmbedDType, endianness: Endianness):
for i in range(10): for i in range(10):
tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32) tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
shape = tensor.shape shape = tensor.shape
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import AsyncGenerator, Mapping from collections.abc import AsyncGenerator, Callable, Mapping
from typing import Any, Final, TypeAlias from functools import partial
from typing import Any, Final, Literal, TypeAlias, cast
import torch import torch
from fastapi import Request from fastapi import Request
...@@ -22,16 +23,18 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -22,16 +23,18 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse, EmbeddingResponse,
EmbeddingResponseData, EmbeddingResponseData,
) )
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import EmbedDType, Endianness
encode_pooling_bytes,
encode_pooling_output,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -113,29 +116,36 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -113,29 +116,36 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(str(e))
def _build_response( def request_output_to_embed_json_response(
self, self,
ctx: EmbeddingServeContext, final_res_batch: list[PoolingRequestOutput],
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse: request_id: str,
final_res_batch_checked = ctx.final_res_batch created_time: int,
model_name: str,
encoding_format = ctx.request.encoding_format encoding_format: Literal["float", "base64"],
embed_dtype = ctx.request.embed_dtype embed_dtype: EmbedDType,
endianness = ctx.request.endianness endianness: Endianness,
) -> EmbeddingResponse:
encode_fn = cast(
Callable[[PoolingRequestOutput], list[float] | str],
(
encode_pooling_output_float
if encoding_format == "float"
else partial(
encode_pooling_output_base64,
embed_dtype=embed_dtype,
endianness=endianness,
)
),
)
def encode_float_base64():
items: list[EmbeddingResponseData] = [] items: list[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch_checked): for idx, final_res in enumerate(final_res_batch):
item = EmbeddingResponseData( item = EmbeddingResponseData(
index=idx, index=idx,
embedding=encode_pooling_output( embedding=encode_fn(final_res),
final_res,
encoding_format=encoding_format,
embed_dtype=embed_dtype,
endianness=endianness,
),
) )
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
...@@ -148,29 +158,38 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -148,29 +158,38 @@ class OpenAIServingEmbedding(OpenAIServing):
) )
return EmbeddingResponse( return EmbeddingResponse(
id=ctx.request_id, id=request_id,
created=ctx.created_time, created=created_time,
model=ctx.model_name, model=model_name,
data=items, data=items,
usage=usage, usage=usage,
) )
def encode_bytes(bytes_only: bool) -> EmbeddingBytesResponse: def request_output_to_embed_bytes_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["bytes", "bytes_only"],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> EmbeddingBytesResponse:
content, items, usage = encode_pooling_bytes( content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch_checked, pooling_outputs=final_res_batch,
embed_dtype=embed_dtype, embed_dtype=embed_dtype,
endianness=endianness, endianness=endianness,
) )
headers = ( headers = (
None None
if bytes_only if encoding_format == "bytes_only"
else { else {
"metadata": json.dumps( "metadata": json.dumps(
{ {
"id": ctx.request_id, "id": request_id,
"created": ctx.created_time, "created": created_time,
"model": ctx.model_name, "model": model_name,
"data": items, "data": items,
"usage": usage, "usage": usage,
} }
...@@ -180,11 +199,36 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -180,11 +199,36 @@ class OpenAIServingEmbedding(OpenAIServing):
return EmbeddingBytesResponse(content=content, headers=headers) return EmbeddingBytesResponse(content=content, headers=headers)
def _build_response(
self,
ctx: EmbeddingServeContext,
) -> EmbeddingResponse | EmbeddingBytesResponse | ErrorResponse:
encoding_format = ctx.request.encoding_format
embed_dtype = ctx.request.embed_dtype
endianness = ctx.request.endianness
if encoding_format == "float" or encoding_format == "base64": if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64() return self.request_output_to_embed_json_response(
elif encoding_format == "bytes" or encoding_format == "bytes_only": ctx.final_res_batch,
return encode_bytes(bytes_only=encoding_format == "bytes_only") ctx.request_id,
else: ctx.created_time,
ctx.model_name,
encoding_format,
embed_dtype,
endianness,
)
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_embed_bytes_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
ctx.model_name,
encoding_format,
embed_dtype,
endianness,
)
assert_never(encoding_format) assert_never(encoding_format)
def _get_max_position_embeddings(self) -> int: def _get_max_position_embeddings(self) -> int:
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
import asyncio import asyncio
import json import json
import time import time
from collections.abc import AsyncGenerator, Sequence from collections.abc import AsyncGenerator, Callable, Sequence
from typing import Any, Final, cast from functools import partial
from typing import Any, Final, Literal, cast
import jinja2 import jinja2
from fastapi import Request from fastapi import Request
...@@ -27,17 +28,16 @@ from vllm.entrypoints.pooling.pooling.protocol import ( ...@@ -27,17 +28,16 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse, PoolingResponse,
PoolingResponseData, PoolingResponseData,
) )
from vllm.entrypoints.pooling.utils import (
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import ( from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
EmbedDType,
EncodingFormat,
Endianness,
encode_pooling_bytes,
encode_pooling_output,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -256,29 +256,36 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -256,29 +256,36 @@ class OpenAIServingPooling(OpenAIServing):
return response return response
def request_output_to_pooling_response( def request_output_to_pooling_json_response(
self, self,
final_res_batch: list[PoolingRequestOutput], final_res_batch: list[PoolingRequestOutput],
request_id: str, request_id: str,
created_time: int, created_time: int,
model_name: str, model_name: str,
encoding_format: EncodingFormat, encoding_format: Literal["float", "base64"],
embed_dtype: EmbedDType, embed_dtype: EmbedDType,
endianness: Endianness, endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse: ) -> PoolingResponse:
def encode_float_base64(): encode_fn = cast(
Callable[[PoolingRequestOutput], list[float] | str],
(
encode_pooling_output_float
if encoding_format == "float"
else partial(
encode_pooling_output_base64,
embed_dtype=embed_dtype,
endianness=endianness,
)
),
)
items: list[PoolingResponseData] = [] items: list[PoolingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData( item = PoolingResponseData(
index=idx, index=idx,
data=encode_pooling_output( data=encode_fn(final_res),
final_res,
encoding_format=encoding_format,
embed_dtype=embed_dtype,
endianness=endianness,
),
) )
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
...@@ -298,7 +305,16 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -298,7 +305,16 @@ class OpenAIServingPooling(OpenAIServing):
usage=usage, usage=usage,
) )
def encode_bytes(bytes_only: bool) -> PoolingBytesResponse: def request_output_to_pooling_bytes_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["bytes", "bytes_only"],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingBytesResponse:
content, items, usage = encode_pooling_bytes( content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch, pooling_outputs=final_res_batch,
embed_dtype=embed_dtype, embed_dtype=embed_dtype,
...@@ -307,7 +323,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -307,7 +323,7 @@ class OpenAIServingPooling(OpenAIServing):
headers = ( headers = (
None None
if bytes_only if encoding_format == "bytes_only"
else { else {
"metadata": json.dumps( "metadata": json.dumps(
{ {
...@@ -321,14 +337,38 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -321,14 +337,38 @@ class OpenAIServingPooling(OpenAIServing):
} }
) )
return PoolingBytesResponse( return PoolingBytesResponse(content=content, headers=headers)
content=content,
headers=headers,
)
def request_output_to_pooling_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> PoolingResponse | PoolingBytesResponse:
if encoding_format == "float" or encoding_format == "base64": if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64() return self.request_output_to_pooling_json_response(
elif encoding_format == "bytes" or encoding_format == "bytes_only": final_res_batch,
return encode_bytes(bytes_only=encoding_format == "bytes_only") request_id,
else: created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self.request_output_to_pooling_bytes_response(
final_res_batch,
request_id,
created_time,
model_name,
encoding_format,
embed_dtype,
endianness,
)
assert_never(encoding_format) assert_never(encoding_format)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import Any
import pybase64
import torch
from vllm.outputs import PoolingRequestOutput
from vllm.utils.serial_utils import (
EMBED_DTYPES,
EmbedDType,
Endianness,
binary2tensor,
tensor2binary,
)
@dataclass
class MetadataItem:
    """Describes where one encoded tensor sits inside a binary response body.

    Instances are consumed by ``decode_pooling_output``, which slices the
    body with ``start``/``end`` and hands the remaining fields to
    ``binary2tensor``.
    """

    # Position of this tensor among the outputs; used as the sort key
    # when decoding.
    index: int
    # Name of the element dtype the tensor was serialized with.
    embed_dtype: EmbedDType
    # Byte order the tensor was serialized with.
    endianness: Endianness
    # Offset of the first byte of this tensor within the body.
    start: int
    # Offset one past the last byte of this tensor within the body.
    end: int
    # Shape to restore when the bytes are decoded back into a tensor.
    shape: tuple[int, ...]
def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
) -> list[MetadataItem]:
    """Create metadata for ``n_request`` same-shaped tensors stored back to back.

    Every tensor is taken to occupy ``prod(shape)`` elements of
    ``embed_dtype``, so item ``i`` covers the byte range
    ``[i * stride, (i + 1) * stride)`` where ``stride`` is the byte size of
    a single tensor.

    Args:
        embed_dtype: Key into ``EMBED_DTYPES`` giving the element width.
        endianness: Byte order recorded on every item.
        shape: Shape recorded on every item.
        n_request: Number of items to generate.

    Returns:
        One ``MetadataItem`` per request, in index order.
    """
    element_width = EMBED_DTYPES[embed_dtype].nbytes
    element_count = math.prod(shape)
    # Number of bytes a single encoded tensor occupies.
    stride = element_count * element_width

    metadata: list[MetadataItem] = []
    for position in range(n_request):
        begin = position * stride
        metadata.append(
            MetadataItem(
                index=position,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=begin,
                end=begin + stride,
                shape=shape,
            )
        )
    return metadata
def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
    """Return the pooled data of ``output`` as a plain (nested) Python list."""
    pooled = output.outputs.data
    return pooled.tolist()
def encode_pooling_output_binary(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Serialize the pooled data of ``output`` to raw bytes.

    The conversion is delegated to ``tensor2binary`` using the requested
    element dtype and byte order.
    """
    pooled = output.outputs.data
    raw = tensor2binary(pooled, embed_dtype, endianness)
    return raw
def encode_pooling_output_base64(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> str:
    """Serialize the pooled data of ``output`` to a base64 text string.

    The tensor is first turned into raw bytes by ``tensor2binary`` (with
    the requested element dtype and byte order) and those bytes are then
    base64-encoded and decoded as UTF-8 text.
    """
    raw = tensor2binary(output.outputs.data, embed_dtype, endianness)
    encoded = pybase64.b64encode(raw)
    return encoded.decode("utf-8")
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, Any]]:
    """Encode a batch of pooling outputs into binary chunks plus metadata.

    Args:
        pooling_outputs: Outputs to serialize, in response order.
        embed_dtype: Element dtype passed to ``tensor2binary``.
        endianness: Byte order passed to ``tensor2binary``.

    Returns:
        A ``(body, items, usage)`` tuple where ``body`` holds one binary
        chunk per output, ``items`` holds one dictionary per output with
        the same keys as ``MetadataItem`` (offsets are relative to the
        concatenation of all chunks), and ``usage`` is a dictionary with
        ``prompt_tokens`` and ``total_tokens`` counts.
    """
    chunks: list[bytes] = []
    metadata: list[dict[str, Any]] = []
    prompt_token_count = 0
    # Running byte offset of the next chunk within the concatenated body.
    cursor = 0

    for position, pooled in enumerate(pooling_outputs):
        payload = tensor2binary(pooled.outputs.data, embed_dtype, endianness)
        next_cursor = cursor + len(payload)

        chunks.append(payload)
        # Dictionary form of MetadataItem
        metadata.append(
            {
                "index": position,
                "embed_dtype": embed_dtype,
                "endianness": endianness,
                "start": cursor,
                "end": next_cursor,
                "shape": pooled.outputs.data.shape,
            }
        )

        prompt_token_count += len(pooled.prompt_token_ids)
        cursor = next_cursor

    # Dictionary form of UsageInfo
    usage = {
        "prompt_tokens": prompt_token_count,
        "total_tokens": prompt_token_count,
    }
    return chunks, metadata, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Rebuild the tensors described by ``items`` from the binary ``body``.

    Items are processed in ascending ``index`` order; the caller's list is
    left untouched because a sorted copy is iterated. Each item's
    ``start``/``end`` select its slice of ``body``, which ``binary2tensor``
    decodes using the item's shape, dtype and byte order.
    """
    decoded: list[torch.Tensor] = []
    for entry in sorted(items, key=lambda candidate: candidate.index):
        chunk = body[entry.start : entry.end]
        decoded.append(
            binary2tensor(chunk, entry.shape, entry.embed_dtype, entry.endianness)
        )
    return decoded
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io import io
import math
import sys import sys
from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal from typing import Literal, get_args
import numpy as np import numpy as np
import numpy.typing as npt
import pybase64
import torch import torch
from typing_extensions import assert_never
if TYPE_CHECKING:
from vllm import PoolingRequestOutput
else:
PoolingRequestOutput = Any
sys_byteorder = sys.byteorder sys_byteorder = sys.byteorder
EMBED_DTYPE_TO_TORCH_DTYPE = { @dataclass(frozen=True)
"float32": torch.float32, class DTypeInfo:
"float16": torch.float16, torch_dtype: torch.dtype
"bfloat16": torch.bfloat16,
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}
EMBED_DTYPE_TO_N_BYTES = {
"float32": 4,
"float16": 2,
"bfloat16": 2,
"fp8_e4m3": 1,
"fp8_e5m2": 1,
}
torch_view_dtype: torch.dtype
numpy_view_dtype: npt.DTypeLike
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = { @property
"float32": torch.float32, def nbytes(self) -> int:
"float16": torch.float16, return self.torch_dtype.itemsize
# numpy does not support bfloat16 and fp8
"bfloat16": torch.float16,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
"float32": np.float32,
"float16": np.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": np.float16,
"fp8_e4m3": np.uint8,
"fp8_e5m2": np.uint8,
}
ENDIANNESS = ["native", "big", "little"]
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"] EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"] Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"] EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
# NOTE: numpy does not support bfloat16 and fp8
EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
"float32": DTypeInfo(torch.float32, torch.float32, np.float32),
"float16": DTypeInfo(torch.float16, torch.float16, np.float16),
"bfloat16": DTypeInfo(torch.bfloat16, torch.float16, np.float16),
"fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8),
"fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8),
}
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)
def tensor2base64(x: torch.Tensor) -> str: def tensor2base64(x: torch.Tensor) -> str:
with io.BytesIO() as buf: with io.BytesIO() as buf:
...@@ -71,21 +51,26 @@ def tensor2base64(x: torch.Tensor) -> str: ...@@ -71,21 +51,26 @@ def tensor2base64(x: torch.Tensor) -> str:
buf.seek(0) buf.seek(0)
binary_data = buf.read() binary_data = buf.read()
return base64.b64encode(binary_data).decode("utf-8") return pybase64.b64encode(binary_data).decode("utf-8")
def tensor2binary( def tensor2binary(
tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness tensor: torch.Tensor,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> bytes: ) -> bytes:
assert isinstance(tensor, torch.Tensor) assert isinstance(tensor, torch.Tensor)
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE assert embed_dtype in EMBED_DTYPES
assert endianness in ENDIANNESS assert endianness in ENDIANNESS
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] dtype_info = EMBED_DTYPES[embed_dtype]
torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]
np_array = ( np_array = (
tensor.to(torch_dtype).flatten().contiguous().view(torch_view_dtype).numpy() tensor.to(dtype_info.torch_dtype)
.flatten()
.contiguous()
.view(dtype_info.torch_view_dtype)
.numpy()
) )
if endianness != "native" and endianness != sys_byteorder: if endianness != "native" and endianness != sys_byteorder:
...@@ -100,115 +85,14 @@ def binary2tensor( ...@@ -100,115 +85,14 @@ def binary2tensor(
embed_dtype: EmbedDType, embed_dtype: EmbedDType,
endianness: Endianness, endianness: Endianness,
) -> torch.Tensor: ) -> torch.Tensor:
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE assert embed_dtype in EMBED_DTYPES
assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
assert endianness in ENDIANNESS assert endianness in ENDIANNESS
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] dtype_info = EMBED_DTYPES[embed_dtype]
np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape) np_array = np.frombuffer(binary, dtype=dtype_info.numpy_view_dtype).reshape(shape)
if endianness != "native" and endianness != sys_byteorder: if endianness != "native" and endianness != sys_byteorder:
np_array = np_array.byteswap() np_array = np_array.byteswap()
return torch.from_numpy(np_array).view(torch_dtype) return torch.from_numpy(np_array).view(dtype_info.torch_dtype)
def encode_pooling_output(
output: PoolingRequestOutput,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> list[float] | str | bytes:
if encoding_format == "float":
return output.outputs.data.tolist()
elif encoding_format == "base64":
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
return base64.b64encode(embedding_bytes).decode("utf-8")
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return tensor2binary(output.outputs.data, embed_dtype, endianness)
assert_never(encoding_format)
@dataclass
class MetadataItem:
index: int
embed_dtype: EmbedDType
endianness: Endianness
start: int
end: int
shape: tuple[int, ...]
def build_metadata_items(
embed_dtype: EmbedDType,
endianness: Endianness,
shape: tuple[int, ...],
n_request: int,
):
n_bytes = EMBED_DTYPE_TO_N_BYTES[embed_dtype]
size = math.prod(shape)
items = [
MetadataItem(
index=i,
embed_dtype=embed_dtype,
endianness=endianness,
start=i * size * n_bytes,
end=(i + 1) * size * n_bytes,
shape=shape,
)
for i in range(n_request)
]
return items
def encode_pooling_bytes(
pooling_outputs: list[PoolingRequestOutput],
embed_dtype: EmbedDType,
endianness: Endianness,
):
num_prompt_tokens = 0
items: list[dict[str, MetadataItem]] = []
body = []
offset = 0
for idx, output in enumerate(pooling_outputs):
binary = tensor2binary(
tensor=output.outputs.data,
embed_dtype=embed_dtype,
endianness=endianness,
)
size = len(binary)
item = {
"index": idx,
"embed_dtype": embed_dtype,
"endianness": endianness,
"start": offset,
"end": offset + size,
"shape": output.outputs.data.shape,
}
body.append(binary)
items.append(item)
prompt_token_ids = output.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
offset += size
usage = {
"prompt_tokens": num_prompt_tokens,
"total_tokens": num_prompt_tokens,
}
return body, items, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
items.sort(key=lambda x: x.index)
tensor_list: list[torch.Tensor] = []
for item in items:
binary = body[item.start : item.end]
tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness)
tensor_list.append(tensor)
return tensor_list
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment