"vllm/vscode:/vscode.git/clone" did not exist on "99caa4910651754f3f68de518ca42349c8c424d1"
Commit 48b67ba7 authored by Cyrus Leung, committed by DarkLight1337
Browse files

[Frontend] Standardize use of `create_error_response` (#32319)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 09f4264a
...@@ -540,14 +540,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -540,14 +540,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
try: try:
generator = await handler.create_completion(request, raw_request) generator = await handler.create_completion(request, raw_request)
except OverflowError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
) from e
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -86,7 +86,7 @@ from vllm.entrypoints.responses_utils import ( ...@@ -86,7 +86,7 @@ from vllm.entrypoints.responses_utils import (
construct_input_messages, construct_input_messages,
) )
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import ( from vllm.inputs.parse import (
PromptComponents, PromptComponents,
...@@ -760,11 +760,15 @@ class OpenAIServing: ...@@ -760,11 +760,15 @@ class OpenAIServing:
err_type = "BadRequestError" err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError)): elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
# Common validation errors from user input # Common validation errors from user input
err_type = "BadRequestError" err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST status_code = HTTPStatus.BAD_REQUEST
param = None param = None
elif isinstance(exc, NotImplementedError):
err_type = "NotImplementedError"
status_code = HTTPStatus.NOT_IMPLEMENTED
param = None
elif exc.__class__.__name__ == "TemplateError": elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2) # jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError" err_type = "BadRequestError"
...@@ -783,9 +787,10 @@ class OpenAIServing: ...@@ -783,9 +787,10 @@ class OpenAIServing:
traceback.print_exc() traceback.print_exc()
else: else:
traceback.print_stack() traceback.print_stack()
return ErrorResponse( return ErrorResponse(
error=ErrorInfo( error=ErrorInfo(
message=message, message=sanitize_message(message),
type=err_type, type=err_type,
code=status_code.value, code=status_code.value,
param=param, param=param,
......
...@@ -16,6 +16,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -16,6 +16,7 @@ from vllm.entrypoints.openai.protocol import (
ModelPermission, ModelPermission,
UnloadLoRAAdapterRequest, UnloadLoRAAdapterRequest,
) )
from vllm.entrypoints.utils import sanitize_message
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
...@@ -300,5 +301,9 @@ def create_error_response( ...@@ -300,5 +301,9 @@ def create_error_response(
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> ErrorResponse: ) -> ErrorResponse:
return ErrorResponse( return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value) error=ErrorInfo(
message=sanitize_message(message),
type=err_type,
code=status_code.value,
)
) )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -36,9 +35,8 @@ async def create_classify(request: ClassificationRequest, raw_request: Request): ...@@ -36,9 +35,8 @@ async def create_classify(request: ClassificationRequest, raw_request: Request):
try: try:
generator = await handler.create_classify(request, raw_request) generator = await handler.create_classify(request, raw_request)
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=generator.model_dump(), status_code=generator.error.code
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -47,9 +47,7 @@ async def create_embedding( ...@@ -47,9 +47,7 @@ async def create_embedding(
try: try:
generator = await handler.create_embedding(request, raw_request) generator = await handler.create_embedding(request, raw_request)
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -44,9 +44,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): ...@@ -44,9 +44,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
try: try:
generator = await handler.create_pooling(request, raw_request) generator = await handler.create_pooling(request, raw_request)
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=generator.model_dump(), status_code=generator.error.code
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -52,9 +52,8 @@ async def create_score(request: ScoreRequest, raw_request: Request): ...@@ -52,9 +52,8 @@ async def create_score(request: ScoreRequest, raw_request: Request):
try: try:
generator = await handler.create_score(request, raw_request) generator = await handler.create_score(request, raw_request)
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=generator.model_dump(), status_code=generator.error.code
...@@ -104,9 +103,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request): ...@@ -104,9 +103,8 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
try: try:
generator = await handler.do_rerank(request, raw_request) generator = await handler.do_rerank(request, raw_request)
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=generator.model_dump(), status_code=generator.error.code
......
...@@ -67,9 +67,8 @@ async def generate(request: GenerateRequest, raw_request: Request): ...@@ -67,9 +67,8 @@ async def generate(request: GenerateRequest, raw_request: Request):
try: try:
generator = await handler.serve_tokens(request, raw_request) generator = await handler.serve_tokens(request, raw_request)
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code content=generator.model_dump(), status_code=generator.error.code
......
...@@ -49,14 +49,8 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): ...@@ -49,14 +49,8 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
try: try:
generator = await handler.create_tokenize(request, raw_request) generator = await handler.create_tokenize(request, raw_request)
except NotImplementedError as e:
raise HTTPException(
status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e)
) from e
except Exception as e: except Exception as e:
raise HTTPException( return handler.create_error_response(e)
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse( return JSONResponse(
......
...@@ -7,7 +7,7 @@ import functools ...@@ -7,7 +7,7 @@ import functools
import os import os
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, Any
import regex as re import regex as re
from fastapi import Request from fastapi import Request
...@@ -22,18 +22,25 @@ from vllm.entrypoints.chat_utils import ( ...@@ -22,18 +22,25 @@ from vllm.entrypoints.chat_utils import (
resolve_hf_chat_template, resolve_hf_chat_template,
resolve_mistral_chat_template, resolve_mistral_chat_template,
) )
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
StreamOptions,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
StreamOptions,
)
from vllm.entrypoints.openai.serving_models import LoRAModulePath
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
logger = init_logger(__name__) logger = init_logger(__name__)
VLLM_SUBCMD_PARSER_EPILOG = ( VLLM_SUBCMD_PARSER_EPILOG = (
...@@ -206,7 +213,7 @@ def _validate_truncation_size( ...@@ -206,7 +213,7 @@ def _validate_truncation_size(
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: ChatCompletionRequest | CompletionRequest, request: "ChatCompletionRequest | CompletionRequest",
input_length: int, input_length: int,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
...@@ -227,6 +234,8 @@ def get_max_tokens( ...@@ -227,6 +234,8 @@ def get_max_tokens(
def log_non_default_args(args: Namespace | EngineArgs): def log_non_default_args(args: Namespace | EngineArgs):
from vllm.entrypoints.openai.cli_args import make_arg_parser
non_default_args = {} non_default_args = {}
# Handle Namespace # Handle Namespace
...@@ -255,7 +264,7 @@ def log_non_default_args(args: Namespace | EngineArgs): ...@@ -255,7 +264,7 @@ def log_non_default_args(args: Namespace | EngineArgs):
def should_include_usage( def should_include_usage(
stream_options: StreamOptions | None, enable_force_include_usage: bool stream_options: "StreamOptions | None", enable_force_include_usage: bool
) -> tuple[bool, bool]: ) -> tuple[bool, bool]:
if stream_options: if stream_options:
include_usage = stream_options.include_usage or enable_force_include_usage include_usage = stream_options.include_usage or enable_force_include_usage
...@@ -270,6 +279,8 @@ def should_include_usage( ...@@ -270,6 +279,8 @@ def should_include_usage(
def process_lora_modules( def process_lora_modules(
args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None args_lora_modules: list[LoRAModulePath], default_mm_loras: dict[str, str] | None
) -> list[LoRAModulePath]: ) -> list[LoRAModulePath]:
from vllm.entrypoints.openai.serving_models import LoRAModulePath
lora_modules = args_lora_modules lora_modules = args_lora_modules
if default_mm_loras: if default_mm_loras:
default_mm_lora_paths = [ default_mm_lora_paths = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment