Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0dcc8cbe
Unverified
Commit
0dcc8cbe
authored
Oct 04, 2024
by
Flávia Béo
Committed by
GitHub
Oct 04, 2024
Browse files
Adds truncate_prompt_tokens param for embeddings creation (#8999)
Signed-off-by:
Flavia Beo
<
flavia.beo@ibm.com
>
parent
26aa325f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
76 additions
and
5 deletions
+76
-5
tests/entrypoints/openai/test_embedding.py
tests/entrypoints/openai/test_embedding.py
+61
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+1
-0
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+14
-5
No files found.
tests/entrypoints/openai/test_embedding.py
View file @
0dcc8cbe
...
...
@@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
0
].
embedding
assert
responses_float
.
data
[
1
].
embedding
==
responses_default
.
data
[
1
].
embedding
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
EMBEDDING_MODEL_NAME
],
)
async
def
test_single_embedding_truncation
(
embedding_client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
input_texts
=
[
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?"
,
]
# test single embedding
embeddings
=
await
embedding_client
.
embeddings
.
create
(
model
=
model_name
,
input
=
input_texts
,
extra_body
=
{
"truncate_prompt_tokens"
:
10
})
assert
embeddings
.
id
is
not
None
assert
len
(
embeddings
.
data
)
==
1
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
4096
assert
embeddings
.
usage
.
completion_tokens
==
0
assert
embeddings
.
usage
.
prompt_tokens
==
10
assert
embeddings
.
usage
.
total_tokens
==
10
input_tokens
=
[
1
,
24428
,
289
,
18341
,
26165
,
285
,
19323
,
283
,
289
,
26789
,
3871
,
28728
,
9901
,
340
,
2229
,
385
,
340
,
315
,
28741
,
28804
,
2
]
embeddings
=
await
embedding_client
.
embeddings
.
create
(
model
=
model_name
,
input
=
input_tokens
,
extra_body
=
{
"truncate_prompt_tokens"
:
10
})
assert
embeddings
.
id
is
not
None
assert
len
(
embeddings
.
data
)
==
1
assert
len
(
embeddings
.
data
[
0
].
embedding
)
==
4096
assert
embeddings
.
usage
.
completion_tokens
==
0
assert
embeddings
.
usage
.
prompt_tokens
==
10
assert
embeddings
.
usage
.
total_tokens
==
10
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
EMBEDDING_MODEL_NAME
],
)
async
def
test_single_embedding_truncation_invalid
(
embedding_client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
input_texts
=
[
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?"
,
]
with
pytest
.
raises
(
openai
.
BadRequestError
):
embeddings
=
await
embedding_client
.
embeddings
.
create
(
model
=
model_name
,
input
=
input_texts
,
extra_body
=
{
"truncate_prompt_tokens"
:
8193
})
assert
"error"
in
embeddings
.
object
assert
"truncate_prompt_tokens value is greater than max_model_len. "
\
"Please, select a smaller truncation size."
in
embeddings
.
message
vllm/entrypoints/openai/protocol.py
View file @
0dcc8cbe
...
...
@@ -671,6 +671,7 @@ class EmbeddingRequest(OpenAIBaseModel):
encoding_format
:
Literal
[
"float"
,
"base64"
]
=
"float"
dimensions
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
# doc: begin-embedding-pooling-params
additional_data
:
Optional
[
Any
]
=
None
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
0dcc8cbe
...
...
@@ -110,6 +110,17 @@ class OpenAIServingEmbedding(OpenAIServing):
request_id
=
f
"embd-
{
random_uuid
()
}
"
created_time
=
int
(
time
.
monotonic
())
truncate_prompt_tokens
=
None
if
request
.
truncate_prompt_tokens
is
not
None
:
if
request
.
truncate_prompt_tokens
<=
self
.
max_model_len
:
truncate_prompt_tokens
=
request
.
truncate_prompt_tokens
else
:
return
self
.
create_error_response
(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size."
)
# Schedule the request and get the result generator.
generators
:
List
[
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]]
=
[]
try
:
...
...
@@ -123,11 +134,9 @@ class OpenAIServingEmbedding(OpenAIServing):
pooling_params
=
request
.
to_pooling_params
()
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
tokenizer
,
self
.
_tokenize_prompt_input_or_inputs
(
request
,
tokenizer
,
request
.
input
,
))
truncate_prompt_tokens
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment