Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6c240aa
Unverified
Commit
c6c240aa
authored
Jun 30, 2024
by
llmpros
Committed by
GitHub
Jun 30, 2024
Browse files
[Frontend]: Support base64 embedding (#5935)
Co-authored-by:
Cyrus Leung
<
cyrus.tl.leung@gmail.com
>
parent
2be6955a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
14 deletions
+47
-14
tests/entrypoints/openai/test_embedding.py
tests/entrypoints/openai/test_embedding.py
+33
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+1
-1
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+13
-13
No files found.
tests/entrypoints/openai/test_embedding.py
View file @
c6c240aa
import
base64
import
numpy
as
np
import
openai
import
pytest
import
ray
...
...
@@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
assert
embeddings
.
usage
.
completion_tokens
==
0
assert
embeddings
.
usage
.
prompt_tokens
==
17
assert
embeddings
.
usage
.
total_tokens
==
17
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
EMBEDDING_MODEL_NAME
],
)
async
def
test_batch_base64_embedding
(
embedding_client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
input_texts
=
[
"Hello my name is"
,
"The best thing about vLLM is that it supports many different models"
]
responses_float
=
await
embedding_client
.
embeddings
.
create
(
input
=
input_texts
,
model
=
model_name
,
encoding_format
=
"float"
)
responses_base64
=
await
embedding_client
.
embeddings
.
create
(
input
=
input_texts
,
model
=
model_name
,
encoding_format
=
"base64"
)
decoded_responses_base64_data
=
[]
for
data
in
responses_base64
.
data
:
decoded_responses_base64_data
.
append
(
np
.
frombuffer
(
base64
.
b64decode
(
data
.
embedding
),
dtype
=
"float"
).
tolist
())
assert
responses_float
.
data
[
0
].
embedding
==
decoded_responses_base64_data
[
0
]
assert
responses_float
.
data
[
1
].
embedding
==
decoded_responses_base64_data
[
1
]
vllm/entrypoints/openai/protocol.py
View file @
c6c240aa
...
...
@@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
class
EmbeddingResponseData
(
BaseModel
):
index
:
int
object
:
str
=
"embedding"
embedding
:
List
[
float
]
embedding
:
Union
[
List
[
float
]
,
str
]
class
EmbeddingResponse
(
BaseModel
):
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
c6c240aa
import
base64
import
time
from
typing
import
AsyncIterator
,
List
,
Optional
,
Tuple
import
numpy
as
np
from
fastapi
import
Request
from
vllm.config
import
ModelConfig
...
...
@@ -20,19 +22,18 @@ TypeTokenIDs = List[int]
def
request_output_to_embedding_response
(
final_res_batch
:
List
[
EmbeddingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
)
->
EmbeddingResponse
:
final_res_batch
:
List
[
EmbeddingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
str
)
->
EmbeddingResponse
:
data
:
List
[
EmbeddingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
assert
final_res
is
not
None
prompt_token_ids
=
final_res
.
prompt_token_ids
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
final_res
.
outputs
.
embedding
)
embedding
=
final_res
.
outputs
.
embedding
if
encoding_format
==
"base64"
:
embedding
=
base64
.
b64encode
(
np
.
array
(
embedding
))
embedding_data
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
embedding
)
data
.
append
(
embedding_data
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
...
...
@@ -72,10 +73,8 @@ class OpenAIServingEmbedding(OpenAIServing):
if
error_check_ret
is
not
None
:
return
error_check_ret
# Return error for unsupported features.
if
request
.
encoding_format
==
"base64"
:
return
self
.
create_error_response
(
"base64 encoding is not currently supported"
)
encoding_format
=
(
request
.
encoding_format
if
request
.
encoding_format
else
"float"
)
if
request
.
dimensions
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
...
...
@@ -129,7 +128,8 @@ class OpenAIServingEmbedding(OpenAIServing):
return
self
.
create_error_response
(
"Client disconnected"
)
final_res_batch
[
i
]
=
res
response
=
request_output_to_embedding_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
)
final_res_batch
,
request_id
,
created_time
,
model_name
,
encoding_format
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment