Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0661cb9d
Unverified
Commit
0661cb9d
Authored Sep 07, 2025 by Flora Feng
Committed by GitHub, Sep 07, 2025
Browse files
Add renderer-based prompt processing for embedding and classification endpoints (#24356)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent
105d3d62
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing 6 changed files with 59 additions and 56 deletions
+59
-56
tests/entrypoints/openai/test_truncation.py
tests/entrypoints/openai/test_truncation.py
+4
-10
tests/entrypoints/test_renderer.py
tests/entrypoints/test_renderer.py
+17
-0
vllm/entrypoints/openai/serving_classification.py
vllm/entrypoints/openai/serving_classification.py
+5
-8
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+19
-26
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+7
-10
vllm/entrypoints/renderer.py
vllm/entrypoints/renderer.py
+7
-2
No files found.
tests/entrypoints/openai/test_truncation.py
View file @
0661cb9d
...
...
@@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
"truncate_prompt_tokens"
:
truncation_size
}
with
pytest
.
raises
(
openai
.
BadRequestError
)
as
err
:
await
client
.
post
(
path
=
"embeddings"
,
cast_to
=
object
,
body
=
{
**
kwargs
})
assert
err
.
value
.
status_code
==
400
error_details
=
err
.
value
.
response
.
json
()[
"error"
]
response
=
await
client
.
post
(
path
=
"embeddings"
,
cast_to
=
object
,
body
=
{
**
kwargs
})
assert
error_details
[
"type"
]
==
"BadRequestError"
assert
"This model's maximum context length is"
in
error_details
[
"message"
]
assert
"tokens in the input for embedding generation"
in
error_details
[
"message"
]
assert
"Please reduce the length of the input"
in
error_details
[
"message"
]
assert
response
[
"usage"
][
"prompt_tokens"
]
==
truncation_size
@
pytest
.
mark
.
asyncio
...
...
tests/entrypoints/test_renderer.py
View file @
0661cb9d
...
...
@@ -130,6 +130,23 @@ class TestRenderPrompt:
assert
call_args
.
kwargs
[
"truncation"
]
is
True
assert
call_args
.
kwargs
[
"max_length"
]
==
50
@
pytest
.
mark
.
asyncio
async
def
test_truncation_negative
(
self
,
renderer
,
mock_async_tokenizer
):
# Test that negative truncation uses model's max_model_len
mock_async_tokenizer
.
return_value
=
MockTokenizerResult
(
[
101
,
7592
,
2088
])
# Truncated to max_model_len
renderer
.
async_tokenizer_pool
[
renderer
.
tokenizer
]
=
mock_async_tokenizer
results
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
"Hello world"
,
max_length
=
200
,
truncate_prompt_tokens
=-
1
)
assert
len
(
results
)
==
1
call_args
=
mock_async_tokenizer
.
call_args
assert
call_args
.
kwargs
[
"truncation"
]
is
True
assert
call_args
.
kwargs
[
"max_length"
]
==
100
# model's max_model_len
@
pytest
.
mark
.
asyncio
async
def
test_token_truncation_last_elements
(
self
,
renderer
):
# Test that token truncation keeps the last N elements
...
...
vllm/entrypoints/openai/serving_classification.py
View file @
0661cb9d
...
...
@@ -54,14 +54,11 @@ class ClassificationMixin(OpenAIServing):
ctx
.
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
(
ctx
.
request_prompts
,
ctx
.
engine_prompts
,
)
=
await
self
.
_preprocess_completion
(
ctx
.
request
,
ctx
.
tokenizer
,
ctx
.
request
.
input
,
)
renderer
=
self
.
_get_renderer
(
ctx
.
tokenizer
)
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
ctx
.
request
.
input
,
max_length
=
self
.
max_model_len
,
truncate_prompt_tokens
=
ctx
.
request
.
truncate_prompt_tokens
)
return
None
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
0661cb9d
...
...
@@ -24,7 +24,6 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
ErrorResponse
,
UsageInfo
)
from
vllm.entrypoints.openai.serving_engine
import
(
EmbeddingServeContext
,
OpenAIServing
,
RequestPrompt
,
ServeContext
,
TextTokensPrompt
)
# yapf: enable
...
...
@@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
ctx
.
lora_request
)
renderer
=
self
.
_get_renderer
(
tokenizer
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
(
_
,
ctx
.
request_prompts
,
_
,
ctx
.
engine_prompts
,
)
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
...
...
@@ -98,13 +98,18 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
else
:
(
ctx
.
request_prompts
,
ctx
.
engine_prompts
)
=
await
self
.
_preprocess_completion
(
ctx
.
request
,
tokenizer
,
ctx
.
request
.
input
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
# Set max_length based on chunked processing capability
if
self
.
_should_use_chunked_processing
(
ctx
.
request
):
max_length
=
None
else
:
max_length
=
self
.
max_embed_len
or
self
.
max_model_len
ctx
.
engine_prompts
=
await
renderer
.
render_prompt
(
prompt_or_prompts
=
ctx
.
request
.
input
,
max_length
=
max_length
,
truncate_prompt_tokens
=
ctx
.
request
.
truncate_prompt_tokens
,
add_special_tokens
=
ctx
.
request
.
add_special_tokens
,
)
return
None
except
(
ValueError
,
TypeError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
...
...
@@ -286,7 +291,6 @@ class EmbeddingMixin(OpenAIServing):
self
,
ctx
:
EmbeddingServeContext
,
engine_prompt
:
Union
[
EngineTokensPrompt
,
EngineEmbedsPrompt
],
request_prompt
:
RequestPrompt
,
pooling_params
:
PoolingParams
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]],
prompt_index
:
int
,
...
...
@@ -295,7 +299,7 @@ class EmbeddingMixin(OpenAIServing):
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
prompt_index
}
"
self
.
_log_inputs
(
request_id_item
,
request
_prompt
,
engine
_prompt
,
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
)
...
...
@@ -353,20 +357,14 @@ class EmbeddingMixin(OpenAIServing):
return
self
.
create_error_response
(
"Engine prompts not available"
)
if
ctx
.
request_prompts
is
None
:
return
self
.
create_error_response
(
"Request prompts not available"
)
max_pos_embeddings
=
self
.
_get_max_position_embeddings
()
for
i
,
engine_prompt
in
enumerate
(
ctx
.
engine_prompts
):
request_prompt
=
ctx
.
request_prompts
[
i
]
# Check if this specific prompt needs chunked processing
if
self
.
_is_text_tokens_prompt
(
request
_prompt
):
if
self
.
_is_text_tokens_prompt
(
engine
_prompt
):
# Cast to TextTokensPrompt since we've verified
# prompt_token_ids
text_tokens_prompt
=
cast
(
TextTokensPrompt
,
request
_prompt
)
text_tokens_prompt
=
cast
(
TextTokensPrompt
,
engine
_prompt
)
if
(
len
(
text_tokens_prompt
[
"prompt_token_ids"
])
>
max_pos_embeddings
):
# Use chunked processing for this prompt
...
...
@@ -382,8 +380,7 @@ class EmbeddingMixin(OpenAIServing):
Union
[
EngineTokensPrompt
,
EngineEmbedsPrompt
],
engine_prompt
)
generator
=
await
self
.
_create_single_prompt_generator
(
ctx
,
engine_prompt_typed
,
request_prompt
,
pooling_params
,
trace_headers
,
i
)
ctx
,
engine_prompt_typed
,
pooling_params
,
trace_headers
,
i
)
generators
.
append
(
generator
)
from
vllm.utils
import
merge_async_iterators
...
...
@@ -419,10 +416,6 @@ class EmbeddingMixin(OpenAIServing):
if
not
use_chunked
:
return
await
super
().
_collect_batch
(
ctx
=
ctx
)
if
ctx
.
request_prompts
is
None
:
return
self
.
create_error_response
(
"Request prompts not available"
)
if
ctx
.
result_generator
is
None
:
return
self
.
create_error_response
(
"Result generator not available"
)
...
...
@@ -538,7 +531,7 @@ class EmbeddingMixin(OpenAIServing):
data
=
final_embedding
)
# Get original prompt token IDs for this prompt
original_prompt
=
ctx
.
request
_prompts
[
prompt_idx
]
original_prompt
=
ctx
.
engine
_prompts
[
prompt_idx
]
if
not
self
.
_is_text_tokens_prompt
(
original_prompt
):
return
self
.
create_error_response
(
f
"Chunked prompt
{
prompt_idx
}
is not a "
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
0661cb9d
...
...
@@ -368,23 +368,20 @@ class OpenAIServing:
for
i
,
engine_prompt
in
enumerate
(
ctx
.
engine_prompts
):
request_id_item
=
f
"
{
ctx
.
request_id
}
-
{
i
}
"
if
ctx
.
request_prompts
is
None
:
return
self
.
create_error_response
(
"Request prompts not available"
)
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt
=
cast
(
Union
[
EngineTokensPrompt
,
EngineEmbedsPrompt
],
engine_prompt
)
self
.
_log_inputs
(
request_id_item
,
ctx
.
request
_prompt
s
[
i
]
,
engine
_prompt
,
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
,
)
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt
=
cast
(
Union
[
EngineTokensPrompt
,
EngineEmbedsPrompt
],
engine_prompt
)
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
pooling_params
,
...
...
vllm/entrypoints/renderer.py
View file @
0661cb9d
...
...
@@ -108,10 +108,15 @@ class CompletionRenderer(BaseRenderer):
for detailed parameter documentation.
"""
if
truncate_prompt_tokens
is
not
None
:
if
max_length
is
not
None
:
assert
0
<=
truncate_prompt_tokens
<=
max_length
if
truncate_prompt_tokens
==
0
:
return
[]
if
truncate_prompt_tokens
<
0
:
truncate_prompt_tokens
=
self
.
model_config
.
max_model_len
if
max_length
is
not
None
and
truncate_prompt_tokens
>
max_length
:
raise
ValueError
(
f
"truncate_prompt_tokens (
{
truncate_prompt_tokens
}
) "
f
"cannot be greater than max_length (
{
max_length
}
). "
f
"Please select a smaller truncation size."
)
# Parse and batch the input prompts
batch_inputs
=
parse_and_batch_prompt
(
prompt_or_prompts
)
...
...
Write
Preview
Markdown is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment