Unverified Commit 1ab055ef authored by R3hankhan's avatar R3hankhan Committed by GitHub
Browse files

[OpenAI] Extend VLLMValidationError to additional validation parameters (#31870)


Signed-off-by: default avatarRehan Khan <Rehan.Khan7@ibm.com>
parent b665bbc2
...@@ -911,7 +911,7 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -911,7 +911,7 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: Request, exc: RequestValidationError): async def validation_exception_handler(_: Request, exc: RequestValidationError):
from vllm.entrypoints.openai.protocol import VLLMValidationError from vllm.exceptions import VLLMValidationError
param = None param = None
for error in exc.errors(): for error in exc.errors():
......
...@@ -72,6 +72,7 @@ from pydantic import ( ...@@ -72,6 +72,7 @@ from pydantic import (
) )
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.sampling_params import ( from vllm.sampling_params import (
...@@ -131,36 +132,6 @@ class ErrorResponse(OpenAIBaseModel): ...@@ -131,36 +132,6 @@ class ErrorResponse(OpenAIBaseModel):
error: ErrorInfo error: ErrorInfo
class VLLMValidationError(ValueError):
"""vLLM-specific validation error for request validation failures.
Args:
message: The error message describing the validation failure.
parameter: Optional parameter name that failed validation.
value: Optional value that was rejected during validation.
"""
def __init__(
self,
message: str,
*,
parameter: str | None = None,
value: Any = None,
) -> None:
super().__init__(message)
self.parameter = parameter
self.value = value
def __str__(self):
base = super().__str__()
extras = []
if self.parameter is not None:
extras.append(f"parameter={self.parameter}")
if self.value is not None:
extras.append(f"value={self.value}")
return f"{base} ({', '.join(extras)})" if extras else base
class ModelPermission(OpenAIBaseModel): class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission" object: str = "model_permission"
......
...@@ -140,16 +140,16 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -140,16 +140,16 @@ class OpenAIServingCompletion(OpenAIServing):
) )
except ValueError as e: except ValueError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(e)
except TypeError as e: except TypeError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(e)
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(e)
except jinja2.TemplateError as e: except jinja2.TemplateError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(e)
# Extract data_parallel_rank from header (router can inject it) # Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request) data_parallel_rank = self._get_data_parallel_rank(raw_request)
......
...@@ -754,7 +754,7 @@ class OpenAIServing: ...@@ -754,7 +754,7 @@ class OpenAIServing:
if isinstance(message, Exception): if isinstance(message, Exception):
exc = message exc = message
from vllm.entrypoints.openai.protocol import VLLMValidationError from vllm.exceptions import VLLMValidationError
if isinstance(exc, VLLMValidationError): if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError" err_type = "BadRequestError"
......
...@@ -373,7 +373,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -373,7 +373,7 @@ class OpenAIServingResponses(OpenAIServing):
NotImplementedError, NotImplementedError,
) as e: ) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}") return self.create_error_response(e)
request_metadata = RequestResponseMetadata(request_id=request.request_id) request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request: if raw_request:
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from pydantic import Field from pydantic import Field
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom exceptions for vLLM."""
from typing import Any
class VLLMValidationError(ValueError):
"""vLLM-specific validation error for request validation failures.
Args:
message: The error message describing the validation failure.
parameter: Optional parameter name that failed validation.
value: Optional value that was rejected during validation.
"""
def __init__(
self,
message: str,
*,
parameter: str | None = None,
value: Any = None,
) -> None:
super().__init__(message)
self.parameter = parameter
self.value = value
def __str__(self):
base = super().__str__()
extras = []
if self.parameter is not None:
extras.append(f"parameter={self.parameter}")
if self.value is not None:
extras.append(f"value={self.value}")
return f"{base} ({', '.join(extras)})" if extras else base
...@@ -11,6 +11,7 @@ from typing import Annotated, Any ...@@ -11,6 +11,7 @@ from typing import Annotated, Any
import msgspec import msgspec
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor from vllm.logits_process import LogitsProcessor
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -393,11 +394,17 @@ class SamplingParams( ...@@ -393,11 +394,17 @@ class SamplingParams(
f"{self.repetition_penalty}." f"{self.repetition_penalty}."
) )
if self.temperature < 0.0: if self.temperature < 0.0:
raise ValueError( raise VLLMValidationError(
f"temperature must be non-negative, got {self.temperature}." f"temperature must be non-negative, got {self.temperature}.",
parameter="temperature",
value=self.temperature,
) )
if not 0.0 < self.top_p <= 1.0: if not 0.0 < self.top_p <= 1.0:
raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") raise VLLMValidationError(
f"top_p must be in (0, 1], got {self.top_p}.",
parameter="top_p",
value=self.top_p,
)
# quietly accept -1 as disabled, but prefer 0 # quietly accept -1 as disabled, but prefer 0
if self.top_k < -1: if self.top_k < -1:
raise ValueError( raise ValueError(
...@@ -410,7 +417,11 @@ class SamplingParams( ...@@ -410,7 +417,11 @@ class SamplingParams(
if not 0.0 <= self.min_p <= 1.0: if not 0.0 <= self.min_p <= 1.0:
raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
if self.max_tokens is not None and self.max_tokens < 1: if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") raise VLLMValidationError(
f"max_tokens must be at least 1, got {self.max_tokens}.",
parameter="max_tokens",
value=self.max_tokens,
)
if self.min_tokens < 0: if self.min_tokens < 0:
raise ValueError( raise ValueError(
f"min_tokens must be greater than or equal to 0, got {self.min_tokens}." f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
...@@ -421,24 +432,30 @@ class SamplingParams( ...@@ -421,24 +432,30 @@ class SamplingParams(
f"max_tokens={self.max_tokens}, got {self.min_tokens}." f"max_tokens={self.max_tokens}, got {self.min_tokens}."
) )
if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0: if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
raise ValueError( raise VLLMValidationError(
f"logprobs must be non-negative or -1, got {self.logprobs}." f"logprobs must be non-negative or -1, got {self.logprobs}.",
parameter="logprobs",
value=self.logprobs,
) )
if ( if (
self.prompt_logprobs is not None self.prompt_logprobs is not None
and self.prompt_logprobs != -1 and self.prompt_logprobs != -1
and self.prompt_logprobs < 0 and self.prompt_logprobs < 0
): ):
raise ValueError( raise VLLMValidationError(
f"prompt_logprobs must be non-negative or -1, got " f"prompt_logprobs must be non-negative or -1, got "
f"{self.prompt_logprobs}." f"{self.prompt_logprobs}.",
parameter="prompt_logprobs",
value=self.prompt_logprobs,
) )
if self.truncate_prompt_tokens is not None and ( if self.truncate_prompt_tokens is not None and (
self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1 self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
): ):
raise ValueError( raise VLLMValidationError(
f"truncate_prompt_tokens must be an integer >= 1 or -1, " f"truncate_prompt_tokens must be an integer >= 1 or -1, "
f"got {self.truncate_prompt_tokens}" f"got {self.truncate_prompt_tokens}",
parameter="truncate_prompt_tokens",
value=self.truncate_prompt_tokens,
) )
assert isinstance(self.stop_token_ids, list) assert isinstance(self.stop_token_ids, list)
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
...@@ -516,12 +533,14 @@ class SamplingParams( ...@@ -516,12 +533,14 @@ class SamplingParams(
if token_id < 0 or token_id > tokenizer.max_token_id if token_id < 0 or token_id > tokenizer.max_token_id
] ]
if len(invalid_token_ids) > 0: if len(invalid_token_ids) > 0:
raise ValueError( raise VLLMValidationError(
f"The model vocabulary size is {tokenizer.max_token_id + 1}," f"The model vocabulary size is {tokenizer.max_token_id + 1},"
f" but the following tokens" f" but the following tokens"
f" were specified as bad: {invalid_token_ids}." f" were specified as bad: {invalid_token_ids}."
f" All token id values should be integers satisfying:" f" All token id values should be integers satisfying:"
f" 0 <= token_id <= {tokenizer.max_token_id}." f" 0 <= token_id <= {tokenizer.max_token_id}.",
parameter="bad_words",
value=self.bad_words,
) )
@cached_property @cached_property
......
...@@ -6,6 +6,7 @@ from collections.abc import Mapping ...@@ -6,6 +6,7 @@ from collections.abc import Mapping
from typing import Any, Literal, cast from typing import Any, Literal, cast
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
...@@ -83,9 +84,11 @@ class InputProcessor: ...@@ -83,9 +84,11 @@ class InputProcessor:
if num_logprobs == -1: if num_logprobs == -1:
num_logprobs = self.model_config.get_vocab_size() num_logprobs = self.model_config.get_vocab_size()
if num_logprobs > max_logprobs: if num_logprobs > max_logprobs:
raise ValueError( raise VLLMValidationError(
f"Requested sample logprobs of {num_logprobs}, " f"Requested sample logprobs of {num_logprobs}, "
f"which is greater than max allowed: {max_logprobs}" f"which is greater than max allowed: {max_logprobs}",
parameter="logprobs",
value=num_logprobs,
) )
# Validate prompt logprobs. # Validate prompt logprobs.
...@@ -94,9 +97,11 @@ class InputProcessor: ...@@ -94,9 +97,11 @@ class InputProcessor:
if num_prompt_logprobs == -1: if num_prompt_logprobs == -1:
num_prompt_logprobs = self.model_config.get_vocab_size() num_prompt_logprobs = self.model_config.get_vocab_size()
if num_prompt_logprobs > max_logprobs: if num_prompt_logprobs > max_logprobs:
raise ValueError( raise VLLMValidationError(
f"Requested prompt logprobs of {num_prompt_logprobs}, " f"Requested prompt logprobs of {num_prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}" f"which is greater than max allowed: {max_logprobs}",
parameter="prompt_logprobs",
value=num_prompt_logprobs,
) )
def _validate_sampling_params( def _validate_sampling_params(
...@@ -134,9 +139,11 @@ class InputProcessor: ...@@ -134,9 +139,11 @@ class InputProcessor:
invalid_token_ids.append(token_id) invalid_token_ids.append(token_id)
if invalid_token_ids: if invalid_token_ids:
raise ValueError( raise VLLMValidationError(
f"token_id(s) {invalid_token_ids} in logit_bias contain " f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}" f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
parameter="logit_bias",
value=invalid_token_ids,
) )
def _validate_supported_sampling_params( def _validate_supported_sampling_params(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment