Unverified Commit c6c240aa authored by llmpros's avatar llmpros Committed by GitHub
Browse files

[Frontend]: Support base64 embedding (#5935)


Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 2be6955a
import base64
import numpy as np
import openai import openai
import pytest import pytest
import ray import ray
...@@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI, ...@@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 17 assert embeddings.usage.prompt_tokens == 17
assert embeddings.usage.total_tokens == 17 assert embeddings.usage.total_tokens == 17
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"Hello my name is",
"The best thing about vLLM is that it supports many different models"
]
responses_float = await embedding_client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float")
responses_base64 = await embedding_client.embeddings.create(
input=input_texts, model=model_name, encoding_format="base64")
decoded_responses_base64_data = []
for data in responses_base64.data:
decoded_responses_base64_data.append(
np.frombuffer(base64.b64decode(data.embedding),
dtype="float").tolist())
assert responses_float.data[0].embedding == decoded_responses_base64_data[
0]
assert responses_float.data[1].embedding == decoded_responses_base64_data[
1]
...@@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel): ...@@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
class EmbeddingResponseData(BaseModel): class EmbeddingResponseData(BaseModel):
index: int index: int
object: str = "embedding" object: str = "embedding"
embedding: List[float] embedding: Union[List[float], str]
class EmbeddingResponse(BaseModel): class EmbeddingResponse(BaseModel):
......
import base64
import time import time
from typing import AsyncIterator, List, Optional, Tuple from typing import AsyncIterator, List, Optional, Tuple
import numpy as np
from fastapi import Request from fastapi import Request
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -20,19 +22,18 @@ TypeTokenIDs = List[int] ...@@ -20,19 +22,18 @@ TypeTokenIDs = List[int]
def request_output_to_embedding_response( def request_output_to_embedding_response(
final_res_batch: List[EmbeddingRequestOutput], final_res_batch: List[EmbeddingRequestOutput], request_id: str,
request_id: str, created_time: int, model_name: str,
created_time: int, encoding_format: str) -> EmbeddingResponse:
model_name: str,
) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
assert final_res is not None assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
embedding_data = EmbeddingResponseData( if encoding_format == "base64":
index=idx, embedding=final_res.outputs.embedding) embedding = base64.b64encode(np.array(embedding))
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data) data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids) num_prompt_tokens += len(prompt_token_ids)
...@@ -72,10 +73,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -72,10 +73,8 @@ class OpenAIServingEmbedding(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
# Return error for unsupported features. encoding_format = (request.encoding_format
if request.encoding_format == "base64": if request.encoding_format else "float")
return self.create_error_response(
"base64 encoding is not currently supported")
if request.dimensions is not None: if request.dimensions is not None:
return self.create_error_response( return self.create_error_response(
"dimensions is currently not supported") "dimensions is currently not supported")
...@@ -129,7 +128,8 @@ class OpenAIServingEmbedding(OpenAIServing): ...@@ -129,7 +128,8 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
final_res_batch[i] = res final_res_batch[i] = res
response = request_output_to_embedding_response( response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name) final_res_batch, request_id, created_time, model_name,
encoding_format)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment