Commit 3f28174c (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Frontend] Standardize use of `create_error_response` (#32319)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 769d0629
......@@ -368,14 +368,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
try:
generator = await handler.create_completion(request, raw_request)
except OverflowError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......
......@@ -4,7 +4,7 @@
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.chat_completion.protocol import (
......@@ -53,12 +53,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
return base_server.create_error_response(
message="The model does not support Chat Completions API"
)
try:
generator = await handler.create_chat_completion(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -94,7 +94,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
PromptComponents,
......@@ -768,11 +768,15 @@ class OpenAIServing:
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
......@@ -791,9 +795,10 @@ class OpenAIServing:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse(
error=ErrorInfo(
message=message,
message=sanitize_message(message),
type=err_type,
code=status_code.value,
param=param,
......
......@@ -5,7 +5,7 @@
from collections.abc import AsyncGenerator
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
......@@ -64,9 +64,7 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
try:
generator = await handler.create_responses(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......@@ -101,9 +99,7 @@ async def retrieve_responses(
stream=stream,
)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(response, ErrorResponse):
return JSONResponse(
......@@ -128,9 +124,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
try:
response = await handler.cancel_responses(response_id)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(response, ErrorResponse):
return JSONResponse(
......
......@@ -18,6 +18,7 @@ from vllm.entrypoints.serve.lora.protocol import (
LoadLoRAAdapterRequest,
UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.utils import sanitize_message
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
......@@ -302,5 +303,9 @@ def create_error_response(
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
)
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
......@@ -36,9 +35,8 @@ async def create_classify(request: ClassificationRequest, raw_request: Request):
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
......@@ -47,9 +47,7 @@ async def create_embedding(
try:
generator = await handler.create_embedding(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
......@@ -44,9 +44,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
try:
generator = await handler.create_pooling(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
......@@ -52,9 +52,8 @@ async def create_score(request: ScoreRequest, raw_request: Request):
try:
generator = await handler.create_score(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......@@ -104,9 +103,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
try:
generator = await handler.do_rerank(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -67,9 +67,8 @@ async def generate(request: GenerateRequest, raw_request: Request):
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -51,14 +51,8 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
try:
generator = await handler.create_tokenize(request, raw_request)
except NotImplementedError as e:
raise HTTPException(
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......
......@@ -7,7 +7,7 @@ import functools
import os
from argparse import Namespace
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
import regex as re
from fastapi import Request
......@@ -22,20 +22,27 @@ from vllm.entrypoints.chat_utils import (
resolve_hf_chat_template,
resolve_mistral_chat_template,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.engine.protocol import (
CompletionRequest,
StreamOptions,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
CompletionRequest,
StreamOptions,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
logger = init_logger(__name__)
VLLM_SUBCMD_PARSER_EPILOG = (
......@@ -208,7 +215,7 @@ def _validate_truncation_size(
def get_max_tokens(
max_model_len: int,
request: ChatCompletionRequest | CompletionRequest,
request: "ChatCompletionRequest | CompletionRequest",
input_length: int,
default_sampling_params: dict,
) -> int:
......@@ -229,6 +236,8 @@ def get_max_tokens(
def log_non_default_args(args: Namespace | EngineArgs):
from vllm.entrypoints.openai.cli_args import make_arg_parser
non_default_args = {}
# Handle Namespace
......@@ -257,7 +266,7 @@ def log_non_default_args(args: Namespace | EngineArgs):
def should_include_usage(
stream_options: StreamOptions | None, enable_force_include_usage: bool
stream_options: "StreamOptions | None", enable_force_include_usage: bool
) -> tuple[bool, bool]:
if stream_options:
include_usage = stream_options.include_usage or enable_force_include_usage
......@@ -272,6 +281,8 @@ def should_include_usage(
def process_lora_modules(
args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None
) -> list[LoRAModulePath]:
from vllm.entrypoints.openai.serving_models import LoRAModulePath
lora_modules = args_lora_modules
if default_mm_loras:
default_mm_lora_paths = [
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment