Commit 48b67ba7 authored by Cyrus Leung's avatar Cyrus Leung Committed by DarkLight1337
Browse files

[Frontend] Standardize use of `create_error_response` (#32319)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 09f4264a
......@@ -540,14 +540,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
try:
generator = await handler.create_completion(request, raw_request)
except OverflowError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......
......@@ -86,7 +86,7 @@ from vllm.entrypoints.responses_utils import (
construct_input_messages,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
PromptComponents,
......@@ -760,11 +760,15 @@ class OpenAIServing:
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
......@@ -783,9 +787,10 @@ class OpenAIServing:
traceback.print_exc()
else:
traceback.print_stack()
return ErrorResponse(
error=ErrorInfo(
message=message,
message=sanitize_message(message),
type=err_type,
code=status_code.value,
param=param,
......
......@@ -16,6 +16,7 @@ from vllm.entrypoints.openai.protocol import (
ModelPermission,
UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.utils import sanitize_message
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
......@@ -300,5 +301,9 @@ def create_error_response(
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse:
return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value)
error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
)
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
from typing_extensions import assert_never
......@@ -36,9 +35,8 @@ async def create_classify(request: ClassificationRequest, raw_request: Request):
try:
generator = await handler.create_classify(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
......@@ -47,9 +47,7 @@ async def create_embedding(
try:
generator = await handler.create_embedding(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never
......@@ -44,9 +44,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
try:
generator = await handler.create_pooling(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse
from typing_extensions import assert_never
......@@ -52,9 +52,8 @@ async def create_score(request: ScoreRequest, raw_request: Request):
try:
generator = await handler.create_score(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......@@ -104,9 +103,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
try:
generator = await handler.do_rerank(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -67,9 +67,8 @@ async def generate(request: GenerateRequest, raw_request: Request):
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
......
......@@ -49,14 +49,8 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
try:
generator = await handler.create_tokenize(request, raw_request)
except NotImplementedError as e:
raise HTTPException(
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
) from e
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
return handler.create_error_response(e)
if isinstance(generator, ErrorResponse):
return JSONResponse(
......
......@@ -7,7 +7,7 @@ import functools
import os
from argparse import Namespace
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
import regex as re
from fastapi import Request
......@@ -22,18 +22,25 @@ from vllm.entrypoints.chat_utils import (
resolve_hf_chat_template,
resolve_mistral_chat_template,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
StreamOptions,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
StreamOptions,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
logger = init_logger(__name__)
VLLM_SUBCMD_PARSER_EPILOG = (
......@@ -206,7 +213,7 @@ def _validate_truncation_size(
def get_max_tokens(
max_model_len: int,
request: ChatCompletionRequest | CompletionRequest,
request: "ChatCompletionRequest | CompletionRequest",
input_length: int,
default_sampling_params: dict,
) -> int:
......@@ -227,6 +234,8 @@ def get_max_tokens(
def log_non_default_args(args: Namespace | EngineArgs):
from vllm.entrypoints.openai.cli_args import make_arg_parser
non_default_args = {}
# Handle Namespace
......@@ -255,7 +264,7 @@ def log_non_default_args(args: Namespace | EngineArgs):
def should_include_usage(
stream_options: StreamOptions | None, enable_force_include_usage: bool
stream_options: "StreamOptions | None", enable_force_include_usage: bool
) -> tuple[bool, bool]:
if stream_options:
include_usage = stream_options.include_usage or enable_force_include_usage
......@@ -270,6 +279,8 @@ def should_include_usage(
def process_lora_modules(
args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None
) -> list[LoRAModulePath]:
from vllm.entrypoints.openai.serving_models import LoRAModulePath
lora_modules = args_lora_modules
if default_mm_loras:
default_mm_lora_paths = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment