Unverified commit 176c799f, authored by Ning Xie and committed by GitHub
Browse files

[openai api] log exception in exception handler (1/N) (#31164)


Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent 612e7729
......@@ -6,7 +6,6 @@ from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import GenerationError, OpenAIServing
......@@ -38,32 +37,6 @@ async def test_raise_if_error_raises_generation_error():
serving._raise_if_error(None, "test-request-id") # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
    """Check that a GenerationError is converted into a 500 ErrorResponse."""
    # Minimal engine stand-in: only the model config attribute is populated.
    engine_stub = MagicMock()
    engine_stub.model_config = MagicMock()
    engine_stub.model_config.max_model_len = 100

    serving = OpenAIServing(
        engine_client=engine_stub,
        models=MagicMock(),
        request_logger=None,
    )

    # Run a freshly built GenerationError through the converter under test.
    converted = serving._convert_generation_error_to_response(
        GenerationError("Internal server error")
    )

    assert isinstance(converted, ErrorResponse)
    error_info = converted.error
    assert error_info.type == "InternalServerError"
    assert error_info.message == "Internal server error"
    assert error_info.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
"""test _convert_generation_error_to_streaming_response output"""
......
......@@ -13,7 +13,7 @@ from typing import Any
import pytest
import pytest_asyncio
import requests
from openai import BadRequestError, NotFoundError, OpenAI
from openai import InternalServerError, NotFoundError, OpenAI
from openai_harmony import Message
from ....utils import RemoteOpenAIServer
......@@ -698,7 +698,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
async def test_function_calling_required(client: OpenAI, model_name: str):
tools = [GET_WEATHER_SCHEMA]
with pytest.raises(BadRequestError):
with pytest.raises(InternalServerError):
await client.responses.create(
model=model_name,
input="What's the weather like in Paris today?",
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
......@@ -11,7 +10,7 @@ import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -145,12 +144,8 @@ async def test_chat_error_non_stream():
stream=False,
)
response = await serving_chat.create_chat_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
with pytest.raises(GenerationError):
await serving_chat.create_chat_completion(request)
@pytest.mark.asyncio
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import MagicMock
......@@ -11,7 +10,7 @@ import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.protocol import GenerationError
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -131,12 +130,8 @@ async def test_completion_error_non_stream():
stream=False,
)
response = await serving_completion.create_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
with pytest.raises(GenerationError):
await serving_completion.create_completion(request)
@pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from typing import Final
import pytest
import schemathesis
from httpx import URL
from hypothesis import settings
from schemathesis import GenerationConfig
from schemathesis.checks import not_a_server_error
from schemathesis.internal.checks import CheckContext
from schemathesis.models import Case
from schemathesis.transports.responses import GenericResponse
from ...utils import RemoteOpenAIServer
......@@ -127,10 +133,25 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
return strategy.filter(no_invalid_types)
def customized_not_a_server_error(
    ctx: CheckContext, response: GenericResponse, case: Case
) -> bool | None:
    """Run ``not_a_server_error`` but tolerate 501 on chat completion routes.

    The stock check is delegated to first. If it raises, the failure is
    swallowed only when the request path is one of the two chat completion
    endpoints listed below and the response status is 501 (Not Implemented);
    in that case the check reports success. Any other failure is re-raised
    unchanged.
    """
    try:
        verdict = not_a_server_error(ctx, response, case)
    except Exception:
        tolerated_paths = ["/v1/chat/completions/render", "/v1/chat/completions"]
        # Path is tested before the status, matching the short-circuit order.
        if URL(response.request.url).path not in tolerated_paths:
            raise
        if response.status_code != HTTPStatus.NOT_IMPLEMENTED.value:
            raise
        return True
    return verdict
@schema.parametrize()
@schema.override(headers={"Content-Type": "application/json"})
@settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50)
def test_openapi_stateless(case: schemathesis.Case):
def test_openapi_stateless(case: Case):
key = (
case.operation.method.upper(),
case.operation.path,
......@@ -155,4 +176,9 @@ def test_openapi_stateless(case: schemathesis.Case):
}.get(key, DEFAULT_TIMEOUT_SECONDS)
# No need to verify SSL certificate for localhost
case.call_and_validate(verify=False, timeout=timeout)
case.call_and_validate(
verify=False,
timeout=timeout,
additional_checks=(customized_not_a_server_error,),
excluded_checks=(not_a_server_error,),
)
......@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.exceptions import VLLMValidationError
from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
......@@ -818,9 +819,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
max_tokens=10,
)
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "context length is only" in resp.error.message
with pytest.raises(VLLMValidationError):
await serving_chat.create_chat_completion(req)
@pytest.mark.asyncio
......@@ -860,9 +860,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
max_tokens=1,
)
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "context length is only" in resp.error.message
with pytest.raises(VLLMValidationError):
await serving_chat.create_chat_completion(req)
@pytest.mark.asyncio
......
......@@ -17,9 +17,6 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionResponse,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs import PromptType
......@@ -542,11 +539,9 @@ async def test_header_dp_rank_argument():
# Test 2: Out-of-range DP rank (1)
mock_raw_request.headers = {"X-data-parallel-rank": "1"}
# should return ErrorResponse for out-of-range rank
response2 = await serving_chat.create_chat_completion(req, mock_raw_request)
assert isinstance(response2, ErrorResponse), (
"Expected an ErrorResponse for out-of-range DP rank"
)
# should raise ValueError for out-of-range rank
with pytest.raises(ValueError):
await serving_chat.create_chat_completion(req, mock_raw_request)
@pytest.mark.asyncio
......
......@@ -4,11 +4,10 @@
import asyncio
import signal
import socket
from http import HTTPStatus
from typing import Any
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi import FastAPI
from vllm import envs
from vllm.engine.protocol import EngineClient
......@@ -19,7 +18,6 @@ from vllm.entrypoints.constants import (
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils.network_utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__)
......@@ -75,7 +73,7 @@ async def serve_http(
config.h11_max_header_count = h11_max_header_count
config.load()
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
app.state.server = server
loop = asyncio.get_running_loop()
......@@ -148,40 +146,3 @@ def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
engine_errored = engine.errored and not engine.is_running
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """
    Register exception handlers that answer engine failures with HTTP 500.

    VLLM V1 AsyncLLM catches exceptions and surfaces only two types:
    EngineGenerateError and EngineDeadError.

    EngineGenerateError comes from the per-request generate() method. It
    may be request specific (and therefore recoverable - e.g. an error in
    input processing).

    EngineDeadError comes from the background output_handler method. It is
    global and therefore not recoverable.

    The handlers registered here return a response to the end user when
    these errors occur and shut the server down if needed.
    See https://fastapi.tiangolo.com/tutorial/handling-errors/
    for more details on how exception handlers work.

    If an exception is encountered inside a StreamingResponse generator,
    it is not raised, because a 200 status has already been sent; an error
    message is sent as the next chunk instead. Since nothing is raised in
    that case, the server does not shut down automatically, and the
    watchdog background task is used to check for the errored state.
    """

    async def runtime_exception_handler(request: Request, _exc):
        # Let the server decide whether to exit, then reply with a bare 500.
        terminate_if_errored(
            server=server,
            engine=request.app.state.engine_client,
        )
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    # Same registration order as stacked decorators (innermost applied first).
    for handled_type in (EngineGenerateError, EngineDeadError, RuntimeError):
        app.exception_handler(handled_type)(runtime_exception_handler)
......@@ -31,6 +31,8 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_se
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.server_utils import (
engine_error_handler,
exception_handler,
get_uvicorn_log_config,
http_exception_handler,
lifespan,
......@@ -57,6 +59,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory
......@@ -250,6 +253,9 @@ def build_app(
app.exception_handler(HTTPException)(http_exception_handler)
app.exception_handler(RequestValidationError)(validation_exception_handler)
app.exception_handler(EngineGenerateError)(engine_error_handler)
app.exception_handler(EngineDeadError)(engine_error_handler)
app.exception_handler(Exception)(exception_handler)
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
......@@ -355,7 +361,6 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
log_error_stack=args.log_error_stack,
)
if any(task in supported_tasks for task in ("generate", "render")):
......
......@@ -39,6 +39,7 @@ def chat(request: Request) -> OpenAIServingChat | None:
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
},
)
@with_cancellation
......@@ -54,10 +55,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
message="The model does not support Chat Completions API"
)
try:
generator = await handler.create_chat_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
generator = await handler.create_chat_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -81,6 +79,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
},
)
async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request):
......@@ -93,10 +92,7 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re
message="The model does not support Chat Completions API"
)
try:
result = await handler.render_chat_request(request)
except Exception as e:
result = handler.create_error_response(e)
result = await handler.render_chat_request(request)
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
......
......@@ -8,7 +8,6 @@ from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Any, Final
import jinja2
import partial_json_parser
import regex as re
from fastapi import Request
......@@ -105,7 +104,6 @@ class OpenAIServingChat(OpenAIServing):
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
enable_log_deltas: bool = True,
log_error_stack: bool = False,
default_chat_template_kwargs: dict[str, Any] | None = None,
) -> None:
super().__init__(
......@@ -113,7 +111,6 @@ class OpenAIServingChat(OpenAIServing):
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.response_role = response_role
......@@ -235,81 +232,76 @@ class OpenAIServingChat(OpenAIServing):
if self.engine_client.errored:
raise self.engine_client.dead_error
try:
tokenizer = self.renderer.tokenizer
tool_parser = self.tool_parser
if is_mistral_tokenizer(tokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
_mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
_mt.validate_request_params(request)
# Check if tool parsing is unavailable (common condition)
tool_parsing_unavailable = (
tool_parser is None
and not is_mistral_tokenizer(tokenizer)
and not self.use_harmony
)
tokenizer = self.renderer.tokenizer
# Validate tool_choice when tool parsing is required but unavailable
if tool_parsing_unavailable and request.tool_choice not in (
None,
"none",
):
if request.tool_choice == "auto" and not self.enable_auto_tools:
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
'"auto" tool choice requires '
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
elif request.tool_choice != "auto":
# "required" or named tool requires tool parser
return self.create_error_response(
f'tool_choice="{request.tool_choice}" requires '
"--tool-call-parser to be set"
)
tool_parser = self.tool_parser
if is_mistral_tokenizer(tokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
_mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
_mt.validate_request_params(request)
# Check if tool parsing is unavailable (common condition)
tool_parsing_unavailable = (
tool_parser is None
and not is_mistral_tokenizer(tokenizer)
and not self.use_harmony
)
if request.tools is None or (
request.tool_choice == "none"
and self.exclude_tools_when_tool_choice_none
):
tool_dicts = None
else:
tool_dicts = [tool.model_dump() for tool in request.tools]
if not self.use_harmony:
# Common case.
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
# Validate tool_choice when tool parsing is required but unavailable
if tool_parsing_unavailable and request.tool_choice not in (
None,
"none",
):
if request.tool_choice == "auto" and not self.enable_auto_tools:
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
'"auto" tool choice requires '
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
if error_check_ret is not None:
return error_check_ret
conversation, engine_prompts = await self._preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
elif request.tool_choice != "auto":
# "required" or named tool requires tool parser
return self.create_error_response(
f'tool_choice="{request.tool_choice}" requires '
"--tool-call-parser to be set"
)
else:
# For GPT-OSS.
should_include_tools = tool_dicts is not None
conversation, engine_prompts = self._make_request_with_harmony(
request, should_include_tools
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
if request.tools is None or (
request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
):
tool_dicts = None
else:
tool_dicts = [tool.model_dump() for tool in request.tools]
if not self.use_harmony:
# Common case.
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
)
if error_check_ret is not None:
return error_check_ret
conversation, engine_prompts = await self._preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
)
else:
# For GPT-OSS.
should_include_tools = tool_dicts is not None
conversation, engine_prompts = self._make_request_with_harmony(
request, should_include_tools
)
return conversation, engine_prompts
......@@ -329,20 +321,16 @@ class OpenAIServingChat(OpenAIServing):
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
reasoning_parser: ReasoningParser | None = None
try:
if self.reasoning_parser_cls:
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
if self.reasoning_parser_cls:
# Pass the same chat template kwargs as used in tokenization
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser = self.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
)
result = await self.render_chat_request(request)
if isinstance(result, ErrorResponse):
return result
......@@ -357,15 +345,9 @@ class OpenAIServingChat(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(
request, supports_default_mm_loras=True
)
lora_request = self._maybe_get_adapters(request, supports_default_mm_loras=True)
model_name = self.models.model_name(lora_request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
model_name = self.models.model_name(lora_request)
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
......@@ -373,81 +355,76 @@ class OpenAIServingChat(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_token_ids = self._extract_prompt_components(
engine_prompt
).token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
)
for i, engine_prompt in enumerate(engine_prompts):
prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
)
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
self._extract_prompt_len(engine_prompt),
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
)
self._log_inputs(
sub_request_id,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
self._log_inputs(
sub_request_id,
engine_prompt,
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=sub_request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
else:
reasoning_ended = (
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=sub_request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
reasoning_ended = (
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
if reasoning_parser
else None
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
reasoning_ended=reasoning_ended,
)
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
reasoning_ended=reasoning_ended,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
generators.append(generator)
assert len(generators) == 1
(result_generator,) = generators
......@@ -464,21 +441,16 @@ class OpenAIServingChat(OpenAIServing):
reasoning_parser,
)
try:
return await self.chat_completion_full_generator(
request,
result_generator,
request_id,
model_name,
conversation,
tokenizer,
request_metadata,
reasoning_parser,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
return await self.chat_completion_full_generator(
request,
result_generator,
request_id,
model_name,
conversation,
tokenizer,
request_metadata,
reasoning_parser,
)
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
......@@ -1414,8 +1386,6 @@ class OpenAIServingChat(OpenAIServing):
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
assert final_res is not None
......
......@@ -54,10 +54,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
message="The model does not support Completions API"
)
try:
generator = await handler.create_completion(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -91,10 +88,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request):
message="The model does not support Completions API"
)
try:
result = await handler.render_completion_request(request)
except Exception as e:
result = handler.create_error_response(e)
result = await handler.render_completion_request(request)
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
......
......@@ -7,7 +7,6 @@ from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import cast
import jinja2
from fastapi import Request
from vllm.engine.protocol import EngineClient
......@@ -56,14 +55,12 @@ class OpenAIServingCompletion(OpenAIServing):
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
......@@ -110,15 +107,11 @@ class OpenAIServingCompletion(OpenAIServing):
"prompt_logprobs is not compatible with prompt embeds."
)
try:
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
)
return engine_prompts
......@@ -149,11 +142,7 @@ class OpenAIServingCompletion(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
lora_request = self._maybe_get_adapters(request)
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
lora_request = self._maybe_get_adapters(request)
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
......@@ -161,64 +150,61 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
max_tokens = get_max_tokens(
max_model_len,
request.max_tokens,
self._extract_prompt_len(engine_prompt),
for i, engine_prompt in enumerate(engine_prompts):
max_tokens = get_max_tokens(
max_model_len,
request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params: SamplingParams | BeamSearchParams
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
max_tokens, self.default_sampling_params
)
else:
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
)
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
request_id_item = f"{request_id}-{i}"
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
self._log_inputs(
request_id_item,
engine_prompt,
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
request_id=request_id,
params=sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
generators.append(generator)
result_generator = merge_async_iterators(*generators)
......@@ -273,10 +259,6 @@ class OpenAIServingCompletion(OpenAIServing):
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
return self.create_error_response(e)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
......
......@@ -4,6 +4,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from http import HTTPStatus
from typing import Any, ClassVar, Literal, TypeAlias
import regex as re
......@@ -262,6 +263,14 @@ class DeltaMessage(OpenAIBaseModel):
tool_calls: list[DeltaToolCall] = Field(default_factory=list)
class GenerationError(Exception):
    """Raised when finish_reason indicates an internal server error (500).

    Attributes:
        status_code: Set to ``HTTPStatus.INTERNAL_SERVER_ERROR`` on every
            instance.
    """

    def __init__(self, message: str = "Internal server error"):
        super().__init__(message)
        # The status is fixed; only the message can be customised by callers.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
......
......@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from http import HTTPStatus
......@@ -38,10 +36,10 @@ from vllm.entrypoints.openai.completion.protocol import (
CompletionResponse,
)
from vllm.entrypoints.openai.engine.protocol import (
ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
GenerationError,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.responses.context import (
......@@ -89,7 +87,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.entrypoints.utils import create_error_response, get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import (
ProcessorInputs,
......@@ -125,15 +123,6 @@ from vllm.utils.async_utils import (
)
from vllm.utils.mistral import is_mistral_tokenizer
class GenerationError(Exception):
    """Raised when finish_reason indicates an internal server error (500).

    Attributes:
        status_code: Set to ``HTTPStatus.INTERNAL_SERVER_ERROR`` on every
            instance.
    """

    def __init__(self, message: str = "Internal server error"):
        super().__init__(message)
        # The status is fixed; only the message can be customised by callers.
        self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__)
......@@ -225,7 +214,6 @@ class OpenAIServing:
*,
request_logger: RequestLogger | None,
return_tokens_as_token_ids: bool = False,
log_error_stack: bool = False,
):
super().__init__()
......@@ -236,8 +224,6 @@ class OpenAIServing:
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.log_error_stack = log_error_stack
self.model_config = engine_client.model_config
self.renderer = engine_client.renderer
self.io_processor = engine_client.io_processor
......@@ -526,133 +512,79 @@ class OpenAIServing:
"""Schedule the request and get the result generator."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
trace_headers = (
None
if ctx.raw_request is None
else await self._get_trace_headers(ctx.raw_request.headers)
)
trace_headers = (
None
if ctx.raw_request is None
else await self._get_trace_headers(ctx.raw_request.headers)
)
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
pooling_params = self._create_pooling_params(ctx)
if isinstance(pooling_params, ErrorResponse):
return pooling_params
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
self._log_inputs(
request_id_item,
engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
generators.append(generator)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
)
ctx.result_generator = merge_async_iterators(*generators)
generators.append(generator)
return None
ctx.result_generator = merge_async_iterators(*generators)
except Exception as e:
return self.create_error_response(e)
return None
# NOTE(review): this span comes from a diff rendering. Indentation has been
# stripped and every statement of the body appears twice: one copy sits
# between `try:` and `except Exception`, the other does not (presumably the
# two sides of the change -- confirm against the repository before editing).
# Read once, the body performs these steps in order:
#   1. return an error response if `ctx.engine_prompts` is None;
#   2. allocate one result slot per engine prompt;
#   3. return an error response if `ctx.result_generator` is None;
#   4. store each `(i, res)` pair from the generator at index `i`;
#   5. return an error response if any slot is still None;
#   6. store the collected results on `ctx.final_res_batch`, return None.
async def _collect_batch(
self,
ctx: ServeContext,
) -> ErrorResponse | None:
"""Collect batch results from the result generator."""
try:
# Guard: nothing can be collected without engine prompts.
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
if ctx.engine_prompts is None:
return self.create_error_response("Engine prompts not available")
# One slot per prompt; None marks "no result received yet".
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts
# Guard: the result generator must have been set on the context.
if ctx.result_generator is None:
return self.create_error_response("Result generator not available")
if ctx.result_generator is None:
return self.create_error_response("Result generator not available")
# Results are yielded as (index, result) pairs and stored by index.
async for i, res in ctx.result_generator:
final_res_batch[i] = res
# A slot that is still None means a prompt produced no result.
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts"
)
async for i, res in ctx.result_generator:
final_res_batch[i] = res
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts"
)
# None signals success to the caller.
return None
ctx.final_res_batch = [res for res in final_res_batch if res is not None]
# In the copy wrapped by `try`, any Exception is converted into an error
# response instead of propagating.
except Exception as e:
return self.create_error_response(e)
return None
# NOTE(review): this span comes from a diff rendering with indentation
# stripped, and it mixes two shapes of the same method:
#   * an instance method (takes `self`, reads `self.log_error_stack`) that
#     builds the ErrorResponse inline, and
#   * a `@staticmethod` whose whole body is the final
#     `return create_error_response(...)` line; that name is imported from
#     `vllm.entrypoints.utils` in this file's import block.
# Presumably these are the two sides of the change -- confirm against the
# repository before editing.
@staticmethod
def create_error_response(
self,
message: str | Exception,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
) -> ErrorResponse:
exc: Exception | None = None
# When an exception is passed, err_type/status_code/param are derived from
# its type and the caller-supplied values are overwritten.
if isinstance(message, Exception):
exc = message
from vllm.exceptions import VLLMValidationError
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
# Only this branch sets a non-None `param`, taken from the exception.
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
else:
# Any other exception type is reported as a 500.
err_type = "InternalServerError"
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
param = None
# From here on `message` is the exception's string form.
message = str(exc)
# Optional stack logging: print the active exception's traceback if there is
# one, otherwise print the current call stack.
if self.log_error_stack:
exc_type, _, _ = sys.exc_info()
if exc_type is not None:
traceback.print_exc()
else:
traceback.print_stack()
# The message goes through sanitize_message before being placed in the
# response; `code` is the integer value of the HTTP status.
return ErrorResponse(
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
param=param,
)
)
# Delegating form: forwards all four arguments to the module-level helper.
return create_error_response(message, err_type, status_code, param)
def create_streaming_error_response(
self,
......@@ -680,16 +612,6 @@ class OpenAIServing:
)
raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
    self, e: GenerationError
) -> ErrorResponse:
    """Convert GenerationError to ErrorResponse.

    Args:
        e: The generation error to convert; its string form becomes the
            response message and its ``status_code`` the response status.

    Returns:
        The result of ``self.create_error_response`` called with error type
        ``"InternalServerError"``.
    """
    # Pass the message as a string, not the exception object, together with
    # an explicit error type and the status carried on the exception.
    return self.create_error_response(
        str(e),
        err_type="InternalServerError",
        status_code=e.status_code,
    )
def _convert_generation_error_to_streaming_response(
self, e: GenerationError
) -> str:
......
......@@ -87,7 +87,6 @@ async def init_generate_state(
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
log_error_stack=args.log_error_stack,
)
if "generate" in supported_tasks
else None
......@@ -111,7 +110,6 @@ async def init_generate_state(
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
enable_log_deltas=args.enable_log_deltas,
log_error_stack=args.log_error_stack,
)
if any(task in supported_tasks for task in ("generate", "render"))
else None
......@@ -127,7 +125,6 @@ async def init_generate_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
log_error_stack=args.log_error_stack,
)
if any(task in supported_tasks for task in ("generate", "render"))
else None
......@@ -156,7 +153,6 @@ async def init_generate_state(
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
log_error_stack=args.log_error_stack,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_log_outputs=args.enable_log_outputs,
force_no_detokenize=args.tokens_only,
......
......@@ -68,7 +68,6 @@ def init_realtime_state(
engine_client,
state.openai_serving_models,
request_logger=request_logger,
log_error_stack=args.log_error_stack,
)
if "realtime" in supported_tasks
else None
......
......@@ -33,13 +33,11 @@ class OpenAIServingRealtime(OpenAIServing):
models: OpenAIServingModels,
*,
request_logger: RequestLogger | None,
log_error_stack: bool = False,
):
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack,
)
self.task_type: Literal["realtime"] = "realtime"
......
......@@ -63,10 +63,8 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
return base_server.create_error_response(
message="The model does not support Responses API"
)
try:
generator = await handler.create_responses(request, raw_request)
except Exception as e:
generator = handler.create_error_response(e)
generator = await handler.create_responses(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -95,14 +93,11 @@ async def retrieve_responses(
message="The model does not support Responses API"
)
try:
response = await handler.retrieve_responses(
response_id,
starting_after=starting_after,
stream=stream,
)
except Exception as e:
response = handler.create_error_response(e)
response = await handler.retrieve_responses(
response_id,
starting_after=starting_after,
stream=stream,
)
if isinstance(response, ErrorResponse):
return JSONResponse(
......@@ -125,10 +120,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
message="The model does not support Responses API"
)
try:
response = await handler.cancel_responses(response_id)
except Exception as e:
response = handler.create_error_response(e)
response = await handler.cancel_responses(response_id)
if isinstance(response, ErrorResponse):
return JSONResponse(
......
......@@ -11,7 +11,6 @@ from copy import copy
from http import HTTPStatus
from typing import Final
import jinja2
from fastapi import Request
from openai.types.responses import (
ResponseContentPartAddedEvent,
......@@ -174,14 +173,12 @@ class OpenAIServingResponses(OpenAIServing):
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
log_error_stack: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
log_error_stack=log_error_stack,
)
self.chat_template = chat_template
......@@ -365,28 +362,15 @@ class OpenAIServingResponses(OpenAIServing):
else:
prev_response = None
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
request, prev_response
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response
)
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
except (
ValueError,
TypeError,
RuntimeError,
jinja2.TemplateError,
NotImplementedError,
) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(e)
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
request, prev_response
)
else:
messages, engine_prompts = await self._make_request(request, prev_response)
request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request:
......@@ -424,86 +408,83 @@ class OpenAIServingResponses(OpenAIServing):
else:
assert len(builtin_tool_list) == 0
available_tools = []
try:
tokenizer = self.renderer.get_tokenizer()
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None:
return maybe_error
default_max_tokens = get_max_tokens(
max_model_len,
request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
tokenizer = self.renderer.get_tokenizer()
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None:
return maybe_error
default_max_tokens = get_max_tokens(
max_model_len,
request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
self.override_max_tokens,
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
trace_headers = (
None
if raw_request is None
else await self._get_trace_headers(raw_request.headers)
)
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(messages, available_tools)
else:
context = HarmonyContext(messages, available_tools)
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(messages, available_tools)
else:
if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT:
# This is a feature in development for parsing
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
tokenizer=tokenizer,
reasoning_parser_cls=self.parser.reasoning_parser_cls
if self.parser
else None,
request=request,
tool_parser_cls=self.parser.tool_parser_cls
if self.parser
else None,
available_tools=available_tools,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
else:
context = SimpleContext()
if self.parser and self.parser.reasoning_parser_cls is not None:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
if (
isinstance(
struct_out := sampling_params.structured_outputs,
StructuredOutputsParams,
)
and struct_out.all_non_structural_tag_constraints_none()
):
sampling_params.structured_outputs = replace(
struct_out,
structural_tag=reasoning_parser.prepare_structured_tag(
struct_out.structural_tag, self.tool_server
),
)
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
)
generators.append(generator)
except ValueError as e:
return self.create_error_response(e)
context = HarmonyContext(messages, available_tools)
else:
if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT:
# This is a feature in development for parsing
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
tokenizer=tokenizer,
reasoning_parser_cls=self.parser.reasoning_parser_cls
if self.parser
else None,
request=request,
tool_parser_cls=self.parser.tool_parser_cls
if self.parser
else None,
available_tools=available_tools,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
else:
context = SimpleContext()
if self.parser and self.parser.reasoning_parser_cls is not None:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
if (
isinstance(
struct_out := sampling_params.structured_outputs,
StructuredOutputsParams,
)
and struct_out.all_non_structural_tag_constraints_none()
):
sampling_params.structured_outputs = replace(
struct_out,
structural_tag=reasoning_parser.prepare_structured_tag(
struct_out.structural_tag, self.tool_server
),
)
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
)
generators.append(generator)
assert len(generators) == 1
(result_generator,) = generators
......@@ -578,20 +559,15 @@ class OpenAIServingResponses(OpenAIServing):
request_metadata,
)
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(e)
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
async def _make_request(
self,
......@@ -675,8 +651,6 @@ class OpenAIServingResponses(OpenAIServing):
pass
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
return self.create_error_response(e)
# NOTE: Implementation of status is still WIP, but for now
# we guarantee that if the status is not "completed", it is accurate.
......@@ -1129,16 +1103,11 @@ class OpenAIServingResponses(OpenAIServing):
new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None
generator = self.responses_stream_generator(request, *args, **kwargs)
try:
generator = self.responses_stream_generator(request, *args, **kwargs)
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(e)
finally:
new_event_signal.set()
......@@ -1157,13 +1126,7 @@ class OpenAIServingResponses(OpenAIServing):
*args,
**kwargs,
):
try:
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(e)
response = await self.responses_full_generator(request, *args, **kwargs)
if isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed".
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment