Unverified Commit 9a7c8842 authored by gongwei-130's avatar gongwei-130 Committed by GitHub
Browse files

accomendate json schema in the "schema" field, not in "json_schema" field of...

accomendate json schema in the "schema" field, not in "json_schema" field of response_format (#9786)
parent 7a16db9b
......@@ -460,6 +460,38 @@ class ChatCompletionRequest(BaseModel):
values["tool_choice"] = "auto"
return values
@model_validator(mode="before")
@classmethod
def set_json_schema(cls, values):
response_format = values.get("response_format")
if not response_format:
return values
if response_format.get("type") != "json_schema":
return values
schema = response_format.pop("schema", None)
json_schema = response_format.get("json_schema")
if json_schema:
return values
if schema:
name_ = schema.get("title", "Schema")
strict_ = False
if "properties" in schema and "strict" in schema["properties"]:
item = schema["properties"].pop("strict", None)
if item and item.get("default", False):
strict_ = True
response_format["json_schema"] = {
"name": name_,
"schema": schema,
"strict": strict_,
}
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
min_p: float = 0.0
......
......@@ -18,7 +18,7 @@ import time
import unittest
from typing import Dict, List, Optional
from pydantic import ValidationError
from pydantic import BaseModel, Field, ValidationError
from sglang.srt.entrypoints.openai.protocol import (
BatchRequest,
......@@ -192,6 +192,81 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream_reasoning)
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
def test_chat_completion_json_format(self):
"""Test chat completion json format"""
transcript = "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
"so let's get started. First, I need to make a quick breakfast. I think I'll have some "
"scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
"emails to see if there's anything urgent."
messages = [
{
"role": "system",
"content": "The following is a voice message transcript. Only answer in JSON.",
},
{
"role": "user",
"content": transcript,
},
]
class VoiceNote(BaseModel):
title: str = Field(description="A title for the voice note")
summary: str = Field(
description="A short one sentence summary of the voice note."
)
strict: Optional[bool] = True
actionItems: List[str] = Field(
description="A list of action items from the voice note"
)
request = ChatCompletionRequest(
model="test-model",
messages=messages,
top_k=40,
min_p=0.05,
separate_reasoning=False,
stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"},
response_format={
"type": "json_schema",
"schema": VoiceNote.model_json_schema(),
},
)
res_format = request.response_format
json_format = res_format.json_schema
name = json_format.name
schema = json_format.schema_
strict = json_format.strict
self.assertEqual(name, "VoiceNote")
self.assertEqual(strict, True)
self.assertNotIn("strict", schema["properties"])
request = ChatCompletionRequest(
model="test-model",
messages=messages,
top_k=40,
min_p=0.05,
separate_reasoning=False,
stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"},
response_format={
"type": "json_schema",
"json_schema": {
"name": "VoiceNote",
"schema": VoiceNote.model_json_schema(),
"strict": True,
},
},
)
res_format = request.response_format
json_format = res_format.json_schema
name = json_format.name
schema = json_format.schema_
strict = json_format.strict
self.assertEqual(name, "VoiceNote")
self.assertEqual(strict, True)
class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment