Unverified Commit 16f69b1f authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat: Improve Mistral and Qwen25 function call parsing (#6597)

parent 65f09131
...@@ -72,20 +72,51 @@ class BaseFormatDetector(ABC): ...@@ -72,20 +72,51 @@ class BaseFormatDetector(ABC):
action = json.loads(text) action = json.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools)) return StreamingParseResult(calls=self.parse_base_json(action, tools))
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
    """
    Check whether ``buffer`` ends with a partial ``bot_token``.

    Returns the length of the longest suffix of ``buffer`` that is a proper
    (strict) prefix of ``bot_token``, or 0 if there is none.  This is needed
    because for some formats the bot_token is not a single token in the
    model's vocabulary (e.g. ``[TOOL_CALLS] [`` in Mistral), so it can arrive
    split across several streamed chunks.

    Scanning from the longest candidate downward matters when the token's
    first character recurs inside it: returning the shortest match would
    under-count how much of the buffer must be held back.
    """
    # The full-length token is deliberately excluded (len(bot_token) - 1):
    # a complete bot_token is detected by the caller with a substring check.
    for length in range(min(len(buffer), len(bot_token) - 1), 0, -1):
        if bot_token.startswith(buffer[-length:]):
            return length
    return 0
def parse_streaming_increment( def parse_streaming_increment(
self, new_text: str, tools: List[Tool] self, new_text: str, tools: List[Tool]
) -> StreamingParseResult: ) -> StreamingParseResult:
""" """
Streaming incremental parsing with tool validation. Streaming incremental parsing with tool validation.
This base implementation works best with formats where:
1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
2. JSON can be parsed incrementally using partial_json_loads
3. Multiple tool calls are separated by "; " or ", "
Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
- Each tool call is wrapped in a separate block: See Qwen25Detector
- Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
- Tool call is Pythonic style
For incompatible formats, detectors should override this method with custom logic.
""" """
# Append new text to buffer # Append new text to buffer
self._buffer += new_text self._buffer += new_text
current_text = self._buffer current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")): if not (self.bot_token in current_text or current_text.startswith("{")):
self._buffer = "" # Only clear buffer if we're sure no tool call is starting
if self.eot_token in new_text: if not self.ends_with_partial_token(self._buffer, self.bot_token):
new_text = new_text.replace(self.eot_token, "") normal_text = self._buffer
return StreamingParseResult(normal_text=new_text) self._buffer = ""
if self.eot_token in normal_text:
normal_text = normal_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=normal_text)
else:
# Might be partial bot_token, keep buffering
return StreamingParseResult()
# Build tool indices if not already built # Build tool indices if not already built
if not hasattr(self, "_tool_indices"): if not hasattr(self, "_tool_indices"):
......
...@@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector): ...@@ -149,8 +149,8 @@ class DeepSeekV3Detector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]): def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf( return EBNFComposer.build_ebnf(
tools, tools,
bot_token=self.bot_token, sequence_start_token=self.bot_token,
eot_token=self.eot_token, sequence_end_token=self.eot_token,
tool_call_separator="", tool_call_separator="",
call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"', call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
function_format="json", function_format="json",
......
...@@ -30,11 +30,6 @@ class EBNFComposer: ...@@ -30,11 +30,6 @@ class EBNFComposer:
ws ::= [ \n\t]* ws ::= [ \n\t]*
""" """
TOOL_CALLS_MAP = {
"pythonic": '"[" function_call ("," function_call)* "]"',
"json": "function_call",
}
CALL_RULE_MAP = { CALL_RULE_MAP = {
"pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"', "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
"json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"', "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
...@@ -138,35 +133,54 @@ class EBNFComposer: ...@@ -138,35 +133,54 @@ class EBNFComposer:
@staticmethod @staticmethod
def build_ebnf( def build_ebnf(
tools, tools,
*,
call_rule_fmt: Optional[str] = None,
function_format: Literal["pythonic", "json"] = "json", function_format: Literal["pythonic", "json"] = "json",
bot_token: Optional[str] = None, # Parameters for wrapping the entire sequence of tool calls
eot_token: Optional[str] = None, sequence_start_token: Optional[str] = None,
sequence_end_token: Optional[str] = None,
# Parameters for wrapping individual tool calls
individual_call_start_token: Optional[str] = None,
individual_call_end_token: Optional[str] = None,
# Parameter for separating multiple tool calls
tool_call_separator: Optional[str] = None, tool_call_separator: Optional[str] = None,
call_rule_fmt: Optional[str] = None,
): ):
""" """
Generalized EBNF builder for all detectors. Generalized EBNF builder for all detectors.
Args: Args:
tools: List of Tool objects to generate EBNF grammar for tools: List of Tool objects to generate EBNF grammar for
function_format: The format of function calls, either "pythonic" or "json"
sequence_start_token: Token that wraps the entire sequence of tool calls (start)
sequence_end_token: Token that wraps the entire sequence of tool calls (end)
individual_call_start_token: Token that wraps each individual tool call (start)
individual_call_end_token: Token that wraps each individual tool call (end)
tool_call_separator: The separator between multiple tool calls
call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
format based on function_format will be used. format based on function_format will be used.
function_format: The format of function calls, either "pythonic" or "json"
bot_token: The token that indicates the start of a tool call section
eot_token: The token that indicates the end of a tool call section
tool_call_separator: The separator between multiple tool calls
""" """
# ================================================================= # =================================================================
# Step 1: Determine the root tool calls rule # Step 1: Determine the root tool calls rule
# ================================================================= # =================================================================
if bot_token and eot_token: # Handle a single function call
if tool_call_separator: if individual_call_start_token and individual_call_end_token:
root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"' function_call_unit = f'"{individual_call_start_token}" function_call "{individual_call_end_token}"'
else: else:
root_rule = f'"{bot_token}" function_call "{eot_token}"' function_call_unit = "function_call"
# Handle multiple function calls with separators
if tool_call_separator is not None:
base_pattern = f'{function_call_unit} ( "{tool_call_separator}" {function_call_unit} )*'
else:
# Assume only support single function call
base_pattern = function_call_unit
# Apply sequence-level wrapping if needed
if sequence_start_token and sequence_end_token:
root_rule = (
f'"{sequence_start_token}" {base_pattern} "{sequence_end_token}"'
)
else: else:
root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format] root_rule = base_pattern
# ================================================================= # =================================================================
# Step 2: Build the header rules # Step 2: Build the header rules
......
import json import json
import logging
import re import re
from typing import List from typing import List
...@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import ( ...@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class MistralDetector(BaseFormatDetector): class MistralDetector(BaseFormatDetector):
""" """
Detector for Mistral models. Detector for Mistral models.
Assumes function call format: Assumes function call format:
[TOOL_CALLS] [{"name":"xxx", "arguments":{...}}] [TOOL_CALLS] [{"name":"func1", "arguments":{...}}, {"name":"func2", "arguments":{...}}]
""" """
def __init__(self): def __init__(self):
...@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector): ...@@ -32,21 +35,6 @@ class MistralDetector(BaseFormatDetector):
"""Check if the text contains a Mistral format tool call.""" """Check if the text contains a Mistral format tool call."""
return self.bot_token in text return self.bot_token in text
def _clean_text(self, text: str) -> str:
"""
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
for example,
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
The key pattern is [TOOL_CALLS] [...]
"""
# TODO: check if Mistral supports multiple tool calls, currently assume only support one tool call
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
if len(find_results) > 0:
return find_results[0]
else:
return ""
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
""" """
One-time parsing: Detects and parses tool calls in the provided text. One-time parsing: Detects and parses tool calls in the provided text.
...@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector): ...@@ -57,17 +45,74 @@ class MistralDetector(BaseFormatDetector):
""" """
idx = text.find(self.bot_token) idx = text.find(self.bot_token)
normal_text = text[:idx].strip() if idx != -1 else text normal_text = text[:idx].strip() if idx != -1 else text
text = self._clean_text(text)
tool_content = text.replace("[TOOL_CALLS]", "").strip() if self.bot_token not in text:
raw_tool_calls = self.tool_call_regex.findall(tool_content) return StreamingParseResult(normal_text=normal_text, calls=[])
# Extract the JSON array part from [TOOL_CALLS] [...]
# Use bracket counting to properly handle nested brackets in JSON content
json_array_str = self._extract_json_array(text)
if not json_array_str:
return StreamingParseResult(normal_text=normal_text, calls=[])
calls = [] calls = []
if len(raw_tool_calls) > 0: try:
raw_tool_call = raw_tool_calls[0] function_call_arr = json.loads(json_array_str)
function_call_arr = json.loads(raw_tool_call) # Handle both single object and array of objects
for match_result in function_call_arr: if not isinstance(function_call_arr, list):
calls.extend(self.parse_base_json(match_result, tools)) function_call_arr = [function_call_arr]
calls = self.parse_base_json(function_call_arr, tools)
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}"
)
return StreamingParseResult(normal_text=normal_text, calls=calls) return StreamingParseResult(normal_text=normal_text, calls=calls)
def _extract_json_array(self, text: str) -> str:
"""
Extract the JSON array part using bracket counting to handle nested brackets.
:param text: The complete text containing [TOOL_CALLS] [...]
:return: The JSON array string or None if not found
"""
start_idx = text.find(self.bot_token)
if start_idx == -1:
return None
# Start from the opening bracket after [TOOL_CALLS]
json_start = (
start_idx + len(self.bot_token) - 1
) # -1 to include the opening bracket
bracket_count = 0
in_string = False
escape_next = False
for i in range(json_start, len(text)):
char = text[i]
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
if char == '"' and not escape_next:
in_string = not in_string
continue
if not in_string:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1
if bracket_count == 0:
return text[json_start : i + 1]
return None
def structure_info(self) -> _GetInfoFunc: def structure_info(self) -> _GetInfoFunc:
return lambda name: StructureInfo( return lambda name: StructureInfo(
begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":', begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
...@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector): ...@@ -78,7 +123,8 @@ class MistralDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]): def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf( return EBNFComposer.build_ebnf(
tools, tools,
bot_token=self.bot_token, sequence_start_token=self.bot_token,
eot_token=self.eot_token, sequence_end_token=self.eot_token,
function_format="json", function_format="json",
tool_call_separator=", ",
) )
...@@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector): ...@@ -156,8 +156,8 @@ class PythonicDetector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]) -> Optional[str]: def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
return EBNFComposer.build_ebnf( return EBNFComposer.build_ebnf(
tools, tools,
bot_token="[", sequence_start_token="[",
eot_token="]", sequence_end_token="]",
tool_call_separator=",", tool_call_separator=",",
function_format="pythonic", function_format="pythonic",
) )
import json import json
import logging
import re import re
from typing import List from typing import List
...@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import ( ...@@ -11,12 +12,14 @@ from sglang.srt.function_call.core_types import (
from sglang.srt.function_call.ebnf_composer import EBNFComposer from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class Qwen25Detector(BaseFormatDetector): class Qwen25Detector(BaseFormatDetector):
""" """
Detector for Qwen 2.5 models. Detector for Qwen 2.5 models.
Assumes function call format: Assumes function call format:
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call> <tool_call>\n{"name":"func1", "arguments":{...}}\n</tool_call>\n<tool_call>\n{"name":"func2", "arguments":{...}}\n</tool_call>
""" """
def __init__(self): def __init__(self):
...@@ -24,8 +27,9 @@ class Qwen25Detector(BaseFormatDetector): ...@@ -24,8 +27,9 @@ class Qwen25Detector(BaseFormatDetector):
Initializes the detector with necessary state variables. Initializes the detector with necessary state variables.
""" """
super().__init__() super().__init__()
self.bot_token = "<tool_call>" self.bot_token = "<tool_call>\n"
self.eot_token = "</tool_call>" self.eot_token = "\n</tool_call>"
self._normal_text_buffer = "" # Buffer for handling partial end tokens
def has_tool_call(self, text: str) -> bool: def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a Qwen 2.5 format tool call.""" """Check if the text contains a Qwen 2.5 format tool call."""
...@@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector): ...@@ -43,15 +47,64 @@ class Qwen25Detector(BaseFormatDetector):
normal_text = text[:idx].strip() if idx != -1 else text normal_text = text[:idx].strip() if idx != -1 else text
if self.bot_token not in text: if self.bot_token not in text:
return StreamingParseResult(normal_text=normal_text, calls=[]) return StreamingParseResult(normal_text=normal_text, calls=[])
pattern = rf"{self.bot_token}(.*?){self.eot_token}"
# Find all <tool_call>\n...\n</tool_call> blocks
pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
match_result_list = re.findall(pattern, text, re.DOTALL) match_result_list = re.findall(pattern, text, re.DOTALL)
calls = [] calls = []
for match_result in match_result_list: for match_result in match_result_list:
match_result = json.loads(match_result) try:
calls.extend(self.parse_base_json(match_result, tools)) parsed_call = json.loads(match_result.strip())
calls.extend(self.parse_base_json(parsed_call, tools))
except json.JSONDecodeError as e:
logger.warning(
f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}"
)
continue
return StreamingParseResult(normal_text=normal_text, calls=calls) return StreamingParseResult(normal_text=normal_text, calls=calls)
def parse_streaming_increment(
    self, new_text: str, tools: List[Tool]
) -> StreamingParseResult:
    """
    Streaming incremental parsing for Qwen 2.5 tool calls.

    Delegates to the base-class implementation, then post-processes the
    normal-text channel so that an end token ("</tool_call>") arriving one
    character at a time is never leaked to the caller.
    """
    result = super().parse_streaming_increment(new_text, tools)

    # Nothing to post-process when no normal text was produced.
    if not result.normal_text:
        return result

    self._normal_text_buffer += result.normal_text

    # The end token may appear in normal text without its leading newline,
    # since the newline can be consumed separately during streaming.
    bare_eot = self.eot_token[1:]  # "</tool_call>"
    if bare_eot in self._normal_text_buffer:
        # Complete end token present: strip it out and flush the buffer.
        result.normal_text = self._normal_text_buffer.replace(bare_eot, "")
        self._normal_text_buffer = ""
        return result

    # A suffix of the buffer may be the start of the end token; hold that
    # suffix back and emit only the text that cannot be part of it.
    held = self.ends_with_partial_token(self._normal_text_buffer, bare_eot)
    if held:
        result.normal_text = self._normal_text_buffer[:-held]
        self._normal_text_buffer = self._normal_text_buffer[-held:]
    else:
        # No partial match: everything buffered is safe to emit.
        result.normal_text = self._normal_text_buffer
        self._normal_text_buffer = ""
    return result
def structure_info(self) -> _GetInfoFunc: def structure_info(self) -> _GetInfoFunc:
# TODO: Update the begin and end tokens with '\n' if necessary
return lambda name: StructureInfo( return lambda name: StructureInfo(
begin='<tool_call>{"name":"' + name + '", "arguments":', begin='<tool_call>{"name":"' + name + '", "arguments":',
end="}</tool_call>", end="}</tool_call>",
...@@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector): ...@@ -61,7 +114,8 @@ class Qwen25Detector(BaseFormatDetector):
def build_ebnf(self, tools: List[Tool]): def build_ebnf(self, tools: List[Tool]):
return EBNFComposer.build_ebnf( return EBNFComposer.build_ebnf(
tools, tools,
bot_token=self.bot_token, individual_call_start_token=self.bot_token.replace("\n", "\\n"),
eot_token=self.eot_token, individual_call_end_token=self.eot_token.replace("\n", "\\n"),
tool_call_separator="\\n",
function_format="json", function_format="json",
) )
...@@ -265,6 +265,118 @@ class TestPythonicDetector(unittest.TestCase): ...@@ -265,6 +265,118 @@ class TestPythonicDetector(unittest.TestCase):
self.assertEqual(params["data"], [1, 2, 3]) self.assertEqual(params["data"], [1, 2, 3])
class TestMistralDetector(unittest.TestCase):
    """Unit tests for MistralDetector's one-shot [TOOL_CALLS] parsing.

    Focuses on the bracket-counting extraction: JSON arguments whose string
    values contain brackets must not be truncated.
    """

    def setUp(self):
        """Set up test tools and detector for Mistral format testing."""
        # Single schema-bearing tool; all tests call this function name.
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="make_next_step_decision",
                    description="Test function for decision making",
                    parameters={
                        "type": "object",
                        "properties": {
                            "decision": {
                                "type": "string",
                                "description": "The next step to take",
                            },
                            "content": {
                                "type": "string",
                                "description": "The content of the next step",
                            },
                        },
                        "required": ["decision", "content"],
                    },
                ),
            ),
        ]
        self.detector = MistralDetector()

    def test_detect_and_parse_with_nested_brackets_in_content(self):
        """Test parsing Mistral format with nested brackets in JSON content.

        This test case specifically addresses the issue where the regex pattern
        was incorrectly truncating JSON when it contained nested brackets like [City Name].
        """
        # This is the exact problematic text from the original test failure
        test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"","content":"```\\nTOOL: Access a weather API or service\\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\\n```"}}]'

        result = self.detector.detect_and_parse(test_text, self.tools)

        # Verify that the parsing was successful
        self.assertEqual(len(result.calls), 1, "Should detect exactly one tool call")

        call = result.calls[0]
        self.assertEqual(
            call.name,
            "make_next_step_decision",
            "Should detect the correct function name",
        )

        # Verify that the parameters are valid JSON and contain the expected content
        params = json.loads(call.parameters)
        self.assertEqual(
            params["decision"], "", "Decision parameter should be empty string"
        )

        # The content should contain the full text including the nested brackets [City Name]
        expected_content = "```\nTOOL: Access a weather API or service\nOBSERVATION: Retrieve the current weather data for the top 5 populated cities in the US\nANSWER: The weather in the top 5 populated cities in the US is as follows: [City Name] - [Weather Conditions] - [Temperature]\n```"
        self.assertEqual(
            params["content"],
            expected_content,
            "Content should include nested brackets without truncation",
        )

        # Verify that normal text is empty (since the entire input is a tool call)
        self.assertEqual(
            result.normal_text, "", "Normal text should be empty for pure tool call"
        )

    def test_detect_and_parse_simple_case(self):
        """Test parsing a simple Mistral format tool call without nested brackets."""
        test_text = '[TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"TOOL", "content":"Use weather API"}}]'

        result = self.detector.detect_and_parse(test_text, self.tools)

        self.assertEqual(len(result.calls), 1)
        call = result.calls[0]
        self.assertEqual(call.name, "make_next_step_decision")

        params = json.loads(call.parameters)
        self.assertEqual(params["decision"], "TOOL")
        self.assertEqual(params["content"], "Use weather API")

    def test_detect_and_parse_no_tool_calls(self):
        """Test parsing text without any tool calls."""
        test_text = "This is just normal text without any tool calls."

        result = self.detector.detect_and_parse(test_text, self.tools)

        self.assertEqual(len(result.calls), 0, "Should detect no tool calls")
        self.assertEqual(
            result.normal_text,
            test_text,
            "Should return the original text as normal text",
        )

    def test_detect_and_parse_with_text_before_tool_call(self):
        """Test parsing text that has content before the tool call."""
        # Text preceding [TOOL_CALLS] must be returned as (stripped) normal text.
        test_text = 'Here is some text before the tool call: [TOOL_CALLS] [{"name":"make_next_step_decision", "arguments":{"decision":"ANSWER", "content":"The answer is 42"}}]'

        result = self.detector.detect_and_parse(test_text, self.tools)

        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.normal_text, "Here is some text before the tool call:")

        call = result.calls[0]
        self.assertEqual(call.name, "make_next_step_decision")

        params = json.loads(call.parameters)
        self.assertEqual(params["decision"], "ANSWER")
        self.assertEqual(params["content"], "The answer is 42")
class TestEBNFGeneration(unittest.TestCase): class TestEBNFGeneration(unittest.TestCase):
def setUp(self): def setUp(self):
# Create sample tools for testing # Create sample tools for testing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment