Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
176c799f
Unverified
Commit
176c799f
authored
Mar 06, 2026
by
Ning Xie
Committed by
GitHub
Mar 05, 2026
Browse files
[openai api] log exception in exception handler (1/N) (#31164)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
612e7729
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
499 additions
and
548 deletions
+499
-548
vllm/entrypoints/openai/server_utils.py
vllm/entrypoints/openai/server_utils.py
+76
-4
vllm/entrypoints/openai/speech_to_text/api_router.py
vllm/entrypoints/openai/speech_to_text/api_router.py
+5
-10
vllm/entrypoints/openai/speech_to_text/serving.py
vllm/entrypoints/openai/speech_to_text/serving.py
+0
-4
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
+45
-57
vllm/entrypoints/pooling/__init__.py
vllm/entrypoints/pooling/__init__.py
+0
-4
vllm/entrypoints/pooling/base/serving.py
vllm/entrypoints/pooling/base/serving.py
+16
-26
vllm/entrypoints/pooling/embed/api_router.py
vllm/entrypoints/pooling/embed/api_router.py
+1
-4
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+197
-214
vllm/entrypoints/pooling/pooling/api_router.py
vllm/entrypoints/pooling/pooling/api_router.py
+2
-4
vllm/entrypoints/pooling/pooling/serving.py
vllm/entrypoints/pooling/pooling/serving.py
+76
-90
vllm/entrypoints/pooling/score/api_router.py
vllm/entrypoints/pooling/score/api_router.py
+3
-8
vllm/entrypoints/pooling/score/serving.py
vllm/entrypoints/pooling/score/serving.py
+0
-5
vllm/entrypoints/serve/disagg/api_router.py
vllm/entrypoints/serve/disagg/api_router.py
+2
-4
vllm/entrypoints/serve/disagg/serving.py
vllm/entrypoints/serve/disagg/serving.py
+27
-38
vllm/entrypoints/serve/tokenize/api_router.py
vllm/entrypoints/serve/tokenize/api_router.py
+1
-4
vllm/entrypoints/serve/tokenize/serving.py
vllm/entrypoints/serve/tokenize/serving.py
+33
-43
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+15
-29
No files found.
vllm/entrypoints/openai/server_utils.py
View file @
176c799f
...
...
@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager
from
http
import
HTTPStatus
import
pydantic
from
fastapi
import
FastAPI
,
HTTPException
,
Request
from
fastapi
import
FastAPI
,
HTTPException
,
Request
,
Response
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.responses
import
JSONResponse
from
starlette.concurrency
import
iterate_in_threadpool
...
...
@@ -20,11 +20,13 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
from
vllm
import
envs
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.launcher
import
terminate_if_errored
from
vllm.entrypoints.openai.engine.protocol
import
ErrorInfo
,
ErrorResponse
from
vllm.entrypoints.utils
import
sanitize_message
from
vllm.entrypoints.utils
import
create_error_response
,
sanitize_message
from
vllm.exceptions
import
VLLMValidationError
from
vllm.logger
import
init_logger
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.v1.engine.exceptions
import
EngineDeadError
,
EngineGenerateError
logger
=
init_logger
(
"vllm.entrypoints.openai.server_utils"
)
...
...
@@ -309,7 +311,69 @@ async def log_response(request: Request, call_next):
return
response
async
def
http_exception_handler
(
_
:
Request
,
exc
:
HTTPException
):
async
def
engine_error_handler
(
req
:
Request
,
exc
:
EngineDeadError
|
EngineGenerateError
):
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task for check for errored state.
"""
if
req
.
app
.
state
.
args
.
log_error_stack
:
logger
.
exception
(
"Engine Exception caught. Request id: %s"
,
req
.
state
.
request_metadata
.
request_id
if
hasattr
(
req
.
state
,
"request_metadata"
)
else
None
,
)
terminate_if_errored
(
server
=
req
.
app
.
state
.
server
,
engine
=
req
.
app
.
state
.
engine_client
,
)
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
async
def
exception_handler
(
req
:
Request
,
exc
:
Exception
):
if
req
.
app
.
state
.
args
.
log_error_stack
:
logger
.
exception
(
"Exception caught. Request id: %s"
,
req
.
state
.
request_metadata
.
request_id
if
hasattr
(
req
.
state
,
"request_metadata"
)
else
None
,
)
err
=
create_error_response
(
exc
)
return
JSONResponse
(
err
.
model_dump
(),
status_code
=
err
.
error
.
code
)
async
def
http_exception_handler
(
req
:
Request
,
exc
:
HTTPException
):
if
req
.
app
.
state
.
args
.
log_error_stack
:
logger
.
exception
(
"HTTPException caught. Request id: %s"
,
req
.
state
.
request_metadata
.
request_id
if
hasattr
(
req
.
state
,
"request_metadata"
)
else
None
,
)
err
=
ErrorResponse
(
error
=
ErrorInfo
(
message
=
sanitize_message
(
exc
.
detail
),
...
...
@@ -320,7 +384,15 @@ async def http_exception_handler(_: Request, exc: HTTPException):
return
JSONResponse
(
err
.
model_dump
(),
status_code
=
exc
.
status_code
)
async
def
validation_exception_handler
(
_
:
Request
,
exc
:
RequestValidationError
):
async
def
validation_exception_handler
(
req
:
Request
,
exc
:
RequestValidationError
):
if
req
.
app
.
state
.
args
.
log_error_stack
:
logger
.
exception
(
"RequestValidationError caught. Request id: %s"
,
req
.
state
.
request_metadata
.
request_id
if
hasattr
(
req
.
state
,
"request_metadata"
)
else
None
,
)
param
=
None
errors
=
exc
.
errors
()
for
error
in
errors
:
...
...
vllm/entrypoints/openai/speech_to_text/api_router.py
View file @
176c799f
...
...
@@ -71,10 +71,9 @@ async def create_transcriptions(
)
audio_data
=
await
request
.
file
.
read
()
try
:
generator
=
await
handler
.
create_transcription
(
audio_data
,
request
,
raw_request
)
except
Exception
as
e
:
return
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
create_transcription
(
audio_data
,
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
error
.
code
...
...
@@ -108,10 +107,8 @@ async def create_translations(
)
audio_data
=
await
request
.
file
.
read
()
try
:
generator
=
await
handler
.
create_translation
(
audio_data
,
request
,
raw_request
)
except
Exception
as
e
:
return
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
create_translation
(
audio_data
,
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
@@ -140,7 +137,6 @@ def init_transcription_state(
engine_client
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
log_error_stack
=
args
.
log_error_stack
,
enable_force_include_usage
=
args
.
enable_force_include_usage
,
)
if
"transcription"
in
supported_tasks
...
...
@@ -151,7 +147,6 @@ def init_transcription_state(
engine_client
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
log_error_stack
=
args
.
log_error_stack
,
enable_force_include_usage
=
args
.
enable_force_include_usage
,
)
if
"transcription"
in
supported_tasks
...
...
vllm/entrypoints/openai/speech_to_text/serving.py
View file @
176c799f
...
...
@@ -40,7 +40,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
*
,
request_logger
:
RequestLogger
|
None
,
return_tokens_as_token_ids
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
enable_force_include_usage
:
bool
=
False
,
):
super
().
__init__
(
...
...
@@ -49,7 +48,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
task_type
=
"transcribe"
,
log_error_stack
=
log_error_stack
,
enable_force_include_usage
=
enable_force_include_usage
,
)
...
...
@@ -113,7 +111,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
*
,
request_logger
:
RequestLogger
|
None
,
return_tokens_as_token_ids
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
enable_force_include_usage
:
bool
=
False
,
):
super
().
__init__
(
...
...
@@ -122,7 +119,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
task_type
=
"translate"
,
log_error_stack
=
log_error_stack
,
enable_force_include_usage
=
enable_force_include_usage
,
)
...
...
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
View file @
176c799f
...
...
@@ -97,7 +97,6 @@ class OpenAISpeechToText(OpenAIServing):
request_logger
:
RequestLogger
|
None
,
return_tokens_as_token_ids
:
bool
=
False
,
task_type
:
Literal
[
"transcribe"
,
"translate"
]
=
"transcribe"
,
log_error_stack
:
bool
=
False
,
enable_force_include_usage
:
bool
=
False
,
):
super
().
__init__
(
...
...
@@ -105,7 +104,6 @@ class OpenAISpeechToText(OpenAIServing):
models
=
models
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
log_error_stack
=
log_error_stack
,
)
self
.
default_sampling_params
=
self
.
model_config
.
get_diff_sampling_param
()
...
...
@@ -517,69 +515,61 @@ class OpenAISpeechToText(OpenAIServing):
if
raw_request
:
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
engine_prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
request
=
request
,
audio_data
=
audio_data
,
request_id
=
request_id
,
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
except
ValueError
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
e
)
engine_prompts
,
duration_s
=
await
self
.
_preprocess_speech_to_text
(
request
=
request
,
audio_data
=
audio_data
,
request_id
=
request_id
,
)
# Schedule the request and get the result generator.
max_model_len
=
self
.
model_config
.
max_model_len
list_result_generator
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
|
None
=
None
try
:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
max_tokens
=
get_max_tokens
(
max_model_len
,
request
.
max_completion_tokens
,
0
,
self
.
default_sampling_params
,
)
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
max_tokens
=
get_max_tokens
(
max_model_len
,
request
.
max_completion_tokens
,
0
,
self
.
default_sampling_params
,
)
sampling_params
=
request
.
to_sampling_params
(
max_tokens
,
self
.
default_sampling_params
,
sampling_params
=
request
.
to_sampling_params
(
max_tokens
,
self
.
default_sampling_params
,
)
if
request
.
response_format
==
"verbose_json"
:
sampling_params
.
logprobs
=
1
list_result_generator
=
[]
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
_
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_prompt
,
params
=
sampling_params
,
lora_request
=
lora_request
,
)
if
request
.
response_format
==
"verbose_json"
:
sampling_params
.
logprobs
=
1
list_result_generator
=
[]
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
_
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_prompt
,
params
=
sampling_params
,
lora_request
=
lora_request
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
generator
=
self
.
engine_client
.
generate
(
engine_prompt
,
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
)
generator
=
self
.
engine_client
.
generate
(
engine_prompt
,
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
)
list_result_generator
.
append
(
generator
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
list_result_generator
.
append
(
generator
)
if
request
.
stream
:
return
stream_generator_method
(
...
...
@@ -663,8 +653,6 @@ class OpenAISpeechToText(OpenAIServing):
return
final_response
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
async
def
_speech_to_text_stream_generator
(
self
,
...
...
vllm/entrypoints/pooling/__init__.py
View file @
176c799f
...
...
@@ -72,7 +72,6 @@ def init_pooling_state(
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
log_error_stack
=
args
.
log_error_stack
,
)
)
if
any
(
t
in
supported_tasks
for
t
in
POOLING_TASKS
)
...
...
@@ -86,7 +85,6 @@ def init_pooling_state(
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
log_error_stack
=
args
.
log_error_stack
,
)
if
"embed"
in
supported_tasks
else
None
...
...
@@ -99,7 +97,6 @@ def init_pooling_state(
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
log_error_stack
=
args
.
log_error_stack
,
)
if
"classify"
in
supported_tasks
else
None
...
...
@@ -114,7 +111,6 @@ def init_pooling_state(
state
.
openai_serving_models
,
request_logger
=
request_logger
,
score_template
=
resolved_chat_template
,
log_error_stack
=
args
.
log_error_stack
,
use_gpu_for_pooling_score
=
getattr
(
args
,
"use_gpu_for_pooling_score"
,
False
),
)
if
any
(
t
in
supported_tasks
for
t
in
(
"embed"
,
"score"
,
"token_embed"
))
...
...
vllm/entrypoints/pooling/base/serving.py
View file @
176c799f
...
...
@@ -41,7 +41,6 @@ from vllm.tracing import (
from
vllm.utils
import
random_uuid
from
vllm.utils.async_utils
import
merge_async_iterators
from
...utils
import
create_error_response
from
.io_processor
import
PoolingIOProcessor
PoolingRequestT
=
TypeVar
(
"PoolingRequestT"
,
bound
=
AnyPoolingRequest
)
...
...
@@ -112,34 +111,25 @@ class PoolingServing:
request
:
AnyPoolingRequest
,
raw_request
:
Request
,
)
->
JSONResponse
:
try
:
model_name
=
self
.
models
.
model_name
()
request_id
=
(
f
"
{
self
.
request_id_prefix
}
-
{
self
.
_base_request_id
(
raw_request
)
}
"
)
model_name
=
self
.
models
.
model_name
()
request_id
=
f
"
{
self
.
request_id_prefix
}
-
{
self
.
_base_request_id
(
raw_request
)
}
"
await
self
.
_check_model
(
request
)
await
self
.
_check_model
(
request
)
ctx
=
PoolingServeContext
(
request
=
request
,
raw_request
=
raw_request
,
model_name
=
model_name
,
request_id
=
request_id
,
)
ctx
=
PoolingServeContext
(
request
=
request
,
raw_request
=
raw_request
,
model_name
=
model_name
,
request_id
=
request_id
,
)
self
.
_validate_request
(
ctx
)
self
.
_maybe_get_adapters
(
ctx
)
await
self
.
_preprocess
(
ctx
)
await
self
.
_prepare_generators
(
ctx
)
await
self
.
_collect_batch
(
ctx
)
response
=
await
self
.
_build_response
(
ctx
)
return
JSONResponse
(
content
=
response
.
model_dump
())
except
Exception
as
e
:
error_response
=
create_error_response
(
e
)
return
JSONResponse
(
content
=
error_response
.
model_dump
(),
status_code
=
error_response
.
error
.
code
,
)
self
.
_validate_request
(
ctx
)
self
.
_maybe_get_adapters
(
ctx
)
await
self
.
_preprocess
(
ctx
)
await
self
.
_prepare_generators
(
ctx
)
await
self
.
_collect_batch
(
ctx
)
response
=
await
self
.
_build_response
(
ctx
)
return
JSONResponse
(
content
=
response
.
model_dump
())
async
def
_preprocess
(
self
,
...
...
vllm/entrypoints/pooling/embed/api_router.py
View file @
176c799f
...
...
@@ -61,10 +61,7 @@ async def create_embedding(
message
=
"The model does not support Embeddings API"
)
try
:
generator
=
await
handler
.
create_embedding
(
request
,
raw_request
)
except
Exception
as
e
:
generator
=
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
create_embedding
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
vllm/entrypoints/pooling/embed/serving.py
View file @
176c799f
...
...
@@ -54,13 +54,11 @@ class OpenAIServingEmbedding(OpenAIServing):
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
...
...
@@ -75,38 +73,34 @@ class OpenAIServingEmbedding(OpenAIServing):
self
,
ctx
:
EmbeddingServeContext
,
)
->
ErrorResponse
|
None
:
try
:
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
ctx
.
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
)
elif
isinstance
(
ctx
.
request
,
EmbeddingCompletionRequest
):
ctx
.
engine_prompts
=
await
self
.
_preprocess_completion
(
ctx
.
request
,
prompt_input
=
ctx
.
request
.
input
,
prompt_embeds
=
None
,
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
ctx
.
lora_request
=
self
.
_maybe_get_adapters
(
ctx
.
request
)
if
isinstance
(
ctx
.
request
,
EmbeddingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
ctx
.
request
.
chat_template
,
chat_template_kwargs
=
ctx
.
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
ctx
.
engine_prompts
=
await
self
.
_preprocess_chat
(
ctx
.
request
,
ctx
.
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
)
elif
isinstance
(
ctx
.
request
,
EmbeddingCompletionRequest
):
ctx
.
engine_prompts
=
await
self
.
_preprocess_completion
(
ctx
.
request
,
prompt_input
=
ctx
.
request
.
input
,
prompt_embeds
=
None
,
)
else
:
return
self
.
create_error_response
(
"Invalid classification request type"
)
return
None
except
(
ValueError
,
TypeError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
return
None
def
request_output_to_embed_json_response
(
self
,
...
...
@@ -397,51 +391,47 @@ class OpenAIServingEmbedding(OpenAIServing):
# Custom logic for chunked processing
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
trace_headers
=
(
None
if
ctx
.
raw_request
is
None
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
)
pooling_params
=
self
.
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
return
pooling_params
trace_headers
=
(
None
if
ctx
.
raw_request
is
None
else
await
self
.
_get_trace_headers
(
ctx
.
raw_request
.
headers
)
)
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
pooling_params
=
self
.
_create_pooling_params
(
ctx
)
if
isinstance
(
pooling_params
,
ErrorResponse
):
return
pooling_params
max_pos_embeddings
=
self
.
_get_max_position_embeddings
()
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
for
i
,
engine_prompt
in
enumerate
(
ctx
.
engine_prompts
):
# Check if this specific prompt needs chunked processing
if
"prompt_token_ids"
in
engine_prompt
:
prompt_token_ids
=
engine_prompt
[
"prompt_token_ids"
]
# type: ignore[typeddict-item]
if
len
(
prompt_token_ids
)
>
max_pos_embeddings
:
# Use chunked processing for this prompt
chunk_generators
=
await
self
.
_process_chunked_request
(
ctx
,
prompt_token_ids
,
pooling_params
,
trace_headers
,
i
,
)
generators
.
extend
(
chunk_generators
)
continue
max_pos_embeddings
=
self
.
_get_max_position_embeddings
()
# Normal processing for short prompts or non-token prompts
generator
=
await
self
.
_create_single_prompt_generator
(
ctx
,
engine_prompt
,
pooling_params
,
trace_headers
,
i
)
generators
.
append
(
generator
)
for
i
,
engine_prompt
in
enumerate
(
ctx
.
engine_prompts
):
# Check if this specific prompt needs chunked processing
if
"prompt_token_ids"
in
engine_prompt
:
prompt_token_ids
=
engine_prompt
[
"prompt_token_ids"
]
# type: ignore[typeddict-item]
if
len
(
prompt_token_ids
)
>
max_pos_embeddings
:
# Use chunked processing for this prompt
chunk_generators
=
await
self
.
_process_chunked_request
(
ctx
,
prompt_token_ids
,
pooling_params
,
trace_headers
,
i
,
)
generators
.
extend
(
chunk_generators
)
continue
ctx
.
result_generator
=
merge_async_iterators
(
*
generators
)
# Normal processing for short prompts or non-token prompts
generator
=
await
self
.
_create_single_prompt_generator
(
ctx
,
engine_prompt
,
pooling_params
,
trace_headers
,
i
)
generators
.
append
(
generator
)
return
None
ctx
.
result_generator
=
merge_async_iterators
(
*
generators
)
except
Exception
as
e
:
return
self
.
create_error_response
(
e
)
return
None
async
def
_collect_batch
(
self
,
...
...
@@ -454,164 +444,157 @@ class OpenAIServingEmbedding(OpenAIServing):
minimize memory usage.
For regular requests, collects results normally.
"""
try
:
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
# Check if we used chunked processing
use_chunked
=
self
.
_should_use_chunked_processing
(
ctx
.
request
)
if
not
use_chunked
:
return
await
super
().
_collect_batch
(
ctx
=
ctx
)
if
ctx
.
result_generator
is
None
:
return
self
.
create_error_response
(
"Result generator not available"
)
# Online aggregation for chunked requests to
# minimize memory usage
# Track aggregation state for each prompt
prompt_aggregators
:
dict
[
int
,
dict
[
str
,
Any
]]
=
{}
short_prompts_results
:
dict
[
int
,
PoolingRequestOutput
]
=
{}
async
for
result_idx
,
result
in
ctx
.
result_generator
:
if
"-chunk-"
in
result
.
request_id
:
# Extract prompt_idx from chunked request_id
parts
=
result
.
request_id
.
split
(
"-"
)
try
:
prompt_idx
=
int
(
parts
[
parts
.
index
(
"prompt"
)
+
1
])
except
(
ValueError
,
IndexError
):
# Fallback: extract from result_idx if parsing fails
prompt_idx
=
result_idx
# Initialize aggregator for this prompt if needed
if
prompt_idx
not
in
prompt_aggregators
:
prompt_aggregators
[
prompt_idx
]
=
{
"weighted_sum"
:
None
,
"total_weight"
:
0
,
"chunk_count"
:
0
,
"request_id"
:
result
.
request_id
.
split
(
"-chunk-"
)[
0
],
}
aggregator
=
prompt_aggregators
[
prompt_idx
]
# MEAN pooling with online weighted averaging
# Ensure result is PoolingRequestOutput
# for embedding processing
if
not
isinstance
(
result
,
PoolingRequestOutput
):
return
self
.
create_error_response
(
f
"Expected PoolingRequestOutput for "
f
"chunked embedding, got "
f
"
{
type
(
result
).
__name__
}
"
)
if
ctx
.
engine_prompts
is
None
:
return
self
.
create_error_response
(
"Engine prompts not available"
)
# Handle both PoolingOutput and
# EmbeddingOutput types
if
hasattr
(
result
.
outputs
,
"data"
):
# PoolingOutput case
embedding_data
=
result
.
outputs
.
data
elif
hasattr
(
result
.
outputs
,
"embedding"
):
# EmbeddingOutput case -
# convert embedding list to tensor
embedding_data
=
result
.
outputs
.
embedding
else
:
return
self
.
create_error_response
(
f
"Unsupported output type:
{
type
(
result
.
outputs
).
__name__
}
"
)
# Check if we used chunked processing
use_chunked
=
self
.
_should_use_chunked_processing
(
ctx
.
request
)
if
not
isinstance
(
embedding_data
,
torch
.
Tensor
):
embedding_data
=
torch
.
tensor
(
embedding_data
,
dtype
=
torch
.
float32
)
if
not
use_chunked
:
return
await
super
().
_collect_batch
(
ctx
=
ctx
)
if
ctx
.
result_generator
is
None
:
return
self
.
create_error_response
(
"Result generator not available"
)
# Online aggregation for chunked requests to
# minimize memory usage
# Track aggregation state for each prompt
prompt_aggregators
:
dict
[
int
,
dict
[
str
,
Any
]]
=
{}
short_prompts_results
:
dict
[
int
,
PoolingRequestOutput
]
=
{}
async
for
result_idx
,
result
in
ctx
.
result_generator
:
if
"-chunk-"
in
result
.
request_id
:
# Extract prompt_idx from chunked request_id
parts
=
result
.
request_id
.
split
(
"-"
)
try
:
prompt_idx
=
int
(
parts
[
parts
.
index
(
"prompt"
)
+
1
])
except
(
ValueError
,
IndexError
):
# Fallback: extract from result_idx if parsing fails
prompt_idx
=
result_idx
# Initialize aggregator for this prompt if needed
if
prompt_idx
not
in
prompt_aggregators
:
prompt_aggregators
[
prompt_idx
]
=
{
"weighted_sum"
:
None
,
"total_weight"
:
0
,
"chunk_count"
:
0
,
"request_id"
:
result
.
request_id
.
split
(
"-chunk-"
)[
0
],
}
if
result
.
prompt_token_ids
is
None
:
return
self
.
create_error_response
(
"prompt_token_ids cannot be None for chunked processing"
)
weight
=
len
(
result
.
prompt_token_ids
)
aggregator
=
prompt_aggregators
[
prompt_idx
]
# MEAN pooling with online weighted averaging
# Ensure result is PoolingRequestOutput
# for embedding processing
if
not
isinstance
(
result
,
PoolingRequestOutput
):
return
self
.
create_error_response
(
f
"Expected PoolingRequestOutput for "
f
"chunked embedding, got "
f
"
{
type
(
result
).
__name__
}
"
)
weighted_embedding
=
embedding_data
.
to
(
dtype
=
torch
.
float32
)
*
weight
# Handle both PoolingOutput and
# EmbeddingOutput types
if
hasattr
(
result
.
outputs
,
"data"
):
# PoolingOutput case
embedding_data
=
result
.
outputs
.
data
elif
hasattr
(
result
.
outputs
,
"embedding"
):
# EmbeddingOutput case -
# convert embedding list to tensor
embedding_data
=
result
.
outputs
.
embedding
else
:
return
self
.
create_error_response
(
f
"Unsupported output type:
{
type
(
result
.
outputs
).
__name__
}
"
)
if
aggregator
[
"weighted_sum"
]
is
None
:
# First chunk
aggregator
[
"weighted_sum"
]
=
weighted_embedding
else
:
# Accumulate
aggregator
[
"weighted_sum"
]
+=
weighted_embedding
if
not
isinstance
(
embedding_data
,
torch
.
Tensor
):
embedding_data
=
torch
.
tensor
(
embedding_data
,
dtype
=
torch
.
float32
)
aggregator
[
"total_weight"
]
+=
weight
aggregator
[
"chunk_count"
]
+=
1
if
result
.
prompt_token_ids
is
None
:
return
self
.
create_error_response
(
"prompt_token_ids cannot be None for chunked processing"
)
weight
=
len
(
result
.
prompt_token_ids
)
weighted_embedding
=
embedding_data
.
to
(
dtype
=
torch
.
float32
)
*
weight
if
aggregator
[
"weighted_sum"
]
is
None
:
# First chunk
aggregator
[
"weighted_sum"
]
=
weighted_embedding
else
:
# Non-chunked result - extract prompt_idx from request_id
parts
=
result
.
request_id
.
split
(
"-"
)
try
:
# Last part should be prompt index
prompt_idx
=
int
(
parts
[
-
1
])
except
(
ValueError
,
IndexError
):
prompt_idx
=
result_idx
# Fallback to result_idx
short_prompts_results
[
prompt_idx
]
=
result
# Finalize aggregated results
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
num_prompts
=
len
(
ctx
.
engine_prompts
)
for
prompt_idx
in
range
(
num_prompts
):
if
prompt_idx
in
prompt_aggregators
:
# Finalize MEAN aggregation for this chunked prompt
aggregator
=
prompt_aggregators
[
prompt_idx
]
weighted_sum
=
aggregator
[
"weighted_sum"
]
total_weight
=
aggregator
[
"total_weight"
]
if
(
weighted_sum
is
not
None
and
isinstance
(
weighted_sum
,
torch
.
Tensor
)
and
isinstance
(
total_weight
,
(
int
,
float
))
and
total_weight
>
0
):
# Compute final mean embedding
final_embedding
=
weighted_sum
/
total_weight
# Create a PoolingRequestOutput
# for the aggregated result
pooling_output_data
=
PoolingOutput
(
data
=
final_embedding
)
# Get original prompt token IDs for this prompt
original_prompt
=
ctx
.
engine_prompts
[
prompt_idx
]
if
"prompt_token_ids"
not
in
original_prompt
:
return
self
.
create_error_response
(
f
"Chunked prompt
{
prompt_idx
}
does not contain "
"token IDs"
)
original_token_ids
=
original_prompt
[
"prompt_token_ids"
]
# type: ignore[typeddict-item]
pooling_request_output
=
PoolingRequestOutput
(
request_id
=
aggregator
[
"request_id"
],
prompt_token_ids
=
original_token_ids
,
outputs
=
pooling_output_data
,
num_cached_tokens
=
0
,
finished
=
True
,
)
# Accumulate
aggregator
[
"weighted_sum"
]
+=
weighted_embedding
final_res_batch
.
append
(
pooling_request_output
)
else
:
aggregator
[
"total_weight"
]
+=
weight
aggregator
[
"chunk_count"
]
+=
1
else
:
# Non-chunked result - extract prompt_idx from request_id
parts
=
result
.
request_id
.
split
(
"-"
)
try
:
# Last part should be prompt index
prompt_idx
=
int
(
parts
[
-
1
])
except
(
ValueError
,
IndexError
):
prompt_idx
=
result_idx
# Fallback to result_idx
short_prompts_results
[
prompt_idx
]
=
result
# Finalize aggregated results
final_res_batch
:
list
[
PoolingRequestOutput
]
=
[]
num_prompts
=
len
(
ctx
.
engine_prompts
)
for
prompt_idx
in
range
(
num_prompts
):
if
prompt_idx
in
prompt_aggregators
:
# Finalize MEAN aggregation for this chunked prompt
aggregator
=
prompt_aggregators
[
prompt_idx
]
weighted_sum
=
aggregator
[
"weighted_sum"
]
total_weight
=
aggregator
[
"total_weight"
]
if
(
weighted_sum
is
not
None
and
isinstance
(
weighted_sum
,
torch
.
Tensor
)
and
isinstance
(
total_weight
,
(
int
,
float
))
and
total_weight
>
0
):
# Compute final mean embedding
final_embedding
=
weighted_sum
/
total_weight
# Create a PoolingRequestOutput
# for the aggregated result
pooling_output_data
=
PoolingOutput
(
data
=
final_embedding
)
# Get original prompt token IDs for this prompt
original_prompt
=
ctx
.
engine_prompts
[
prompt_idx
]
if
"prompt_token_ids"
not
in
original_prompt
:
return
self
.
create_error_response
(
f
"
Failed to aggregate chunks for prompt
{
prompt_idx
}
"
f
"
Chunked prompt
{
prompt_idx
}
does not contain token IDs
"
)
elif
prompt_idx
in
short_prompts_results
:
final_res_batch
.
append
(
short_prompts_results
[
prompt_idx
])
original_token_ids
=
original_prompt
[
"prompt_token_ids"
]
# type: ignore[typeddict-item]
pooling_request_output
=
PoolingRequestOutput
(
request_id
=
aggregator
[
"request_id"
],
prompt_token_ids
=
original_token_ids
,
outputs
=
pooling_output_data
,
num_cached_tokens
=
0
,
finished
=
True
,
)
final_res_batch
.
append
(
pooling_request_output
)
else
:
return
self
.
create_error_response
(
f
"
Result not found
for prompt
{
prompt_idx
}
"
f
"
Failed to aggregate chunks
for prompt
{
prompt_idx
}
"
)
elif
prompt_idx
in
short_prompts_results
:
final_res_batch
.
append
(
short_prompts_results
[
prompt_idx
])
else
:
return
self
.
create_error_response
(
f
"Result not found for prompt
{
prompt_idx
}
"
)
ctx
.
final_res_batch
=
final_res_batch
return
None
ctx
.
final_res_batch
=
final_res_batch
except
Exception
as
e
:
return
self
.
create_error_response
(
e
)
return
None
async
def
create_embedding
(
self
,
...
...
vllm/entrypoints/pooling/pooling/api_router.py
View file @
176c799f
...
...
@@ -41,10 +41,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
return
base_server
.
create_error_response
(
message
=
"The model does not support Pooling API"
)
try
:
generator
=
await
handler
.
create_pooling
(
request
,
raw_request
)
except
Exception
as
e
:
generator
=
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
create_pooling
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
vllm/entrypoints/pooling/pooling/serving.py
View file @
176c799f
...
...
@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, Callable, Sequence
from
functools
import
partial
from
typing
import
Final
,
Literal
,
cast
import
jinja2
from
fastapi
import
Request
from
typing_extensions
import
assert_never
...
...
@@ -53,13 +52,11 @@ class OpenAIServingPooling(OpenAIServing):
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
...
...
@@ -84,101 +81,92 @@ class OpenAIServingPooling(OpenAIServing):
request_id
=
f
"pool-
{
self
.
_base_request_id
(
raw_request
)
}
"
created_time
=
int
(
time
.
time
())
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
if
getattr
(
request
,
"dimensions"
,
None
)
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
if
getattr
(
request
,
"dimensions"
,
None
)
is
not
None
:
return
self
.
create_error_response
(
"dimensions is currently not supported"
)
engine_prompts
:
Sequence
[
ProcessorInputs
]
if
use_io_processor
:
=
isinstance
(
request
,
IOProcessorRequest
):
if
self
.
io_processor
is
None
:
raise
ValueError
(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
engine_prompts
:
Sequence
[
ProcessorInputs
]
if
use_io_processor
:
=
isinstance
(
request
,
IOProcessorRequest
):
if
self
.
io_processor
is
None
:
raise
ValueError
(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details."
)
validated_prompt
=
self
.
io_processor
.
parse_data
(
request
.
data
)
validated_prompt
=
self
.
io_processor
.
parse_data
(
request
.
data
)
raw_prompts
=
await
self
.
io_processor
.
pre_process_async
(
prompt
=
validated_prompt
,
request_id
=
request_id
)
engine_prompts
=
await
self
.
_preprocess_cmpl
(
request
,
prompt_to_seq
(
raw_prompts
),
)
elif
isinstance
(
request
,
PoolingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
request
.
chat_template
,
chat_template_kwargs
=
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
request
,
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
)
elif
isinstance
(
request
,
PoolingCompletionRequest
):
engine_prompts
=
await
self
.
_preprocess_completion
(
request
,
prompt_input
=
request
.
input
,
prompt_embeds
=
None
,
)
else
:
raise
ValueError
(
f
"Unsupported request of type
{
type
(
request
)
}
"
)
except
(
ValueError
,
TypeError
,
jinja2
.
TemplateError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
raw_prompts
=
await
self
.
io_processor
.
pre_process_async
(
prompt
=
validated_prompt
,
request_id
=
request_id
)
engine_prompts
=
await
self
.
_preprocess_cmpl
(
request
,
prompt_to_seq
(
raw_prompts
),
)
elif
isinstance
(
request
,
PoolingChatRequest
):
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
request
.
chat_template
,
chat_template_kwargs
=
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
request
,
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
)
elif
isinstance
(
request
,
PoolingCompletionRequest
):
engine_prompts
=
await
self
.
_preprocess_completion
(
request
,
prompt_input
=
request
.
input
,
prompt_embeds
=
None
,
)
else
:
raise
ValueError
(
f
"Unsupported request of type
{
type
(
request
)
}
"
)
# Schedule the request and get the result generator.
generators
:
list
[
AsyncGenerator
[
PoolingRequestOutput
,
None
]]
=
[]
try
:
if
use_io_processor
:
assert
self
.
io_processor
is
not
None
pooling_params
=
self
.
io_processor
.
merge_pooling_params
()
if
pooling_params
.
task
is
None
:
pooling_params
.
task
=
"plugin"
else
:
pooling_params
=
request
.
to_pooling_params
()
# type: ignore
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_prompt
,
params
=
pooling_params
,
lora_request
=
lora_request
,
)
if
use_io_processor
:
assert
self
.
io_processor
is
not
None
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
pooling_params
=
self
.
io_processor
.
merge_pooling_params
()
if
pooling_params
.
task
is
None
:
pooling_params
.
task
=
"plugin"
else
:
pooling_params
=
request
.
to_pooling_params
()
# type: ignore
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
for
i
,
engine_prompt
in
enumerate
(
engine_prompts
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
self
.
_log_inputs
(
request_id_item
,
engine_prompt
,
params
=
pooling_params
,
lora_request
=
lora_request
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
generator
=
self
.
engine_client
.
encode
(
engine_prompt
,
pooling_params
,
request_id_item
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
generators
.
append
(
generator
)
result_generator
=
merge_async_iterators
(
*
generators
)
...
...
@@ -233,8 +221,6 @@ class OpenAIServingPooling(OpenAIServing):
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
return
response
...
...
vllm/entrypoints/pooling/score/api_router.py
View file @
176c799f
...
...
@@ -49,10 +49,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
message
=
"The model does not support Score API"
)
try
:
generator
=
await
handler
.
create_score
(
request
,
raw_request
)
except
Exception
as
e
:
generator
=
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
create_score
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
@@ -100,10 +97,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
return
base_server
.
create_error_response
(
message
=
"The model does not support Rerank (Score) API"
)
try
:
generator
=
await
handler
.
do_rerank
(
request
,
raw_request
)
except
Exception
as
e
:
generator
=
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
do_rerank
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
vllm/entrypoints/pooling/score/serving.py
View file @
176c799f
...
...
@@ -62,7 +62,6 @@ class ServingScores(OpenAIServing):
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
score_template
=
score_template
self
.
use_gpu_for_pooling_score
=
use_gpu_for_pooling_score
...
...
@@ -518,8 +517,6 @@ class ServingScores(OpenAIServing):
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
async
def
do_rerank
(
self
,
request
:
RerankRequest
,
raw_request
:
Request
|
None
=
None
...
...
@@ -562,8 +559,6 @@ class ServingScores(OpenAIServing):
)
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
e
)
def
request_output_to_score_response
(
self
,
...
...
vllm/entrypoints/serve/disagg/api_router.py
View file @
176c799f
...
...
@@ -64,10 +64,8 @@ async def generate(request: GenerateRequest, raw_request: Request):
return
tokenization
(
raw_request
).
create_error_response
(
message
=
"The model does not support generate tokens API"
)
try
:
generator
=
await
handler
.
serve_tokens
(
request
,
raw_request
)
except
Exception
as
e
:
generator
=
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
serve_tokens
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
vllm/entrypoints/serve/disagg/serving.py
View file @
176c799f
...
...
@@ -49,7 +49,6 @@ class ServingTokens(OpenAIServing):
request_logger
:
RequestLogger
|
None
,
force_no_detokenize
:
bool
=
False
,
return_tokens_as_token_ids
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
enable_prompt_tokens_details
:
bool
=
False
,
enable_log_outputs
:
bool
=
False
,
):
...
...
@@ -58,7 +57,6 @@ class ServingTokens(OpenAIServing):
models
=
models
,
request_logger
=
request_logger
,
return_tokens_as_token_ids
=
return_tokens_as_token_ids
,
log_error_stack
=
log_error_stack
,
)
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
self
.
enable_log_outputs
=
enable_log_outputs
...
...
@@ -108,45 +106,38 @@ class ServingTokens(OpenAIServing):
# Schedule the request and get the result generator.
result_generator
:
AsyncGenerator
[
RequestOutput
,
None
]
|
None
=
None
try
:
sampling_params
=
request
.
sampling_params
if
self
.
force_no_detokenize
:
sampling_params
.
detokenize
=
False
self
.
_log_inputs
(
request_id
,
engine_prompt
,
params
=
sampling_params
,
lora_request
=
lora_request
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
sampling_params
=
request
.
sampling_params
if
self
.
force_no_detokenize
:
sampling_params
.
detokenize
=
False
self
.
_log_inputs
(
request_id
,
engine_prompt
,
params
=
sampling_params
,
lora_request
=
lora_request
,
)
result_generator
=
self
.
engine_client
.
generate
(
engine_prompt
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
trace_headers
=
(
None
if
raw_request
is
None
else
await
self
.
_get_trace_headers
(
raw_request
.
headers
)
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
result_generator
=
self
.
engine_client
.
generate
(
engine_prompt
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
# TODO(NickLucche): Implement streaming response
try
:
assert
result_generator
is
not
None
return
await
self
.
serve_tokens_full_generator
(
request
,
result_generator
,
request_id
,
model_name
,
request_metadata
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
assert
result_generator
is
not
None
return
await
self
.
serve_tokens_full_generator
(
request
,
result_generator
,
request_id
,
model_name
,
request_metadata
)
async
def
serve_tokens_full_generator
(
self
,
...
...
@@ -165,8 +156,6 @@ class ServingTokens(OpenAIServing):
final_res
=
res
except
asyncio
.
CancelledError
:
return
self
.
create_error_response
(
"Client disconnected"
)
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
assert
final_res
is
not
None
...
...
vllm/entrypoints/serve/tokenize/api_router.py
View file @
176c799f
...
...
@@ -49,10 +49,7 @@ router = APIRouter()
async
def
tokenize
(
request
:
TokenizeRequest
,
raw_request
:
Request
):
handler
=
tokenization
(
raw_request
)
try
:
generator
=
await
handler
.
create_tokenize
(
request
,
raw_request
)
except
Exception
as
e
:
generator
=
handler
.
create_error_response
(
e
)
generator
=
await
handler
.
create_tokenize
(
request
,
raw_request
)
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
...
...
vllm/entrypoints/serve/tokenize/serving.py
View file @
176c799f
...
...
@@ -3,7 +3,6 @@
from
dataclasses
import
dataclass
from
typing
import
Any
,
Final
import
jinja2
from
fastapi
import
Request
from
vllm.engine.protocol
import
EngineClient
...
...
@@ -37,13 +36,11 @@ class OpenAIServingTokenization(OpenAIServing):
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
trust_request_chat_template
:
bool
=
False
,
log_error_stack
:
bool
=
False
,
)
->
None
:
super
().
__init__
(
engine_client
=
engine_client
,
models
=
models
,
request_logger
=
request_logger
,
log_error_stack
=
log_error_stack
,
)
self
.
chat_template
=
chat_template
...
...
@@ -61,40 +58,36 @@ class OpenAIServingTokenization(OpenAIServing):
request_id
=
f
"tokenize-
{
self
.
_base_request_id
(
raw_request
)
}
"
try
:
lora_request
=
self
.
_maybe_get_adapters
(
request
)
if
isinstance
(
request
,
TokenizeChatRequest
):
tool_dicts
=
(
None
if
request
.
tools
is
None
else
[
tool
.
model_dump
()
for
tool
in
request
.
tools
]
)
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
request
.
chat_template
,
chat_template_kwargs
=
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
request
,
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
tool_dicts
=
tool_dicts
,
)
else
:
engine_prompts
=
await
self
.
_preprocess_completion
(
request
,
prompt_input
=
request
.
prompt
,
prompt_embeds
=
None
,
)
except
(
ValueError
,
TypeError
,
jinja2
.
TemplateError
)
as
e
:
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
f
"
{
e
}
{
e
.
__cause__
}
"
)
lora_request
=
self
.
_maybe_get_adapters
(
request
)
if
isinstance
(
request
,
TokenizeChatRequest
):
tool_dicts
=
(
None
if
request
.
tools
is
None
else
[
tool
.
model_dump
()
for
tool
in
request
.
tools
]
)
error_check_ret
=
self
.
_validate_chat_template
(
request_chat_template
=
request
.
chat_template
,
chat_template_kwargs
=
request
.
chat_template_kwargs
,
trust_request_chat_template
=
self
.
trust_request_chat_template
,
)
if
error_check_ret
is
not
None
:
return
error_check_ret
_
,
engine_prompts
=
await
self
.
_preprocess_chat
(
request
,
request
.
messages
,
default_template
=
self
.
chat_template
,
default_template_content_format
=
self
.
chat_template_content_format
,
default_template_kwargs
=
None
,
tool_dicts
=
tool_dicts
,
)
else
:
engine_prompts
=
await
self
.
_preprocess_completion
(
request
,
prompt_input
=
request
.
prompt
,
prompt_embeds
=
None
,
)
input_ids
:
list
[
int
]
=
[]
for
engine_prompt
in
engine_prompts
:
...
...
@@ -152,12 +145,9 @@ class OpenAIServingTokenization(OpenAIServing):
self
,
)
->
TokenizerInfoResponse
|
ErrorResponse
:
"""Get comprehensive tokenizer information."""
try
:
tokenizer
=
self
.
renderer
.
get_tokenizer
()
info
=
TokenizerInfo
(
tokenizer
,
self
.
chat_template
).
to_dict
()
return
TokenizerInfoResponse
(
**
info
)
except
Exception
as
e
:
return
self
.
create_error_response
(
f
"Failed to get tokenizer info:
{
str
(
e
)
}
"
)
tokenizer
=
self
.
renderer
.
get_tokenizer
()
info
=
TokenizerInfo
(
tokenizer
,
self
.
chat_template
).
to_dict
()
return
TokenizerInfoResponse
(
**
info
)
@
dataclass
...
...
vllm/entrypoints/utils.py
View file @
176c799f
...
...
@@ -5,13 +5,10 @@ import asyncio
import
dataclasses
import
functools
import
os
import
sys
import
traceback
from
argparse
import
Namespace
from
http
import
HTTPStatus
from
logging
import
Logger
from
string
import
Template
from
typing
import
TYPE_CHECKING
import
regex
as
re
from
fastapi
import
Request
...
...
@@ -20,24 +17,17 @@ from starlette.background import BackgroundTask, BackgroundTasks
from
vllm
import
envs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.exceptions
import
VLLMValidationError
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorInfo
,
ErrorResponse
,
GenerationError
,
StreamOptions
,
)
from
vllm.entrypoints.openai.models.protocol
import
LoRAModulePath
from
vllm.logger
import
current_formatter_type
,
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
if
TYPE_CHECKING
:
from
vllm.entrypoints.openai.engine.protocol
import
(
ErrorInfo
,
ErrorResponse
,
StreamOptions
,
)
from
vllm.entrypoints.openai.models.protocol
import
LoRAModulePath
else
:
ErrorResponse
=
object
ErrorInfo
=
object
LoRAModulePath
=
object
StreamOptions
=
object
logger
=
init_logger
(
__name__
)
VLLM_SUBCMD_PARSER_EPILOG
=
(
...
...
@@ -307,20 +297,19 @@ def create_error_response(
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
,
param
:
str
|
None
=
None
,
log_error_stack
:
bool
=
False
,
)
->
"ErrorResponse"
:
)
->
ErrorResponse
:
exc
:
Exception
|
None
=
None
from
vllm.entrypoints.openai.engine.protocol
import
ErrorInfo
,
ErrorResponse
if
isinstance
(
message
,
Exception
):
exc
=
message
from
vllm.exceptions
import
VLLMValidationError
if
isinstance
(
exc
,
VLLMValidationError
):
err_type
=
"BadRequestError"
status_code
=
HTTPStatus
.
BAD_REQUEST
param
=
exc
.
parameter
elif
isinstance
(
exc
,
(
ValueError
,
TypeError
,
RuntimeError
,
OverflowError
)):
elif
isinstance
(
exc
,
(
ValueError
,
TypeError
,
OverflowError
)):
# Common validation errors from user input
err_type
=
"BadRequestError"
status_code
=
HTTPStatus
.
BAD_REQUEST
...
...
@@ -329,6 +318,10 @@ def create_error_response(
err_type
=
"NotImplementedError"
status_code
=
HTTPStatus
.
NOT_IMPLEMENTED
param
=
None
elif
isinstance
(
exc
,
GenerationError
):
err_type
=
"InternalServerError"
status_code
=
exc
.
status_code
param
=
None
elif
exc
.
__class__
.
__name__
==
"TemplateError"
:
# jinja2.TemplateError (avoid importing jinja2)
err_type
=
"BadRequestError"
...
...
@@ -341,13 +334,6 @@ def create_error_response(
message
=
str
(
exc
)
if
log_error_stack
:
exc_type
,
_
,
_
=
sys
.
exc_info
()
if
exc_type
is
not
None
:
traceback
.
print_exc
()
else
:
traceback
.
print_stack
()
return
ErrorResponse
(
error
=
ErrorInfo
(
message
=
sanitize_message
(
message
),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment