"vllm/vscode:/vscode.git/clone" did not exist on "2c4f59afc3d50fda805c4ad94c9d9be168cded0b"
Unverified Commit 176c799f authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[openai api] log exception in exception handler (1/N) (#31164)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent 612e7729
......@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import GenerationError, OpenAIServing
......@@ -38,32 +37,6 @@ async def test_raise_if_error_raises_generation_error():
serving._raise_if_error(None, "test-request-id") # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
"""test _convert_generation_error_to_response creates proper ErrorResponse"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to ErrorResponse
error_response = serving._convert_generation_error_to_response(gen_error)
assert isinstance(error_response, ErrorResponse)
assert error_response.error.type == "InternalServerError"
assert error_response.error.message == "Internal server error"
assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
"""test _convert_generation_error_to_streaming_response output"""
......
......@@ -13,7 +13,7 @@ from typing import Any
import pytest
import pytest_asyncio
import requests
from openai import BadRequestError, NotFoundError, OpenAI
from openai import InternalServerError, NotFoundError, OpenAI
from openai_harmony import Message
from ....utils import RemoteOpenAIServer
......@@ -698,7 +698,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
async def test_function_calling_required(client: OpenAI, model_name: str):
tools = [GET_WEATHER_SCHEMA]
with pytest.raises(BadRequestError):
with pytest.raises(InternalServerError):
await client.responses.create(
model=model_name,
input="What's the weather like in Paris today?",
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
......@@ -11,7 +10,7 @@ import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -145,12 +144,8 @@ async def test_chat_error_non_stream():
stream=False,
)
response = await serving_chat.create_chat_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
with pytest.raises(GenerationError):
await serving_chat.create_chat_completion(request)
@pytest.mark.asyncio
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import MagicMock
......@@ -11,7 +10,7 @@ import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -131,12 +130,8 @@ async def test_completion_error_non_stream():
stream=False,
)
response = await serving_completion.create_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
with pytest.raises(GenerationError):
await serving_completion.create_completion(request)
@pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from typing import Final
import pytest
import schemathesis
from httpx import URL
from hypothesis import settings
from schemathesis import GenerationConfig
from schemathesis.checks import not_a_server_error
from schemathesis.internal.checks import CheckContext
from schemathesis.models import Case
from schemathesis.transports.responses import GenericResponse
from ...utils import RemoteOpenAIServer
......@@ -127,10 +133,25 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
return strategy.filter(no_invalid_types)
def customized_not_a_server_error(
ctx: CheckContext, response: GenericResponse, case: Case
) -> bool | None:
try:
return not_a_server_error(ctx, response, case)
except Exception:
if (
URL(response.request.url).path
in ["/v1/chat/completions/render", "/v1/chat/completions"]
and response.status_code == HTTPStatus.NOT_IMPLEMENTED.value
):
return True
raise
@schema.parametrize()
@schema.override(headers={"Content-Type": "application/json"})
@settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50)
def test_openapi_stateless(case: schemathesis.Case):
def test_openapi_stateless(case: Case):
key = (
case.operation.method.upper(),
case.operation.path,
......@@ -155,4 +176,9 @@ def test_openapi_stateless(case: schemathesis.Case):
}.get(key, DEFAULT_TIMEOUT_SECONDS)
# No need to verify SSL certificate for localhost
case.call_and_validate(verify=False, timeout=timeout)
case.call_and_validate(
verify=False,
timeout=timeout,
additional_checks=(customized_not_a_server_error,),
excluded_checks=(not_a_server_error,),
)
......@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.exceptions import VLLMValidationError
from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
......@@ -818,9 +819,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
max_tokens=10,
)
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "context length is only" in resp.error.message
with pytest.raises(VLLMValidationError):
await serving_chat.create_chat_completion(req)
@pytest.mark.asyncio
......@@ -860,9 +860,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
max_tokens=1,
)
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "context length is only" in resp.error.message
with pytest.raises(VLLMValidationError):
await serving_chat.create_chat_completion(req)
@pytest.mark.asyncio
......
......@@ -17,9 +17,6 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionResponse,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs import PromptType
......@@ -542,11 +539,9 @@ async def test_header_dp_rank_argument():
# Test 2: Out-of-range DP rank (1)
mock_raw_request.headers = {"X-data-parallel-rank": "1"}
# should return ErrorResponse for out-of-range rank
response2 = await serving_chat.create_chat_completion(req, mock_raw_request)
assert isinstance(response2, ErrorResponse), (
"Expected an ErrorResponse for out-of-range DP rank"
)
# should raise ValueError for out-of-range rank
with pytest.raises(ValueError):
await serving_chat.create_chat_completion(req, mock_raw_request)
@pytest.mark.asyncio
......
......@@ -4,11 +4,10 @@
import asyncio
import signal
import socket
from http import HTTPStatus
from typing import Any
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi import FastAPI
from vllm import envs
from vllm.engine.protocol import EngineClient
......@@ -19,7 +18,6 @@ from vllm.entrypoints.constants import (
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils.network_utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__)
......@@ -75,7 +73,7 @@ async def serve_http(
config.h11_max_header_count = h11_max_header_count
config.load()
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
app.state.server = server
loop = asyncio.get_running_loop()
......@@ -148,40 +146,3 @@ def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
engine_errored = engine.errored and not engine.is_running
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task for check for errored state.
"""
@app.exception_handler(RuntimeError)
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):
terminate_if_errored(
server=server,
engine=request.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
......@@ -31,6 +31,8 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
engine_error_handler,
exception_handler,
get_uvicorn_log_config,
http_exception_handler,
lifespan,
......@@ -57,6 +59,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory
......@@ -250,6 +253,9 @@ def build_app(
app.exception_handler(HTTPException)(http_exception_handler)
app.exception_handler(RequestValidationError)(validation_exception_handler)
app.exception_handler(EngineGenerateError)(engine_error_handler)
app.exception_handler(EngineDeadError)(engine_error_handler)
app.exception_handler(Exception)(exception_handler)
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
......@@ -355,7 +361,6 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if any(task in supported_tasks for task in ("generate", "render")):
......
......@@ -39,6 +39,7 @@ def chat(request: Request) -> OpenAIServingChat | None:
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
},
)
@with_cancellation
......@@ -54,10 +55,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
message="The model does not support Chat Completions API"
)
try:
generator = await handler.create_chat_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -81,6 +79,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
},
)
async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request):
......@@ -93,10 +92,7 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re
message="The model does not support Chat Completions API"
)
try:
result = await handler.render_chat_request(request)
except Exception as e:
result = handler.create_error_response(e)
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
......
......@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Any, Final
import jinja2
import partial_json_parser
import regex as re
from fastapi import Request
......@@ -105,7 +104,6 @@ class OpenAIServingChat(OpenAIServing):
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
enable_log_deltas: bool = True,
log_error_stack: bool = False,
default_chat_template_kwargs: dict[str, Any] | None = None,
) -> None:
super().__init__(
......@@ -113,7 +111,6 @@ class OpenAIServingChat(OpenAIServing):
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.response_role = response_role
......@@ -235,7 +232,6 @@ class OpenAIServingChat(OpenAIServing):
if self.engine_client.errored:
raise self.engine_client.dead_error
try:
tokenizer = self.renderer.tokenizer
tool_parser = self.tool_parser
......@@ -275,8 +271,7 @@ class OpenAIServingChat(OpenAIServing):
)
if request.tools is None or (
request.tool_choice == "none"
and self.exclude_tools_when_tool_choice_none
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
):
tool_dicts = None
else:
......@@ -307,9 +302,6 @@ class OpenAIServingChat(OpenAIServing):
conversation, engine_prompts = self._make_request_with_harmony(
request, should_include_tools
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
return conversation, engine_prompts
......@@ -329,7 +321,6 @@ class OpenAIServingChat(OpenAIServing):
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
reasoning_parser: ReasoningParser | None = None
try:
if self.reasoning_parser_cls:
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
......@@ -340,9 +331,6 @@ class OpenAIServingChat(OpenAIServing):
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
result = await self.render_chat_request(request)
if isinstance(result, ErrorResponse):
return result
......@@ -357,15 +345,9 @@ class OpenAIServingChat(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(
request, supports_default_mm_loras=True
)
lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)
model_name = self.models.model_name(lora_request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
......@@ -373,11 +355,8 @@ class OpenAIServingChat(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_token_ids = self._extract_prompt_components(
engine_prompt
).token_ids
prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
......@@ -446,8 +425,6 @@ class OpenAIServingChat(OpenAIServing):
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
assert len(generators) == 1
(result_generator,) = generators
......@@ -464,7 +441,6 @@ class OpenAIServingChat(OpenAIServing):
reasoning_parser,
)
try:
return await self.chat_completion_full_generator(
request,
result_generator,
......@@ -475,10 +451,6 @@ class OpenAIServingChat(OpenAIServing):
request_metadata,
reasoning_parser,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
......@@ -1414,8 +1386,6 @@ class OpenAIServingChat(OpenAIServing):
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
assert final_res is not None
......
......@@ -54,10 +54,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
message="The model does not support Completions API"
)
try:
generator = await handler.create_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -91,10 +88,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request):
message="The model does not support Completions API"
)
try:
result = await handler.render_completion_request(request)
except Exception as e:
result = handler.create_error_response(e)
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
......
......@@ -7,7 +7,6 @@ from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast
import jinja2
from fastapi import Request
from vllm.engine.protocol import EngineClient
......@@ -56,14 +55,12 @@ class OpenAIServingCompletion(OpenAIServing):
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
......@@ -110,15 +107,11 @@ class OpenAIServingCompletion(OpenAIServing):
"prompt_logprobs is not compatible with prompt embeds."
)
try:
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
return engine_prompts
......@@ -149,11 +142,7 @@ class OpenAIServingCompletion(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
......@@ -161,7 +150,6 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
max_tokens = get_max_tokens(
max_model_len,
......@@ -217,8 +205,6 @@ class OpenAIServingCompletion(OpenAIServing):
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
result_generator = merge_async_iterators(*generators)
......@@ -273,10 +259,6 @@ class OpenAIServingCompletion(OpenAIServing):
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
......
......@@ -4,6 +4,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from http import HTTPStatus
from typing import Any, ClassVar, Literal, TypeAlias
import regex as re
......@@ -262,6 +263,14 @@ class DeltaMessage(OpenAIBaseModel):
tool_calls: list[DeltaToolCall] = Field(default_factory=list)
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
......
......@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from http import HTTPStatus
......@@ -38,10 +36,10 @@ from vllm.entrypoints.openai.completion.protocol import (
CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
GenerationError,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.context import (
......@@ -89,7 +87,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.entrypoints.utils import create_error_response, get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import (
ProcessorInputs,
......@@ -125,15 +123,6 @@ from vllm.utils.async_utils import (
)
from vllm.utils.mistral import is_mistral_tokenizer
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__)
......@@ -225,7 +214,6 @@ class OpenAIServing:
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__()
......@@ -236,8 +224,6 @@ class OpenAIServing:
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.log_error_stack = log_error_stack
self.model_config = engine_client.model_config
self.renderer = engine_client.renderer
self.io_processor = engine_client.io_processor
......@@ -526,7 +512,6 @@ class OpenAIServing:
"""Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
trace_headers = (
None
if ctx.raw_request is None
......@@ -565,15 +550,11 @@ class OpenAIServing:
return None
except Exception as e:
return self.create_error_response(e)
async def _collect_batch(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect batch results from the result generator."""
try:
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
......@@ -596,63 +577,14 @@ class OpenAIServing:
return None
except Exception as e:
return self.create_error_response(e)
@staticmethod
def create_error_response(
self,
message: str | Exception,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
) -> ErrorResponse:
exc: Exception | None = None
if isinstance(message, Exception):
exc = message
from vllm.exceptions import VLLMValidationError
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
else:
err_type = "InternalServerError"
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
param = None
message = str(exc)
if self.log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
param=param,
)
)
return create_error_response(message, err_type, status_code, param)
def create_streaming_error_response(
self,
......@@ -680,16 +612,6 @@ class OpenAIServing:
)
raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
self, e: GenerationError
) -> ErrorResponse:
"""Convert GenerationError to ErrorResponse."""
return self.create_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
def _convert_generation_error_to_streaming_response(
self, e: GenerationError
) -> str:
......
......@@ -87,7 +87,6 @@ async def init_generate_state(
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
......@@ -111,7 +110,6 @@ async def init_generate_state(
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
)
if any(task in supported_tasks for task in ("generate", "render"))
else None
......@@ -127,7 +125,6 @@ async def init_generate_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
)
if any(task in supported_tasks for task in ("generate", "render"))
else None
......@@ -156,7 +153,6 @@ async def init_generate_state(
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
......
......@@ -68,7 +68,6 @@ def init_realtime_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
)
if "realtime" in supported_tasks
else None
......
......@@ -33,13 +33,11 @@ class OpenAIServingRealtime(OpenAIServing):
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.task_type: Literal["realtime"] = "realtime"
......
......@@ -63,10 +63,8 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
return base_server.create_error_response(
message="The model does not support Responses API"
)
try:
generator = await handler.create_responses(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -95,14 +93,11 @@ async def retrieve_responses(
message="The model does not support Responses API"
)
try:
response = await handler.retrieve_responses(
response_id,
starting_after=starting_after,
stream=stream,
)
except Exception as e:
response = handler.create_error_response(e)
if isinstance(response, ErrorResponse):
return JSONResponse(
......@@ -125,10 +120,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
message="The model does not support Responses API"
)
try:
response = await handler.cancel_responses(response_id)
except Exception as e:
response = handler.create_error_response(e)
if isinstance(response, ErrorResponse):
return JSONResponse(
......
......@@ -11,7 +11,6 @@ from copy import copy
from http import HTTPStatus
from typing import Final
import jinja2
from fastapi import Request
from openai.types.responses import (
ResponseContentPartAddedEvent,
......@@ -174,14 +173,12 @@ class OpenAIServingResponses(OpenAIServing):
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
......@@ -365,7 +362,6 @@ class OpenAIServingResponses(OpenAIServing):
else:
prev_response = None
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
......@@ -374,19 +370,7 @@ class OpenAIServingResponses(OpenAIServing):
request, prev_response
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response
)
except (
ValueError,
TypeError,
RuntimeError,
jinja2.TemplateError,
NotImplementedError,
) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
messages, engine_prompts = await self._make_request(request, prev_response)
request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request:
......@@ -424,7 +408,6 @@ class OpenAIServingResponses(OpenAIServing):
else:
assert len(builtin_tool_list) == 0
available_tools = []
try:
tokenizer = self.renderer.get_tokenizer()
for engine_prompt in engine_prompts:
......@@ -502,8 +485,6 @@ class OpenAIServingResponses(OpenAIServing):
trace_headers=trace_headers,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
assert len(generators) == 1
(result_generator,) = generators
......@@ -578,7 +559,6 @@ class OpenAIServingResponses(OpenAIServing):
request_metadata,
)
try:
return await self.responses_full_generator(
request,
sampling_params,
......@@ -588,10 +568,6 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(e)
async def _make_request(
self,
......@@ -675,8 +651,6 @@ class OpenAIServingResponses(OpenAIServing):
pass
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
# NOTE: Implementation of status is still WIP, but for now
# we guarantee that if the status is not "completed", it is accurate.
......@@ -1129,16 +1103,11 @@ class OpenAIServingResponses(OpenAIServing):
new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None
try:
generator = self.responses_stream_generator(request, *args, **kwargs)
try:
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(e)
finally:
new_event_signal.set()
......@@ -1157,13 +1126,7 @@ class OpenAIServingResponses(OpenAIServing):
*args,
**kwargs,
):
try:
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(e)
if isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed".
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment