Unverified Commit 9a7c8842 authored by gongwei-130's avatar gongwei-130 Committed by GitHub
Browse files

accomendate json schema in the "schema" field, not in "json_schema" field of...

accomendate json schema in the "schema" field, not in "json_schema" field of response_format (#9786)
parent 7a16db9b
...@@ -460,6 +460,38 @@ class ChatCompletionRequest(BaseModel): ...@@ -460,6 +460,38 @@ class ChatCompletionRequest(BaseModel):
values["tool_choice"] = "auto" values["tool_choice"] = "auto"
return values return values
@model_validator(mode="before")
@classmethod
def set_json_schema(cls, values):
response_format = values.get("response_format")
if not response_format:
return values
if response_format.get("type") != "json_schema":
return values
schema = response_format.pop("schema", None)
json_schema = response_format.get("json_schema")
if json_schema:
return values
if schema:
name_ = schema.get("title", "Schema")
strict_ = False
if "properties" in schema and "strict" in schema["properties"]:
item = schema["properties"].pop("strict", None)
if item and item.get("default", False):
strict_ = True
response_format["json_schema"] = {
"name": name_,
"schema": schema,
"strict": strict_,
}
return values
# Extra parameters for SRT backend only and will be ignored by OpenAI models. # Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1 top_k: int = -1
min_p: float = 0.0 min_p: float = 0.0
......
...@@ -18,7 +18,7 @@ import time ...@@ -18,7 +18,7 @@ import time
import unittest import unittest
from typing import Dict, List, Optional from typing import Dict, List, Optional
from pydantic import ValidationError from pydantic import BaseModel, Field, ValidationError
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
BatchRequest, BatchRequest,
...@@ -192,6 +192,81 @@ class TestChatCompletionRequest(unittest.TestCase): ...@@ -192,6 +192,81 @@ class TestChatCompletionRequest(unittest.TestCase):
self.assertFalse(request.stream_reasoning) self.assertFalse(request.stream_reasoning)
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"}) self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
def test_chat_completion_json_format(self):
"""Test chat completion json format"""
transcript = "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, "
"so let's get started. First, I need to make a quick breakfast. I think I'll have some "
"scrambled eggs and toast with a cup of coffee. While I'm cooking, I'll also check my "
"emails to see if there's anything urgent."
messages = [
{
"role": "system",
"content": "The following is a voice message transcript. Only answer in JSON.",
},
{
"role": "user",
"content": transcript,
},
]
class VoiceNote(BaseModel):
title: str = Field(description="A title for the voice note")
summary: str = Field(
description="A short one sentence summary of the voice note."
)
strict: Optional[bool] = True
actionItems: List[str] = Field(
description="A list of action items from the voice note"
)
request = ChatCompletionRequest(
model="test-model",
messages=messages,
top_k=40,
min_p=0.05,
separate_reasoning=False,
stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"},
response_format={
"type": "json_schema",
"schema": VoiceNote.model_json_schema(),
},
)
res_format = request.response_format
json_format = res_format.json_schema
name = json_format.name
schema = json_format.schema_
strict = json_format.strict
self.assertEqual(name, "VoiceNote")
self.assertEqual(strict, True)
self.assertNotIn("strict", schema["properties"])
request = ChatCompletionRequest(
model="test-model",
messages=messages,
top_k=40,
min_p=0.05,
separate_reasoning=False,
stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"},
response_format={
"type": "json_schema",
"json_schema": {
"name": "VoiceNote",
"schema": VoiceNote.model_json_schema(),
"strict": True,
},
},
)
res_format = request.response_format
json_format = res_format.json_schema
name = json_format.name
schema = json_format.schema_
strict = json_format.strict
self.assertEqual(name, "VoiceNote")
self.assertEqual(strict, True)
class TestModelSerialization(unittest.TestCase): class TestModelSerialization(unittest.TestCase):
"""Test model serialization with hidden states""" """Test model serialization with hidden states"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment