Unverified Commit bac414ab authored by mlmz, committed by GitHub
Browse files

[Feature] integrate Structural Tag in xgrammar backend for function calling (#3566)


Co-authored-by: shuaills <shishuaiuoe@gmail.com>
parent eec3f6d1
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# ============================================================================== # ==============================================================================
"""Constrained decoding with xgrammar backend.""" """Constrained decoding with xgrammar backend."""
import json
import logging import logging
from typing import List, Tuple from typing import List, Tuple
...@@ -21,6 +22,7 @@ from xgrammar import ( ...@@ -21,6 +22,7 @@ from xgrammar import (
CompiledGrammar, CompiledGrammar,
GrammarCompiler, GrammarCompiler,
GrammarMatcher, GrammarMatcher,
StructuralTagItem,
TokenizerInfo, TokenizerInfo,
allocate_token_bitmask, allocate_token_bitmask,
apply_token_bitmask_inplace, apply_token_bitmask_inplace,
...@@ -138,6 +140,23 @@ class XGrammarGrammarBackend(BaseGrammarBackend): ...@@ -138,6 +140,23 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
except RuntimeError as e: except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None return None
elif key_type == "structural_tag":
try:
structural_tag = json.loads(key_string)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
)
for structure in structural_tag["structures"]
]
ctx = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"]
)
except RuntimeError as e:
logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
return None
else: else:
raise ValueError(f"Invalid key_type: {key_type}") raise ValueError(f"Invalid key_type: {key_type}")
......
...@@ -710,6 +710,7 @@ class Scheduler: ...@@ -710,6 +710,7 @@ class Scheduler:
req.sampling_params.json_schema is not None req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None or req.sampling_params.regex is not None
or req.sampling_params.ebnf is not None or req.sampling_params.ebnf is not None
or req.sampling_params.structural_tag is not None
): ):
assert self.grammar_backend is not None assert self.grammar_backend is not None
if req.sampling_params.json_schema is not None: if req.sampling_params.json_schema is not None:
...@@ -718,6 +719,8 @@ class Scheduler: ...@@ -718,6 +719,8 @@ class Scheduler:
key = ("regex", req.sampling_params.regex) key = ("regex", req.sampling_params.regex)
elif req.sampling_params.ebnf is not None: elif req.sampling_params.ebnf is not None:
key = ("ebnf", req.sampling_params.ebnf) key = ("ebnf", req.sampling_params.ebnf)
elif req.sampling_params.structural_tag:
key = ("structural_tag", req.sampling_params.structural_tag)
req.grammar = self.grammar_backend.get_cached_value(key) req.grammar = self.grammar_backend.get_cached_value(key)
if not req.grammar: if not req.grammar:
......
...@@ -994,10 +994,17 @@ def v1_chat_generate_request( ...@@ -994,10 +994,17 @@ def v1_chat_generate_request(
"ignore_eos": request.ignore_eos, "ignore_eos": request.ignore_eos,
"skip_special_tokens": request.skip_special_tokens, "skip_special_tokens": request.skip_special_tokens,
} }
if request.response_format and request.response_format.type == "json_schema": if request.response_format and request.response_format.type == "json_schema":
sampling_params["json_schema"] = convert_json_schema_to_str( sampling_params["json_schema"] = convert_json_schema_to_str(
request.response_format.json_schema.schema_ request.response_format.json_schema.schema_
) )
elif (
request.response_format and request.response_format.type == "structural_tag"
):
sampling_params["structural_tag"] = convert_json_schema_to_str(
request.response_format.model_dump(by_alias=True)
)
sampling_params_list.append(sampling_params) sampling_params_list.append(sampling_params)
image_data_list.append(image_data) image_data_list.append(image_data)
......
...@@ -258,6 +258,18 @@ class ResponseFormat(BaseModel): ...@@ -258,6 +258,18 @@ class ResponseFormat(BaseModel):
json_schema: Optional[JsonSchemaResponseFormat] = None json_schema: Optional[JsonSchemaResponseFormat] = None
class StructuresResponseFormat(BaseModel):
    """One begin/schema/end entry of a structural-tag response format."""

    # Forwarded as the `begin` argument when the xgrammar tag item is built.
    begin: str
    # Serialized under the key "schema" (via the alias) when the model is
    # dumped with by_alias=True; defaults to None when the client omits it.
    schema_: Optional[Dict[str, object]] = Field(default=None, alias="schema")
    # Forwarded as the `end` argument when the xgrammar tag item is built.
    end: str
class StructuralTagResponseFormat(BaseModel):
    """Response format whose `type` is the literal "structural_tag"."""

    # Discriminator: only the exact string "structural_tag" validates.
    type: Literal["structural_tag"]
    # Tag entries; each one carries its own begin/schema/end values.
    structures: List[StructuresResponseFormat]
    # Passed through unchanged as the trigger list when the grammar backend
    # compiles the structural tag.
    triggers: List[str]
class Function(BaseModel): class Function(BaseModel):
"""Function descriptions.""" """Function descriptions."""
...@@ -298,7 +310,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -298,7 +310,7 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
n: int = 1 n: int = 1
presence_penalty: float = 0.0 presence_penalty: float = 0.0
response_format: Optional[ResponseFormat] = None response_format: Union[ResponseFormat, StructuralTagResponseFormat] = None
seed: Optional[int] = None seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = None
stream: bool = False stream: bool = False
......
...@@ -45,6 +45,7 @@ class SamplingParams: ...@@ -45,6 +45,7 @@ class SamplingParams:
json_schema: Optional[str] = None, json_schema: Optional[str] = None,
regex: Optional[str] = None, regex: Optional[str] = None,
ebnf: Optional[str] = None, ebnf: Optional[str] = None,
structural_tag: Optional[str] = None,
no_stop_trim: bool = False, no_stop_trim: bool = False,
ignore_eos: bool = False, ignore_eos: bool = False,
skip_special_tokens: bool = True, skip_special_tokens: bool = True,
...@@ -72,6 +73,7 @@ class SamplingParams: ...@@ -72,6 +73,7 @@ class SamplingParams:
self.n = n self.n = n
self.json_schema = json_schema self.json_schema = json_schema
self.ebnf = ebnf self.ebnf = ebnf
self.structural_tag = structural_tag
self.no_stop_trim = no_stop_trim self.no_stop_trim = no_stop_trim
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
self.custom_params = custom_params self.custom_params = custom_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment