Unverified Commit b6c14ec0 authored by cicirori's avatar cicirori Committed by GitHub
Browse files

add `response_format` support for `completion` API (#9665)

parent 43de1d73
...@@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel): ...@@ -108,6 +108,23 @@ class JsonSchemaResponseFormat(BaseModel):
strict: Optional[bool] = False strict: Optional[bool] = False
class ResponseFormat(BaseModel):
    """OpenAI-compatible ``response_format`` payload for constrained output."""

    # Requested output mode: plain text, any JSON object, or a named JSON schema.
    type: Literal["text", "json_object", "json_schema"]
    # Schema details; only meaningful when type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
    """A single begin/end-delimited structure inside a structural_tag format."""

    # Literal text opening the structured region.
    begin: str
    # JSON schema for the content between begin and end; serialized under the
    # wire name "schema" (trailing underscore avoids shadowing BaseModel.schema).
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    # Literal text closing the structured region.
    end: str
class StructuralTagResponseFormat(BaseModel):
    """``response_format`` variant for structural-tag constrained decoding."""

    type: Literal["structural_tag"]
    # Tag structures the model may emit.
    structures: List[StructuresResponseFormat]
    # Trigger strings; presumably each activates schema-guided generation of a
    # matching structure — confirm against the grammar backend.
    triggers: List[str]
class FileRequest(BaseModel): class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create # https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded file: bytes # The File object (not file name) to be uploaded
...@@ -200,6 +217,7 @@ class CompletionRequest(BaseModel): ...@@ -200,6 +217,7 @@ class CompletionRequest(BaseModel):
skip_special_tokens: bool = True skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None session_params: Optional[Dict] = None
response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
# For PD disaggregation # For PD disaggregation
bootstrap_host: Optional[Union[List[str], str]] = None bootstrap_host: Optional[Union[List[str], str]] = None
...@@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[ ...@@ -359,23 +377,6 @@ ChatCompletionMessageParam = Union[
] ]
class ResponseFormat(BaseModel):
    """OpenAI-compatible ``response_format`` payload for constrained output."""

    # Requested output mode: plain text, any JSON object, or a named JSON schema.
    type: Literal["text", "json_object", "json_schema"]
    # Schema details; only meaningful when type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
    """A single begin/end-delimited structure inside a structural_tag format."""

    # Literal text opening the structured region.
    begin: str
    # JSON schema for the content between begin and end; serialized under the
    # wire name "schema" (trailing underscore avoids shadowing BaseModel.schema).
    schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None)
    # Literal text closing the structured region.
    end: str
class StructuralTagResponseFormat(BaseModel):
    """``response_format`` variant for structural-tag constrained decoding."""

    type: Literal["structural_tag"]
    # Tag structures the model may emit.
    structures: List[StructuresResponseFormat]
    # Trigger strings; presumably each activates schema-guided generation of a
    # matching structure — confirm against the grammar backend.
    triggers: List[str]
class Function(BaseModel): class Function(BaseModel):
"""Function descriptions.""" """Function descriptions."""
......
...@@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import ( ...@@ -23,6 +23,7 @@ from sglang.srt.entrypoints.openai.utils import (
from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.template_manager import TemplateManager from sglang.srt.managers.template_manager import TemplateManager
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.utils import convert_json_schema_to_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase): ...@@ -125,6 +126,20 @@ class OpenAIServingCompletion(OpenAIServingBase):
"logit_bias": request.logit_bias, "logit_bias": request.logit_bias,
} }
# Handle response_format constraints
if request.response_format and request.response_format.type == "json_schema":
sampling_params["json_schema"] = convert_json_schema_to_str(
request.response_format.json_schema.schema_
)
elif request.response_format and request.response_format.type == "json_object":
sampling_params["json_schema"] = '{"type": "object"}'
elif (
request.response_format and request.response_format.type == "structural_tag"
):
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
return sampling_params return sampling_params
async def _handle_streaming_request( async def _handle_streaming_request(
......
...@@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase): ...@@ -95,6 +95,63 @@ class ServingCompletionTestCase(unittest.TestCase):
self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded" self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded"
self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"]) self.assertEqual(self.sc._prepare_echo_prompts(req), ["decoded"])
# ---------- response_format handling ----------
def test_response_format_json_object(self):
    """A json_object response_format pins generation to a generic object schema."""
    request = CompletionRequest(
        model="x",
        prompt="Generate a JSON object:",
        max_tokens=100,
        response_format={"type": "json_object"},
    )
    params = self.sc._build_sampling_params(request)
    # json_object is implemented as the catch-all {"type": "object"} schema.
    self.assertEqual('{"type": "object"}', params["json_schema"])
def test_response_format_json_schema(self):
    """A json_schema response_format serializes the client schema into sampling params.

    Previously this test only checked that the value was *some* string; it now
    round-trips the serialized schema and compares it to the input, so a
    serialization regression cannot pass unnoticed.
    """
    import json

    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    }
    req = CompletionRequest(
        model="x",
        prompt="Generate a JSON object:",
        max_tokens=100,
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "person", "schema": schema},
        },
    )
    sampling_params = self.sc._build_sampling_params(req)
    self.assertIn("json_schema", sampling_params)
    self.assertIsInstance(sampling_params["json_schema"], str)
    # The string produced by convert_json_schema_to_str must decode back to
    # exactly the schema the client supplied.
    self.assertEqual(json.loads(sampling_params["json_schema"]), schema)
def test_response_format_structural_tag(self):
    """A structural_tag response_format serializes the whole tag spec.

    Previously this test only checked that the value was *some* string; it now
    decodes the serialized spec and verifies the structures and triggers
    survive model_dump(by_alias=True) intact.
    """
    import json

    req = CompletionRequest(
        model="x",
        prompt="Generate structured output:",
        max_tokens=100,
        response_format={
            "type": "structural_tag",
            "structures": [{"begin": "<data>", "end": "</data>"}],
            "triggers": ["<data>"],
        },
    )
    sampling_params = self.sc._build_sampling_params(req)
    self.assertIn("structural_tag", sampling_params)
    self.assertIsInstance(sampling_params["structural_tag"], str)
    tag = json.loads(sampling_params["structural_tag"])
    self.assertEqual(tag["type"], "structural_tag")
    self.assertEqual(tag["triggers"], ["<data>"])
    self.assertEqual(len(tag["structures"]), 1)
    self.assertEqual(tag["structures"][0]["begin"], "<data>")
    self.assertEqual(tag["structures"][0]["end"], "</data>")
def test_response_format_none(self):
    """Omitting response_format must not inject structural constraints."""
    request = CompletionRequest(model="x", prompt="Generate text:", max_tokens=100)
    params = self.sc._build_sampling_params(request)
    # Only structural_tag is asserted absent; json_schema is left unchecked
    # because the legacy top-level json_schema field can populate it
    # independently of response_format.
    self.assertIsNone(params.get("structural_tag"))
# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main(verbosity=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment