Unverified Commit 276e7b3e authored by DarkSharpness's avatar DarkSharpness Committed by GitHub
Browse files

[Feature] New structural tag support (#10691)

parent 296f6892
...@@ -349,6 +349,50 @@ ...@@ -349,6 +349,50 @@
"print_highlight(response.choices[0].message.content)" "print_highlight(response.choices[0].message.content)"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
" messages=messages,\n",
" response_format={\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_weather,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_date,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" },\n",
" },\n",
")\n",
"\n",
"print_highlight(response.choices[0].message.content)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -594,6 +638,56 @@ ...@@ -594,6 +638,56 @@
"print_highlight(response.json())" "print_highlight(response.json())"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"payload = {\n",
" \"text\": text,\n",
" \"sampling_params\": {\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_weather,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_date,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" },\n",
" }\n",
" )\n",
" },\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"response = requests.post(f\"http://localhost:{port}/generate\", json=payload)\n",
"print_highlight(response.json())"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -825,6 +919,57 @@ ...@@ -825,6 +919,57 @@
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Support for XGrammar latest structural tag format\n",
"# https://xgrammar.mlc.ai/docs/tutorials/structural_tag.html\n",
"\n",
"sampling_params = {\n",
" \"temperature\": 0.8,\n",
" \"top_p\": 0.95,\n",
" \"structural_tag\": json.dumps(\n",
" {\n",
" \"type\": \"structural_tag\",\n",
" \"format\": {\n",
" \"type\": \"triggered_tags\",\n",
" \"triggers\": [\"<function=\"],\n",
" \"tags\": [\n",
" {\n",
" \"begin\": \"<function=get_current_weather>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_weather,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" {\n",
" \"begin\": \"<function=get_current_date>\",\n",
" \"content\": {\n",
" \"type\": \"json_schema\",\n",
" \"json_schema\": schema_get_current_date,\n",
" },\n",
" \"end\": \"</function>\",\n",
" },\n",
" ],\n",
" \"at_least_one\": False,\n",
" \"stop_after_first\": False,\n",
" },\n",
" }\n",
" ),\n",
"}\n",
"\n",
"\n",
"# Send POST request to the API endpoint\n",
"outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
......
...@@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import ( ...@@ -32,6 +32,7 @@ from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarBackend, BaseGrammarBackend,
BaseGrammarObject, BaseGrammarObject,
) )
from sglang.srt.constrained.utils import is_legacy_structural_tag
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend): ...@@ -160,6 +161,7 @@ class GuidanceBackend(BaseGrammarBackend):
def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]: def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
try: try:
structural_tag = json.loads(key_string) structural_tag = json.loads(key_string)
assert is_legacy_structural_tag(structural_tag)
tags = [ tags = [
StructTag( StructTag(
begin=structure["begin"], begin=structure["begin"],
......
from typing import Dict
def is_legacy_structural_tag(obj: Dict) -> bool:
    """Return ``True`` if *obj* uses the legacy structural-tag layout.

    The legacy layout carries ``structures`` + ``triggers`` keys, while the
    new XGrammar layout carries a single ``format`` key.  See
    ``StructuralTagResponseFormat`` in ``sglang.srt.entrypoints.openai.protocol``.
    """
    is_legacy = obj.get("structures") is not None
    if is_legacy:
        # legacy tags must also declare their trigger strings
        assert obj.get("triggers") is not None
    else:
        # anything else must be the new format-based layout
        assert obj.get("format") is not None
    return is_legacy
...@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import ( ...@@ -34,6 +34,7 @@ from sglang.srt.constrained.base_grammar_backend import (
BaseGrammarObject, BaseGrammarObject,
GrammarStats, GrammarStats,
) )
from sglang.srt.constrained.utils import is_legacy_structural_tag
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
...@@ -241,7 +242,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -241,7 +242,9 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]: def dispatch_structural_tag(self, key_string: str) -> Optional[XGrammarGrammar]:
try: try:
# TODO(dark): it's REALLY stupid to construct object from string and decode it again
structural_tag = json.loads(key_string) structural_tag = json.loads(key_string)
if is_legacy_structural_tag(structural_tag):
tags = [ tags = [
StructuralTagItem( StructuralTagItem(
begin=structure["begin"], begin=structure["begin"],
...@@ -253,6 +256,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -253,6 +256,8 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
ctx = self.grammar_compiler.compile_structural_tag( ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"] tags, structural_tag["triggers"]
) )
else:
ctx = self.grammar_compiler.compile_structural_tag(key_string)
except (RuntimeError, json.decoder.JSONDecodeError) as e: except (RuntimeError, json.decoder.JSONDecodeError) as e:
logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}") logging.error(f"Hit invalid structural_tag: {key_string=}, {e=}")
return INVALID_GRAMMAR_OBJ return INVALID_GRAMMAR_OBJ
......
...@@ -17,7 +17,7 @@ import logging ...@@ -17,7 +17,7 @@ import logging
import time import time
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, TypeAlias, Union from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union
from openai.types.responses import ( from openai.types.responses import (
ResponseFunctionToolCall, ResponseFunctionToolCall,
...@@ -37,6 +37,7 @@ from pydantic import ( ...@@ -37,6 +37,7 @@ from pydantic import (
model_validator, model_validator,
) )
from typing_extensions import Literal from typing_extensions import Literal
from xgrammar import StructuralTag
from sglang.utils import convert_json_schema_to_str from sglang.utils import convert_json_schema_to_str
...@@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel): ...@@ -128,12 +129,23 @@ class StructuresResponseFormat(BaseModel):
end: str end: str
class StructuralTagResponseFormat(BaseModel): # NOTE(dark): keep this for backward compatibility
class LegacyStructuralTagResponseFormat(BaseModel):
type: Literal["structural_tag"] type: Literal["structural_tag"]
structures: List[StructuresResponseFormat] structures: List[StructuresResponseFormat]
triggers: List[str] triggers: List[str]
StructuralTagResponseFormat: TypeAlias = Union[
LegacyStructuralTagResponseFormat, StructuralTag
]
ToolCallConstraint: TypeAlias = Union[
Tuple[Literal["structural_tag"], StructuralTagResponseFormat],
Tuple[Literal["json_schema"], Any], # json_schema can be dict/str/None
]
class FileRequest(BaseModel): class FileRequest(BaseModel):
# https://platform.openai.com/docs/api-reference/files/create # https://platform.openai.com/docs/api-reference/files/create
file: bytes # The File object (not file name) to be uploaded file: bytes # The File object (not file name) to be uploaded
...@@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -583,7 +595,7 @@ class ChatCompletionRequest(BaseModel):
self, self,
stop: List[str], stop: List[str],
model_generation_config: Dict[str, Any], model_generation_config: Dict[str, Any],
tool_call_constraint: Optional[Any] = None, tool_call_constraint: Optional[ToolCallConstraint] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Convert request to sampling parameters. Convert request to sampling parameters.
...@@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -649,7 +661,7 @@ class ChatCompletionRequest(BaseModel):
) )
elif constraint_type == "json_schema": elif constraint_type == "json_schema":
sampling_params[constraint_type] = convert_json_schema_to_str( sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value constraint_value # type: ignore
) )
else: else:
sampling_params[constraint_type] = constraint_value sampling_params[constraint_type] = constraint_value
...@@ -1177,7 +1189,7 @@ class MessageProcessingResult: ...@@ -1177,7 +1189,7 @@ class MessageProcessingResult:
video_data: Optional[Any] video_data: Optional[Any]
modalities: List[str] modalities: List[str]
stop: List[str] stop: List[str]
tool_call_constraint: Optional[Any] = None tool_call_constraint: Optional[ToolCallConstraint] = None
class ToolCallProcessingResult(NamedTuple): class ToolCallProcessingResult(NamedTuple):
......
...@@ -2,9 +2,10 @@ import logging ...@@ -2,9 +2,10 @@ import logging
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.entrypoints.openai.protocol import ( from sglang.srt.entrypoints.openai.protocol import (
StructuralTagResponseFormat, LegacyStructuralTagResponseFormat,
StructuresResponseFormat, StructuresResponseFormat,
Tool, Tool,
ToolCallConstraint,
ToolChoice, ToolChoice,
) )
from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.base_format_detector import BaseFormatDetector
...@@ -51,7 +52,6 @@ class FunctionCallParser: ...@@ -51,7 +52,6 @@ class FunctionCallParser:
} }
def __init__(self, tools: List[Tool], tool_call_parser: str): def __init__(self, tools: List[Tool], tool_call_parser: str):
detector: Type[BaseFormatDetector] = None
detector_class = self.ToolCallParserEnum.get(tool_call_parser) detector_class = self.ToolCallParserEnum.get(tool_call_parser)
if detector_class: if detector_class:
detector = detector_class() detector = detector_class()
...@@ -123,7 +123,7 @@ class FunctionCallParser: ...@@ -123,7 +123,7 @@ class FunctionCallParser:
return final_normal_text, final_calls return final_normal_text, final_calls
def get_structure_tag(self) -> StructuralTagResponseFormat: def get_structure_tag(self) -> LegacyStructuralTagResponseFormat:
""" """
Generate a structural tag response format for all available tools. Generate a structural tag response format for all available tools.
...@@ -151,7 +151,9 @@ class FunctionCallParser: ...@@ -151,7 +151,9 @@ class FunctionCallParser:
) )
tool_trigger_set.add(info.trigger) tool_trigger_set.add(info.trigger)
return StructuralTagResponseFormat( # TODO(dark): move this into new structural tag format
# This requires all grammar backend support the new format
return LegacyStructuralTagResponseFormat(
type="structural_tag", type="structural_tag",
structures=tool_structures, structures=tool_structures,
triggers=list(tool_trigger_set), triggers=list(tool_trigger_set),
...@@ -159,7 +161,7 @@ class FunctionCallParser: ...@@ -159,7 +161,7 @@ class FunctionCallParser:
def get_structure_constraint( def get_structure_constraint(
self, tool_choice: Union[ToolChoice, Literal["auto", "required"]] self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
) -> Optional[Tuple[str, Any]]: ) -> Optional[ToolCallConstraint]:
""" """
Returns the appropriate structure constraint for tool calls based on the tool_choice. Returns the appropriate structure constraint for tool calls based on the tool_choice.
The constraint is used to guide the model's output format. The constraint is used to guide the model's output format.
...@@ -178,8 +180,8 @@ class FunctionCallParser: ...@@ -178,8 +180,8 @@ class FunctionCallParser:
and tool_choice == "auto" and tool_choice == "auto"
and any(tool.function.strict for tool in self.tools) and any(tool.function.strict for tool in self.tools)
): ):
strict_tag = self.get_structure_tag() tag = self.get_structure_tag()
return ("structural_tag", strict_tag) return ("structural_tag", tag)
elif tool_choice == "required" or isinstance(tool_choice, ToolChoice): elif tool_choice == "required" or isinstance(tool_choice, ToolChoice):
json_schema = get_json_schema_constraint(self.tools, tool_choice) json_schema = get_json_schema_constraint(self.tools, tool_choice)
return ("json_schema", json_schema) return ("json_schema", json_schema)
......
"""
python3 -m unittest test.srt.openai_server.features.test_structural_tag
"""
import json
import unittest
from typing import Any
import openai
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
def setup_class(cls, backend: str):
    """Configure *cls* with a small test model and launch a local server.

    ``backend`` selects the grammar backend (e.g. ``"xgrammar"``) and is
    forwarded to the server process via ``--grammar-backend``.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST
    extra_cli_args = [
        "--max-running-requests",
        "10",
        "--grammar-backend",
        backend,
    ]
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=extra_cli_args,
    )
class TestStructuralTagXGrammarBackend(CustomTestCase):
    """End-to-end tests of the new XGrammar structural-tag response format."""

    # populated by setup_class() during setUpClass
    model: str
    base_url: str
    process: Any

    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="xgrammar")

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def _ask_capital(self, response_format) -> str:
        """Send the shared capital-of-France prompt constrained by *response_format*."""
        api = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        completion = api.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {
                    "role": "user",
                    "content": "Introduce the capital of France. Return in a JSON format.",
                },
            ],
            temperature=0,
            max_tokens=128,
            response_format=response_format,
        )
        return completion.choices[0].message.content

    def test_stag_constant_str_openai(self):
        # even when the answer is ridiculous, the model should follow the instruction
        answer = "The capital of France is Berlin."
        text = self._ask_capital(
            {
                "type": "structural_tag",
                "format": {
                    "type": "const_string",
                    "value": answer,
                },
            }
        )
        self.assertEqual(text, answer)

    def test_stag_json_schema_openai(self):
        json_schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string", "pattern": "^[\\w]+$"},
                "population": {"type": "integer"},
            },
            "required": ["name", "population"],
            "additionalProperties": False,
        }
        text = self._ask_capital(
            {
                "type": "structural_tag",
                "format": {
                    "type": "json_schema",
                    "json_schema": json_schema,
                },
            }
        )
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment