Unverified Commit 9144ed10 authored by zifeitong's avatar zifeitong Committed by GitHub
Browse files

Support OpenAI API json_schema response format (#1363)

parent 69b3bb9a
...@@ -23,7 +23,6 @@ from collections import defaultdict ...@@ -23,7 +23,6 @@ from collections import defaultdict
import interegular import interegular
import outlines.caching import outlines.caching
from outlines.fsm.json_schema import build_regex_from_schema
from sglang.srt.constrained import ( from sglang.srt.constrained import (
FSMInfo, FSMInfo,
......
...@@ -28,6 +28,13 @@ from fastapi import HTTPException, Request, UploadFile ...@@ -28,6 +28,13 @@ from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import ValidationError from pydantic import ValidationError
try:
from outlines.fsm.json_schema import convert_json_schema_to_str
except ImportError:
# Before outlines 0.0.47, convert_json_schema_to_str is under
# outlines.integrations.utils
from outlines.integrations.utils import convert_json_schema_to_str
from sglang.srt.conversation import ( from sglang.srt.conversation import (
Conversation, Conversation,
SeparatorStyle, SeparatorStyle,
...@@ -888,22 +895,26 @@ def v1_chat_generate_request( ...@@ -888,22 +895,26 @@ def v1_chat_generate_request(
return_logprobs.append(request.logprobs) return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1) logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs) top_logprobs_nums.append(request.top_logprobs)
sampling_params_list.append(
{ sampling_params = {
"temperature": request.temperature, "temperature": request.temperature,
"max_new_tokens": request.max_tokens, "max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens, "min_new_tokens": request.min_tokens,
"stop": stop, "stop": stop,
"stop_token_ids": request.stop_token_ids, "stop_token_ids": request.stop_token_ids,
"top_p": request.top_p, "top_p": request.top_p,
"presence_penalty": request.presence_penalty, "presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty, "frequency_penalty": request.frequency_penalty,
"repetition_penalty": request.repetition_penalty, "repetition_penalty": request.repetition_penalty,
"regex": request.regex, "regex": request.regex,
"json_schema": request.json_schema, "n": request.n,
"n": request.n, }
} if request.response_format and request.response_format.type == "json_schema":
) sampling_params["json_schema"] = convert_json_schema_to_str(
request.response_format.json_schema.schema_
)
sampling_params_list.append(sampling_params)
image_data_list.append(image_data) image_data_list.append(image_data)
modalities_list.extend(modalities) modalities_list.extend(modalities)
if len(all_requests) == 1: if len(all_requests) == 1:
......
...@@ -82,6 +82,14 @@ class StreamOptions(BaseModel): ...@@ -82,6 +82,14 @@ class StreamOptions(BaseModel):
include_usage: Optional[bool] = False include_usage: Optional[bool] = False
class JsonSchemaResponseFormat(BaseModel):
name: str
description: Optional[str] = None
# use alias to workaround pydantic conflict
schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
strict: Optional[bool] = False
class FileRequest(BaseModel): class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create # https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded file: bytes # The File object (not file name) to be uploaded
...@@ -237,8 +245,8 @@ ChatCompletionMessageParam = Union[ ...@@ -237,8 +245,8 @@ ChatCompletionMessageParam = Union[
class ResponseFormat(BaseModel): class ResponseFormat(BaseModel):
# type must be "json_object" or "text" type: Literal["text", "json_object", "json_schema"]
type: Literal["text", "json_object"] json_schema: Optional[JsonSchemaResponseFormat] = None
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
...@@ -264,7 +272,6 @@ class ChatCompletionRequest(BaseModel): ...@@ -264,7 +272,6 @@ class ChatCompletionRequest(BaseModel):
# Extra parameters for SRT backend only and will be ignored by OpenAI models. # Extra parameters for SRT backend only and will be ignored by OpenAI models.
regex: Optional[str] = None regex: Optional[str] = None
json_schema: Optional[str] = None
min_tokens: Optional[int] = 0 min_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0 repetition_penalty: Optional[float] = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
......
...@@ -79,7 +79,10 @@ class TestJSONConstrained(unittest.TestCase): ...@@ -79,7 +79,10 @@ class TestJSONConstrained(unittest.TestCase):
], ],
temperature=0, temperature=0,
max_tokens=128, max_tokens=128,
extra_body={"json_schema": self.json_schema}, response_format={
"type": "json_schema",
"json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
},
) )
text = response.choices[0].message.content text = response.choices[0].message.content
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment