Unverified Commit ed0c3035 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat(Tool Calling): Support `required` and specific function mode (#6550)

parent e6f11356
......@@ -54,10 +54,12 @@
"source": [
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
"\n",
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
"- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n",
"- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n",
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)."
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html).\n",
"- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n"
]
},
{
......@@ -360,6 +362,164 @@
"print(final_response.choices[0].message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Choice Mode\n",
"\n",
"SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n",
"\n",
"### Supported Tool Choice Options\n",
"\n",
"- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n",
"- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n",
"\n",
"### Backend Compatibility\n",
"\n",
"Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n",
"\n",
"### Example: Required Tool Choice"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Response with tool_choice='required':\n",
"Content: None\n",
"Tool calls: [ChatCompletionMessageToolCall(id='call_NFO3TSZuRRO8Eu3Cv79uiQ', function=Function(arguments='{\"city\": \"Paris\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n"
]
}
],
"source": [
"from openai import OpenAI\n",
"import json\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
" import nest_asyncio\n",
"\n",
" nest_asyncio.apply()\n",
"\n",
"# Start a new server session for tool choice examples\n",
"server_process_tool_choice, port_tool_choice = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n",
")\n",
"wait_for_server(f\"http://localhost:{port_tool_choice}\")\n",
"\n",
"# Initialize client for tool choice examples\n",
"client_tool_choice = OpenAI(\n",
" api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n",
")\n",
"model_name_tool_choice = client_tool_choice.models.list().data[0].id\n",
"\n",
"# Example with tool_choice=\"required\" - forces the model to call a tool\n",
"messages_required = [\n",
" {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n",
"]\n",
"\n",
"# Define tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"unit\"],\n",
" },\n",
" },\n",
" }\n",
"]\n",
"\n",
"response_required = client_tool_choice.chat.completions.create(\n",
" model=model_name_tool_choice,\n",
" messages=messages_required,\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" tools=tools,\n",
" tool_choice=\"required\", # Force the model to call a tool\n",
")\n",
"\n",
"print_highlight(\"Response with tool_choice='required':\")\n",
"print(\"Content:\", response_required.choices[0].message.content)\n",
"print(\"Tool calls:\", response_required.choices[0].message.tool_calls)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example: Specific Function Choice\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Response with specific function choice:\n",
"Content: None\n",
"Tool calls: [ChatCompletionMessageToolCall(id='call_fGL_1qsPQFqntNBPkSynJw', function=Function(arguments='{\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n",
"Called function: get_current_weather\n",
"Arguments: {\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}\n"
]
}
],
"source": [
"# Example with specific function choice - forces the model to call a specific function\n",
"messages_specific = [\n",
" {\"role\": \"user\", \"content\": \"What are the most attractive places in France?\"}\n",
"]\n",
"\n",
"response_specific = client_tool_choice.chat.completions.create(\n",
" model=model_name_tool_choice,\n",
" messages=messages_specific,\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" tools=tools,\n",
" tool_choice={\n",
" \"type\": \"function\",\n",
" \"function\": {\"name\": \"get_current_weather\"},\n",
" }, # Force the model to call the specific get_current_weather function\n",
")\n",
"\n",
"print_highlight(\"Response with specific function choice:\")\n",
"print(\"Content:\", response_specific.choices[0].message.content)\n",
"print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n",
"\n",
"if response_specific.choices[0].message.tool_calls:\n",
" tool_call = response_specific.choices[0].message.tool_calls[0]\n",
" print(f\"Called function: {tool_call.function.name}\")\n",
" print(f\"Arguments: {tool_call.function.arguments}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
......@@ -444,7 +604,7 @@
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.function_call_parser import FunctionCallParser\n",
"from sglang.srt.function_call.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
......
......@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
......
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.utils import (
_find_common_prefix,
_is_complete_json,
_partial_json_loads,
)
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class BaseFormatDetector(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses set ``bot_token`` / ``eot_token`` (the markers that open and
    close a tool-call section) and implement the abstract methods. The base
    class supplies JSON-based parsing helpers shared by detectors whose tool
    calls are JSON objects with a ``name`` and ``parameters``/``arguments``.
    """

    def __init__(self):
        # State used while parsing tool calls in streaming mode.
        # Text accumulated across streaming chunks that is still being parsed.
        self._buffer = ""
        # Tool-call objects parsed from the buffer on the previous increment.
        self.prev_tool_call_arr: List[Dict] = []
        # Index (into the parsed tool-call array) of the call being streamed;
        # -1 until the first call is seen.
        self.current_tool_id: int = -1
        # Whether the name of the current tool call has already been emitted.
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # map what has been streamed for each tool so far to a list
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
        """Convert decoded JSON tool call(s) into ``ToolCallItem`` objects.

        Args:
            action: A decoded JSON value: one tool-call object or a list of them.
                Each object is expected to carry ``name`` and either
                ``parameters`` or ``arguments``.
            tools: The tools offered to the model; calls naming anything else
                are dropped with a warning.

        Returns:
            One ``ToolCallItem`` per recognised call, in input order, with the
            arguments re-serialised as a JSON string.
        """
        # Same guard as the streaming path: ignore tools without a usable name.
        tool_indices = {
            tool.function.name: i
            for i, tool in enumerate(tools)
            if tool.function and tool.function.name
        }
        if not isinstance(action, list):
            action = [action]

        results = []
        for act in action:
            # Model output is untrusted: a decoded JSON scalar or nested list
            # has no ``.get`` and must not crash the whole parse.
            if not isinstance(act, dict):
                logger.warning(
                    f"Skipping malformed tool call entry (not a JSON object): {act!r}"
                )
                continue

            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
            else:
                logger.warning(f"Model attempted to call undefined function: {name}")

        return results

    @abstractmethod
    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the complete text in one go.

        This default body treats ``text`` as a JSON document and returns the
        tool calls found in it; subclasses must override the method (it is
        abstract) and may call this implementation via ``super()``.
        ``json.loads`` errors propagate to the caller.
        """
        action = json.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Appends ``new_text`` to the internal buffer, re-parses the buffered
        (possibly partial) JSON tool calls and returns only what is new since
        the previous increment: first the tool name, then argument text.
        Text that is not part of a tool call is returned as ``normal_text``.
        Any unexpected exception is logged and an empty result is returned.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        if not (self.bot_token in current_text or current_text.startswith("{")):
            # Not a tool call: flush the buffer and pass the chunk through,
            # minus any end-of-tool-call marker.
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built
        # (cached on the instance after the first increment).
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the name has been sent, partial strings are disallowed; this
        # presumably keeps a half-generated name from being emitted. TODO confirm
        # against partial_json_parser's Allow semantics.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Skip the parsed object plus a two-character "; " separator.
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except MalformedJSON:
                # Not enough text yet to parse anything; wait for more.
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # NOTE: while current_tool_id is still -1 this indexes the last
            # parsed call.
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # Handle new tool in array
            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                if self.current_tool_id >= 0:
                    # Flush whatever remains of the previous call's arguments
                    # before moving on to the new call.
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()

                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                # Returns before prev_tool_call_arr is refreshed below.
                return res

            # Handle tool name
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # The call's JSON is complete: emit everything not yet
                        # sent and reset the per-call streaming state.
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        # Still partial: only emit the prefix that is common to
                        # the previous and the current serialisation.
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    @abstractmethod
    def has_tool_call(self, text: str) -> bool:
        """Return whether ``text`` contains a tool call in this detector's format."""
        raise NotImplementedError()

    @abstractmethod
    def structure_info(self) -> _GetInfoFunc:
        """Return a callable mapping a function name to its ``StructureInfo``."""
        raise NotImplementedError()

    @abstractmethod
    def build_ebnf(self, tools: List[Tool]) -> str:
        """Return an EBNF grammar constraining output to calls of ``tools``."""
        raise NotImplementedError()
from dataclasses import dataclass
from typing import Callable, List, Optional
from pydantic import BaseModel
class ToolCallItem(BaseModel):
    """A single parsed tool call (or streamed fragment of one).

    Kept deliberately small so it is easy to pass around in streaming code.
    """

    # Index identifying the tool this item belongs to.
    tool_index: int
    # Function name; optional, defaults to None when not supplied.
    name: Optional[str] = None
    # Arguments (or an incremental piece of them) as a JSON string.
    parameters: str  # JSON string
class StreamingParseResult(BaseModel):
    """Outcome of one parsing step: plain text plus any tool calls found."""

    # Text that is not part of a tool call.
    normal_text: str = ""
    # Tool calls (or fragments) produced by this step; empty by default.
    calls: List[ToolCallItem] = []
@dataclass
class StructureInfo:
    """The three strings a detector supplies for one function name."""

    begin: str
    end: str
    trigger: str


# Helper alias of function.
# Usually it is a function that takes a name string and returns a StructureInfo
# object, which can be used to construct a structural_tag object.
_GetInfoFunc = Callable[[str], StructureInfo]
import json
import logging
import re
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class DeepSeekV3Detector(BaseFormatDetector):
    """
    Detector for DeepSeek models.
    Assumes function call format:
      '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
    """

    # Literal markers delimiting a single call. They are kept as plain strings
    # and passed through re.escape() before being used in a pattern, so that
    # characters such as "|" are matched literally instead of acting as regex
    # alternation.
    _CALL_BEGIN = "<|tool▁call▁begin|>"
    _CALL_END = "<|tool▁call▁end|>"
    _TOOL_SEP = "<|tool▁sep|>"

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool▁calls▁begin|>"
        self.eot_token = "<|tool▁calls▁end|>"

        call_begin = re.escape(self._CALL_BEGIN)
        call_end = re.escape(self._CALL_END)
        tool_sep = re.escape(self._TOOL_SEP)

        # One complete call, begin marker through end marker (non-greedy).
        self.func_call_regex = rf"{call_begin}.*?{call_end}"
        # Groups: (1) call type, (2) function name, (3) JSON arguments.
        self.func_detail_regex = (
            rf"{call_begin}(.*){tool_sep}(.*)\n```json\n(.*)\n```{call_end}"
        )
        # Same as above but without the closing fence, for text still streaming.
        self._partial_call_regex = rf"{call_begin}(.*){tool_sep}(.*)\n```json\n(.*)"

        # Argument text already emitted for the call currently being streamed.
        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a deepseek format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: A StreamingParseResult whose ``normal_text`` is the (stripped)
            text preceding the tool-call section and whose ``calls`` are the
            parsed calls. If no tool-call section is present, or parsing fails,
            the whole text is returned as ``normal_text`` with no calls.
        """
        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
        calls = []
        try:
            for match_result in match_result_list:
                # Get function name and arguments. A call that does not match
                # the detailed pattern yields None here; the resulting
                # AttributeError is handled by the except below.
                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
                func_name = func_detail.group(2)
                func_args = func_detail.group(3)
                func_args = json.loads(func_args)
                # construct match_result for parse_base_json
                match_result = {"name": func_name, "parameters": func_args}
                calls.extend(self.parse_base_json(match_result, tools))
            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for DeepSeekV3 format.

        Buffers ``new_text``; once the buffer contains a tool-call section it
        emits the function name first and then, on later increments, only the
        argument text not emitted before. When the buffered arguments form
        complete JSON the per-call state is reset.
        """
        self._buffer += new_text
        current_text = self._buffer

        if self.bot_token not in current_text:
            # Plain text: flush the buffer and strip any trailing call markers
            # left over after a call was completed.
            self._buffer = ""
            for e_token in [self.eot_token, "```", self._CALL_END]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Built once and cached on the instance.
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            partial_match = re.search(
                pattern=self._partial_call_regex,
                string=current_text,
                flags=re.DOTALL,
            )
            if partial_match:
                func_name = partial_match.group(2).strip()
                func_args_raw = partial_match.group(3).strip()

                if not self.current_tool_name_sent:
                    calls.append(
                        ToolCallItem(
                            # Unknown names fall back to index 0.
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=func_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                else:
                    # Emit only the suffix not sent yet; if the buffer no longer
                    # extends what was sent, resend the whole argument text.
                    argument_diff = (
                        func_args_raw[len(self._last_arguments) :]
                        if func_args_raw.startswith(self._last_arguments)
                        else func_args_raw
                    )

                    if argument_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self._tool_indices.get(func_name, 0),
                                name=None,
                                parameters=argument_diff,
                            )
                        )
                        self._last_arguments += argument_diff

                    if _is_complete_json(func_args_raw):
                        # Call finished: reset state for the next call.
                        result = StreamingParseResult(normal_text="", calls=calls)
                        self._buffer = ""
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)

    def structure_info(self) -> _GetInfoFunc:
        """Return a function building the ``StructureInfo`` for a function name."""
        return lambda name: StructureInfo(
            begin=">" + name + "\n```json\n",
            end="\n```<",
            trigger=">" + name + "\n```json\n",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the EBNF grammar for DeepSeek-V3 formatted calls to ``tools``."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            tool_call_separator="",
            call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
            function_format="json",
        )
from typing import Literal, Optional
class EBNFComposer:
    """Builds EBNF grammars that constrain model output to valid tool calls.

    ``build_ebnf`` is the entry point. It emits a ``root`` rule, a
    ``function_call`` alternation over the given tools, one ``call_<name>`` and
    ``arguments_<name>`` rule per tool, and finally the base JSON (or pythonic)
    value grammar that the per-tool rules refer to.
    """

    # Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
    json_grammar_ebnf_str = r"""
json ::= basic_array | basic_object
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_boolean ::= "true" | "false"
basic_null ::= "null"
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""

    pythonic_grammar_ebnf_str = r"""
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
basic_any ::= basic_number | basic_string | basic_array | basic_object
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""

    # Root rule used when no begin/end tokens are supplied.
    TOOL_CALLS_MAP = {
        "pythonic": '"[" function_call ("," function_call)* "]"',
        "json": "function_call",
    }

    # Default per-tool call rule; placeholders: {name}, {arguments_rule}.
    CALL_RULE_MAP = {
        "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
        "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
    }

    # Wrapper around the combined argument rules; placeholder: {arg_rules}.
    ARGUMENTS_RULE_MAP = {
        "pythonic": "{arg_rules}",
        "json": '"{{" {arg_rules} "}}"',
    }

    # One key/value pair; placeholders: {key}, {valrule}.
    KEY_VALUE_RULE_MAP = {
        "pythonic": '"{key}" "=" {valrule}',
        "json": '"\\"{key}\\"" ":" {valrule}',
    }

    # JSON-schema type name -> rule in the base grammar.
    JSON_TYPE_MAPPING = {
        "string": "basic_string",
        "number": "basic_number",
        "integer": "basic_number",
        "boolean": "basic_boolean",
        "null": "basic_null",
        "array": "basic_array",
        "object": "basic_object",
    }

    PYTHONIC_TYPE_MAPPING = {
        "string": "basic_string",
        "number": "basic_number",
        "integer": "basic_number",
        "boolean": '"True" | "False"',
        "null": '"None"',
        "array": "basic_array",
        "object": "basic_object",
    }

    @staticmethod
    def _terminal(text: str) -> str:
        """Return ``text`` as a double-quoted EBNF terminal (escaping ``\\`` and ``"``)."""
        escaped = text.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    @staticmethod
    def _group(rule: str) -> str:
        """Parenthesise ``rule`` if it is an alternation.

        ``|`` binds more loosely than sequencing, so an alternation spliced
        after ``"key" ":"`` must be grouped or only its first alternative
        would belong to the key/value pair.
        """
        return f"({rule})" if " | " in rule else rule

    @staticmethod
    def get_value_rule(
        prop: dict, function_format: Literal["pythonic", "json"] = "json"
    ) -> str:
        """Return the EBNF rule describing the values allowed for ``prop``.

        ``enum`` takes precedence over ``type``. A schema with neither falls
        back to the format's catch-all rule, which is named after the format
        itself (``json`` / ``pythonic`` in the base grammars).
        """
        if "enum" in prop:
            return EBNFComposer._handle_enum(prop, function_format)

        if "type" in prop:
            return EBNFComposer._handle_type(prop, function_format)

        return function_format

    @staticmethod
    def _handle_enum(prop: dict, function_format: str) -> str:
        """Handle enum properties by formatting each value according to type and format.

        Every value becomes a quoted EBNF terminal. When the declared ``type``
        is a single string it decides the formatting (unknown types are treated
        as strings); when it is a list (e.g. ``["string", "null"]``) the
        formatting is inferred from each value's Python type.
        """
        enum_values = prop["enum"]
        declared_type = prop.get("type", "string")
        pythonic = function_format == "pythonic"

        def infer_kind(value) -> str:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(value, bool):
                return "boolean"
            if isinstance(value, (int, float)):
                return "number"
            if value is None:
                return "null"
            return "string"

        def format_value(value) -> str:
            kind = (
                declared_type if isinstance(declared_type, str) else infer_kind(value)
            )
            if kind == "boolean":
                if pythonic:
                    text = "True" if value else "False"
                else:
                    text = "true" if value else "false"
            elif kind in ("number", "integer"):
                text = str(value)
            elif kind == "null":
                text = "None" if pythonic else "null"
            else:
                # Strings (and any unrecognised type) are emitted double-quoted.
                text = f'"{value}"'
            return EBNFComposer._terminal(text)

        formatted_values = [format_value(value) for value in enum_values]
        enum_rule = " | ".join(formatted_values)

        # Wrap in parentheses if there are multiple values to ensure correct EBNF precedence
        if len(formatted_values) > 1:
            enum_rule = f"({enum_rule})"

        return enum_rule

    @staticmethod
    def _handle_type(prop: dict, function_format: str) -> str:
        """Handle type properties using the appropriate type mapping.

        A list of types becomes a parenthesised alternation of the known ones;
        an unknown type (or a list with no known member) falls back to the
        format's catch-all rule.
        """
        prop_type = prop["type"]
        type_mapping = (
            EBNFComposer.PYTHONIC_TYPE_MAPPING
            if function_format == "pythonic"
            else EBNFComposer.JSON_TYPE_MAPPING
        )

        if isinstance(prop_type, list):
            type_rules = [
                type_mapping[single_type]
                for single_type in prop_type
                if single_type in type_mapping
            ]
            if not type_rules:
                return function_format
            return EBNFComposer._group(" | ".join(type_rules))

        return EBNFComposer._group(type_mapping.get(prop_type, function_format))

    @staticmethod
    def _combine_arg_rules(required: list, optional: list) -> str:
        """Combine key/value rules into one argument-list rule.

        Required pairs come first, comma separated. Optional pairs follow in
        their given order; any subset may appear, and a comma is only demanded
        in front of a pair that is actually present.
        """
        parts = []
        if required:
            parts.append(' "," '.join(required))

        if optional:
            # Alternative i starts at optional[i]; each later optional pair may
            # follow, each carrying its own leading comma.
            alternatives = [
                first + "".join(f' ( "," {later} )?' for later in optional[i + 1 :])
                for i, first in enumerate(optional)
            ]
            choice = " | ".join(alternatives)
            if required:
                parts.append(f'( "," ( {choice} ) )?')
            else:
                parts.append(f"( {choice} )?")

        return " ".join(parts)

    @staticmethod
    def build_ebnf(
        tools,
        *,
        call_rule_fmt: Optional[str] = None,
        function_format: Literal["pythonic", "json"] = "json",
        bot_token: Optional[str] = None,
        eot_token: Optional[str] = None,
        tool_call_separator: Optional[str] = None,
    ):
        """
        Generalized EBNF builder for all detectors.
        Args:
            tools: List of Tool objects to generate EBNF grammar for
            call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
                the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
                format based on function_format will be used.
            function_format: The format of function calls, either "pythonic" or "json"
            bot_token: The token that indicates the start of a tool call section
            eot_token: The token that indicates the end of a tool call section
            tool_call_separator: The separator between multiple tool calls
        Returns:
            The grammar as a single newline-joined string.
        """
        # =================================================================
        # Step 1: Determine the root tool calls rule
        # =================================================================
        if bot_token and eot_token:
            if tool_call_separator:
                root_rule = f'"{bot_token}" function_call ( "{tool_call_separator}" function_call )* "{eot_token}"'
            else:
                root_rule = f'"{bot_token}" function_call "{eot_token}"'
        else:
            root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]

        # =================================================================
        # Step 2: Build the header rules
        # =================================================================
        ebnf_lines = [
            f"root ::= {root_rule}",
            "function_call ::= "
            + " | ".join([f"call_{tool.function.name}" for tool in tools]),
        ]

        # =================================================================
        # Step 3: Set up formatting templates
        # =================================================================
        call_template = (
            f"call_{{name}} ::= {call_rule_fmt}"
            if call_rule_fmt
            else EBNFComposer.CALL_RULE_MAP[function_format]
        )
        args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
        key_value_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]

        # =================================================================
        # Step 4: Build rules for each tool
        # =================================================================
        for tool in tools:
            tool_name = tool.function.name
            params = tool.function.parameters or {}
            properties = params.get("properties", {})
            required_props = set(params.get("required", []))

            # Build key/value rules, split by whether the property is required.
            required_pairs = []
            optional_pairs = []
            for prop_name, prop_schema in properties.items():
                value_rule = EBNFComposer.get_value_rule(prop_schema, function_format)
                pair = key_value_template.format(key=prop_name, valrule=value_rule)
                if prop_name in required_props:
                    required_pairs.append(pair)
                else:
                    optional_pairs.append(pair)

            # Combine all argument rules
            combined_args = EBNFComposer._combine_arg_rules(
                required_pairs, optional_pairs
            )
            arguments_rule = args_template.format(arg_rules=combined_args)

            # Add the function call rule and its arguments rule
            ebnf_lines.append(
                call_template.format(
                    name=tool_name, arguments_rule=f"arguments_{tool_name}"
                )
            )
            ebnf_lines.append(f"arguments_{tool_name} ::= {arguments_rule}")

        # =================================================================
        # Step 5: Add base grammar rules
        # =================================================================
        base_grammar = (
            EBNFComposer.pythonic_grammar_ebnf_str
            if function_format == "pythonic"
            else EBNFComposer.json_grammar_ebnf_str
        )
        ebnf_lines.append(base_grammar)

        return "\n".join(ebnf_lines)
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
class FunctionCallParser:
"""
Parser for function/tool calls in model outputs.
This class handles both streaming and non-streaming parsing of function calls using a detector.
In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment
and returns the resulting normal_text and calls to the upper layer (or SSE).
"""
ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
"llama3": Llama32Detector,
"qwen25": Qwen25Detector,
"mistral": MistralDetector,
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
detector: Type[BaseFormatDetector] = None
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
if detector_class:
detector = detector_class()
else:
raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
self.detector = detector
self.tools = tools
def has_tool_call(self, text: str) -> bool:
"""
Check if the given text contains a tool call in the format supported by this parser.
This delegates to the detector's implementation.
Args:
text: The text to check for tool calls
Returns:
True if the text contains a tool call, False otherwise
"""
return self.detector.has_tool_call(text)
def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
"""
One-time parsing of the full text to extract tool calls.
Args:
full_text: The complete text to parse
Returns:
A tuple containing:
- The remaining text after parsing that was not consumed by the detector (can be treated as normal text)
- A list of tool calls parsed from the text
"""
parsed_result = self.detector.detect_and_parse(full_text, self.tools)
tool_call_list = parsed_result.calls
if tool_call_list:
return parsed_result.normal_text, tool_call_list
else:
return full_text, []
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
"""
Streaming incremental parsing of chunks of text as they arrive.
Args:
chunk_text: The new chunk of text to parse
Returns:
A tuple containing:
- The normal text that should be displayed to the user
- A list of tool calls parsed from the chunk
"""
final_normal_text = ""
final_calls = []
sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)
if sp_result.normal_text:
final_normal_text = sp_result.normal_text
if sp_result.calls:
final_calls.extend(sp_result.calls)
final_normal_text = sp_result.normal_text
return final_normal_text, final_calls
def get_structure_tag(self) -> StructuralTagResponseFormat:
"""
Generate a structural tag response format for all available tools.
This creates the necessary structural tags that guide the model's output format.
"""
tool_structures: List[StructuresResponseFormat] = list()
tool_trigger_set: Set[str] = set()
get_structure_info = self.detector.structure_info()
for tool in self.tools:
function = tool.function
name = function.name
assert name is not None
info = get_structure_info(name)
# accept all if not strict, otherwise only accept the schema
schema = function.parameters if function.strict else {}
tool_structures.append(
StructuresResponseFormat(
begin=info.begin,
schema=schema, # type: ignore
end=info.end,
)
)
tool_trigger_set.add(info.trigger)
return StructuralTagResponseFormat(
type="structural_tag",
structures=tool_structures,
triggers=list(tool_trigger_set),
)
def get_structure_constraint(
    self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
) -> Optional[Tuple[str, Any]]:
    """
    Pick the output constraint that matches the request's ``tool_choice``.

    Args:
        tool_choice: The tool choice setting from the request.

    Returns:
        ``("structural_tag", <tag>)`` for "auto" with at least one strict tool
        (never for the pythonic detector), ``("ebnf", <grammar>)`` for
        "required" or a specific ``ToolChoice`` when a grammar is available,
        otherwise ``None``.
    """
    # NOTE: structural_tag only supports JSON-compatible content between the
    # begin and end. It cannot parse or validate Python syntax like function calls.
    use_structural_tag = (
        not isinstance(self.detector, PythonicDetector)
        and tool_choice == "auto"
        and any(tool.function.strict for tool in self.tools)
    )
    if use_structural_tag:
        return ("structural_tag", self.get_structure_tag())

    if tool_choice == "required" or isinstance(tool_choice, ToolChoice):
        grammar = self.get_ebnf(tool_choice)
        if grammar is None:
            return None
        return ("ebnf", grammar)

    return None
def get_ebnf(
    self, tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[str]:
    """
    Build the EBNF grammar for the given tool choice.

    A specific ``ToolChoice`` restricts the grammar to the tools whose function
    name matches; any other value uses every registered tool. The grammar
    itself is produced by the detector's ``build_ebnf``.
    """
    if isinstance(tool_choice, ToolChoice):
        wanted_name = tool_choice.function.name
        candidates = [
            tool for tool in self.tools if tool.function.name == wanted_name
        ]
    else:
        candidates = self.tools
    return self.detector.build_ebnf(candidates)
import json
import logging
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}
    Several calls may follow one another, separated by ";" or ",".
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return "<|python_tag|>" in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse function calls from text, handling multiple JSON objects.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text preceding the first <|python_tag|> as normal_text (empty
            when the text starts directly with "{") plus the parsed calls.
        """
        if "<|python_tag|>" not in text and not text.startswith("{"):
            return StreamingParseResult(normal_text=text, calls=[])
        if "<|python_tag|>" in text:
            # maxsplit=1: a second tag must not break the two-name unpacking.
            normal_text, action_text = text.split("<|python_tag|>", maxsplit=1)
        else:
            normal_text, action_text = "", text

        all_actions = self._decode_actions(action_text)
        calls = []
        # Only process if we found valid JSON objects
        if all_actions:
            calls = self.parse_base_json(all_actions, tools)
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def _decode_actions(self, action_text: str) -> list:
        """
        Decode every JSON object found in ``action_text``.

        Objects are read with a real JSON decoder instead of splitting on a
        separator character, so ";" or "," inside string arguments is safe and
        both the ";" separator and the "," separator requested by
        ``build_ebnf`` are accepted. An undecodable fragment is logged and
        skipped up to the next ";".
        """
        decoder = json.JSONDecoder()
        actions = []
        pos, end = 0, len(action_text)
        while pos < end:
            ch = action_text[pos]
            if ch.isspace() or ch in ";,":
                pos += 1
                continue
            if action_text.startswith(self.bot_token, pos):
                # A repeated tag simply introduces the next call.
                pos += len(self.bot_token)
                continue
            try:
                action, pos = decoder.raw_decode(action_text, pos)
            except json.JSONDecodeError as e:
                resume = action_text.find(";", pos)
                part = action_text[pos:] if resume == -1 else action_text[pos:resume]
                logger.warning(f"Failed to parse JSON part: {part.strip()}")
                logger.warning(f"JSON parse error: {str(e)}")
                if resume == -1:
                    break
                pos = resume + 1
                continue
            if isinstance(action, dict):
                actions.append(action)
            else:
                # parse_base_json expects objects with "name"/"arguments".
                logger.warning(f"Ignoring non-object tool call payload: {action!r}")
        return actions

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the EBNF grammar for ``tools`` (JSON calls separated by ",")."""
        return EBNFComposer.build_ebnf(
            tools,
            function_format="json",
            tool_call_separator=",",
        )
import json
import re
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.
    Assumes function call format:
      [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        self.eot_token = "]"
        # Kept for backward compatibility; parsing uses a JSON decoder instead.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def _extract_tool_call_array(self, text: str):
        """
        Locate and decode the JSON array that follows the first bot token.

        Returns ``(marker_start, array_end, decoded_array)`` or ``None`` when
        the marker is absent or the array is not valid (e.g. truncated) JSON.
        Decoding with the JSON decoder keeps "]" inside arguments intact,
        which a "first closing bracket" regex cannot do.
        """
        marker_start = text.find(self.bot_token)
        if marker_start == -1:
            return None
        # The last character of the bot token is the "[" that opens the array.
        array_start = marker_start + len(self.bot_token) - 1
        try:
            decoded, array_end = json.JSONDecoder().raw_decode(text, array_start)
        except json.JSONDecodeError:
            return None
        return marker_start, array_end, decoded

    def _clean_text(self, text: str) -> str:
        """
        clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
        for example,
            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]
        Returns "" when no well-formed tool call array is present.
        """
        located = self._extract_tool_call_array(text)
        if located is None:
            return ""
        marker_start, array_end, _ = located
        return text[marker_start:array_end]

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text before the bot token (stripped) as normal_text, or the
            whole text when the token is absent, plus the parsed calls. A
            missing or malformed array yields no calls rather than raising.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text

        calls = []
        located = self._extract_tool_call_array(text)
        if located is not None:
            for entry in located[2]:
                # parse_base_json expects objects; skip anything else.
                if isinstance(entry, dict):
                    calls.extend(self.parse_base_json(entry, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
            end="}]",
            trigger="[TOOL_CALLS]",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the EBNF grammar for ``tools`` wrapped in the bot/eot tokens."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
import ast
import json
import logging
import re
from typing import List, Optional
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Check whether the stripped text matches the pythonic call-list pattern."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing of a ``[func(kw=value, ...), ...]`` list.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: The parsed calls with empty normal_text on success; otherwise
            the stripped text as normal_text and no calls. Only keyword
            arguments are read; calls whose callee is not a plain name are
            skipped. Unknown function names get ``tool_index=-1``.
        """
        # Try parsing the text as a Python list of function calls
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])

            calls = []
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            # Best effort: anything unparsable is handed back as normal text.
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    def _find_matching_bracket(self, buffer: str, start: int) -> int:
        """
        Find the matching closing bracket for the opening bracket at start position.
        Properly handles nested brackets.

        NOTE: brackets inside string literals are counted like any other
        bracket, so an argument such as ``x="]"`` ends the match early.

        Args:
            buffer: The text buffer to search in
            start: Position of the opening bracket '['

        Returns:
            Position of the matching closing bracket ']', or -1 if not found
        """
        bracket_count = 0
        for i in range(start, len(buffer)):
            if buffer[i] == "[":
                bracket_count += 1
            elif buffer[i] == "]":
                bracket_count -= 1
                if bracket_count == 0:
                    return i
        return -1  # No matching bracket found

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.
        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.
        """
        self._buffer += new_text
        start = self._buffer.find("[")

        if start == -1:
            # No bracket at all: everything buffered is plain text.
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)

        normal_text = self._buffer[:start] if start > 0 else ""

        end = self._find_matching_bracket(self._buffer, start)
        if end != -1:
            call_text = self._buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)
            self._buffer = self._buffer[end + 1 :]

            # If we had normal text before the tool call, add it to the result
            if normal_text:
                result.normal_text = normal_text + (result.normal_text or "")

            return result

        # We have an opening bracket but no closing bracket yet
        if normal_text:
            # Release the text before "[" and keep buffering from the bracket.
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=normal_text)

        # Otherwise, we're still accumulating a potential tool call
        return StreamingParseResult(normal_text="")

    def _get_parameter_value(self, val):
        """
        Convert an AST node for a literal argument into a JSON-serialisable value.

        Supports constants, signed numbers (``-1``, ``+2.5``), dicts with
        constant keys, lists and tuples (tuples become lists).

        Raises:
            ValueError: if the node is not one of the supported literal forms.
        """
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.UnaryOp) and isinstance(
            val.op, (ast.USub, ast.UAdd)
        ):
            # Negative numbers are parsed as UnaryOp(USub, Constant), not Constant.
            operand = self._get_parameter_value(val.operand)
            if type(operand) not in (int, float):
                raise ValueError("Tool call arguments must be literals")
            return -operand if isinstance(val.op, ast.USub) else +operand
        elif isinstance(val, ast.Dict):
            converted = {}
            for k, v in zip(val.keys, val.values):
                # A None key is a ``**mapping`` unpack; other nodes are expressions.
                if not isinstance(k, ast.Constant):
                    raise ValueError("Tool call arguments must be literals")
                converted[k.value] = self._get_parameter_value(v)
            return converted
        elif isinstance(val, (ast.List, ast.Tuple)):
            return [self._get_parameter_value(v) for v in val.elts]
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""

        def info(name: str):
            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")

        return info

    def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
        """Build the EBNF grammar for ``tools`` in pythonic call-list form."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token="[",
            eot_token="]",
            tool_call_separator=",",
            function_format="pythonic",
        )
import json
import re
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
      <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text before the first <tool_call> (stripped) as normal_text,
            or the whole text when no tool call is present, plus the parsed
            calls. A payload that is not valid JSON is logged and skipped so
            the remaining calls are still returned.
        """
        # Local import: this module does not import logging at file level.
        import logging

        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        # Escape the tokens so they are always matched literally.
        pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
        calls = []
        for payload in re.findall(pattern, text, re.DOTALL):
            try:
                action = json.loads(payload)
            except json.JSONDecodeError as e:
                logging.getLogger(__name__).warning(
                    f"Failed to parse JSON part: {payload}, JSON parse error: {str(e)}"
                )
                continue
            calls.extend(self.parse_base_json(action, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin='<tool_call>{"name":"' + name + '", "arguments":',
            end="}</tool_call>",
            trigger="<tool_call>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the EBNF grammar for ``tools`` wrapped in the bot/eot tokens."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """
    Parse possibly-incomplete JSON and report how much of the input was used.

    Returns ``(value, consumed_length)``. When the parser complains about
    "Extra data", only the leading JSON value is decoded and the returned
    length is the index where it ends. Any other decode error is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return (parsed, len(input_str))
def _is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
import ast
import json
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from json import JSONDecodeError, JSONDecoder
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
import partial_json_parser
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from pydantic import BaseModel
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
)
logger = logging.getLogger(__name__)
# Marker substrings associated with tool-call output. Several of them equal
# (or prefix) the bot_token of a detector in this file: "<tool_call>" (Qwen 2.5),
# "<|python_tag|>" (Llama 3.2), "[TOOL_CALLS]" (Mistral) and
# "<|tool▁calls▁begin|>" (DeepSeek V3).
# NOTE(review): how this list is consumed is not visible here — confirm at the
# call sites before adding or removing entries.
TOOLS_TAG_LIST = [
    "<|plugin|>",
    "<function=",
    "<tool_call>",
    "<|python_tag|>",
    "[TOOL_CALLS]",
    "<|tool▁calls▁begin|>",
]
class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    # Index identifying the tool. One-shot parsing uses the tool's position in
    # the request's tool list (PythonicDetector uses -1 for unknown names).
    tool_index: int
    # Function name. Streaming increments that only carry argument text pass
    # "" or None here.
    name: Optional[str] = None
    # Arguments as a JSON string; in streaming mode this may be a fragment.
    parameters: str  # JSON string
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """
    Load JSON that may be incomplete and return ``(value, consumed_length)``.

    If the parser reports "Extra data", the leading JSON value alone is
    decoded and the length is where that value ends; other decode errors
    propagate to the caller.
    """
    try:
        value = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as exc:
        has_trailing_data = "Extra data" in exc.msg
        if not has_trailing_data:
            raise
        decoder = JSONDecoder()
        return decoder.raw_decode(input_str)
    else:
        return (value, len(input_str))
def _is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
class StreamingParseResult:
    """
    Result of streaming incremental parsing.

    Attributes:
        normal_text: Text to surface as ordinary output.
        calls: Tool calls found; an empty list when none (or a falsy value) is given.
    """

    def __init__(
        self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
    ):
        self.normal_text = normal_text
        # A fresh list per instance avoids sharing a mutable default.
        self.calls = calls if calls else []
# Strings describing how one tool's call is delimited in model output.
# The detectors' structure_info() factories build one instance per tool name.
@dataclass
class StructureInfo:
    # Literal text that opens the call (includes the tool name in every
    # detector shown in this file).
    begin: str
    # Literal text that closes the call.
    end: str
    # Text registered as a trigger for the structure.
    trigger: str


# Signature of the factories returned by structure_info(): tool name -> StructureInfo.
_GetInfoFunc = Callable[[str], StructureInfo]
"""
Helper alias of function
Usually it is a function that takes a name string and returns a StructureInfo object,
which can be used to construct a structural_tag object
"""
class BaseFormatDetector(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental."""

    def __init__(self):
        # State used while parsing tool calls in streaming mode.
        # Text received so far that has not been resolved yet.
        self._buffer = ""
        # Tool-call objects decoded on the previous streaming step.
        self.prev_tool_call_arr: List[Dict] = []
        # Index of the tool call currently being streamed; -1 until one is seen.
        self.current_tool_id: int = -1
        # Whether the current tool's name has already been emitted.
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # map what has been streamed for each tool so far to a list
        # Begin/end-of-tool-call markers; subclasses set the ones they use.
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
        """
        Convert decoded JSON action(s) into ToolCallItem objects.

        :param action: One action dict or a list of them. Each is read via
            ``.get("name")`` and ``.get("parameters")`` / ``.get("arguments")``.
        :param tools: List of available tools; an action is kept only if its
            name matches one of them.
        :return: One ToolCallItem per recognised action, with the arguments
            re-serialised as a JSON string. Unknown names are logged and dropped.
        """
        tool_indices = {
            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
        }
        if not isinstance(action, list):
            action = [action]

        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        # "parameters" wins when present and truthy, else "arguments".
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
            else:
                logger.warning(f"Model attempted to call undefined function: {name}")

        return results

    @abstractmethod
    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the text in one go. Subclasses must override this method.

        The default body treats the whole text as JSON and returns the calls
        produced by parse_base_json, with empty normal text.
        """
        action = json.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Appends ``new_text`` to the internal buffer, re-parses the buffered
        text as partial JSON and returns what is new since the previous step:
        first a call carrying the tool name, then calls carrying argument
        fragments. Any exception during parsing is logged and an empty result
        is returned.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        # Neither a bot token nor a leading "{": treat the chunk as plain text.
        if not (self.bot_token in current_text or current_text.startswith("{")):
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built
        # (cached on the instance, so the first ``tools`` seen is the one used)
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the name has been sent, partial strings are not allowed.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # Skip the bot token only when the buffer starts with it.
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Advance past the object and an assumed two-character "; " separator.
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except MalformedJSON:
                # Not enough text yet to form partial JSON; wait for more.
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # With current_tool_id == -1 this indexes the last element.
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # Handle new tool in array
            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                if self.current_tool_id >= 0:
                    # Flush whatever remains of the previous tool's arguments.
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()
                # Move on to the newest tool; its name is emitted on a later step.
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return res

            # Handle tool name
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Object finished: send the rest and reset per-call state.
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""

                    elif prev_arguments:
                        # Only the prefix shared with the previous step is sent.
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            # Remember this step's objects for the next diff.
            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            # Best effort: a failed step yields an empty result instead of raising.
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    @abstractmethod
    def has_tool_call(self, text: str) -> bool:
        """Return whether ``text`` contains a tool call in this detector's format."""
        raise NotImplementedError()

    @abstractmethod
    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its StructureInfo."""
        raise NotImplementedError()
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
      <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text before the first <tool_call> (stripped) as normal_text,
            or the whole text when no tool call is present, plus the parsed
            calls. A payload that is not valid JSON is logged and skipped so
            the remaining calls are still returned.
        """
        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        # Escape the tokens so they are always matched literally.
        pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
        calls = []
        for payload in re.findall(pattern, text, re.DOTALL):
            try:
                action = json.loads(payload)
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Failed to parse JSON part: {payload}, JSON parse error: {str(e)}"
                )
                continue
            calls.extend(self.parse_base_json(action, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin='<tool_call>{"name":"' + name + '", "arguments":',
            end="}</tool_call>",
            trigger="<tool_call>",
        )
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.
    Assumes function call format:
      [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        # Kept for backward compatibility; parsing uses a JSON decoder instead.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def _extract_tool_call_array(self, text: str):
        """
        Locate and decode the JSON array that follows the first bot token.

        Returns ``(marker_start, array_end, decoded_array)`` or ``None`` when
        the marker is absent or the array is not valid (e.g. truncated) JSON.
        Decoding with the JSON decoder keeps "]" inside arguments intact,
        which a "first closing bracket" regex cannot do.
        """
        marker_start = text.find(self.bot_token)
        if marker_start == -1:
            return None
        # The last character of the bot token is the "[" that opens the array.
        array_start = marker_start + len(self.bot_token) - 1
        try:
            decoded, array_end = json.JSONDecoder().raw_decode(text, array_start)
        except json.JSONDecodeError:
            return None
        return marker_start, array_end, decoded

    def _clean_text(self, text: str) -> str:
        """
        clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
        for example,
            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]
        Returns "" when no well-formed tool call array is present.
        """
        located = self._extract_tool_call_array(text)
        if located is None:
            return ""
        marker_start, array_end, _ = located
        return text[marker_start:array_end]

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text before the bot token (stripped) as normal_text, or the
            whole text when the token is absent, plus the parsed calls. A
            missing or malformed array yields no calls rather than raising.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text

        calls = []
        located = self._extract_tool_call_array(text)
        if located is not None:
            for entry in located[2]:
                # parse_base_json expects objects; skip anything else.
                if isinstance(entry, dict):
                    calls.extend(self.parse_base_json(entry, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
            end="}]",
            trigger="[TOOL_CALLS]",
        )
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return "<|python_tag|>" in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse function calls from text, handling multiple JSON objects.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text preceding the first <|python_tag|> as normal_text (empty
            when the text starts directly with "{") plus the parsed calls.
            Parts that are not valid JSON are logged and skipped.
        """
        if "<|python_tag|>" not in text and not text.startswith("{"):
            return StreamingParseResult(normal_text=text, calls=[])

        if "<|python_tag|>" in text:
            # maxsplit=1: without it a second tag yields three pieces and the
            # two-name unpacking raises ValueError.
            normal_text, action_text = text.split("<|python_tag|>", maxsplit=1)
        else:
            normal_text, action_text = "", text

        # Split by semicolon and process each part
        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
        all_actions = []
        for part in json_parts:
            try:
                # Parse each individual JSON object
                all_actions.append(json.loads(part))
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON part: {part}")
                logger.warning(f"JSON parse error: {str(e)}")
                continue

        calls = []
        # Only process if we found valid JSON objects
        if all_actions:
            calls = self.parse_base_json(all_actions, tools)
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )
class DeepSeekV3Detector(BaseFormatDetector):
    """
    Detector for DeepSeek models.
    Assumes function call format:
      '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool▁calls▁begin|>"
        self.eot_token = "<|tool▁calls▁end|>"

        # The delimiter tokens contain "|", which is alternation in a regex.
        # They must be escaped so they are matched as literal text.
        call_begin = re.escape("<|tool▁call▁begin|>")
        call_end = re.escape("<|tool▁call▁end|>")
        sep = re.escape("<|tool▁sep|>")

        # One complete call, begin marker through end marker.
        self.func_call_regex = call_begin + r".*?" + call_end
        # Groups: (1) call type, (2) function name, (3) JSON arguments.
        self.func_detail_regex = (
            call_begin + r"(.*)" + sep + r"(.*)\n```json\n(.*)\n```" + call_end
        )
        # Same groups, but the arguments and closing fence may still be missing.
        self._partial_call_regex = call_begin + r"(.*)" + sep + r"(.*)\n```json\n(.*)"
        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a deepseek format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Text before the bot token (stripped) as normal_text plus the
            parsed calls. If any call cannot be parsed, the error is logged and
            the whole text is returned as normal_text with no calls.
        """
        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()

        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
        calls = []
        try:
            for match_result in match_result_list:
                # Get function name
                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
                if func_detail is None:
                    raise ValueError(f"Malformed tool call: {match_result}")
                func_name = func_detail.group(2)
                func_args = json.loads(func_detail.group(3))
                # construct match_result for parse_base_json
                match_result = {"name": func_name, "parameters": func_args}
                calls.extend(self.parse_base_json(match_result, tools))
            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory producing the begin/end/trigger strings for a tool name."""
        return lambda name: StructureInfo(
            begin=">" + name + "\n```json\n",
            end="\n```<",
            trigger=">" + name + "\n```json\n",
        )

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for DeepSeekV3 format.

        Buffers text once the bot token has been seen. The first step that
        matches a call emits its name; later steps emit the argument text that
        is new since the previous step. State is reset once the arguments form
        complete JSON.
        """
        self._buffer += new_text
        current_text = self._buffer

        if self.bot_token not in current_text:
            # Plain text: drop the buffer and strip stray closing markers.
            self._buffer = ""
            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            partial_match = re.search(
                pattern=self._partial_call_regex,
                string=current_text,
                flags=re.DOTALL,
            )
            if partial_match:
                func_name = partial_match.group(2).strip()
                func_args_raw = partial_match.group(3).strip()

                if not self.current_tool_name_sent:
                    calls.append(
                        ToolCallItem(
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=func_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                else:
                    # Send only what was not streamed before; resend everything
                    # if the text no longer extends the previous arguments.
                    argument_diff = (
                        func_args_raw[len(self._last_arguments) :]
                        if func_args_raw.startswith(self._last_arguments)
                        else func_args_raw
                    )

                    if argument_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self._tool_indices.get(func_name, 0),
                                name=None,
                                parameters=argument_diff,
                            )
                        )
                        self._last_arguments += argument_diff

                    if _is_complete_json(func_args_raw):
                        # Call finished: reset state for the next one.
                        result = StreamingParseResult(normal_text="", calls=calls)
                        self._buffer = ""
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)
class MultiFormatParser:
    """Tries a list of format detectors in order and returns the first that yields calls."""

    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: A series of available Detector instances passed in
        """
        self.detectors = detectors

    def parse_once(
        self, text: str, tools: List[Tool]
    ) -> Tuple[str, list[ToolCallItem]]:
        """
        One-time parsing: ask each detector in turn to parse the full text.

        Return: (final_text, all_calls)
        - final_text: normal text reported by the first detector that found
          calls, or the unchanged input when none did
        - all_calls: the calls found by that detector, or an empty list
        """
        for detector in self.detectors:
            outcome = detector.detect_and_parse(text, tools)
            if len(outcome.calls) > 0:  # parsed successfully
                return outcome.normal_text, outcome.calls
        # No detector produced calls: the whole input is normal text.
        return text, []

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming incremental parsing: feed new_text to each detector in order
        and stop at the first one that returns calls.

        A detector that returns only normal text may simply not be the right
        format, so its text is kept as a fallback while later detectors are tried.
        """
        visible_text = ""
        collected = []
        for detector in self.detectors:
            step = detector.parse_streaming_increment(new_text, tools)
            if step.normal_text:
                visible_text = step.normal_text
            if step.calls:
                # A result with calls is a successful parse; take it as is.
                collected.extend(step.calls)
                visible_text = step.normal_text
                break
        return visible_text, collected
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        # Coarse shape check used only by has_tool_call(); the actual parsing
        # in detect_and_parse() goes through the `ast` module.
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Return True when the stripped text begins with a pythonic call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parse a complete ``[name(kw=value, ...), ...]`` expression.

        :param text: Candidate text; surrounding whitespace is ignored
        :param tools: Available tools, used to resolve each call's ``tool_index``
            (-1 when the called name is not among them)
        :return: A result carrying the parsed calls and empty ``normal_text`` on
            success; otherwise the stripped text as ``normal_text`` and no calls
        """
        # Try parsing the text as a Python list of function calls
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                # Only plain `name(...)` calls are accepted; attribute calls
                # such as `a.b(...)` are skipped
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                # Only keyword arguments are collected
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            # Syntax errors and non-literal arguments both land here; the text
            # is handed back untouched as ordinary output
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    @staticmethod
    def _find_matching_bracket(buffer: str, start: int) -> int:
        """
        Return the index of the ``]`` that closes the ``[`` at ``start``.

        Nesting depth is tracked so that list-valued arguments do not end the
        call early. Returns -1 while the closing bracket has not arrived yet.
        Brackets inside string literals are not distinguished from structural ones.
        """
        depth = 0
        for index in range(start, len(buffer)):
            char = buffer[index]
            if char == "[":
                depth += 1
            elif char == "]":
                depth -= 1
                if depth == 0:
                    return index
        return -1

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.

        Text that cannot belong to a tool call is released immediately as
        ``normal_text``; text from an opening ``[`` onwards is buffered until
        its matching ``]`` arrives and is then parsed with ``detect_and_parse``.
        Anything after the closing bracket stays buffered for the next call.

        :param new_text: Newly generated chunk (may be empty to drain the buffer)
        :param tools: Available tools, forwarded to ``detect_and_parse``
        :return: The normal text and/or calls that became available with this chunk
        """
        self._buffer += new_text
        start = self._buffer.find("[")
        if start == -1:
            # No call can start in what we hold: emit it rather than withholding
            # plain text from the caller indefinitely
            pending = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=pending)

        # Text in front of the bracket is ordinary output and must not be lost
        leading_text = self._buffer[:start]
        end = self._find_matching_bracket(self._buffer, start)
        if end == -1:
            # The call is still incomplete: keep only the part from "[" onwards
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=leading_text)

        call_text = self._buffer[start : end + 1]
        parsed = self.detect_and_parse(call_text, tools)
        self._buffer = self._buffer[end + 1 :]
        # When parsing fails, detect_and_parse returns the bracketed text as
        # normal_text, so nothing is dropped in that case either
        return StreamingParseResult(
            normal_text=leading_text + (parsed.normal_text or ""),
            calls=parsed.calls,
        )

    def _get_parameter_value(self, val):
        """
        Convert an ``ast`` node for an argument value into a plain Python value.

        Supports constants, signed numeric constants, dicts and lists (recursively).

        :raises ValueError: If the node is any other kind of expression
        """
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.UnaryOp) and isinstance(
            val.op, (ast.USub, ast.UAdd)
        ):
            # `-5` / `+5` parse as UnaryOp(Constant), not as a bare Constant
            operand = val.operand
            if (
                isinstance(operand, ast.Constant)
                and isinstance(operand.value, (int, float))
                and not isinstance(operand.value, bool)
            ):
                if isinstance(val.op, ast.USub):
                    return -operand.value
                return operand.value
            raise ValueError("Tool call arguments must be literals")
        elif isinstance(val, ast.Dict):
            return {
                k.value: self._get_parameter_value(v)
                for k, v in zip(val.keys, val.values)
            }
        elif isinstance(val, ast.List):
            return [self._get_parameter_value(v) for v in val.elts]
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return a factory giving the same ``[`` / ``]`` delimiters for every tool name."""

        def info(name: str):
            return StructureInfo(begin="[", end="]", trigger="")

        return info
class FunctionCallParser:
    """
    Front door for tool-call parsing.

    Wraps a MultiFormatParser built from the detector selected by name, and
    exposes one-shot parsing, streaming parsing and structural-tag generation
    for the configured tools.
    """

    # Maps the `tool_call_parser` name to the detector class that implements it
    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
        "deepseekv3": DeepSeekV3Detector,
        "pythonic": PythonicDetector,
    }

    def __init__(self, tools: List[Tool], tool_call_parser: str):
        """
        :param tools: Tools the model may call
        :param tool_call_parser: Key into ``ToolCallParserEnum``
        :raises ValueError: If the parser name is empty or not registered
        """
        if not tool_call_parser:
            raise ValueError("Tool Call Parser Not Given!")
        detector_cls = self.ToolCallParserEnum.get(tool_call_parser)
        if not detector_cls:
            raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
        self.multi_format_parser = MultiFormatParser([detector_cls()])
        self.tools = tools

    def has_tool_call(self, text: str) -> bool:
        """
        Report whether any configured detector recognises a tool call in ``text``.

        :param text: The text to check for tool calls
        :return: True if the text contains a tool call, False otherwise
        """
        return any(
            detector.has_tool_call(text)
            for detector in self.multi_format_parser.detectors
        )

    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Non-streaming call: parse a complete response in one go.

        :return: (normal_text, calls) as produced by ``MultiFormatParser.parse_once``
        """
        return self.multi_format_parser.parse_once(full_text, self.tools)

    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
        """
        Streaming call: parse the next chunk incrementally.

        :return: (normal_text, calls) as produced by
            ``MultiFormatParser.parse_streaming_increment``
        """
        return self.multi_format_parser.parse_streaming_increment(
            chunk_text, self.tools
        )

    def structure_infos(self) -> List[_GetInfoFunc]:
        """
        Returns one structure_info function per configured detector
        """
        infos = []
        for detector in self.multi_format_parser.detectors:
            infos.append(detector.structure_info())
        return infos

    def get_structure_tag(self) -> StructuralTagResponseFormat:
        """
        Build the structural-tag response format covering every tool.

        One structure is emitted per (detector, tool) pair; the distinct
        trigger strings are collected alongside.
        """
        structures: List[StructuresResponseFormat] = []
        triggers: Set[str] = set()
        for get_info in self.structure_infos():
            for tool in self.tools:
                fn = tool.function
                assert fn.name is not None
                info = get_info(fn.name)
                structures.append(
                    StructuresResponseFormat(
                        begin=info.begin,
                        # accept all if not strict, otherwise only accept the schema
                        schema=fn.parameters if fn.strict else {},  # type: ignore
                        end=info.end,
                    )
                )
                triggers.add(info.trigger)
        return StructuralTagResponseFormat(
            type="structural_tag",
            structures=structures,
            triggers=list(triggers),
        )
......@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
......@@ -970,7 +970,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
strict_tag = None
tool_call_constraint = None
prompt = ""
prompt_ids = []
if not isinstance(request.messages, str):
......@@ -989,7 +989,9 @@ def v1_chat_generate_request(
tool_call_parser = tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser)
strict_tag = parser.get_structure_tag()
tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
if chat_template_name is None:
openai_compatible_messages = []
......@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
request.response_format.model_dump(by_alias=True)
)
if strict_tag is not None:
if (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
):
logger.warning(
"Constrained decoding is not compatible with tool calls."
# Check if there are already existing output constraints
has_existing_constraints = (
sampling_params.get("regex")
or sampling_params.get("ebnf")
or sampling_params.get("structural_tag")
or sampling_params.get("json_schema")
)
if tool_call_constraint and has_existing_constraints:
logger.warning("Constrained decoding is not compatible with tool calls.")
elif tool_call_constraint:
constraint_type, constraint_value = tool_call_constraint
if constraint_type == "structural_tag":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
)
else:
sampling_params["structural_tag"] = convert_json_schema_to_str(
strict_tag.model_dump(by_alias=True)
)
sampling_params[constraint_type] = constraint_value
sampling_params_list.append(sampling_params)
......
......@@ -36,7 +36,7 @@ suites = {
TestFile("test_fa3.py", 376),
TestFile("test_fim_completion.py", 40),
TestFile("test_fp8_kernel.py", 8),
TestFile("test_function_calling.py", 60),
TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 30),
TestFile("test_hicache.py", 116),
TestFile("test_hicache_mla.py", 254),
......@@ -54,6 +54,7 @@ suites = {
TestFile("test_flashmla.py", 300),
TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 216),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
......
import json
import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestPythonicDetector(unittest.TestCase):
    """Streaming-parser tests for PythonicDetector."""

    @staticmethod
    def _make_tool(name, description, properties, required):
        """Build a function Tool with the given parameter schema."""
        return Tool(
            type="function",
            function=Function(
                name=name,
                description=description,
                parameters={"properties": properties, "required": required},
            ),
        )

    def setUp(self):
        # Two sample tools shared by every test
        weather_properties = {
            "location": {
                "type": "string",
                "description": "Location to get weather for",
            },
            "unit": {
                "type": "string",
                "description": "Temperature unit",
                "enum": ["celsius", "fahrenheit"],
            },
        }
        search_properties = {
            "query": {
                "type": "string",
                "description": "Search query",
            },
        }
        self.tools = [
            self._make_tool(
                "get_weather",
                "Get weather information",
                weather_properties,
                ["location"],
            ),
            self._make_tool(
                "search", "Search for information", search_properties, ["query"]
            ),
        ]
        self.detector = PythonicDetector()

    # ----- helpers -----

    def _feed(self, chunk):
        """Push one chunk through the detector and return its result."""
        return self.detector.parse_streaming_increment(chunk, self.tools)

    def _assert_calls(self, result, *expected_names):
        """Check the number of calls and each call's function name, in order."""
        self.assertEqual(len(result.calls), len(expected_names))
        for call, expected in zip(result.calls, expected_names):
            self.assertEqual(call.name, expected)

    def _assert_no_calls(self, result):
        self.assertEqual(result.calls, [])

    def _assert_buffer(self, expected):
        """Check what the detector is still holding back."""
        self.assertEqual(self.detector._buffer, expected)

    def _assert_params(self, call, **expected):
        """Decode the call's JSON parameters and compare the given keys."""
        params = json.loads(call.parameters)
        for key, value in expected.items():
            self.assertEqual(params[key], value)

    # ----- tests -----

    def test_parse_streaming_no_brackets(self):
        """Plain text without brackets is passed straight through."""
        text = "This is just normal text without any tool calls."
        result = self._feed(text)
        self.assertEqual(result.normal_text, text)
        self._assert_no_calls(result)
        self._assert_buffer("")  # Buffer should be cleared

    def test_parse_streaming_complete_tool_call(self):
        """A complete tool call arriving in a single chunk is parsed."""
        result = self._feed(
            "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
        )
        self.assertEqual(result.normal_text, "Here's a tool call: ")
        self._assert_calls(result, "get_weather")
        self._assert_buffer("")  # Buffer should be cleared after processing
        self._assert_params(result.calls[0], location="New York", unit="celsius")

    def test_parse_streaming_text_before_tool_call(self):
        """Text in front of a tool call is returned as normal text."""
        result = self._feed("This is some text before [get_weather(location='London')]")
        self.assertEqual(result.normal_text, "This is some text before ")
        self._assert_calls(result, "get_weather")
        self._assert_params(result.calls[0], location="London")

    def test_parse_streaming_partial_tool_call(self):
        """A tool call split across two chunks is assembled from the buffer."""
        # First chunk: opening bracket but no closing bracket
        first = self._feed("Let me check the weather: [get_weather(location=")
        self.assertEqual(first.normal_text, "Let me check the weather: ")
        self._assert_no_calls(first)
        self._assert_buffer("[get_weather(location=")  # Partial call stays buffered

        # Second chunk completes the call
        second = self._feed("'Paris')]")
        self.assertEqual(second.normal_text, "")
        self._assert_calls(second, "get_weather")
        self._assert_params(second.calls[0], location="Paris")
        self._assert_buffer("")  # Buffer should be cleared after processing

    def test_parse_streaming_bracket_without_text_before(self):
        """A tool call that starts at the very beginning of the text."""
        result = self._feed("[search(query='python programming')]")
        self.assertEqual(result.normal_text, "")
        self._assert_calls(result, "search")
        self._assert_params(result.calls[0], query="python programming")

    def test_parse_streaming_text_after_tool_call(self):
        """Text following a tool call is held back, then released on the next feed."""
        result = self._feed("[get_weather(location='Tokyo')] Here's the forecast:")
        self.assertEqual(result.normal_text, "")
        self._assert_calls(result, "get_weather")
        self._assert_buffer(" Here's the forecast:")  # Trailing text stays buffered

        # Drain the remaining text from the buffer
        drained = self._feed("")
        self.assertEqual(drained.normal_text, " Here's the forecast:")
        self._assert_no_calls(drained)
        self._assert_buffer("")  # Buffer should be cleared

    def test_parse_streaming_multiple_tool_calls(self):
        """Two bracketed calls in one chunk are emitted one per feed."""
        first = self._feed(
            "[get_weather(location='Berlin')] and [search(query='restaurants')]"
        )
        self._assert_calls(first, "get_weather")
        self._assert_buffer(" and [search(query='restaurants')]")

        second = self._feed("")
        self.assertEqual(second.normal_text, " and ")
        self._assert_calls(second, "search")
        self._assert_buffer("")

    def test_parse_streaming_opening_bracket_only(self):
        """An opening bracket with nothing after it is kept in the buffer."""
        result = self._feed("Let's try this: [")
        self.assertEqual(result.normal_text, "Let's try this: ")
        self._assert_no_calls(result)
        self._assert_buffer("[")  # Opening bracket remains in buffer

    def test_parse_streaming_nested_brackets(self):
        """A list argument's brackets do not terminate the call early."""
        result = self._feed(
            "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
        )
        self.assertEqual(result.normal_text, "")
        self._assert_calls(result, "get_weather")
        self._assert_buffer("")
        self._assert_params(
            result.calls[0], location="New York", unit="celsius", data=[1, 2, 3]
        )

    def test_parse_streaming_nested_brackets_dict(self):
        """Nested dictionaries and lists inside arguments are preserved."""
        result = self._feed(
            "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
        )
        self.assertEqual(result.normal_text, "")
        self._assert_calls(result, "search")
        self._assert_buffer("")

        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "test")
        self.assertEqual(params["config"]["options"], [1, 2])
        self.assertEqual(params["config"]["nested"]["key"], "value")

    def test_parse_streaming_multiple_tools_with_nested_brackets(self):
        """Several calls inside one list, each with list arguments."""
        result = self._feed(
            "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
        )
        self.assertEqual(result.normal_text, "")
        self._assert_calls(result, "get_weather", "search")
        self._assert_buffer("")
        self._assert_params(result.calls[0], location="Paris", data=[10, 20])
        self._assert_params(result.calls[1], query="test", filters=["a", "b"])

    def test_parse_streaming_partial_nested_brackets(self):
        """A call with a list argument that is split across chunks."""
        # First chunk stops in the middle of the nested list
        first = self._feed("Here's a call: [get_weather(location='Tokyo', data=[1, 2")
        self.assertEqual(first.normal_text, "Here's a call: ")
        self._assert_no_calls(first)
        self._assert_buffer("[get_weather(location='Tokyo', data=[1, 2")

        # Second chunk closes both the nested list and the call
        second = self._feed(", 3])]")
        self.assertEqual(second.normal_text, "")
        self._assert_calls(second, "get_weather")
        self._assert_buffer("")
        self._assert_params(second.calls[0], location="Tokyo", data=[1, 2, 3])
class TestEBNFGeneration(unittest.TestCase):
    """Checks that each detector's generated EBNF has the expected shape and compiles."""

    def setUp(self):
        # Sample tools shared by every test
        get_weather = Function(
            name="get_weather",
            description="Get weather information",
            parameters={
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "Location to get weather for",
                    },
                    "unit": {
                        "type": "string",
                        "description": "Temperature unit",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["location"],
            },
        )
        search = Function(
            name="search",
            description="Search for information",
            parameters={
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query",
                    },
                },
                "required": ["query"],
            },
        )
        self.tools = [
            Tool(type="function", function=get_weather),
            Tool(type="function", function=search),
        ]

        # Grammar compiler backed by the small test model's tokenizer
        self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer)
        self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)

        # Initialize all detectors
        self.pythonic_detector = PythonicDetector()
        self.deepseekv3_detector = DeepSeekV3Detector()
        self.llama32_detector = Llama32Detector()
        self.mistral_detector = MistralDetector()
        self.qwen25_detector = Qwen25Detector()

    def _check_ebnf(self, detector, expected_fragments):
        """
        Build the detector's EBNF for the sample tools, check that every
        expected fragment occurs in it, then verify GrammarCompiler accepts it.
        """
        ebnf = detector.build_ebnf(self.tools)
        self.assertIsNotNone(ebnf)

        # Check that the EBNF contains expected patterns
        for fragment in expected_fragments:
            self.assertIn(fragment, ebnf)

        # Validate that the EBNF can be compiled by GrammarCompiler
        try:
            ctx = self.grammar_compiler.compile_grammar(ebnf)
            self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
        except RuntimeError as e:
            self.fail(f"Failed to compile EBNF: {e}")

    def test_pythonic_detector_ebnf(self):
        """Test that the PythonicDetector generates valid EBNF."""
        self._check_ebnf(
            self.pythonic_detector,
            [
                'call_get_weather ::= "get_weather" "(" ',
                '"location" "=" basic_string',
                '[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]',
            ],
        )

    def test_deepseekv3_detector_ebnf(self):
        """Test that the DeepSeekV3Detector generates valid EBNF."""
        self._check_ebnf(
            self.deepseekv3_detector,
            [
                "<|tool▁calls▁begin|>",
                "<|tool▁call▁begin|>function<|tool▁sep|>get_weather",
                '\\"location\\"" ":" basic_string ',
            ],
        )

    def test_llama32_detector_ebnf(self):
        """Test that the Llama32Detector generates valid EBNF."""
        self._check_ebnf(
            self.llama32_detector,
            [
                '\\"name\\"" ":" "\\"get_weather\\"',
                '"\\"arguments\\"" ":"',
            ],
        )

    def test_mistral_detector_ebnf(self):
        """Test that the MistralDetector generates valid EBNF."""
        self._check_ebnf(
            self.mistral_detector,
            [
                '"[TOOL_CALLS] ["',
                "call_get_weather | call_search",
                '"\\"arguments\\"" ":"',
            ],
        )

    def test_qwen25_detector_ebnf(self):
        """Test that the Qwen25Detector generates valid EBNF."""
        self._check_ebnf(
            self.qwen25_detector,
            [
                "<tool_call>",
                '\\"name\\"" ":" "\\"get_weather\\"',
                '"\\"arguments\\"" ":"',
            ],
        )
if __name__ == "__main__":
unittest.main()
......@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
def test_function_call_required(self):
    """
    Test: Whether tool_choice: "required" works as expected
    - When tool_choice == "required", the model should return one or more tool_calls.

    The prompt is deliberately unrelated to either tool and sampling is not
    greedy, so which tool the model picks (and with what values) is not
    deterministic. The assertions therefore cover only what "required" is
    meant to guarantee: at least one call, each naming an offered tool, with
    JSON arguments that contain that tool's required parameters.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    tools = [
        {
            "type": "function",
            "function": {
                "name": "sub",
                "description": "Compute the difference of two integers",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "int_a": {
                            "type": "integer",
                            "description": "First integer",
                        },
                        "int_b": {
                            "type": "integer",
                            "description": "Second integer",
                        },
                    },
                    "required": ["int_a", "int_b"],
                },
                "strict": True,
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "use this to get latest weather information for a city given its name",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "name of the city to get weather for",
                        }
                    },
                    "required": ["city"],
                },
            },
        },
    ]
    messages = [{"role": "user", "content": "What is the capital of France?"}]
    response = client.chat.completions.create(
        model=self.model,
        messages=messages,
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=tools,
        tool_choice="required",
    )

    tool_calls = response.choices[0].message.tool_calls
    self.assertIsNotNone(tool_calls, "No tool_calls in the response")
    # Guard the indexing below: an empty list is also a contract violation
    self.assertGreaterEqual(
        len(tool_calls), 1, "tool_choice='required' should yield at least one call"
    )

    # Required parameter names per offered tool, taken from the schemas above
    required_by_tool = {
        tool["function"]["name"]: tool["function"]["parameters"]["required"]
        for tool in tools
    }
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        self.assertIn(
            function_name,
            required_by_tool,
            f"Function name '{function_name}' is not one of the offered tools",
        )
        args_obj = json.loads(tool_call.function.arguments)
        for param in required_by_tool[function_name]:
            self.assertIn(
                param, args_obj, f"Function arguments should have '{param}'"
            )
def test_function_call_specific(self):
    """
    Test: Whether tool_choice: ToolChoice works as expected
    - When tool_choice names a specific function, the response should carry
      tool_calls and the first one should be for that function.
    """
    sub_tool = {
        "type": "function",
        "function": {
            "name": "sub",
            "description": "Compute the difference of two integers",
            "parameters": {
                "type": "object",
                "properties": {
                    "int_a": {
                        "type": "integer",
                        "description": "First integer",
                    },
                    "int_b": {
                        "type": "integer",
                        "description": "Second integer",
                    },
                },
                "required": ["int_a", "int_b"],
            },
            "strict": True,
        },
    }
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "use this to get latest weather information for a city given its name",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "name of the city to get weather for",
                    }
                },
                "required": ["city"],
            },
        },
    }
    # Force the model onto get_weather even though the prompt is unrelated
    forced_choice = {"type": "function", "function": {"name": "get_weather"}}

    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    response = client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=[sub_tool, weather_tool],
        tool_choice=forced_choice,
    )

    tool_calls = response.choices[0].message.tool_calls
    self.assertIsNotNone(tool_calls, "No tool_calls in the response")

    first_call = tool_calls[0].function
    function_name = first_call.name
    args_obj = json.loads(first_call.arguments)
    self.assertEqual(
        function_name, "get_weather", "Function name should be 'get_weather'"
    )
    self.assertIn("city", args_obj, "Function arguments should have 'city'")
class TestOpenAIPythonicFunctionCalling(CustomTestCase):
PYTHONIC_TOOLS = [
......@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
stream=False,
)
tool_calls = response.choices[0].message.tool_calls
self.assertIsInstance(tool_calls, list)
self.assertIsInstance(tool_calls, list, "No tool_calls found")
self.assertGreaterEqual(len(tool_calls), 1)
names = [tc.function.name for tc in tool_calls]
self.assertIn("get_weather", names)
self.assertIn("get_tourist_attractions", names)
self.assertTrue(
"get_weather" in names or "get_tourist_attractions" in names,
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
def test_pythonic_tool_call_streaming(self):
"""
......@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
self.assertTrue(found_index, "No index field found in any streamed tool_call")
self.assertIn("get_weather", found_names)
self.assertIn("get_tourist_attractions", found_names)
self.assertTrue(
"get_weather" in found_names or "get_tourist_attractions" in found_names,
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment