Unverified commit 769f27e7, authored by R3hankhan and committed by GitHub
Browse files

[OpenAI] Add parameter metadata to validation errors (#30134)


Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
parent 23daef54
...@@ -909,6 +909,16 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -909,6 +909,16 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler(_: Request, exc: RequestValidationError): async def validation_exception_handler(_: Request, exc: RequestValidationError):
from vllm.entrypoints.openai.protocol import VLLMValidationError
param = None
for error in exc.errors():
if "ctx" in error and "error" in error["ctx"]:
ctx_error = error["ctx"]["error"]
if isinstance(ctx_error, VLLMValidationError):
param = ctx_error.parameter
break
exc_str = str(exc) exc_str = str(exc)
errors_str = str(exc.errors()) errors_str = str(exc.errors())
...@@ -922,6 +932,7 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -922,6 +932,7 @@ def build_app(args: Namespace) -> FastAPI:
message=message, message=message,
type=HTTPStatus.BAD_REQUEST.phrase, type=HTTPStatus.BAD_REQUEST.phrase,
code=HTTPStatus.BAD_REQUEST, code=HTTPStatus.BAD_REQUEST,
param=param,
) )
) )
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
......
...@@ -131,6 +131,36 @@ class ErrorResponse(OpenAIBaseModel): ...@@ -131,6 +131,36 @@ class ErrorResponse(OpenAIBaseModel):
error: ErrorInfo error: ErrorInfo
class VLLMValidationError(ValueError):
"""vLLM-specific validation error for request validation failures.
Args:
message: The error message describing the validation failure.
parameter: Optional parameter name that failed validation.
value: Optional value that was rejected during validation.
"""
def __init__(
self,
message: str,
*,
parameter: str | None = None,
value: Any = None,
) -> None:
super().__init__(message)
self.parameter = parameter
self.value = value
def __str__(self):
base = super().__str__()
extras = []
if self.parameter is not None:
extras.append(f"parameter={self.parameter}")
if self.value is not None:
extras.append(f"value={self.value}")
return f"{base} ({', '.join(extras)})" if extras else base
class ModelPermission(OpenAIBaseModel): class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission" object: str = "model_permission"
...@@ -466,7 +496,9 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -466,7 +496,9 @@ class ResponsesRequest(OpenAIBaseModel):
@model_validator(mode="before") @model_validator(mode="before")
def validate_prompt(cls, data): def validate_prompt(cls, data):
if data.get("prompt") is not None: if data.get("prompt") is not None:
raise ValueError("prompt template is not supported") raise VLLMValidationError(
"prompt template is not supported", parameter="prompt"
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
...@@ -850,7 +882,10 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -850,7 +882,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
@classmethod @classmethod
def validate_stream_options(cls, data): def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError("Stream options can only be defined when `stream=True`.") raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data return data
...@@ -859,19 +894,29 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -859,19 +894,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
def check_logprobs(cls, data): def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None: if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise ValueError( raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`." "`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
) )
if prompt_logprobs < 0 and prompt_logprobs != -1: if prompt_logprobs < 0 and prompt_logprobs != -1:
raise ValueError("`prompt_logprobs` must be a positive value or -1.") raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (top_logprobs := data.get("top_logprobs")) is not None: if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0 and top_logprobs != -1: if top_logprobs < 0 and top_logprobs != -1:
raise ValueError("`top_logprobs` must be a positive value or -1.") raise VLLMValidationError(
"`top_logprobs` must be a positive value or -1.",
parameter="top_logprobs",
value=top_logprobs,
)
if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"): if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"):
raise ValueError( raise VLLMValidationError(
"when using `top_logprobs`, `logprobs` must be set to true." "when using `top_logprobs`, `logprobs` must be set to true.",
parameter="top_logprobs",
) )
return data return data
...@@ -1285,9 +1330,10 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1285,9 +1330,10 @@ class CompletionRequest(OpenAIBaseModel):
for k in ("json", "regex", "choice") for k in ("json", "regex", "choice")
) )
if count > 1: if count > 1:
raise ValueError( raise VLLMValidationError(
"You can only use one kind of constraints for structured " "You can only use one kind of constraints for structured "
"outputs ('json', 'regex' or 'choice')." "outputs ('json', 'regex' or 'choice').",
parameter="structured_outputs",
) )
return data return data
...@@ -1296,14 +1342,23 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1296,14 +1342,23 @@ class CompletionRequest(OpenAIBaseModel):
def check_logprobs(cls, data): def check_logprobs(cls, data):
if (prompt_logprobs := data.get("prompt_logprobs")) is not None: if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1):
raise ValueError( raise VLLMValidationError(
"`prompt_logprobs` are not available when `stream=True`." "`prompt_logprobs` are not available when `stream=True`.",
parameter="prompt_logprobs",
) )
if prompt_logprobs < 0 and prompt_logprobs != -1: if prompt_logprobs < 0 and prompt_logprobs != -1:
raise ValueError("`prompt_logprobs` must be a positive value or -1.") raise VLLMValidationError(
"`prompt_logprobs` must be a positive value or -1.",
parameter="prompt_logprobs",
value=prompt_logprobs,
)
if (logprobs := data.get("logprobs")) is not None and logprobs < 0: if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
raise ValueError("`logprobs` must be a positive value.") raise VLLMValidationError(
"`logprobs` must be a positive value.",
parameter="logprobs",
value=logprobs,
)
return data return data
...@@ -1311,7 +1366,10 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1311,7 +1366,10 @@ class CompletionRequest(OpenAIBaseModel):
@classmethod @classmethod
def validate_stream_options(cls, data): def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"): if data.get("stream_options") and not data.get("stream"):
raise ValueError("Stream options can only be defined when `stream=True`.") raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter="stream_options",
)
return data return data
...@@ -2138,7 +2196,15 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -2138,7 +2196,15 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False) stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream: if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError("Stream options can only be defined when `stream=True`.") # Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data return data
...@@ -2351,7 +2417,15 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -2351,7 +2417,15 @@ class TranslationRequest(OpenAIBaseModel):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False) stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream: if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError("Stream options can only be defined when `stream=True`.") # Find which specific stream option was set
invalid_param = next(
(so for so in stream_opts if data.get(so, False)),
"stream_include_usage",
)
raise VLLMValidationError(
"Stream options can only be defined when `stream=True`.",
parameter=invalid_param,
)
return data return data
......
...@@ -417,8 +417,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -417,8 +417,7 @@ class OpenAIServingChat(OpenAIServing):
generators.append(generator) generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
assert len(generators) == 1 assert len(generators) == 1
(result_generator,) = generators (result_generator,) = generators
...@@ -448,8 +447,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -448,8 +447,7 @@ class OpenAIServingChat(OpenAIServing):
except GenerationError as e: except GenerationError as e:
return self._convert_generation_error_to_response(e) return self._convert_generation_error_to_response(e)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
...@@ -682,7 +680,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -682,7 +680,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parsers = [None] * num_choices tool_parsers = [None] * num_choices
except Exception as e: except Exception as e:
logger.exception("Error in tool parser creation.") logger.exception("Error in tool parser creation.")
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return return
...@@ -1328,9 +1326,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1328,9 +1326,8 @@ class OpenAIServingChat(OpenAIServing):
except GenerationError as e: except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n" yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.") logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished # Send the final done message after all response.n are finished
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
...@@ -1354,8 +1351,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1354,8 +1351,7 @@ class OpenAIServingChat(OpenAIServing):
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
assert final_res is not None assert final_res is not None
......
...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
UsageInfo, UsageInfo,
VLLMValidationError,
) )
from vllm.entrypoints.openai.serving_engine import ( from vllm.entrypoints.openai.serving_engine import (
GenerationError, GenerationError,
...@@ -247,8 +248,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -247,8 +248,7 @@ class OpenAIServingCompletion(OpenAIServing):
generators.append(generator) generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
...@@ -308,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -308,8 +308,7 @@ class OpenAIServingCompletion(OpenAIServing):
except GenerationError as e: except GenerationError as e:
return self._convert_generation_error_to_response(e) return self._convert_generation_error_to_response(e)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
# When user requests streaming but we don't stream, we still need to # When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event. # return a streaming response with a single event.
...@@ -510,9 +509,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -510,9 +509,8 @@ class OpenAIServingCompletion(OpenAIServing):
except GenerationError as e: except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n" yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.") logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
...@@ -660,8 +658,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -660,8 +658,11 @@ class OpenAIServingCompletion(OpenAIServing):
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
else: else:
if tokenizer is None: if tokenizer is None:
raise ValueError( raise VLLMValidationError(
"Unable to get tokenizer because `skip_tokenizer_init=True`" "Unable to get tokenizer because "
"`skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
) )
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
...@@ -720,6 +721,15 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -720,6 +721,15 @@ class OpenAIServingCompletion(OpenAIServing):
request: CompletionRequest, request: CompletionRequest,
max_input_length: int | None = None, max_input_length: int | None = None,
) -> RenderConfig: ) -> RenderConfig:
# Validate max_tokens before using it
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
raise VLLMValidationError(
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
f"the model's maximum context length ({self.max_model_len}).",
parameter="max_tokens",
value=request.max_tokens,
)
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0) max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig( return RenderConfig(
max_length=max_input_tokens_len, max_length=max_input_tokens_len,
......
...@@ -57,6 +57,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -57,6 +57,7 @@ from vllm.entrypoints.openai.protocol import (
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
TranslationRequest, TranslationRequest,
VLLMValidationError,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
...@@ -322,8 +323,10 @@ class OpenAIServing: ...@@ -322,8 +323,10 @@ class OpenAIServing:
input_processor = self.input_processor input_processor = self.input_processor
tokenizer = input_processor.tokenizer tokenizer = input_processor.tokenizer
if tokenizer is None: if tokenizer is None:
raise ValueError( raise VLLMValidationError(
"You cannot use beam search when `skip_tokenizer_init=True`" "You cannot use beam search when `skip_tokenizer_init=True`",
parameter="skip_tokenizer_init",
value=True,
) )
eos_token_id: int = tokenizer.eos_token_id # type: ignore eos_token_id: int = tokenizer.eos_token_id # type: ignore
...@@ -706,8 +709,7 @@ class OpenAIServing: ...@@ -706,8 +709,7 @@ class OpenAIServing:
return None return None
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
async def _collect_batch( async def _collect_batch(
self, self,
...@@ -738,14 +740,43 @@ class OpenAIServing: ...@@ -738,14 +740,43 @@ class OpenAIServing:
return None return None
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(e)
def create_error_response( def create_error_response(
self, self,
message: str, message: str | Exception,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
) -> ErrorResponse: ) -> ErrorResponse:
exc: Exception | None = None
if isinstance(message, Exception):
exc = message
from vllm.entrypoints.openai.protocol import VLLMValidationError
if isinstance(exc, VLLMValidationError):
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = exc.parameter
elif isinstance(exc, (ValueError, TypeError, RuntimeError)):
# Common validation errors from user input
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
elif exc.__class__.__name__ == "TemplateError":
# jinja2.TemplateError (avoid importing jinja2)
err_type = "BadRequestError"
status_code = HTTPStatus.BAD_REQUEST
param = None
else:
err_type = "InternalServerError"
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
param = None
message = str(exc)
if self.log_error_stack: if self.log_error_stack:
exc_type, _, _ = sys.exc_info() exc_type, _, _ = sys.exc_info()
if exc_type is not None: if exc_type is not None:
...@@ -753,18 +784,27 @@ class OpenAIServing: ...@@ -753,18 +784,27 @@ class OpenAIServing:
else: else:
traceback.print_stack() traceback.print_stack()
return ErrorResponse( return ErrorResponse(
error=ErrorInfo(message=message, type=err_type, code=status_code.value) error=ErrorInfo(
message=message,
type=err_type,
code=status_code.value,
param=param,
)
) )
def create_streaming_error_response( def create_streaming_error_response(
self, self,
message: str, message: str | Exception,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
param: str | None = None,
) -> str: ) -> str:
json_str = json.dumps( json_str = json.dumps(
self.create_error_response( self.create_error_response(
message=message, err_type=err_type, status_code=status_code message=message,
err_type=err_type,
status_code=status_code,
param=param,
).model_dump() ).model_dump()
) )
return json_str return json_str
...@@ -825,6 +865,7 @@ class OpenAIServing: ...@@ -825,6 +865,7 @@ class OpenAIServing:
message=f"The model `{request.model}` does not exist.", message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError", err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
param="model",
) )
def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None: def _get_active_default_mm_loras(self, request: AnyRequest) -> LoRARequest | None:
...@@ -991,11 +1032,13 @@ class OpenAIServing: ...@@ -991,11 +1032,13 @@ class OpenAIServing:
ClassificationChatRequest: "classification", ClassificationChatRequest: "classification",
} }
operation = operations.get(type(request), "embedding generation") operation = operations.get(type(request), "embedding generation")
raise ValueError( raise VLLMValidationError(
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested " f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for {operation}. " f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input." f"Please reduce the length of the input.",
parameter="input_tokens",
value=token_num,
) )
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
...@@ -1017,20 +1060,24 @@ class OpenAIServing: ...@@ -1017,20 +1060,24 @@ class OpenAIServing:
# Note: input length can be up to model context length - 1 for # Note: input length can be up to model context length - 1 for
# completion-like requests. # completion-like requests.
if token_num >= self.max_model_len: if token_num >= self.max_model_len:
raise ValueError( raise VLLMValidationError(
f"This model's maximum context length is " f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, your request has " f"{self.max_model_len} tokens. However, your request has "
f"{token_num} input tokens. Please reduce the length of " f"{token_num} input tokens. Please reduce the length of "
"the input messages." "the input messages.",
parameter="input_tokens",
value=token_num,
) )
if max_tokens is not None and token_num + max_tokens > self.max_model_len: if max_tokens is not None and token_num + max_tokens > self.max_model_len:
raise ValueError( raise VLLMValidationError(
"'max_tokens' or 'max_completion_tokens' is too large: " "'max_tokens' or 'max_completion_tokens' is too large: "
f"{max_tokens}. This model's maximum context length is " f"{max_tokens}. This model's maximum context length is "
f"{self.max_model_len} tokens and your request has " f"{self.max_model_len} tokens and your request has "
f"{token_num} input tokens ({max_tokens} > {self.max_model_len}" f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
f" - {token_num})." f" - {token_num}).",
parameter="max_tokens",
value=max_tokens,
) )
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
......
...@@ -94,6 +94,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -94,6 +94,7 @@ from vllm.entrypoints.openai.protocol import (
ResponsesResponse, ResponsesResponse,
ResponseUsage, ResponseUsage,
StreamingResponsesResponse, StreamingResponsesResponse,
VLLMValidationError,
) )
from vllm.entrypoints.openai.serving_engine import ( from vllm.entrypoints.openai.serving_engine import (
GenerationError, GenerationError,
...@@ -271,6 +272,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -271,6 +272,7 @@ class OpenAIServingResponses(OpenAIServing):
err_type="invalid_request_error", err_type="invalid_request_error",
message=error_message, message=error_message,
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="input",
) )
return None return None
...@@ -282,6 +284,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -282,6 +284,7 @@ class OpenAIServingResponses(OpenAIServing):
err_type="invalid_request_error", err_type="invalid_request_error",
message="logprobs are not supported with gpt-oss models", message="logprobs are not supported with gpt-oss models",
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="logprobs",
) )
if request.store and not self.enable_store and request.background: if request.store and not self.enable_store and request.background:
return self.create_error_response( return self.create_error_response(
...@@ -294,6 +297,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -294,6 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
"the vLLM server." "the vLLM server."
), ),
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="background",
) )
if request.previous_input_messages and request.previous_response_id: if request.previous_input_messages and request.previous_response_id:
return self.create_error_response( return self.create_error_response(
...@@ -301,6 +305,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -301,6 +305,7 @@ class OpenAIServingResponses(OpenAIServing):
message="Only one of `previous_input_messages` and " message="Only one of `previous_input_messages` and "
"`previous_response_id` can be set.", "`previous_response_id` can be set.",
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="previous_response_id",
) )
return None return None
...@@ -457,8 +462,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -457,8 +462,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
generators.append(generator) generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
assert len(generators) == 1 assert len(generators) == 1
(result_generator,) = generators (result_generator,) = generators
...@@ -546,7 +550,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -546,7 +550,7 @@ class OpenAIServingResponses(OpenAIServing):
except GenerationError as e: except GenerationError as e:
return self._convert_generation_error_to_response(e) return self._convert_generation_error_to_response(e)
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(e)
async def _make_request( async def _make_request(
self, self,
...@@ -630,8 +634,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -630,8 +634,7 @@ class OpenAIServingResponses(OpenAIServing):
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
# NOTE: Implementation of status is still WIP, but for now # NOTE: Implementation of status is still WIP, but for now
# we guarantee that if the status is not "completed", it is accurate. # we guarantee that if the status is not "completed", it is accurate.
...@@ -1074,7 +1077,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1074,7 +1077,7 @@ class OpenAIServingResponses(OpenAIServing):
response = self._convert_generation_error_to_response(e) response = self._convert_generation_error_to_response(e)
except Exception as e: except Exception as e:
logger.exception("Background request failed for %s", request.request_id) logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e)) response = self.create_error_response(e)
finally: finally:
new_event_signal.set() new_event_signal.set()
...@@ -1099,7 +1102,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1099,7 +1102,7 @@ class OpenAIServingResponses(OpenAIServing):
response = self._convert_generation_error_to_response(e) response = self._convert_generation_error_to_response(e)
except Exception as e: except Exception as e:
logger.exception("Background request failed for %s", request.request_id) logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e)) response = self.create_error_response(e)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed". # If the request has failed, update the status to "failed".
...@@ -1116,7 +1119,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1116,7 +1119,11 @@ class OpenAIServingResponses(OpenAIServing):
starting_after: int | None = None, starting_after: int | None = None,
) -> AsyncGenerator[StreamingResponsesResponse, None]: ) -> AsyncGenerator[StreamingResponsesResponse, None]:
if response_id not in self.event_store: if response_id not in self.event_store:
raise ValueError(f"Unknown response_id: {response_id}") raise VLLMValidationError(
f"Unknown response_id: {response_id}",
parameter="response_id",
value=response_id,
)
event_deque, new_event_signal = self.event_store[response_id] event_deque, new_event_signal = self.event_store[response_id]
start_index = 0 if starting_after is None else starting_after + 1 start_index = 0 if starting_after is None else starting_after + 1
...@@ -1172,6 +1179,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1172,6 +1179,7 @@ class OpenAIServingResponses(OpenAIServing):
return self.create_error_response( return self.create_error_response(
err_type="invalid_request_error", err_type="invalid_request_error",
message="Cannot cancel a synchronous response.", message="Cannot cancel a synchronous response.",
param="response_id",
) )
# Update the status to "cancelled". # Update the status to "cancelled".
...@@ -1191,6 +1199,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1191,6 +1199,7 @@ class OpenAIServingResponses(OpenAIServing):
err_type="invalid_request_error", err_type="invalid_request_error",
message=f"Response with id '{response_id}' not found.", message=f"Response with id '{response_id}' not found.",
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
param="response_id",
) )
def _make_store_not_supported_error(self) -> ErrorResponse: def _make_store_not_supported_error(self) -> ErrorResponse:
...@@ -1203,6 +1212,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1203,6 +1212,7 @@ class OpenAIServingResponses(OpenAIServing):
"starting the vLLM server." "starting the vLLM server."
), ),
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
param="store",
) )
async def _process_simple_streaming_events( async def _process_simple_streaming_events(
......
...@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
TranslationSegment, TranslationSegment,
TranslationStreamResponse, TranslationStreamResponse,
UsageInfo, UsageInfo,
VLLMValidationError,
) )
from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest from vllm.entrypoints.openai.serving_engine import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
...@@ -259,7 +260,11 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -259,7 +260,11 @@ class OpenAISpeechToText(OpenAIServing):
) )
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.") raise VLLMValidationError(
"Maximum file size exceeded",
parameter="audio_filesize_mb",
value=len(audio_data) / 1024**2,
)
with io.BytesIO(audio_data) as bytes_: with io.BytesIO(audio_data) as bytes_:
# NOTE resample to model SR here for efficiency. This is also a # NOTE resample to model SR here for efficiency. This is also a
...@@ -287,12 +292,18 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -287,12 +292,18 @@ class OpenAISpeechToText(OpenAIServing):
) )
if request.response_format == "verbose_json": if request.response_format == "verbose_json":
if not isinstance(prompt, dict): if not isinstance(prompt, dict):
raise ValueError(f"Expected prompt to be a dict,got {type(prompt)}") raise VLLMValidationError(
"Expected prompt to be a dict",
parameter="prompt",
value=type(prompt).__name__,
)
prompt_dict = cast(dict, prompt) prompt_dict = cast(dict, prompt)
decoder_prompt = prompt.get("decoder_prompt") decoder_prompt = prompt.get("decoder_prompt")
if not isinstance(decoder_prompt, str): if not isinstance(decoder_prompt, str):
raise ValueError( raise VLLMValidationError(
f"Expected decoder_prompt to bestr, got {type(decoder_prompt)}" "Expected decoder_prompt to be str",
parameter="decoder_prompt",
value=type(decoder_prompt).__name__,
) )
prompt_dict["decoder_prompt"] = decoder_prompt.replace( prompt_dict["decoder_prompt"] = decoder_prompt.replace(
"<|notimestamps|>", "<|0.00|>" "<|notimestamps|>", "<|0.00|>"
...@@ -412,7 +423,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -412,7 +423,7 @@ class OpenAISpeechToText(OpenAIServing):
except ValueError as e: except ValueError as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e)) return self.create_error_response(e)
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
try: try:
...@@ -448,8 +459,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -448,8 +459,7 @@ class OpenAISpeechToText(OpenAIServing):
for i, prompt in enumerate(prompts) for i, prompt in enumerate(prompts)
] ]
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
if request.stream: if request.stream:
return stream_generator_method( return stream_generator_method(
...@@ -523,8 +533,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -523,8 +533,7 @@ class OpenAISpeechToText(OpenAIServing):
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error return self.create_error_response(e)
return self.create_error_response(str(e))
async def _speech_to_text_stream_generator( async def _speech_to_text_stream_generator(
self, self,
...@@ -634,9 +643,8 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -634,9 +643,8 @@ class OpenAISpeechToText(OpenAIServing):
) )
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in %s stream generator.", self.task_type) logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(e)
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished # Send the final done message after all response.n are finished
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
......
...@@ -12,6 +12,7 @@ import torch ...@@ -12,6 +12,7 @@ import torch
from pydantic import Field from pydantic import Field
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -162,8 +163,9 @@ class BaseRenderer(ABC): ...@@ -162,8 +163,9 @@ class BaseRenderer(ABC):
) -> list[EmbedsPrompt]: ) -> list[EmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects.""" """Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds: if not self.model_config.enable_prompt_embeds:
raise ValueError( raise VLLMValidationError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`." "You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
parameter="prompt_embeds",
) )
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
...@@ -396,10 +398,12 @@ class CompletionRenderer(BaseRenderer): ...@@ -396,10 +398,12 @@ class CompletionRenderer(BaseRenderer):
) -> TokensPrompt: ) -> TokensPrompt:
"""Create validated TokensPrompt.""" """Create validated TokensPrompt."""
if max_length is not None and len(token_ids) > max_length: if max_length is not None and len(token_ids) > max_length:
raise ValueError( raise VLLMValidationError(
f"This model's maximum context length is {max_length} tokens. " f"This model's maximum context length is {max_length} tokens. "
f"However, your request has {len(token_ids)} input tokens. " f"However, your request has {len(token_ids)} input tokens. "
"Please reduce the length of the input messages." "Please reduce the length of the input messages.",
parameter="input_tokens",
value=len(token_ids),
) )
tokens_prompt = TokensPrompt(prompt_token_ids=token_ids) tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment