Unverified Commit 0dfe2491 authored by yhyang201's avatar yhyang201 Committed by GitHub
Browse files

Preliminary Support for Qwen3XMLDetector (#8260)


Co-authored-by: default avatarChayenne <zhaochen20@outlook.com>
parent ff45ab7a
......@@ -14,6 +14,7 @@ from sglang.srt.function_call.kimik2_detector import KimiK2Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen3_detector import Qwen3XMLDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
logger = logging.getLogger(__name__)
......@@ -35,6 +36,7 @@ class FunctionCallParser:
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
"kimi_k2": KimiK2Detector,
"qwen3": Qwen3XMLDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
......
import ast
import html
import json
import logging
import re
from typing import Any, Dict, List, Tuple
from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
logger = logging.getLogger(__name__)
def _safe_val(raw: str) -> Any:
raw = html.unescape(raw.strip())
try:
return json.loads(raw)
except Exception:
try:
return ast.literal_eval(raw)
except Exception:
return raw
class Qwen3XMLDetector(BaseFormatDetector):
"""
Detector for Qwen 3 models.
Assumes function call format:
<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>
"""
def __init__(self):
super().__init__()
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_prefix: str = "<function="
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
)
self._buf: str = ""
def has_tool_call(self, text: str) -> bool:
return self.tool_call_start_token in text
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
normal, calls = self._extract(text, tools)
return StreamingParseResult(normal_text=normal, calls=calls)
def parse_streaming_increment(
self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
self._buf += new_text
normal = ""
calls: List[ToolCallItem] = []
while True:
if self.tool_call_start_token not in self._buf:
normal += self._buf
self._buf = ""
break
s = self._buf.find(self.tool_call_start_token)
if s > 0:
normal += self._buf[:s]
self._buf = self._buf[s:]
e = self._buf.find(self.tool_call_end_token)
if e == -1:
break
block = self._buf[: e + len(self.tool_call_end_token)]
self._buf = self._buf[e + len(self.tool_call_end_token) :]
calls.extend(self._parse_block(block, tools))
return StreamingParseResult(normal_text=normal, calls=calls)
def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
normal_parts: List[str] = []
calls: List[ToolCallItem] = []
cursor = 0
while True:
s = text.find(self.tool_call_start_token, cursor)
if s == -1:
normal_parts.append(text[cursor:])
break
normal_parts.append(text[cursor:s])
e = text.find(self.tool_call_end_token, s)
if e == -1:
normal_parts.append(text[s:])
break
block = text[s : e + len(self.tool_call_end_token)]
cursor = e + len(self.tool_call_end_token)
calls.extend(self._parse_block(block, tools))
return "".join(normal_parts), calls
def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
res: List[ToolCallItem] = []
for m in self.tool_call_function_regex.findall(block):
txt = m[0] if m[0] else m[1]
if ">" not in txt:
continue
idx = txt.index(">")
fname = txt[:idx].strip()
body = txt[idx + 1 :]
params: Dict[str, Any] = {}
for pm in self.tool_call_parameter_regex.findall(body):
ptxt = pm[0] if pm[0] else pm[1]
if ">" not in ptxt:
continue
pidx = ptxt.index(">")
pname = ptxt[:pidx].strip()
pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
params[pname] = _safe_val(pval)
raw = {"name": fname, "arguments": params}
try:
res.extend(self.parse_base_json(raw, tools))
except Exception:
logger.warning("invalid tool call for %s dropped", fname)
return res
def structure_info(self) -> _GetInfoFunc:
return lambda n: StructureInfo(
begin=f"{self.tool_call_start_token}\n<function={n}>",
end=f"</function>\n{self.tool_call_end_token}",
trigger=self.tool_call_start_token,
)
# TODO: fake ebnf for xml + outlines backend
def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf(
tools,
individual_call_start_token=self.tool_call_start_token.replace("\n", "\\n"),
individual_call_end_token=self.tool_call_end_token.replace("\n", "\\n"),
tool_call_separator="\\n",
function_format="json",
)
......@@ -1099,6 +1099,7 @@ class ServerArgs:
"deepseekv3",
"pythonic",
"kimi_k2",
"qwen3",
],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', and 'kimi_k2'.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment