Unverified Commit 176c799f authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[openai api] log exception in exception handler (1/N) (#31164)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent 612e7729
...@@ -6,7 +6,6 @@ from unittest.mock import MagicMock ...@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
import pytest import pytest
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import GenerationError, OpenAIServing from vllm.entrypoints.openai.engine.serving import GenerationError, OpenAIServing
...@@ -38,32 +37,6 @@ async def test_raise_if_error_raises_generation_error(): ...@@ -38,32 +37,6 @@ async def test_raise_if_error_raises_generation_error():
serving._raise_if_error(None, "test-request-id") # should not raise serving._raise_if_error(None, "test-request-id") # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
"""test _convert_generation_error_to_response creates proper ErrorResponse"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to ErrorResponse
error_response = serving._convert_generation_error_to_response(gen_error)
assert isinstance(error_response, ErrorResponse)
assert error_response.error.type == "InternalServerError"
assert error_response.error.message == "Internal server error"
assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response(): async def test_convert_generation_error_to_streaming_response():
"""test _convert_generation_error_to_streaming_response output""" """test _convert_generation_error_to_streaming_response output"""
......
...@@ -13,7 +13,7 @@ from typing import Any ...@@ -13,7 +13,7 @@ from typing import Any
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import requests import requests
from openai import BadRequestError, NotFoundError, OpenAI from openai import InternalServerError, NotFoundError, OpenAI
from openai_harmony import Message from openai_harmony import Message
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
...@@ -698,7 +698,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str): ...@@ -698,7 +698,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
async def test_function_calling_required(client: OpenAI, model_name: str): async def test_function_calling_required(client: OpenAI, model_name: str):
tools = [GET_WEATHER_SCHEMA] tools = [GET_WEATHER_SCHEMA]
with pytest.raises(BadRequestError): with pytest.raises(InternalServerError):
await client.responses.create( await client.responses.create(
model=model_name, model=model_name,
input="What's the weather like in Paris today?", input="What's the weather like in Paris today?",
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
...@@ -11,7 +10,7 @@ import pytest ...@@ -11,7 +10,7 @@ import pytest
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -145,12 +144,8 @@ async def test_chat_error_non_stream(): ...@@ -145,12 +144,8 @@ async def test_chat_error_non_stream():
stream=False, stream=False,
) )
response = await serving_chat.create_chat_completion(request) with pytest.raises(GenerationError):
await serving_chat.create_chat_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock
...@@ -11,7 +10,7 @@ import pytest ...@@ -11,7 +10,7 @@ import pytest
from vllm.config.multimodal import MultiModalConfig from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.completion.protocol import CompletionRequest from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -131,12 +130,8 @@ async def test_completion_error_non_stream(): ...@@ -131,12 +130,8 @@ async def test_completion_error_non_stream():
stream=False, stream=False,
) )
response = await serving_completion.create_completion(request) with pytest.raises(GenerationError):
await serving_completion.create_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio @pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from http import HTTPStatus
from typing import Final from typing import Final
import pytest import pytest
import schemathesis import schemathesis
from httpx import URL
from hypothesis import settings from hypothesis import settings
from schemathesis import GenerationConfig from schemathesis import GenerationConfig
from schemathesis.checks import not_a_server_error
from schemathesis.internal.checks import CheckContext
from schemathesis.models import Case
from schemathesis.transports.responses import GenericResponse
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
...@@ -127,10 +133,25 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): ...@@ -127,10 +133,25 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
return strategy.filter(no_invalid_types) return strategy.filter(no_invalid_types)
def customized_not_a_server_error(
ctx: CheckContext, response: GenericResponse, case: Case
) -> bool | None:
try:
return not_a_server_error(ctx, response, case)
except Exception:
if (
URL(response.request.url).path
in ["/v1/chat/completions/render", "/v1/chat/completions"]
and response.status_code == HTTPStatus.NOT_IMPLEMENTED.value
):
return True
raise
@schema.parametrize() @schema.parametrize()
@schema.override(headers={"Content-Type": "application/json"}) @schema.override(headers={"Content-Type": "application/json"})
@settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50) @settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50)
def test_openapi_stateless(case: schemathesis.Case): def test_openapi_stateless(case: Case):
key = ( key = (
case.operation.method.upper(), case.operation.method.upper(),
case.operation.path, case.operation.path,
...@@ -155,4 +176,9 @@ def test_openapi_stateless(case: schemathesis.Case): ...@@ -155,4 +176,9 @@ def test_openapi_stateless(case: schemathesis.Case):
}.get(key, DEFAULT_TIMEOUT_SECONDS) }.get(key, DEFAULT_TIMEOUT_SECONDS)
# No need to verify SSL certificate for localhost # No need to verify SSL certificate for localhost
case.call_and_validate(verify=False, timeout=timeout) case.call_and_validate(
verify=False,
timeout=timeout,
additional_checks=(customized_not_a_server_error,),
excluded_checks=(not_a_server_error,),
)
...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.exceptions import VLLMValidationError
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer from vllm.renderers.hf import HfRenderer
...@@ -818,9 +819,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): ...@@ -818,9 +819,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
max_tokens=10, max_tokens=10,
) )
resp = await serving_chat.create_chat_completion(req) with pytest.raises(VLLMValidationError):
assert isinstance(resp, ErrorResponse) await serving_chat.create_chat_completion(req)
assert "context length is only" in resp.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -860,9 +860,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): ...@@ -860,9 +860,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
max_tokens=1, max_tokens=1,
) )
resp = await serving_chat.create_chat_completion(req) with pytest.raises(VLLMValidationError):
assert isinstance(resp, ErrorResponse) await serving_chat.create_chat_completion(req)
assert "context length is only" in resp.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -17,9 +17,6 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -17,9 +17,6 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionResponse, ChatCompletionResponse,
) )
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs import PromptType from vllm.inputs import PromptType
...@@ -542,11 +539,9 @@ async def test_header_dp_rank_argument(): ...@@ -542,11 +539,9 @@ async def test_header_dp_rank_argument():
# Test 2: Out-of-range DP rank (1) # Test 2: Out-of-range DP rank (1)
mock_raw_request.headers = {"X-data-parallel-rank": "1"} mock_raw_request.headers = {"X-data-parallel-rank": "1"}
# should return ErrorResponse for out-of-range rank # should raise ValueError for out-of-range rank
response2 = await serving_chat.create_chat_completion(req, mock_raw_request) with pytest.raises(ValueError):
assert isinstance(response2, ErrorResponse), ( await serving_chat.create_chat_completion(req, mock_raw_request)
"Expected an ErrorResponse for out-of-range DP rank"
)
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -4,11 +4,10 @@ ...@@ -4,11 +4,10 @@
import asyncio import asyncio
import signal import signal
import socket import socket
from http import HTTPStatus
from typing import Any from typing import Any
import uvicorn import uvicorn
from fastapi import FastAPI, Request, Response from fastapi import FastAPI
from vllm import envs from vllm import envs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -19,7 +18,6 @@ from vllm.entrypoints.constants import ( ...@@ -19,7 +18,6 @@ from vllm.entrypoints.constants import (
from vllm.entrypoints.ssl import SSLCertRefresher from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.network_utils import find_process_using_port from vllm.utils.network_utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -75,7 +73,7 @@ async def serve_http( ...@@ -75,7 +73,7 @@ async def serve_http(
config.h11_max_header_count = h11_max_header_count config.h11_max_header_count = h11_max_header_count
config.load() config.load()
server = uvicorn.Server(config) server = uvicorn.Server(config)
_add_shutdown_handlers(app, server) app.state.server = server
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
...@@ -148,40 +146,3 @@ def terminate_if_errored(server: uvicorn.Server, engine: EngineClient): ...@@ -148,40 +146,3 @@ def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
engine_errored = engine.errored and not engine.is_running engine_errored = engine.errored and not engine.is_running
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored: if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
server.should_exit = True server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task for check for errored state.
"""
@app.exception_handler(RuntimeError)
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):
terminate_if_errored(
server=server,
engine=request.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
...@@ -31,6 +31,8 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se ...@@ -31,6 +31,8 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import ( from vllm.entrypoints.openai.server_utils import (
engine_error_handler,
exception_handler,
get_uvicorn_log_config, get_uvicorn_log_config,
http_exception_handler, http_exception_handler,
lifespan, lifespan,
...@@ -57,6 +59,7 @@ from vllm.usage.usage_lib import UsageContext ...@@ -57,6 +59,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory prometheus_multiproc_dir: tempfile.TemporaryDirectory
...@@ -250,6 +253,9 @@ def build_app( ...@@ -250,6 +253,9 @@ def build_app(
app.exception_handler(HTTPException)(http_exception_handler) app.exception_handler(HTTPException)(http_exception_handler)
app.exception_handler(RequestValidationError)(validation_exception_handler) app.exception_handler(RequestValidationError)(validation_exception_handler)
app.exception_handler(EngineGenerateError)(engine_error_handler)
app.exception_handler(EngineDeadError)(engine_error_handler)
app.exception_handler(Exception)(exception_handler)
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]: if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
...@@ -355,7 +361,6 @@ async def init_app_state( ...@@ -355,7 +361,6 @@ async def init_app_state(
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
) )
if any(task in supported_tasks for task in ("generate", "render")): if any(task in supported_tasks for task in ("generate", "render")):
......
...@@ -39,6 +39,7 @@ def chat(request: Request) -> OpenAIServingChat | None: ...@@ -39,6 +39,7 @@ def chat(request: Request) -> OpenAIServingChat | None:
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
}, },
) )
@with_cancellation @with_cancellation
...@@ -54,10 +55,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ...@@ -54,10 +55,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
message="The model does not support Chat Completions API" message="The model does not support Chat Completions API"
) )
try: generator = await handler.create_chat_completion(request, raw_request)
generator = await handler.create_chat_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
...@@ -81,6 +79,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ...@@ -81,6 +79,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
}, },
) )
async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request): async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request):
...@@ -93,10 +92,7 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re ...@@ -93,10 +92,7 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re
message="The model does not support Chat Completions API" message="The model does not support Chat Completions API"
) )
try: result = await handler.render_chat_request(request)
result = await handler.render_chat_request(request)
except Exception as e:
result = handler.create_error_response(e)
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code) return JSONResponse(content=result.model_dump(), status_code=result.error.code)
......
...@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, AsyncIterator ...@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import Any, Final from typing import Any, Final
import jinja2
import partial_json_parser import partial_json_parser
import regex as re import regex as re
from fastapi import Request from fastapi import Request
...@@ -105,7 +104,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -105,7 +104,6 @@ class OpenAIServingChat(OpenAIServing):
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
enable_log_outputs: bool = False, enable_log_outputs: bool = False,
enable_log_deltas: bool = True, enable_log_deltas: bool = True,
log_error_stack: bool = False,
default_chat_template_kwargs: dict[str, Any] | None = None, default_chat_template_kwargs: dict[str, Any] | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -113,7 +111,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -113,7 +111,6 @@ class OpenAIServingChat(OpenAIServing):
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
) )
self.response_role = response_role self.response_role = response_role
...@@ -235,81 +232,76 @@ class OpenAIServingChat(OpenAIServing): ...@@ -235,81 +232,76 @@ class OpenAIServingChat(OpenAIServing):
if self.engine_client.errored: if self.engine_client.errored:
raise self.engine_client.dead_error raise self.engine_client.dead_error
try: tokenizer = self.renderer.tokenizer
tokenizer = self.renderer.tokenizer
tool_parser = self.tool_parser
if is_mistral_tokenizer(tokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
_mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
_mt.validate_request_params(request)
# Check if tool parsing is unavailable (common condition)
tool_parsing_unavailable = (
tool_parser is None
and not is_mistral_tokenizer(tokenizer)
and not self.use_harmony
)
# Validate tool_choice when tool parsing is required but unavailable tool_parser = self.tool_parser
if tool_parsing_unavailable and request.tool_choice not in (
None, if is_mistral_tokenizer(tokenizer):
"none", # because of issues with pydantic we need to potentially
): # re-serialize the tool_calls field of the request
if request.tool_choice == "auto" and not self.enable_auto_tools: # for more info: see comment in `maybe_serialize_tool_calls`
# for hf tokenizers, "auto" tools requires _mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
# --enable-auto-tool-choice and --tool-call-parser _mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
return self.create_error_response( _mt.validate_request_params(request)
'"auto" tool choice requires '
"--enable-auto-tool-choice and --tool-call-parser to be set" # Check if tool parsing is unavailable (common condition)
) tool_parsing_unavailable = (
elif request.tool_choice != "auto": tool_parser is None
# "required" or named tool requires tool parser and not is_mistral_tokenizer(tokenizer)
return self.create_error_response( and not self.use_harmony
f'tool_choice="{request.tool_choice}" requires ' )
"--tool-call-parser to be set"
)
if request.tools is None or ( # Validate tool_choice when tool parsing is required but unavailable
request.tool_choice == "none" if tool_parsing_unavailable and request.tool_choice not in (
and self.exclude_tools_when_tool_choice_none None,
): "none",
tool_dicts = None ):
else: if request.tool_choice == "auto" and not self.enable_auto_tools:
tool_dicts = [tool.model_dump() for tool in request.tools] # for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
if not self.use_harmony: return self.create_error_response(
# Common case. '"auto" tool choice requires '
error_check_ret = self._validate_chat_template( "--enable-auto-tool-choice and --tool-call-parser to be set"
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
) )
if error_check_ret is not None: elif request.tool_choice != "auto":
return error_check_ret # "required" or named tool requires tool parser
return self.create_error_response(
conversation, engine_prompts = await self._preprocess_chat( f'tool_choice="{request.tool_choice}" requires '
request, "--tool-call-parser to be set"
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
) )
else:
# For GPT-OSS. if request.tools is None or (
should_include_tools = tool_dicts is not None request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
conversation, engine_prompts = self._make_request_with_harmony( ):
request, should_include_tools tool_dicts = None
) else:
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: tool_dicts = [tool.model_dump() for tool in request.tools]
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e) if not self.use_harmony:
# Common case.
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
conversation, engine_prompts = await self._preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
)
else:
# For GPT-OSS.
should_include_tools = tool_dicts is not None
conversation, engine_prompts = self._make_request_with_harmony(
request, should_include_tools
)
return conversation, engine_prompts return conversation, engine_prompts
...@@ -329,20 +321,16 @@ class OpenAIServingChat(OpenAIServing): ...@@ -329,20 +321,16 @@ class OpenAIServingChat(OpenAIServing):
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None assert tokenizer is not None
reasoning_parser: ReasoningParser | None = None reasoning_parser: ReasoningParser | None = None
try: if self.reasoning_parser_cls:
if self.reasoning_parser_cls: # Pass the same chat template kwargs as used in tokenization
# Pass the same chat template kwargs as used in tokenization chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
chat_template_kwargs = self._prepare_extra_chat_template_kwargs( request.chat_template_kwargs,
request.chat_template_kwargs, self.default_chat_template_kwargs,
self.default_chat_template_kwargs, )
) reasoning_parser = self.reasoning_parser_cls(
reasoning_parser = self.reasoning_parser_cls( tokenizer,
tokenizer, chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] )
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
result = await self.render_chat_request(request) result = await self.render_chat_request(request)
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
...@@ -357,15 +345,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -357,15 +345,9 @@ class OpenAIServingChat(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)
lora_request = self._maybe_get_adapters(
request, supports_default_mm_loras=True
)
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it) # Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request) data_parallel_rank = self._get_data_parallel_rank(raw_request)
...@@ -373,81 +355,76 @@ class OpenAIServingChat(OpenAIServing): ...@@ -373,81 +355,76 @@ class OpenAIServingChat(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: for i, engine_prompt in enumerate(engine_prompts):
for i, engine_prompt in enumerate(engine_prompts): prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids
prompt_token_ids = self._extract_prompt_components(
engine_prompt # If we are creating sub requests for multiple prompts, ensure that they
).token_ids # have unique request ids.
sub_request_id = (
# If we are creating sub requests for multiple prompts, ensure that they request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
# have unique request ids. )
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" max_tokens = get_max_tokens(
) max_model_len,
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
max_tokens = get_max_tokens( sampling_params: SamplingParams | BeamSearchParams
max_model_len, if request.use_beam_search:
request.max_completion_tokens sampling_params = request.to_beam_search_params(
if request.max_completion_tokens is not None max_tokens, self.default_sampling_params
else request.max_tokens, )
self._extract_prompt_len(engine_prompt), else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params, self.default_sampling_params,
self.override_max_tokens,
) )
sampling_params: SamplingParams | BeamSearchParams self._log_inputs(
if request.use_beam_search: sub_request_id,
sampling_params = request.to_beam_search_params( engine_prompt,
max_tokens, self.default_sampling_params params=sampling_params,
) lora_request=lora_request,
else: )
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
)
self._log_inputs( trace_headers = (
sub_request_id, None
engine_prompt, if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=sub_request_id,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
) )
else:
trace_headers = ( reasoning_ended = (
None reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if raw_request is None if reasoning_parser
else await self._get_trace_headers(raw_request.headers) else None
) )
if isinstance(sampling_params, BeamSearchParams): generator = self.engine_client.generate(
generator = self.beam_search( engine_prompt,
prompt=engine_prompt, sampling_params,
request_id=sub_request_id, sub_request_id,
params=sampling_params, lora_request=lora_request,
lora_request=lora_request, trace_headers=trace_headers,
trace_headers=trace_headers, priority=request.priority,
) data_parallel_rank=data_parallel_rank,
else: reasoning_ended=reasoning_ended,
reasoning_ended = ( )
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
reasoning_ended=reasoning_ended,
)
generators.append(generator) generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
assert len(generators) == 1 assert len(generators) == 1
(result_generator,) = generators (result_generator,) = generators
...@@ -464,21 +441,16 @@ class OpenAIServingChat(OpenAIServing): ...@@ -464,21 +441,16 @@ class OpenAIServingChat(OpenAIServing):
reasoning_parser, reasoning_parser,
) )
try: return await self.chat_completion_full_generator(
return await self.chat_completion_full_generator( request,
request, result_generator,
result_generator, request_id,
request_id, model_name,
model_name, conversation,
conversation, tokenizer,
tokenizer, request_metadata,
request_metadata, reasoning_parser,
reasoning_parser, )
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
...@@ -1414,8 +1386,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1414,8 +1386,6 @@ class OpenAIServingChat(OpenAIServing):
final_res = res final_res = res
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
assert final_res is not None assert final_res is not None
......
...@@ -54,10 +54,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -54,10 +54,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
message="The model does not support Completions API" message="The model does not support Completions API"
) )
try: generator = await handler.create_completion(request, raw_request)
generator = await handler.create_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
...@@ -91,10 +88,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request): ...@@ -91,10 +88,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request):
message="The model does not support Completions API" message="The model does not support Completions API"
) )
try: result = await handler.render_completion_request(request)
result = await handler.render_completion_request(request)
except Exception as e:
result = handler.create_error_response(e)
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code) return JSONResponse(content=result.model_dump(), status_code=result.error.code)
......
...@@ -7,7 +7,6 @@ from collections.abc import AsyncGenerator, AsyncIterator ...@@ -7,7 +7,6 @@ from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import cast from typing import cast
import jinja2
from fastapi import Request from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -56,14 +55,12 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -56,14 +55,12 @@ class OpenAIServingCompletion(OpenAIServing):
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False, enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
log_error_stack: bool = False,
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
) )
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
...@@ -110,15 +107,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -110,15 +107,11 @@ class OpenAIServingCompletion(OpenAIServing):
"prompt_logprobs is not compatible with prompt embeds." "prompt_logprobs is not compatible with prompt embeds."
) )
try: engine_prompts = await self._preprocess_completion(
engine_prompts = await self._preprocess_completion( request,
request, prompt_input=request.prompt,
prompt_input=request.prompt, prompt_embeds=request.prompt_embeds,
prompt_embeds=request.prompt_embeds, )
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
return engine_prompts return engine_prompts
...@@ -149,11 +142,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -149,11 +142,7 @@ class OpenAIServingCompletion(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: lora_request = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it) # Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request) data_parallel_rank = self._get_data_parallel_rank(raw_request)
...@@ -161,64 +150,61 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -161,64 +150,61 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: for i, engine_prompt in enumerate(engine_prompts):
for i, engine_prompt in enumerate(engine_prompts): max_tokens = get_max_tokens(
max_tokens = get_max_tokens( max_model_len,
max_model_len, request.max_tokens,
request.max_tokens, self._extract_prompt_len(engine_prompt),
self._extract_prompt_len(engine_prompt), self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params, self.default_sampling_params,
self.override_max_tokens,
) )
sampling_params: SamplingParams | BeamSearchParams request_id_item = f"{request_id}-{i}"
if request.use_beam_search:
sampling_params = request.to_beam_search_params( self._log_inputs(
max_tokens, self.default_sampling_params request_id_item,
) engine_prompt,
else: params=sampling_params,
sampling_params = request.to_sampling_params( lora_request=lora_request,
max_tokens, )
self.default_sampling_params,
)
request_id_item = f"{request_id}-{i}" trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
self._log_inputs( if isinstance(sampling_params, BeamSearchParams):
request_id_item, generator = self.beam_search(
engine_prompt, prompt=engine_prompt,
request_id=request_id,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers,
) )
else:
trace_headers = ( generator = self.engine_client.generate(
None engine_prompt,
if raw_request is None sampling_params,
else await self._get_trace_headers(raw_request.headers) request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
) )
if isinstance(sampling_params, BeamSearchParams): generators.append(generator)
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
...@@ -273,10 +259,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -273,10 +259,6 @@ class OpenAIServingCompletion(OpenAIServing):
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
# When user requests streaming but we don't stream, we still need to # When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event. # return a streaming response with a single event.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time import time
from http import HTTPStatus
from typing import Any, ClassVar, Literal, TypeAlias from typing import Any, ClassVar, Literal, TypeAlias
import regex as re import regex as re
...@@ -262,6 +263,14 @@ class DeltaMessage(OpenAIBaseModel): ...@@ -262,6 +263,14 @@ class DeltaMessage(OpenAIBaseModel):
tool_calls: list[DeltaToolCall] = Field(default_factory=list) tool_calls: list[DeltaToolCall] = Field(default_factory=list)
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
####### Tokens IN <> Tokens OUT ####### ####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel): class GenerateRequest(BaseModel):
request_id: str = Field( request_id: str = Field(
......
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import json import json
import sys
import time import time
import traceback
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
...@@ -38,10 +36,10 @@ from vllm.entrypoints.openai.completion.protocol import ( ...@@ -38,10 +36,10 @@ from vllm.entrypoints.openai.completion.protocol import (
CompletionResponse, CompletionResponse,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse, ErrorResponse,
FunctionCall, FunctionCall,
FunctionDefinition, FunctionDefinition,
GenerationError,
) )
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.context import ( from vllm.entrypoints.openai.responses.context import (
...@@ -89,7 +87,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -89,7 +87,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
) )
from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.entrypoints.utils import create_error_response, get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ( from vllm.inputs.data import (
ProcessorInputs, ProcessorInputs,
...@@ -125,15 +123,6 @@ from vllm.utils.async_utils import ( ...@@ -125,15 +123,6 @@ from vllm.utils.async_utils import (
) )
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -225,7 +214,6 @@ class OpenAIServing: ...@@ -225,7 +214,6 @@ class OpenAIServing:
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
): ):
super().__init__() super().__init__()
...@@ -236,8 +224,6 @@ class OpenAIServing: ...@@ -236,8 +224,6 @@ class OpenAIServing:
self.request_logger = request_logger self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.log_error_stack = log_error_stack
self.model_config = engine_client.model_config self.model_config = engine_client.model_config
self.renderer = engine_client.renderer self.renderer = engine_client.renderer
self.io_processor = engine_client.io_processor self.io_processor = engine_client.io_processor
...@@ -526,133 +512,79 @@ class OpenAIServing: ...@@ -526,133 +512,79 @@ class OpenAIServing:
"""Schedule the request and get the result generator.""" """Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try: trace_headers = (
trace_headers = ( None
None if ctx.raw_request is None
if ctx.raw_request is None else await self._get_trace_headers(ctx.raw_request.headers)
else await self._get_trace_headers(ctx.raw_request.headers) )
)
pooling_params = self._create_pooling_params(ctx) pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
return pooling_params return pooling_params
if ctx.engine_prompts is None: if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}" request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_prompt,
params=pooling_params, params=pooling_params,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
generators.append(generator) generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
ctx.result_generator = merge_async_iterators(*generators) generators.append(generator)
return None ctx.result_generator = merge_async_iterators(*generators)
except Exception as e: return None
return self.create_error_response(e)
async def _collect_batch( async def _collect_batch(
self, self,
ctx: ServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect batch results from the result generator.""" """Collect batch results from the result generator."""
try: if ctx.engine_prompts is None:
if ctx.engine_prompts is None: return self.create_error_response("Engine prompts not available")
return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None] final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
if ctx.result_generator is None: if ctx.result_generator is None:
return self.create_error_response("Result generator not available") return self.create_error_response("Result generator not available")
async for i, res in ctx.result_generator: async for i, res in ctx.result_generator:
final_res_batch[i] = res final_res_batch[i] = res
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts"
)
ctx.final_res_batch = [res for res in final_res_batch if res is not None] if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts"
)
return None ctx.final_res_batch = [res for res in final_res_batch if res is not None]
except Exception as e: return None
return self.create_error_response(e)
@staticmethod
def create_error_response( def create_error_response(
self,
message: str | Exception, message: str | Exception,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None, param: str | None = None,
) -> ErrorResponse: ) -> ErrorResponse:
exc: Exception | None = None return create_error_response(message, err_type, status_code, param)
if isinstance(message, Exception):
exc = message
from vllm.exceptions import VLLMValidationError
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
else:
err_type = "InternalServerError"
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
param = None
message = str(exc)
if self.log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
param=param,
)
)
def create_streaming_error_response( def create_streaming_error_response(
self, self,
...@@ -680,16 +612,6 @@ class OpenAIServing: ...@@ -680,16 +612,6 @@ class OpenAIServing:
) )
raise GenerationError("Internal server error") raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
self, e: GenerationError
) -> ErrorResponse:
"""Convert GenerationError to ErrorResponse."""
return self.create_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
def _convert_generation_error_to_streaming_response( def _convert_generation_error_to_streaming_response(
self, e: GenerationError self, e: GenerationError
) -> str: ) -> str:
......
...@@ -87,7 +87,6 @@ async def init_generate_state( ...@@ -87,7 +87,6 @@ async def init_generate_state(
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs, enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
) )
if "generate" in supported_tasks if "generate" in supported_tasks
else None else None
...@@ -111,7 +110,6 @@ async def init_generate_state( ...@@ -111,7 +110,6 @@ async def init_generate_state(
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs, enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas, enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
) )
if any(task in supported_tasks for task in ("generate", "render")) if any(task in supported_tasks for task in ("generate", "render"))
else None else None
...@@ -127,7 +125,6 @@ async def init_generate_state( ...@@ -127,7 +125,6 @@ async def init_generate_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
) )
if any(task in supported_tasks for task in ("generate", "render")) if any(task in supported_tasks for task in ("generate", "render"))
else None else None
...@@ -156,7 +153,6 @@ async def init_generate_state( ...@@ -156,7 +153,6 @@ async def init_generate_state(
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs, enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only, force_no_detokenize=args.tokens_only,
......
...@@ -68,7 +68,6 @@ def init_realtime_state( ...@@ -68,7 +68,6 @@ def init_realtime_state(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=args.log_error_stack,
) )
if "realtime" in supported_tasks if "realtime" in supported_tasks
else None else None
......
...@@ -33,13 +33,11 @@ class OpenAIServingRealtime(OpenAIServing): ...@@ -33,13 +33,11 @@ class OpenAIServingRealtime(OpenAIServing):
models: OpenAIServingModels, models: OpenAIServingModels,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
log_error_stack: bool = False,
): ):
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
log_error_stack=log_error_stack,
) )
self.task_type: Literal["realtime"] = "realtime" self.task_type: Literal["realtime"] = "realtime"
......
...@@ -63,10 +63,8 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): ...@@ -63,10 +63,8 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
return base_server.create_error_response( return base_server.create_error_response(
message="The model does not support Responses API" message="The model does not support Responses API"
) )
try:
generator = await handler.create_responses(request, raw_request) generator = await handler.create_responses(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
...@@ -95,14 +93,11 @@ async def retrieve_responses( ...@@ -95,14 +93,11 @@ async def retrieve_responses(
message="The model does not support Responses API" message="The model does not support Responses API"
) )
try: response = await handler.retrieve_responses(
response = await handler.retrieve_responses( response_id,
response_id, starting_after=starting_after,
starting_after=starting_after, stream=stream,
stream=stream, )
)
except Exception as e:
response = handler.create_error_response(e)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse( return JSONResponse(
...@@ -125,10 +120,7 @@ async def cancel_responses(response_id: str, raw_request: Request): ...@@ -125,10 +120,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
message="The model does not support Responses API" message="The model does not support Responses API"
) )
try: response = await handler.cancel_responses(response_id)
response = await handler.cancel_responses(response_id)
except Exception as e:
response = handler.create_error_response(e)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -11,7 +11,6 @@ from copy import copy ...@@ -11,7 +11,6 @@ from copy import copy
from http import HTTPStatus from http import HTTPStatus
from typing import Final from typing import Final
import jinja2
from fastapi import Request from fastapi import Request
from openai.types.responses import ( from openai.types.responses import (
ResponseContentPartAddedEvent, ResponseContentPartAddedEvent,
...@@ -174,14 +173,12 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -174,14 +173,12 @@ class OpenAIServingResponses(OpenAIServing):
enable_prompt_tokens_details: bool = False, enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
enable_log_outputs: bool = False, enable_log_outputs: bool = False,
log_error_stack: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
models=models, models=models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
) )
self.chat_template = chat_template self.chat_template = chat_template
...@@ -365,28 +362,15 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -365,28 +362,15 @@ class OpenAIServingResponses(OpenAIServing):
else: else:
prev_response = None prev_response = None
try: lora_request = self._maybe_get_adapters(request)
lora_request = self._maybe_get_adapters(request) model_name = self.models.model_name(lora_request)
model_name = self.models.model_name(lora_request)
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
request, prev_response
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response
)
except ( if self.use_harmony:
ValueError, messages, engine_prompts = self._make_request_with_harmony(
TypeError, request, prev_response
RuntimeError, )
jinja2.TemplateError, else:
NotImplementedError, messages, engine_prompts = await self._make_request(request, prev_response)
) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
request_metadata = RequestResponseMetadata(request_id=request.request_id) request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request: if raw_request:
...@@ -424,86 +408,83 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -424,86 +408,83 @@ class OpenAIServingResponses(OpenAIServing):
else: else:
assert len(builtin_tool_list) == 0 assert len(builtin_tool_list) == 0
available_tools = [] available_tools = []
try: tokenizer = self.renderer.get_tokenizer()
tokenizer = self.renderer.get_tokenizer()
for engine_prompt in engine_prompts:
for engine_prompt in engine_prompts: maybe_error = self._validate_generator_input(engine_prompt)
maybe_error = self._validate_generator_input(engine_prompt) if maybe_error is not None:
if maybe_error is not None: return maybe_error
return maybe_error
default_max_tokens = get_max_tokens(
default_max_tokens = get_max_tokens( max_model_len,
max_model_len, request.max_output_tokens,
request.max_output_tokens, self._extract_prompt_len(engine_prompt),
self._extract_prompt_len(engine_prompt), self.default_sampling_params,
self.default_sampling_params, self.override_max_tokens,
self.override_max_tokens, )
)
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params default_max_tokens, self.default_sampling_params
) )
trace_headers = ( trace_headers = (
None None
if raw_request is None if raw_request is None
else await self._get_trace_headers(raw_request.headers) else await self._get_trace_headers(raw_request.headers)
) )
context: ConversationContext context: ConversationContext
if self.use_harmony: if self.use_harmony:
if request.stream: if request.stream:
context = StreamingHarmonyContext(messages, available_tools) context = StreamingHarmonyContext(messages, available_tools)
else:
context = HarmonyContext(messages, available_tools)
else: else:
if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: context = HarmonyContext(messages, available_tools)
# This is a feature in development for parsing else:
# tokens during generation instead of at the end if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT:
context = ParsableContext( # This is a feature in development for parsing
response_messages=messages, # tokens during generation instead of at the end
tokenizer=tokenizer, context = ParsableContext(
reasoning_parser_cls=self.parser.reasoning_parser_cls response_messages=messages,
if self.parser tokenizer=tokenizer,
else None, reasoning_parser_cls=self.parser.reasoning_parser_cls
request=request, if self.parser
tool_parser_cls=self.parser.tool_parser_cls else None,
if self.parser request=request,
else None, tool_parser_cls=self.parser.tool_parser_cls
available_tools=available_tools, if self.parser
chat_template=self.chat_template, else None,
chat_template_content_format=self.chat_template_content_format, available_tools=available_tools,
) chat_template=self.chat_template,
else: chat_template_content_format=self.chat_template_content_format,
context = SimpleContext() )
else:
if self.parser and self.parser.reasoning_parser_cls is not None: context = SimpleContext()
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
if ( if self.parser and self.parser.reasoning_parser_cls is not None:
isinstance( reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
struct_out := sampling_params.structured_outputs, if (
StructuredOutputsParams, isinstance(
) struct_out := sampling_params.structured_outputs,
and struct_out.all_non_structural_tag_constraints_none() StructuredOutputsParams,
): )
sampling_params.structured_outputs = replace( and struct_out.all_non_structural_tag_constraints_none()
struct_out, ):
structural_tag=reasoning_parser.prepare_structured_tag( sampling_params.structured_outputs = replace(
struct_out.structural_tag, self.tool_server struct_out,
), structural_tag=reasoning_parser.prepare_structured_tag(
) struct_out.structural_tag, self.tool_server
generator = self._generate_with_builtin_tools( ),
request_id=request.request_id, )
engine_prompt=engine_prompt, generator = self._generate_with_builtin_tools(
sampling_params=sampling_params, request_id=request.request_id,
context=context, engine_prompt=engine_prompt,
lora_request=lora_request, sampling_params=sampling_params,
priority=request.priority, context=context,
trace_headers=trace_headers, lora_request=lora_request,
) priority=request.priority,
generators.append(generator) trace_headers=trace_headers,
except ValueError as e: )
return self.create_error_response(e) generators.append(generator)
assert len(generators) == 1 assert len(generators) == 1
(result_generator,) = generators (result_generator,) = generators
...@@ -578,20 +559,15 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -578,20 +559,15 @@ class OpenAIServingResponses(OpenAIServing):
request_metadata, request_metadata,
) )
try: return await self.responses_full_generator(
return await self.responses_full_generator( request,
request, sampling_params,
sampling_params, result_generator,
result_generator, context,
context, model_name,
model_name, tokenizer,
tokenizer, request_metadata,
request_metadata, )
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(e)
async def _make_request( async def _make_request(
self, self,
...@@ -675,8 +651,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -675,8 +651,6 @@ class OpenAIServingResponses(OpenAIServing):
pass pass
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
# NOTE: Implementation of status is still WIP, but for now # NOTE: Implementation of status is still WIP, but for now
# we guarantee that if the status is not "completed", it is accurate. # we guarantee that if the status is not "completed", it is accurate.
...@@ -1129,16 +1103,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1129,16 +1103,11 @@ class OpenAIServingResponses(OpenAIServing):
new_event_signal = asyncio.Event() new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal) self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None response = None
generator = self.responses_stream_generator(request, *args, **kwargs)
try: try:
generator = self.responses_stream_generator(request, *args, **kwargs)
async for event in generator: async for event in generator:
event_deque.append(event) event_deque.append(event)
new_event_signal.set() # Signal new event available new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(e)
finally: finally:
new_event_signal.set() new_event_signal.set()
...@@ -1157,13 +1126,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1157,13 +1126,7 @@ class OpenAIServingResponses(OpenAIServing):
*args, *args,
**kwargs, **kwargs,
): ):
try: response = await self.responses_full_generator(request, *args, **kwargs)
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(e)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed". # If the request has failed, update the status to "failed".
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment