"vscode:/vscode.git/clone" did not exist on "18e85452979d2f974f2c193d159816a893fbc253"
Unverified Commit 176c799f authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[openai api] log exception in exception handler (1/N) (#31164)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent 612e7729
...@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager ...@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
import pydantic import pydantic
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool from starlette.concurrency import iterate_in_threadpool
...@@ -20,11 +20,13 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send ...@@ -20,11 +20,13 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send
from vllm import envs from vllm import envs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import terminate_if_errored
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
from vllm.entrypoints.utils import sanitize_message from vllm.entrypoints.utils import create_error_response, sanitize_message
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.gc_utils import freeze_gc_heap
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger("vllm.entrypoints.openai.server_utils") logger = init_logger("vllm.entrypoints.openai.server_utils")
...@@ -309,7 +311,69 @@ async def log_response(request: Request, call_next): ...@@ -309,7 +311,69 @@ async def log_response(request: Request, call_next):
return response return response
async def engine_error_handler(
    req: Request, exc: EngineDeadError | EngineGenerateError
):
    """Handle engine errors: optionally log, check for shutdown, reply 500.

    VLLM V1 AsyncLLM catches exceptions and returns only two types:
    EngineGenerateError and EngineDeadError.

    EngineGenerateError is raised by the per-request generate() method.
    This error could be request specific (and therefore recoverable,
    e.g. if there is an error in input processing).

    EngineDeadError is raised by the background output_handler method.
    This error is global and therefore not recoverable.

    We register these @app.exception_handlers to return nice responses
    to the end user if they occur and shut down if needed. See
    https://fastapi.tiangolo.com/tutorial/handling-errors/ for more
    details on how exception handlers work.

    If an exception is encountered in a StreamingResponse generator,
    the exception is not raised, since we already sent a 200 status.
    Rather, we send an error message as the next chunk. Since the
    exception is not raised, this means that the server will not
    automatically shut down. Instead, we use the watchdog background
    task to check for errored state.

    Args:
        req: The request that was being served when the error surfaced.
        exc: The engine exception that triggered this handler.

    Returns:
        A body-less response with status 500 (Internal Server Error).
    """
    if req.app.state.args.log_error_stack:
        logger.exception(
            "Engine Exception caught. Request id: %s",
            # request_metadata is only present once a serving handler has
            # attached it to the request state; fall back to None otherwise.
            req.state.request_metadata.request_id
            if hasattr(req.state, "request_metadata")
            else None,
            # Pass the exception explicitly so the traceback is logged even
            # if this handler is not invoked from inside an active `except`.
            exc_info=exc,
        )
    # Runs regardless of the logging flag; the decision to shut down is
    # delegated to terminate_if_errored.
    terminate_if_errored(
        server=req.app.state.server,
        engine=req.app.state.engine_client,
    )
    return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
async def exception_handler(req: Request, exc: Exception):
    """Turn an exception into a JSON error response.

    When the server's ``log_error_stack`` argument is set, the exception is
    logged with its traceback, together with the request id if
    ``request_metadata`` has been attached to the request state.

    Args:
        req: The request that was being served when the exception was raised.
        exc: The exception to report to the client.

    Returns:
        A JSONResponse carrying the dumped error model built by
        ``create_error_response``; its HTTP status is that error's ``code``.
    """
    if req.app.state.args.log_error_stack:
        logger.exception(
            "Exception caught. Request id: %s",
            req.state.request_metadata.request_id
            if hasattr(req.state, "request_metadata")
            else None,
            # Pass the exception explicitly so the traceback is logged even
            # if this handler is not invoked from inside an active `except`.
            exc_info=exc,
        )
    err = create_error_response(exc)
    return JSONResponse(err.model_dump(), status_code=err.error.code)
async def http_exception_handler(req: Request, exc: HTTPException):
if req.app.state.args.log_error_stack:
logger.exception(
"HTTPException caught. Request id: %s",
req.state.request_metadata.request_id
if hasattr(req.state, "request_metadata")
else None,
)
err = ErrorResponse( err = ErrorResponse(
error=ErrorInfo( error=ErrorInfo(
message=sanitize_message(exc.detail), message=sanitize_message(exc.detail),
...@@ -320,7 +384,15 @@ async def http_exception_handler(_: Request, exc: HTTPException): ...@@ -320,7 +384,15 @@ async def http_exception_handler(_: Request, exc: HTTPException):
return JSONResponse(err.model_dump(), status_code=exc.status_code) return JSONResponse(err.model_dump(), status_code=exc.status_code)
async def validation_exception_handler(_: Request, exc: RequestValidationError): async def validation_exception_handler(req: Request, exc: RequestValidationError):
if req.app.state.args.log_error_stack:
logger.exception(
"RequestValidationError caught. Request id: %s",
req.state.request_metadata.request_id
if hasattr(req.state, "request_metadata")
else None,
)
param = None param = None
errors = exc.errors() errors = exc.errors()
for error in errors: for error in errors:
......
...@@ -71,10 +71,9 @@ async def create_transcriptions( ...@@ -71,10 +71,9 @@ async def create_transcriptions(
) )
audio_data = await request.file.read() audio_data = await request.file.read()
try:
generator = await handler.create_transcription(audio_data, request, raw_request) generator = await handler.create_transcription(audio_data, request, raw_request)
except Exception as e:
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=generator.model_dump(), status_code=generator.error.code
...@@ -108,10 +107,8 @@ async def create_translations( ...@@ -108,10 +107,8 @@ async def create_translations(
) )
audio_data = await request.file.read() audio_data = await request.file.read()
try:
generator = await handler.create_translation(audio_data, request, raw_request) generator = await handler.create_translation(audio_data, request, raw_request)
except Exception as e:
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
...@@ -140,7 +137,6 @@ def init_transcription_state( ...@@ -140,7 +137,6 @@ def init_transcription_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
) )
if "transcription" in supported_tasks if "transcription" in supported_tasks
...@@ -151,7 +147,6 @@ def init_transcription_state( ...@@ -151,7 +147,6 @@ def init_transcription_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
) )
if "transcription" in supported_tasks if "transcription" in supported_tasks
......
...@@ -40,7 +40,6 @@ class OpenAIServingTranscription(OpenAISpeechToText): ...@@ -40,7 +40,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
): ):
super().__init__( super().__init__(
...@@ -49,7 +48,6 @@ class OpenAIServingTranscription(OpenAISpeechToText): ...@@ -49,7 +48,6 @@ class OpenAIServingTranscription(OpenAISpeechToText):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe", task_type="transcribe",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage, enable_force_include_usage=enable_force_include_usage,
) )
...@@ -113,7 +111,6 @@ class OpenAIServingTranslation(OpenAISpeechToText): ...@@ -113,7 +111,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
): ):
super().__init__( super().__init__(
...@@ -122,7 +119,6 @@ class OpenAIServingTranslation(OpenAISpeechToText): ...@@ -122,7 +119,6 @@ class OpenAIServingTranslation(OpenAISpeechToText):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate", task_type="translate",
log_error_stack=log_error_stack,
enable_force_include_usage=enable_force_include_usage, enable_force_include_usage=enable_force_include_usage,
) )
......
...@@ -97,7 +97,6 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -97,7 +97,6 @@ class OpenAISpeechToText(OpenAIServing):
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
task_type: Literal["transcribe", "translate"] = "transcribe", task_type: Literal["transcribe", "translate"] = "transcribe",
log_error_stack: bool = False,
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
): ):
super().__init__( super().__init__(
...@@ -105,7 +104,6 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -105,7 +104,6 @@ class OpenAISpeechToText(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
) )
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
...@@ -517,69 +515,61 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -517,69 +515,61 @@ class OpenAISpeechToText(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: lora_request = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
engine_prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
request_id=request_id,
)
except ValueError as e: engine_prompts, duration_s = await self._preprocess_speech_to_text(
logger.exception("Error in preprocessing prompt inputs") request=request,
return self.create_error_response(e) audio_data=audio_data,
request_id=request_id,
)
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try: # Unlike most decoder-only models, whisper generation length is not
# Unlike most decoder-only models, whisper generation length is not # constrained by the size of the input audio, which is mapped to a
# constrained by the size of the input audio, which is mapped to a # fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be
# fixed-size log-mel-spectrogram. Still, allow for fewer tokens to be # generated by respecting the extra completion tokens arg.
# generated by respecting the extra completion tokens arg. max_tokens = get_max_tokens(
max_tokens = get_max_tokens( max_model_len,
max_model_len, request.max_completion_tokens,
request.max_completion_tokens, 0,
0, self.default_sampling_params,
self.default_sampling_params, )
)
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
max_tokens, max_tokens,
self.default_sampling_params, self.default_sampling_params,
)
if request.response_format == "verbose_json":
sampling_params.logprobs = 1
list_result_generator = []
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}_{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
) )
if request.response_format == "verbose_json":
sampling_params.logprobs = 1
list_result_generator = []
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}_{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = ( trace_headers = (
None None
if raw_request is None if raw_request is None
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_prompt, engine_prompt,
sampling_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
) )
list_result_generator.append(generator) list_result_generator.append(generator)
except ValueError as e:
return self.create_error_response(e)
if request.stream: if request.stream:
return stream_generator_method( return stream_generator_method(
...@@ -663,8 +653,6 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -663,8 +653,6 @@ class OpenAISpeechToText(OpenAIServing):
return final_response return final_response
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
async def _speech_to_text_stream_generator( async def _speech_to_text_stream_generator(
self, self,
......
...@@ -72,7 +72,6 @@ def init_pooling_state( ...@@ -72,7 +72,6 @@ def init_pooling_state(
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
) )
) )
if any(t in supported_tasks for t in POOLING_TASKS) if any(t in supported_tasks for t in POOLING_TASKS)
...@@ -86,7 +85,6 @@ def init_pooling_state( ...@@ -86,7 +85,6 @@ def init_pooling_state(
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
) )
if "embed" in supported_tasks if "embed" in supported_tasks
else None else None
...@@ -99,7 +97,6 @@ def init_pooling_state( ...@@ -99,7 +97,6 @@ def init_pooling_state(
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
) )
if "classify" in supported_tasks if "classify" in supported_tasks
else None else None
...@@ -114,7 +111,6 @@ def init_pooling_state( ...@@ -114,7 +111,6 @@ def init_pooling_state(
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
score_template=resolved_chat_template, score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False), use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
) )
if any(t in supported_tasks for t in ("embed", "score", "token_embed")) if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
......
...@@ -41,7 +41,6 @@ from vllm.tracing import ( ...@@ -41,7 +41,6 @@ from vllm.tracing import (
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators from vllm.utils.async_utils import merge_async_iterators
from ...utils import create_error_response
from .io_processor import PoolingIOProcessor from .io_processor import PoolingIOProcessor
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
...@@ -112,34 +111,25 @@ class PoolingServing: ...@@ -112,34 +111,25 @@ class PoolingServing:
request: AnyPoolingRequest, request: AnyPoolingRequest,
raw_request: Request, raw_request: Request,
) -> JSONResponse: ) -> JSONResponse:
try: model_name = self.models.model_name()
model_name = self.models.model_name() request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
request_id = (
f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
)
await self._check_model(request) await self._check_model(request)
ctx = PoolingServeContext( ctx = PoolingServeContext(
request=request, request=request,
raw_request=raw_request, raw_request=raw_request,
model_name=model_name, model_name=model_name,
request_id=request_id, request_id=request_id,
) )
self._validate_request(ctx) self._validate_request(ctx)
self._maybe_get_adapters(ctx) self._maybe_get_adapters(ctx)
await self._preprocess(ctx) await self._preprocess(ctx)
await self._prepare_generators(ctx) await self._prepare_generators(ctx)
await self._collect_batch(ctx) await self._collect_batch(ctx)
response = await self._build_response(ctx) response = await self._build_response(ctx)
return JSONResponse(content=response.model_dump()) return JSONResponse(content=response.model_dump())
except Exception as e:
error_response = create_error_response(e)
return JSONResponse(
content=error_response.model_dump(),
status_code=error_response.error.code,
)
async def _preprocess( async def _preprocess(
self, self,
......
...@@ -61,10 +61,7 @@ async def create_embedding( ...@@ -61,10 +61,7 @@ async def create_embedding(
message="The model does not support Embeddings API" message="The model does not support Embeddings API"
) )
try: generator = await handler.create_embedding(request, raw_request)
generator = await handler.create_embedding(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
This diff is collapsed.
...@@ -41,10 +41,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): ...@@ -41,10 +41,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
return base_server.create_error_response( return base_server.create_error_response(
message="The model does not support Pooling API" message="The model does not support Pooling API"
) )
try:
generator = await handler.create_pooling(request, raw_request) generator = await handler.create_pooling(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, Callable, Sequence ...@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, Callable, Sequence
from functools import partial from functools import partial
from typing import Final, Literal, cast from typing import Final, Literal, cast
import jinja2
from fastapi import Request from fastapi import Request
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -53,13 +52,11 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -53,13 +52,11 @@ class OpenAIServingPooling(OpenAIServing):
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False, trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack,
) )
self.chat_template = chat_template self.chat_template = chat_template
...@@ -84,101 +81,92 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -84,101 +81,92 @@ class OpenAIServingPooling(OpenAIServing):
request_id = f"pool-{self._base_request_id(raw_request)}" request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time()) created_time = int(time.time())
try: lora_request = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
if getattr(request, "dimensions", None) is not None: if getattr(request, "dimensions", None) is not None:
return self.create_error_response( return self.create_error_response("dimensions is currently not supported")
"dimensions is currently not supported"
)
engine_prompts: Sequence[ProcessorInputs] engine_prompts: Sequence[ProcessorInputs]
if use_io_processor := isinstance(request, IOProcessorRequest): if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
"No IOProcessor plugin installed. Please refer " "No IOProcessor plugin installed. Please refer "
"to the documentation and to the " "to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' " "'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details." "offline inference example for more details."
) )
validated_prompt = self.io_processor.parse_data(request.data) validated_prompt = self.io_processor.parse_data(request.data)
raw_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
) )
engine_prompts = await self._preprocess_cmpl( engine_prompts = await self._preprocess_cmpl(
request, request,
prompt_to_seq(raw_prompts), prompt_to_seq(raw_prompts),
) )
elif isinstance(request, PoolingChatRequest): elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template( error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
) )
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None, default_template_kwargs=None,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await self._preprocess_completion( engine_prompts = await self._preprocess_completion(
request, request,
prompt_input=request.input, prompt_input=request.input,
prompt_embeds=None, prompt_embeds=None,
) )
else: else:
raise ValueError(f"Unsupported request of type {type(request)}") raise ValueError(f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try: if use_io_processor:
if use_io_processor: assert self.io_processor is not None
assert self.io_processor is not None
pooling_params = self.io_processor.merge_pooling_params()
if pooling_params.task is None:
pooling_params.task = "plugin"
else:
pooling_params = request.to_pooling_params() # type: ignore
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=lora_request,
)
trace_headers = ( pooling_params = self.io_processor.merge_pooling_params()
None if pooling_params.task is None:
if raw_request is None pooling_params.task = "plugin"
else await self._get_trace_headers(raw_request.headers) else:
) pooling_params = request.to_pooling_params() # type: ignore
generator = self.engine_client.encode( for i, engine_prompt in enumerate(engine_prompts):
engine_prompt, request_id_item = f"{request_id}-{i}"
pooling_params,
request_id_item, self._log_inputs(
lora_request=lora_request, request_id_item,
trace_headers=trace_headers, engine_prompt,
priority=request.priority, params=pooling_params,
) lora_request=lora_request,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator) generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
...@@ -233,8 +221,6 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -233,8 +221,6 @@ class OpenAIServingPooling(OpenAIServing):
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
return response return response
......
...@@ -49,10 +49,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): ...@@ -49,10 +49,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
message="The model does not support Score API" message="The model does not support Score API"
) )
try: generator = await handler.create_score(request, raw_request)
generator = await handler.create_score(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
...@@ -100,10 +97,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request): ...@@ -100,10 +97,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
return base_server.create_error_response( return base_server.create_error_response(
message="The model does not support Rerank (Score) API" message="The model does not support Rerank (Score) API"
) )
try:
generator = await handler.do_rerank(request, raw_request) generator = await handler.do_rerank(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -62,7 +62,6 @@ class ServingScores(OpenAIServing): ...@@ -62,7 +62,6 @@ class ServingScores(OpenAIServing):
engine_client=engine_client, engine_client=engine_client,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack,
) )
self.score_template = score_template self.score_template = score_template
self.use_gpu_for_pooling_score = use_gpu_for_pooling_score self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
...@@ -518,8 +517,6 @@ class ServingScores(OpenAIServing): ...@@ -518,8 +517,6 @@ class ServingScores(OpenAIServing):
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
async def do_rerank( async def do_rerank(
self, request: RerankRequest, raw_request: Request | None = None self, request: RerankRequest, raw_request: Request | None = None
...@@ -562,8 +559,6 @@ class ServingScores(OpenAIServing): ...@@ -562,8 +559,6 @@ class ServingScores(OpenAIServing):
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
def request_output_to_score_response( def request_output_to_score_response(
self, self,
......
...@@ -64,10 +64,8 @@ async def generate(request: GenerateRequest, raw_request: Request): ...@@ -64,10 +64,8 @@ async def generate(request: GenerateRequest, raw_request: Request):
return tokenization(raw_request).create_error_response( return tokenization(raw_request).create_error_response(
message="The model does not support generate tokens API" message="The model does not support generate tokens API"
) )
try:
generator = await handler.serve_tokens(request, raw_request) generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -49,7 +49,6 @@ class ServingTokens(OpenAIServing): ...@@ -49,7 +49,6 @@ class ServingTokens(OpenAIServing):
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
force_no_detokenize: bool = False, force_no_detokenize: bool = False,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
enable_prompt_tokens_details: bool = False, enable_prompt_tokens_details: bool = False,
enable_log_outputs: bool = False, enable_log_outputs: bool = False,
): ):
...@@ -58,7 +57,6 @@ class ServingTokens(OpenAIServing): ...@@ -58,7 +57,6 @@ class ServingTokens(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
) )
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
...@@ -108,45 +106,38 @@ class ServingTokens(OpenAIServing): ...@@ -108,45 +106,38 @@ class ServingTokens(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None result_generator: AsyncGenerator[RequestOutput, None] | None = None
try: sampling_params = request.sampling_params
sampling_params = request.sampling_params if self.force_no_detokenize:
if self.force_no_detokenize: sampling_params.detokenize = False
sampling_params.detokenize = False
self._log_inputs(
self._log_inputs( request_id,
request_id, engine_prompt,
engine_prompt, params=sampling_params,
params=sampling_params, lora_request=lora_request,
lora_request=lora_request, )
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
result_generator = self.engine_client.generate( trace_headers = (
engine_prompt, None
sampling_params, if raw_request is None
request_id, else await self._get_trace_headers(raw_request.headers)
lora_request=lora_request, )
trace_headers=trace_headers,
priority=request.priority,
)
except ValueError as e: result_generator = self.engine_client.generate(
return self.create_error_response(str(e)) engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
# TODO(NickLucche): Implement streaming response # TODO(NickLucche): Implement streaming response
try: assert result_generator is not None
assert result_generator is not None return await self.serve_tokens_full_generator(
return await self.serve_tokens_full_generator( request, result_generator, request_id, model_name, request_metadata
request, result_generator, request_id, model_name, request_metadata )
)
except ValueError as e:
return self.create_error_response(str(e))
async def serve_tokens_full_generator( async def serve_tokens_full_generator(
self, self,
...@@ -165,8 +156,6 @@ class ServingTokens(OpenAIServing): ...@@ -165,8 +156,6 @@ class ServingTokens(OpenAIServing):
final_res = res final_res = res
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(str(e))
assert final_res is not None assert final_res is not None
......
...@@ -49,10 +49,7 @@ router = APIRouter() ...@@ -49,10 +49,7 @@ router = APIRouter()
async def tokenize(request: TokenizeRequest, raw_request: Request): async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request) handler = tokenization(raw_request)
try: generator = await handler.create_tokenize(request, raw_request)
generator = await handler.create_tokenize(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Final from typing import Any, Final
import jinja2
from fastapi import Request from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -37,13 +36,11 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -37,13 +36,11 @@ class OpenAIServingTokenization(OpenAIServing):
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False, trust_request_chat_template: bool = False,
log_error_stack: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack,
) )
self.chat_template = chat_template self.chat_template = chat_template
...@@ -61,40 +58,36 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -61,40 +58,36 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokenize-{self._base_request_id(raw_request)}" request_id = f"tokenize-{self._base_request_id(raw_request)}"
try: lora_request = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
if isinstance(request, TokenizeChatRequest):
if isinstance(request, TokenizeChatRequest): tool_dicts = (
tool_dicts = ( None
None if request.tools is None
if request.tools is None else [tool.model_dump() for tool in request.tools]
else [tool.model_dump() for tool in request.tools] )
) error_check_ret = self._validate_chat_template(
error_check_ret = self._validate_chat_template( request_chat_template=request.chat_template,
request_chat_template=request.chat_template, chat_template_kwargs=request.chat_template_kwargs,
chat_template_kwargs=request.chat_template_kwargs, trust_request_chat_template=self.trust_request_chat_template,
trust_request_chat_template=self.trust_request_chat_template, )
) if error_check_ret is not None:
if error_check_ret is not None: return error_check_ret
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self._preprocess_chat( request,
request, request.messages,
request.messages, default_template=self.chat_template,
default_template=self.chat_template, default_template_content_format=self.chat_template_content_format,
default_template_content_format=self.chat_template_content_format, default_template_kwargs=None,
default_template_kwargs=None, tool_dicts=tool_dicts,
tool_dicts=tool_dicts, )
) else:
else: engine_prompts = await self._preprocess_completion(
engine_prompts = await self._preprocess_completion( request,
request, prompt_input=request.prompt,
prompt_input=request.prompt, prompt_embeds=None,
prompt_embeds=None, )
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
input_ids: list[int] = [] input_ids: list[int] = []
for engine_prompt in engine_prompts: for engine_prompt in engine_prompts:
...@@ -152,12 +145,9 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -152,12 +145,9 @@ class OpenAIServingTokenization(OpenAIServing):
self, self,
) -> TokenizerInfoResponse | ErrorResponse: ) -> TokenizerInfoResponse | ErrorResponse:
"""Get comprehensive tokenizer information.""" """Get comprehensive tokenizer information."""
try: tokenizer = self.renderer.get_tokenizer()
tokenizer = self.renderer.get_tokenizer() info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
info = TokenizerInfo(tokenizer, self.chat_template).to_dict() return TokenizerInfoResponse(**info)
return TokenizerInfoResponse(**info)
except Exception as e:
return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
@dataclass @dataclass
......
...@@ -5,13 +5,10 @@ import asyncio ...@@ -5,13 +5,10 @@ import asyncio
import dataclasses import dataclasses
import functools import functools
import os import os
import sys
import traceback
from argparse import Namespace from argparse import Namespace
from http import HTTPStatus from http import HTTPStatus
from logging import Logger from logging import Logger
from string import Template from string import Template
from typing import TYPE_CHECKING
import regex as re import regex as re
from fastapi import Request from fastapi import Request
...@@ -20,24 +17,17 @@ from starlette.background import BackgroundTask, BackgroundTasks ...@@ -20,24 +17,17 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs from vllm import envs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.exceptions import VLLMValidationError from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
GenerationError,
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.logger import current_formatter_type, init_logger from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
else:
ErrorResponse = object
ErrorInfo = object
LoRAModulePath = object
StreamOptions = object
logger = init_logger(__name__) logger = init_logger(__name__)
VLLM_SUBCMD_PARSER_EPILOG = ( VLLM_SUBCMD_PARSER_EPILOG = (
...@@ -307,20 +297,19 @@ def create_error_response( ...@@ -307,20 +297,19 @@ def create_error_response(
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None, param: str | None = None,
log_error_stack: bool = False, ) -> ErrorResponse:
) -> "ErrorResponse":
exc: Exception | None = None exc: Exception | None = None
from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
if isinstance(message, Exception): if isinstance(message, Exception):
exc = message exc = message
from vllm.exceptions import VLLMValidationError
if isinstance(exc, VLLMValidationError): if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError" err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): elif isinstance(exc, (ValueError, TypeError, OverflowError)):
# Common validation errors from user input # Common validation errors from user input
err_type = "BadRequestError" err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST status_code = HTTPStatus.BAD_REQUEST
...@@ -329,6 +318,10 @@ def create_error_response( ...@@ -329,6 +318,10 @@ def create_error_response(
err_type = "NotImplementedError" err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED status_code = HTTPStatus.NOT_IMPLEMENTED
param = None param = None
elif isinstance(exc, GenerationError):
err_type = "InternalServerError"
status_code = exc.status_code
param = None
elif exc.__class__.__name__ == "TemplateError": elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2) # jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError" err_type = "BadRequestError"
...@@ -341,13 +334,6 @@ def create_error_response( ...@@ -341,13 +334,6 @@ def create_error_response(
message = str(exc) message = str(exc)
if log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse( return ErrorResponse(
error=ErrorInfo( error=ErrorInfo(
message=sanitize_message(message), message=sanitize_message(message),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment