Unverified Commit ef001d98 authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Fix the pydantic logging validator (#12420)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
parent 5f671cb4
...@@ -6,7 +6,8 @@ from argparse import Namespace ...@@ -6,7 +6,8 @@ from argparse import Namespace
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import Annotated from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
...@@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel): ...@@ -45,14 +46,14 @@ class OpenAIBaseModel(BaseModel):
# Cache class field names # Cache class field names
field_names: ClassVar[Optional[Set[str]]] = None field_names: ClassVar[Optional[Set[str]]] = None
@model_validator(mode="before") @model_validator(mode="wrap")
@classmethod @classmethod
def __log_extra_fields__(cls, data): def __log_extra_fields__(cls, data, handler):
result = handler(data)
if not isinstance(data, dict):
return result
field_names = cls.field_names field_names = cls.field_names
if field_names is None: if field_names is None:
if not isinstance(data, dict):
return data
# Get all class field names and their potential aliases # Get all class field names and their potential aliases
field_names = set() field_names = set()
for field_name, field in cls.model_fields.items(): for field_name, field in cls.model_fields.items():
...@@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel): ...@@ -67,7 +68,7 @@ class OpenAIBaseModel(BaseModel):
"The following fields were present in the request " "The following fields were present in the request "
"but ignored: %s", "but ignored: %s",
data.keys() - field_names) data.keys() - field_names)
return data return result
class ErrorResponse(OpenAIBaseModel): class ErrorResponse(OpenAIBaseModel):
...@@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel): ...@@ -1287,6 +1288,20 @@ class BatchRequestInput(OpenAIBaseModel):
# The parameters of the request. # The parameters of the request.
body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest] body: Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest]
@field_validator('body', mode='plain')
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url = info.data['url']
if url == "/v1/chat/completions":
return ChatCompletionRequest.model_validate(value)
if url == "/v1/embeddings":
return TypeAdapter(EmbeddingRequest).validate_python(value)
if url == "/v1/score":
return ScoreRequest.model_validate(value)
return TypeAdapter(Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest]).validate_python(value)
class BatchResponseData(OpenAIBaseModel): class BatchResponseData(OpenAIBaseModel):
# HTTP status code of the response. # HTTP status code of the response.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment