Unverified Commit ed0c3035 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

feat(Tool Calling): Support `required` and specific function mode (#6550)

parent e6f11356
...@@ -54,10 +54,12 @@ ...@@ -54,10 +54,12 @@
"source": [ "source": [
"Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
"\n", "\n",
"- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", "- llama3: Llama 3.1 / 3.2 / 3.3 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.3-70B-Instruct).\n",
"- llama4: Llama 4 (e.g. meta-llama/Llama-4-Scout-17B-16E-Instruct).\n",
"- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n", "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n",
"Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
"- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html)." "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct) and QwQ (i.e. Qwen/QwQ-32B). Especially, for QwQ, we can enable the reasoning parser together with tool call parser, details about reasoning parser can be found in [reasoning parser](https://docs.sglang.ai/backend/separate_reasoning.html).\n",
"- deepseekv3: DeepSeek-v3 (e.g., deepseek-ai/DeepSeek-V3-0324).\n"
] ]
}, },
{ {
...@@ -360,6 +362,164 @@ ...@@ -360,6 +362,164 @@
"print(final_response.choices[0].message.content)" "print(final_response.choices[0].message.content)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Choice Mode\n",
"\n",
"SGLang supports OpenAI's `tool_choice` parameter to control when and which tools the model should call. This feature is implemented using EBNF (Extended Backus-Naur Form) grammar to ensure reliable tool calling behavior.\n",
"\n",
"### Supported Tool Choice Options\n",
"\n",
"- **`tool_choice=\"required\"`**: Forces the model to call at least one tool\n",
"- **`tool_choice={\"type\": \"function\", \"function\": {\"name\": \"specific_function\"}}`**: Forces the model to call a specific function\n",
"\n",
"### Backend Compatibility\n",
"\n",
"Tool choice is fully supported with the **Xgrammar backend**, which is the default grammar backend (`--grammar-backend xgrammar`). However, it may not be fully supported with other backends such as `outlines`.\n",
"\n",
"### Example: Required Tool Choice"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Response with tool_choice='required':\n",
"Content: None\n",
"Tool calls: [ChatCompletionMessageToolCall(id='call_NFO3TSZuRRO8Eu3Cv79uiQ', function=Function(arguments='{\"city\": \"Paris\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n"
]
}
],
"source": [
"from openai import OpenAI\n",
"import json\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
" import nest_asyncio\n",
"\n",
" nest_asyncio.apply()\n",
"\n",
"# Start a new server session for tool choice examples\n",
"server_process_tool_choice, port_tool_choice = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 0.0.0.0\"\n",
")\n",
"wait_for_server(f\"http://localhost:{port_tool_choice}\")\n",
"\n",
"# Initialize client for tool choice examples\n",
"client_tool_choice = OpenAI(\n",
" api_key=\"None\", base_url=f\"http://0.0.0.0:{port_tool_choice}/v1\"\n",
")\n",
"model_name_tool_choice = client_tool_choice.models.list().data[0].id\n",
"\n",
"# Example with tool_choice=\"required\" - forces the model to call a tool\n",
"messages_required = [\n",
" {\"role\": \"user\", \"content\": \"Hello, what is the capital of France?\"}\n",
"]\n",
"\n",
"# Define tools\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"city\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
" },\n",
" \"unit\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The unit to fetch the temperature in\",\n",
" \"enum\": [\"celsius\", \"fahrenheit\"],\n",
" },\n",
" },\n",
" \"required\": [\"city\", \"unit\"],\n",
" },\n",
" },\n",
" }\n",
"]\n",
"\n",
"response_required = client_tool_choice.chat.completions.create(\n",
" model=model_name_tool_choice,\n",
" messages=messages_required,\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" tools=tools,\n",
" tool_choice=\"required\", # Force the model to call a tool\n",
")\n",
"\n",
"print_highlight(\"Response with tool_choice='required':\")\n",
"print(\"Content:\", response_required.choices[0].message.content)\n",
"print(\"Tool calls:\", response_required.choices[0].message.tool_calls)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example: Specific Function Choice\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Response with specific function choice:\n",
"Content: None\n",
"Tool calls: [ChatCompletionMessageToolCall(id='call_fGL_1qsPQFqntNBPkSynJw', function=Function(arguments='{\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}', name='get_current_weather'), type='function', index=0)]\n",
"Called function: get_current_weather\n",
"Arguments: {\"city\": \"Sophia Antipolis\", \"unit\": \"celsius\"}\n"
]
}
],
"source": [
"# Example with specific function choice - forces the model to call a specific function\n",
"messages_specific = [\n",
" {\"role\": \"user\", \"content\": \"What are the most attactive places in France?\"}\n",
"]\n",
"\n",
"response_specific = client_tool_choice.chat.completions.create(\n",
" model=model_name_tool_choice,\n",
" messages=messages_specific,\n",
" temperature=0,\n",
" max_tokens=1024,\n",
" tools=tools,\n",
" tool_choice={\n",
" \"type\": \"function\",\n",
" \"function\": {\"name\": \"get_current_weather\"},\n",
" }, # Force the model to call the specific get_current_weather function\n",
")\n",
"\n",
"print_highlight(\"Response with specific function choice:\")\n",
"print(\"Content:\", response_specific.choices[0].message.content)\n",
"print(\"Tool calls:\", response_specific.choices[0].message.tool_calls)\n",
"\n",
"if response_specific.choices[0].message.tool_calls:\n",
" tool_call = response_specific.choices[0].message.tool_calls[0]\n",
" print(f\"Called function: {tool_call.function.name}\")\n",
" print(f\"Arguments: {tool_call.function.arguments}\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -444,7 +604,7 @@ ...@@ -444,7 +604,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import sglang as sgl\n", "import sglang as sgl\n",
"from sglang.srt.function_call_parser import FunctionCallParser\n", "from sglang.srt.function_call.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n", "from sglang.srt.managers.io_struct import Tool, Function\n",
"\n", "\n",
"llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n", "llm = sgl.Engine(model_path=\"Qwen/Qwen2.5-7B-Instruct\")\n",
......
...@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import ( ...@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
register_disaggregation_server, register_disaggregation_server,
) )
from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
CloseSessionReqInput, CloseSessionReqInput,
......
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.utils import (
_find_common_prefix,
_is_complete_json,
_partial_json_loads,
)
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class BaseFormatDetector(ABC):
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses implement model-specific tool-call formats (Llama, Qwen,
    DeepSeek, ...) by overriding ``detect_and_parse`` / ``has_tool_call`` /
    ``structure_info`` / ``build_ebnf`` and, where needed,
    ``parse_streaming_increment``.
    """

    def __init__(self):
        # initialize properties used for state when parsing tool calls in
        # _buffer accumulates raw model text across streaming calls until a
        # full tool call (or plain text) can be emitted.
        self._buffer = ""
        # streaming mode
        # Last fully-parsed tool-call array; used to compute argument diffs.
        self.prev_tool_call_arr: List[Dict] = []
        # Index of the tool call currently being streamed (-1 = none yet).
        self.current_tool_id: int = -1
        # Whether the current tool's name has already been emitted to the client.
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # map what has been streamed for each tool so far to a list
        # Start/end delimiter tokens for the tool-call section; set by subclasses.
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
        """Convert a decoded JSON action (dict or list of dicts) into ToolCallItems.

        Actions naming functions not present in ``tools`` are dropped with a
        warning rather than raising.
        """
        tool_indices = {
            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
        }
        if not isinstance(action, list):
            action = [action]
        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        # Accept either "parameters" or "arguments" as the key.
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
            else:
                logger.warning(f"Model attempted to call undefined function: {name}")
        return results

    @abstractmethod
    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        Parses the text in one go. Returns success=True if the format matches, otherwise False.
        Note that leftover_text here represents "content that this parser will not consume further".

        NOTE: although abstract, this provides a default bare-JSON
        implementation that subclasses may reuse via ``super()``.
        """
        action = json.loads(text)
        return StreamingParseResult(calls=self.parse_base_json(action, tools))

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Buffers incoming chunks, partially decodes the JSON tool-call array,
        and emits incremental deltas: first the tool name, then argument-JSON
        fragments as they stabilize.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        # No tool-call marker yet: flush the buffer back out as normal text
        # (stripping a trailing end token if one slipped into this chunk).
        if not (self.bot_token in current_text or current_text.startswith("{")):
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the name has been sent, forbid partial strings so a truncated
        # function name is never emitted.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                # Decode as many "; "-separated JSON objects as the buffer holds.
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except MalformedJSON:
                # Buffer does not yet hold decodable JSON; wait for more chunks.
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # Handle new tool in array
            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                # A new tool appeared: flush any unsent argument tail of the
                # previous tool before switching to the new one.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return res

            # Handle tool name
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    # Arguments JSON is complete: send the remainder and reset
                    # per-tool streaming state.
                    if is_complete[self.current_tool_id]:
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""
                    elif prev_arguments:
                        # Still partial: only emit the stable common prefix
                        # beyond what has already been sent.
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult()

    @abstractmethod
    def has_tool_call(self, text: str) -> bool:
        """Return True if ``text`` contains this format's tool-call marker."""
        raise NotImplementedError()

    @abstractmethod
    def structure_info(self) -> _GetInfoFunc:
        """Return a factory mapping a function name to its StructureInfo tags."""
        raise NotImplementedError()

    @abstractmethod
    def build_ebnf(self, tools: List[Tool]) -> str:
        """Return an EBNF grammar constraining output to this format's tool calls."""
        raise NotImplementedError()
from dataclasses import dataclass
from typing import Callable, List, Optional
from pydantic import BaseModel
class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    # Index of the tool in the request's tool list.
    tool_index: int
    # Function name; None for argument-only streaming deltas.
    name: Optional[str] = None
    parameters: str  # JSON string
class StreamingParseResult(BaseModel):
    """Result of streaming incremental parsing."""

    # Plain text to surface to the client (content outside tool calls).
    normal_text: str = ""
    # Tool-call deltas produced by this increment (pydantic copies the
    # mutable default per instance, so the shared [] is safe here).
    calls: List[ToolCallItem] = []
@dataclass
class StructureInfo:
    # Literal text delimiting one function call: `begin` precedes the JSON
    # arguments, `end` follows them, and `trigger` is the substring whose
    # appearance activates the structural-tag constraint.
    begin: str
    end: str
    trigger: str


"""
Helper alias of function
Usually it is a function that takes a name string and returns a StructureInfo object,
which can be used to construct a structural_tag object
"""
_GetInfoFunc = Callable[[str], StructureInfo]
import json
import logging
import re
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.function_call.utils import _is_complete_json
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class DeepSeekV3Detector(BaseFormatDetector):
    """
    Detector for DeepSeek models.
    Assumes function call format:
    '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool▁calls▁begin|>"
        self.eot_token = "<|tool▁calls▁end|>"
        # FIX: "|" is a regex metacharacter, so the delimiter tokens must be
        # escaped ("\|") before being embedded in a pattern; unescaped, each
        # pattern parses as an alternation ("<" | "tool▁call▁begin" | ...) and
        # matches unintended fragments. (NOTE(review): some DeepSeek tokenizers
        # use the fullwidth bar "｜" instead of ASCII "|" — escaping is
        # behavior-preserving in that case; confirm which token the model emits.)
        self.func_call_regex = r"<\|tool▁call▁begin\|>.*?<\|tool▁call▁end\|>"
        self.func_detail_regex = r"<\|tool▁call▁begin\|>(.*)<\|tool▁sep\|>(.*)\n```json\n(.*)\n```<\|tool▁call▁end\|>"
        # Arguments text already streamed for the in-flight call.
        self._last_arguments = ""

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a deepseek format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
        """
        idx = text.find(self.bot_token)
        # Everything before the tool-call section is normal content.
        normal_text = text[:idx].strip() if idx != -1 else text
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=normal_text, calls=[])
        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
        calls = []
        try:
            for match_result in match_result_list:
                # Get function name
                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
                func_name = func_detail.group(2)
                func_args = func_detail.group(3)
                func_args = json.loads(func_args)
                # construct match_result for parse_base_json
                match_result = {"name": func_name, "parameters": func_args}
                calls.extend(self.parse_base_json(match_result, tools))
            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing tool calls for DeepSeekV3 format.

        Emits the tool name as soon as it is parsed, then argument-text deltas
        as they arrive, resetting state once the arguments JSON is complete.
        """
        self._buffer += new_text
        current_text = self._buffer

        # Not inside a tool-call section: pass the chunk through as normal
        # text, stripping any stray delimiter tokens.
        if self.bot_token not in current_text:
            self._buffer = ""
            for e_token in [self.eot_token, "```", "<|tool▁call▁end|>"]:
                if e_token in new_text:
                    new_text = new_text.replace(e_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices lazily on first use.
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        calls: list[ToolCallItem] = []
        try:
            # FIX: delimiter "|" escaped here as well (see __init__).
            partial_match = re.search(
                pattern=r"<\|tool▁call▁begin\|>(.*)<\|tool▁sep\|>(.*)\n```json\n(.*)",
                string=current_text,
                flags=re.DOTALL,
            )
            if partial_match:
                func_name = partial_match.group(2).strip()
                func_args_raw = partial_match.group(3).strip()

                if not self.current_tool_name_sent:
                    # First emission for this call: send the name only.
                    calls.append(
                        ToolCallItem(
                            tool_index=self._tool_indices.get(func_name, 0),
                            name=func_name,
                            parameters="",
                        )
                    )
                    self.current_tool_name_sent = True
                else:
                    # Send only the new suffix of the arguments text.
                    argument_diff = (
                        func_args_raw[len(self._last_arguments) :]
                        if func_args_raw.startswith(self._last_arguments)
                        else func_args_raw
                    )
                    if argument_diff:
                        calls.append(
                            ToolCallItem(
                                tool_index=self._tool_indices.get(func_name, 0),
                                name=None,
                                parameters=argument_diff,
                            )
                        )
                        self._last_arguments += argument_diff

                    # Arguments JSON complete: reset per-call streaming state.
                    if _is_complete_json(func_args_raw):
                        result = StreamingParseResult(normal_text="", calls=calls)
                        self._buffer = ""
                        self._last_arguments = ""
                        self.current_tool_name_sent = False
                        return result

            return StreamingParseResult(normal_text="", calls=calls)

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            return StreamingParseResult(normal_text=current_text)

    def structure_info(self) -> _GetInfoFunc:
        """Return begin/end/trigger tags for structural-tag constrained output."""
        return lambda name: StructureInfo(
            begin=">" + name + "\n```json\n",
            end="\n```<",
            trigger=">" + name + "\n```json\n",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar forcing DeepSeek-format tool calls for ``tools``."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            tool_call_separator="",
            call_rule_fmt='"<|tool▁call▁begin|>function<|tool▁sep|>{name}\\n```json\\n" {arguments_rule} "\\n```<|tool▁call▁end|>"',
            function_format="json",
        )
from typing import Literal, Optional
class EBNFComposer:
    """Compose EBNF grammars that constrain model output to valid tool calls.

    Supports two surface syntaxes: ``json`` (OpenAI-style
    ``{"name": ..., "arguments": {...}}`` objects) and ``pythonic``
    (``name(arg=value, ...)`` calls inside a list).
    """

    # Adapted from https://xgrammar.mlc.ai/docs/how_to/ebnf_guided_generation.html#try-out-via-hf-transformers
    json_grammar_ebnf_str = r"""
json ::= basic_array | basic_object
basic_any ::= basic_number | basic_string | basic_boolean | basic_null | basic_array | basic_object
basic_integer ::= ("0" | "-"? [1-9] [0-9]*) ".0"?
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_boolean ::= "true" | "false"
basic_null ::= "null"
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""

    pythonic_grammar_ebnf_str = r"""
pythonic ::= basic_number | basic_string | basic_array | "True" | "False" | "None"
basic_any ::= basic_number | basic_string | basic_array | basic_object
basic_number ::= ("0" | "-"? [1-9] [0-9]*) ("." [0-9]+)? ([eE] [+-]? [0-9]+)?
basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
basic_array ::= "[" ("" | ws basic_any (ws "," ws basic_any)*) ws "]"
basic_object ::= "{" ("" | ws basic_string ws ":" ws basic_any ( ws "," ws basic_string ws ":" ws basic_any)*) ws "}"
ws ::= [ \n\t]*
"""

    # Root-level shape of the whole tool-call payload, per syntax.
    TOOL_CALLS_MAP = {
        "pythonic": '"[" function_call ("," function_call)* "]"',
        "json": "function_call",
    }

    # Default template for one call_{name} rule, per syntax.
    CALL_RULE_MAP = {
        "pythonic": 'call_{name} ::= "{name}" "(" {arguments_rule} ")"',
        "json": 'call_{name} ::= "{{" "\\"name\\"" ":" "\\"{name}\\"" ", " "\\"arguments\\"" ":" {arguments_rule} "}}"',
    }

    # Wrapper around the joined key/value pairs, per syntax.
    ARGUMENTS_RULE_MAP = {
        "pythonic": "{arg_rules}",
        "json": '"{{" {arg_rules} "}}"',
    }

    # Template for a single key/value pair, per syntax.
    KEY_VALUE_RULE_MAP = {
        "pythonic": '"{key}" "=" {valrule}',
        "json": '"\\"{key}\\"" ":" {valrule}',
    }

    JSON_TYPE_MAPPING = {
        "string": "basic_string",
        "number": "basic_number",
        "integer": "basic_number",
        "boolean": "basic_boolean",
        "null": "basic_null",
        "array": "basic_array",
        "object": "basic_object",
    }

    PYTHONIC_TYPE_MAPPING = {
        "string": "basic_string",
        "number": "basic_number",
        "integer": "basic_number",
        "boolean": '"True" | "False"',
        "null": '"None"',
        "array": "basic_array",
        "object": "basic_object",
    }

    @staticmethod
    def get_value_rule(
        prop: dict, function_format: Literal["pythonic", "json"] = "json"
    ) -> str:
        """Pick the EBNF value rule for one property schema.

        ``enum`` takes precedence over ``type``; with neither present, fall
        back to the format's generic value rule name.
        """
        if "enum" in prop:
            return EBNFComposer._handle_enum(prop, function_format)
        return (
            EBNFComposer._handle_type(prop, function_format)
            if "type" in prop
            else function_format
        )

    @staticmethod
    def _handle_enum(prop: dict, function_format: str) -> str:
        """Render an ``enum`` schema as an alternation of literal values."""
        allowed_values = prop["enum"]
        declared_type = prop.get("type", "string")

        # Renderers keyed by (schema type, surface syntax).
        renderers = {
            ("string", "json"): lambda v: f'"\\"{v}\\""',
            ("string", "pythonic"): lambda v: f'"\\"{v}\\""',
            ("number", "json"): str,
            ("number", "pythonic"): str,
            ("integer", "json"): str,
            ("integer", "pythonic"): str,
            ("boolean", "json"): lambda v: "true" if v else "false",
            ("boolean", "pythonic"): lambda v: "True" if v else "False",
        }

        # Unknown combinations fall back to string rendering.
        render = renderers.get(
            (declared_type, function_format),
            renderers[("string", function_format)],
        )

        rendered = [render(value) for value in allowed_values]
        rule = " | ".join(rendered)
        # Parenthesize multi-value alternations to keep EBNF precedence correct.
        return f"({rule})" if len(rendered) > 1 else rule

    @staticmethod
    def _handle_type(prop: dict, function_format: str) -> str:
        """Map a schema ``type`` (string or list of strings) to grammar rules."""
        declared = prop["type"]
        mapping = (
            EBNFComposer.PYTHONIC_TYPE_MAPPING
            if function_format == "pythonic"
            else EBNFComposer.JSON_TYPE_MAPPING
        )
        if not isinstance(declared, list):
            return mapping.get(declared, function_format)
        alternatives = [t for t in (mapping.get(s) for s in declared) if t is not None]
        return " | ".join(alternatives) if alternatives else function_format

    @staticmethod
    def build_ebnf(
        tools,
        *,
        call_rule_fmt: Optional[str] = None,
        function_format: Literal["pythonic", "json"] = "json",
        bot_token: Optional[str] = None,
        eot_token: Optional[str] = None,
        tool_call_separator: Optional[str] = None,
    ):
        """
        Generalized EBNF builder for all detectors.

        Args:
            tools: List of Tool objects to generate EBNF grammar for
            call_rule_fmt: Optional custom format string for call_{name} rule. It should define each function call's format, with
                the placeholders {name} for the function name and {arguments_rule} for the arguments rule. If None, a default
                format based on function_format will be used.
            function_format: The format of function calls, either "pythonic" or "json"
            bot_token: The token that indicates the start of a tool call section
            eot_token: The token that indicates the end of a tool call section
            tool_call_separator: The separator between multiple tool calls
        """
        # Root rule: delimiter-wrapped when both tokens are given, otherwise
        # the format's default payload shape.
        if bot_token and eot_token:
            if tool_call_separator:
                root_rule = (
                    f'"{bot_token}" function_call '
                    f'( "{tool_call_separator}" function_call )* "{eot_token}"'
                )
            else:
                root_rule = f'"{bot_token}" function_call "{eot_token}"'
        else:
            root_rule = EBNFComposer.TOOL_CALLS_MAP[function_format]

        # Header: root plus the alternation over every tool's call rule.
        alternation = " | ".join(f"call_{tool.function.name}" for tool in tools)
        grammar_lines = [
            f"root ::= {root_rule}",
            "function_call ::= " + alternation,
        ]

        # Templates for this syntax (custom call format wins when supplied).
        call_template = (
            f"call_{{name}} ::= {call_rule_fmt}"
            if call_rule_fmt
            else EBNFComposer.CALL_RULE_MAP[function_format]
        )
        args_template = EBNFComposer.ARGUMENTS_RULE_MAP[function_format]
        pair_template = EBNFComposer.KEY_VALUE_RULE_MAP[function_format]

        # Per-tool rules: one call_{name} and one arguments_{name}.
        for tool in tools:
            fn = tool.function
            schema = fn.parameters or {}
            properties = schema.get("properties", {})
            mandatory = set(schema.get("required", []))

            pairs = []
            for key, prop_schema in properties.items():
                rendered = pair_template.format(
                    key=key,
                    valrule=EBNFComposer.get_value_rule(prop_schema, function_format),
                )
                # Optional properties are bracketed in EBNF.
                pairs.append(rendered if key in mandatory else f"[ {rendered} ]")

            joined_pairs = ' "," '.join(pairs) if pairs else ""
            grammar_lines.append(
                call_template.format(name=fn.name, arguments_rule=f"arguments_{fn.name}")
            )
            grammar_lines.append(
                f"arguments_{fn.name} ::= {args_template.format(arg_rules=joined_pairs)}"
            )

        # Append the shared primitive rules for this syntax.
        grammar_lines.append(
            EBNFComposer.pythonic_grammar_ebnf_str
            if function_format == "pythonic"
            else EBNFComposer.json_grammar_ebnf_str
        )

        return "\n".join(grammar_lines)
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Type, Union
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.openai_api.protocol import (
StructuralTagResponseFormat,
StructuresResponseFormat,
Tool,
ToolChoice,
)
class FunctionCallParser:
"""
Parser for function/tool calls in model outputs.
This class handles both streaming and non-streaming parsing of function calls using a detector.
In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment
and returns the resulting normal_text and calls to the upper layer (or SSE).
"""
ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
"llama3": Llama32Detector,
"qwen25": Qwen25Detector,
"mistral": MistralDetector,
"deepseekv3": DeepSeekV3Detector,
"pythonic": PythonicDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
detector: Type[BaseFormatDetector] = None
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
if detector_class:
detector = detector_class()
else:
raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
self.detector = detector
self.tools = tools
def has_tool_call(self, text: str) -> bool:
"""
Check if the given text contains a tool call in the format supported by this parser.
This delegates to the detector's implementation.
Args:
text: The text to check for tool calls
Returns:
True if the text contains a tool call, False otherwise
"""
return self.detector.has_tool_call(text)
def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
"""
One-time parsing of the full text to extract tool calls.
Args:
full_text: The complete text to parse
Returns:
A tuple containing:
- The remaining text after parsing that was not consumed by the detector (can be treated as normal text)
- A list of tool calls parsed from the text
"""
parsed_result = self.detector.detect_and_parse(full_text, self.tools)
tool_call_list = parsed_result.calls
if tool_call_list:
return parsed_result.normal_text, tool_call_list
else:
return full_text, []
def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
"""
Streaming incremental parsing of chunks of text as they arrive.
Args:
chunk_text: The new chunk of text to parse
Returns:
A tuple containing:
- The normal text that should be displayed to the user
- A list of tool calls parsed from the chunk
"""
final_normal_text = ""
final_calls = []
sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)
if sp_result.normal_text:
final_normal_text = sp_result.normal_text
if sp_result.calls:
final_calls.extend(sp_result.calls)
final_normal_text = sp_result.normal_text
return final_normal_text, final_calls
def get_structure_tag(self) -> StructuralTagResponseFormat:
    """Build the structural-tag response format covering every available tool.

    For each tool, a structure entry is created from the detector's
    begin/end markers; strict tools additionally constrain the JSON content
    to their declared parameter schema.
    """
    info_for = self.detector.structure_info()
    structures: List[StructuresResponseFormat] = []
    triggers: Set[str] = set()
    for tool in self.tools:
        fn = tool.function
        assert fn.name is not None
        info = info_for(fn.name)
        structures.append(
            StructuresResponseFormat(
                begin=info.begin,
                # Non-strict tools accept any content ({} imposes no schema);
                # strict tools must match their parameter schema exactly.
                schema=fn.parameters if fn.strict else {},  # type: ignore
                end=info.end,
            )
        )
        triggers.add(info.trigger)
    return StructuralTagResponseFormat(
        type="structural_tag",
        structures=structures,
        triggers=list(triggers),
    )
def get_structure_constraint(
    self, tool_choice: Union[ToolChoice, Literal["auto", "required"]]
) -> Optional[Tuple[str, Any]]:
    """Select the output constraint implied by *tool_choice*.

    Args:
        tool_choice: The request's tool choice ("auto", "required", or a
            specific :class:`ToolChoice`).

    Returns:
        A ``(constraint_type, constraint_value)`` pair to merge into the
        sampling parameters, or ``None`` when no constraint applies.
    """
    # structural_tag can only validate JSON-compatible content between the
    # begin/end markers, so it cannot be used for the pythonic (Python
    # function-call syntax) format.
    supports_structural_tag = not isinstance(self.detector, PythonicDetector)
    if (
        supports_structural_tag
        and tool_choice == "auto"
        and any(tool.function.strict for tool in self.tools)
    ):
        return ("structural_tag", self.get_structure_tag())
    if tool_choice == "required" or isinstance(tool_choice, ToolChoice):
        ebnf = self.get_ebnf(tool_choice)
        return None if ebnf is None else ("ebnf", ebnf)
    return None
def get_ebnf(
    self, tool_choice: Union[ToolChoice, Literal["required"]]
) -> Optional[str]:
    """Return an EBNF grammar constraining output to the chosen tool(s).

    A specific :class:`ToolChoice` restricts the grammar to the named
    function only; ``"required"`` allows any configured tool.
    """
    if isinstance(tool_choice, ToolChoice):
        wanted = tool_choice.function.name
        candidates = [t for t in self.tools if t.function.name == wanted]
    else:
        candidates = self.tools
    return self.detector.build_ebnf(candidates)
import json
import logging
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.

    Assumes function call format:
        <|python_tag|>{"name":"xxx", "arguments":{...}}

    Depending on the prompt format, the model may emit the JSON payload
    without the leading ``<|python_tag|>`` token, so a bare leading ``{`` is
    also treated as a potential tool call.
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # Depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token.
        return self.bot_token in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse function calls from text, handling multiple JSON objects.

        Multiple calls are expected as semicolon-separated JSON objects after
        the bot token. Unparseable fragments are skipped with a warning
        (best-effort parsing).
        """
        if self.bot_token not in text and not text.startswith("{"):
            return StreamingParseResult(normal_text=text, calls=[])

        if self.bot_token in text:
            # Bug fix: use maxsplit=1 so additional bot tokens inside the
            # payload do not make split() return >2 parts and raise a
            # ValueError from tuple unpacking.
            normal_text, action_text = text.split(self.bot_token, maxsplit=1)
        else:
            normal_text, action_text = "", text

        # Split by semicolon and attempt to parse each part as JSON.
        all_actions = []
        for part in (p.strip() for p in action_text.split(";")):
            if not part:
                continue
            try:
                all_actions.append(json.loads(part))
            except json.JSONDecodeError as e:
                # Lazy %-style args avoid string formatting when the warning
                # level is disabled.
                logger.warning("Failed to parse JSON part: %s", part)
                logger.warning("JSON parse error: %s", e)
                continue

        # Only invoke the base parser when at least one valid JSON object was
        # recovered.
        calls = self.parse_base_json(all_actions, tools) if all_actions else []
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return begin/end/trigger markers for a given function name."""
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar accepting JSON-format calls for *tools*."""
        return EBNFComposer.build_ebnf(
            tools,
            function_format="json",
            tool_call_separator=",",
        )
import json
import re
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.

    Assumes function call format:
        [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """Initializes the detector with necessary state variables."""
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        self.eot_token = "]"
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Mistral format tool call."""
        return self.bot_token in text

    def _clean_text(self, text: str) -> str:
        """Extract only the '[TOOL_CALLS] [...]' span from *text*.

        Example:
            input:  '[TOOL_CALLS] [{"name": "get_current_weather", ...}] trailing prose'
            output: '[TOOL_CALLS] [{"name": "get_current_weather", ...}]'

        Returns an empty string when no such span is present.
        """
        # TODO: check if Mistral supports multiple tool calls; we currently
        # assume a single tool-call span and keep only the first match.
        matches = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
        return matches[0] if matches else ""

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with leftover normal text and parsed calls.
        """
        marker_pos = text.find(self.bot_token)
        normal_text = text[:marker_pos].strip() if marker_pos != -1 else text
        cleaned = self._clean_text(text)
        tool_content = cleaned.replace("[TOOL_CALLS]", "").strip()
        raw_tool_calls = self.tool_call_regex.findall(tool_content)
        calls = []
        if raw_tool_calls:
            # The regex captures the bracketed JSON array of call objects.
            for entry in json.loads(raw_tool_calls[0]):
                calls.extend(self.parse_base_json(entry, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return begin/end/trigger markers for a given function name."""
        return lambda name: StructureInfo(
            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
            end="}]",
            trigger="[TOOL_CALLS]",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar accepting Mistral-format calls for *tools*."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
import ast
import json
import logging
import re
from typing import List, Optional
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
logger = logging.getLogger(__name__)
class PythonicDetector(BaseFormatDetector):
    """
    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.

    Assumes function call format:
      [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
    Arguments are Python literals (not JSON).
    """

    def __init__(self):
        super().__init__()
        # One or more comma-separated name(kw=value, ...) calls inside [ ].
        self.tool_call_regex = re.compile(
            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
            re.DOTALL,
        )

    def has_tool_call(self, text: str) -> bool:
        """Return True when the stripped text begins with a pythonic call list."""
        return bool(self.tool_call_regex.match(text.strip()))

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """One-time parse: interpret *text* as a Python list of function calls.

        Returns parsed calls on success; otherwise the original text is
        returned as normal text with no calls.
        """
        # Try parsing the text as a Python list of function calls.
        text = text.strip()
        if not (text.startswith("[") and text.endswith("]")):
            # Not a pythonic tool call format.
            return StreamingParseResult(normal_text=text, calls=[])
        try:
            module = ast.parse(text)
            parsed = getattr(module.body[0], "value", None)
            if not (
                isinstance(parsed, ast.List)
                and all(isinstance(e, ast.Call) for e in parsed.elts)
            ):
                return StreamingParseResult(normal_text=text, calls=[])
            calls = []
            # Map tool name -> index in the request's tool list (-1 if unknown).
            tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function.name
            }
            for call in parsed.elts:
                if not isinstance(call.func, ast.Name):
                    continue
                function_name = call.func.id
                arguments = {}
                for keyword in call.keywords:
                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
                calls.append(
                    ToolCallItem(
                        tool_index=tool_indices.get(function_name, -1),
                        name=function_name,
                        parameters=json.dumps(arguments, ensure_ascii=False),
                    )
                )
            return StreamingParseResult(normal_text="", calls=calls)
        except Exception:
            logger.exception("Error in pythonic tool call parsing.")
            return StreamingParseResult(normal_text=text, calls=[])

    def _find_matching_bracket(self, buffer: str, start: int) -> int:
        """
        Find the matching closing bracket for the opening bracket at start position.
        Properly handles nested brackets.

        Args:
            buffer: The text buffer to search in
            start: Position of the opening bracket '['

        Returns:
            Position of the matching closing bracket ']', or -1 if not found
        """
        depth = 0
        for i in range(start, len(buffer)):
            if buffer[i] == "[":
                depth += 1
            elif buffer[i] == "]":
                depth -= 1
                if depth == 0:
                    return i
        return -1  # No matching bracket found

    def parse_streaming_increment(
        self, new_text: str, tools: List[Tool]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing for pythonic tool calls.

        Buffers input until a complete pythonic tool call (from [ to ]) is found,
        then parses and emits any detected calls.
        """
        self._buffer += new_text
        start = self._buffer.find("[")
        if start == -1:
            # No possible tool call: flush the whole buffer as normal text.
            normal_text = self._buffer
            self._buffer = ""
            return StreamingParseResult(normal_text=normal_text)
        normal_text = self._buffer[:start] if start > 0 else ""
        end = self._find_matching_bracket(self._buffer, start)
        if end != -1:
            call_text = self._buffer[start : end + 1]
            result = self.detect_and_parse(call_text, tools)
            self._buffer = self._buffer[end + 1 :]
            # If we had normal text before the tool call, prepend it.
            if normal_text:
                result.normal_text = normal_text + (result.normal_text or "")
            return result
        # We have an opening bracket but no closing bracket yet: emit the
        # leading normal text and keep the partial call buffered.
        if normal_text:
            self._buffer = self._buffer[start:]
            return StreamingParseResult(normal_text=normal_text)
        # Otherwise, we're still accumulating a potential tool call.
        return StreamingParseResult(normal_text="")

    def _get_parameter_value(self, val):
        """Convert an AST literal node into a plain Python value.

        Supports constants, dicts, lists, and signed numeric literals such as
        ``-1`` (bug fix: Python's AST represents signed numbers as UnaryOp
        nodes, which were previously rejected, silently dropping valid calls).

        Raises:
            ValueError: if the node is not a supported literal.
        """
        if isinstance(val, ast.Constant):
            return val.value
        elif isinstance(val, ast.Dict):
            return {
                k.value: self._get_parameter_value(v)
                for k, v in zip(val.keys, val.values)
            }
        elif isinstance(val, ast.List):
            return [self._get_parameter_value(v) for v in val.elts]
        elif isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
            operand = self._get_parameter_value(val.operand)
            if isinstance(operand, (int, float)):
                return -operand if isinstance(val.op, ast.USub) else operand
            raise ValueError("Tool call arguments must be literals")
        else:
            raise ValueError("Tool call arguments must be literals")

    def structure_info(self) -> _GetInfoFunc:
        """Return begin/end/trigger markers for a given function name."""

        def info(name: str):
            return StructureInfo(begin=f"[{name}(", end=")]", trigger=f"[{name}(")

        return info

    def build_ebnf(self, tools: List[Tool]) -> Optional[str]:
        """Build an EBNF grammar accepting pythonic-format calls for *tools*."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token="[",
            eot_token="]",
            tool_call_separator=",",
            function_format="pythonic",
        )
import json
import re
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.

    Assumes function call format:
        <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """Initializes the detector with necessary state variables."""
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Qwen 2.5 format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: StreamingParseResult with leftover normal text and parsed calls.
        """
        idx = text.find(self.bot_token)
        normal_text = text[:idx].strip() if idx != -1 else text
        if self.bot_token not in text:
            return StreamingParseResult(normal_text=normal_text, calls=[])
        pattern = rf"{self.bot_token}(.*?){self.eot_token}"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        calls = []
        for match_result in match_result_list:
            try:
                parsed_call = json.loads(match_result)
            except json.JSONDecodeError:
                # Robustness fix: skip malformed payloads instead of raising,
                # mirroring Llama32Detector's best-effort behavior, so one bad
                # call does not discard the rest of the response.
                continue
            calls.extend(self.parse_base_json(parsed_call, tools))
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def structure_info(self) -> _GetInfoFunc:
        """Return begin/end/trigger markers for a given function name."""
        return lambda name: StructureInfo(
            begin='<tool_call>{"name":"' + name + '", "arguments":',
            end="}</tool_call>",
            trigger="<tool_call>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build an EBNF grammar accepting Qwen-format calls for *tools*."""
        return EBNFComposer.build_ebnf(
            tools,
            bot_token=self.bot_token,
            eot_token=self.eot_token,
            function_format="json",
        )
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Parse possibly-incomplete JSON, returning ``(value, chars_consumed)``.

    Delegates to ``partial_json_parser``. When the input holds a complete
    JSON value followed by extra data, falls back to
    ``JSONDecoder.raw_decode`` so only the leading value is consumed; any
    other decode failure is re-raised.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as e:
        if "Extra data" in e.msg:
            return JSONDecoder().raw_decode(input_str)
        raise
    return (parsed, len(input_str))
def _is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
This diff is collapsed.
...@@ -40,7 +40,7 @@ from sglang.srt.conversation import ( ...@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
get_conv_template_by_model_path, get_conv_template_by_model_path,
register_conv_template, register_conv_template,
) )
from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import ( from sglang.srt.openai_api.protocol import (
BatchRequest, BatchRequest,
...@@ -970,7 +970,7 @@ def v1_chat_generate_request( ...@@ -970,7 +970,7 @@ def v1_chat_generate_request(
# - image_data: None or a list of image strings (URLs or base64 strings). # - image_data: None or a list of image strings (URLs or base64 strings).
# - audio_data: None or a list of audio strings (URLs). # - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput. # None skips any image processing in GenerateReqInput.
strict_tag = None tool_call_constraint = None
prompt = "" prompt = ""
prompt_ids = [] prompt_ids = []
if not isinstance(request.messages, str): if not isinstance(request.messages, str):
...@@ -989,7 +989,9 @@ def v1_chat_generate_request( ...@@ -989,7 +989,9 @@ def v1_chat_generate_request(
tool_call_parser = tokenizer_manager.server_args.tool_call_parser tool_call_parser = tokenizer_manager.server_args.tool_call_parser
parser = FunctionCallParser(request.tools, tool_call_parser) parser = FunctionCallParser(request.tools, tool_call_parser)
strict_tag = parser.get_structure_tag() tool_call_constraint = parser.get_structure_constraint(
request.tool_choice
)
if chat_template_name is None: if chat_template_name is None:
openai_compatible_messages = [] openai_compatible_messages = []
...@@ -1156,20 +1158,24 @@ def v1_chat_generate_request( ...@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
request.response_format.model_dump(by_alias=True) request.response_format.model_dump(by_alias=True)
) )
if strict_tag is not None: # Check if there are already existing output constraints
if ( has_existing_constraints = (
sampling_params.get("regex") sampling_params.get("regex")
or sampling_params.get("ebnf") or sampling_params.get("ebnf")
or sampling_params.get("structural_tag") or sampling_params.get("structural_tag")
or sampling_params.get("json_schema") or sampling_params.get("json_schema")
):
logger.warning(
"Constrained decoding is not compatible with tool calls."
) )
else:
sampling_params["structural_tag"] = convert_json_schema_to_str( if tool_call_constraint and has_existing_constraints:
strict_tag.model_dump(by_alias=True) logger.warning("Constrained decoding is not compatible with tool calls.")
elif tool_call_constraint:
constraint_type, constraint_value = tool_call_constraint
if constraint_type == "structural_tag":
sampling_params[constraint_type] = convert_json_schema_to_str(
constraint_value.model_dump(by_alias=True)
) )
else:
sampling_params[constraint_type] = constraint_value
sampling_params_list.append(sampling_params) sampling_params_list.append(sampling_params)
......
...@@ -36,7 +36,7 @@ suites = { ...@@ -36,7 +36,7 @@ suites = {
TestFile("test_fa3.py", 376), TestFile("test_fa3.py", 376),
TestFile("test_fim_completion.py", 40), TestFile("test_fim_completion.py", 40),
TestFile("test_fp8_kernel.py", 8), TestFile("test_fp8_kernel.py", 8),
TestFile("test_function_calling.py", 60), TestFile("test_function_call_parser.py", 10),
TestFile("test_fused_moe.py", 30), TestFile("test_fused_moe.py", 30),
TestFile("test_hicache.py", 116), TestFile("test_hicache.py", 116),
TestFile("test_hicache_mla.py", 254), TestFile("test_hicache_mla.py", 254),
...@@ -54,6 +54,7 @@ suites = { ...@@ -54,6 +54,7 @@ suites = {
TestFile("test_flashmla.py", 300), TestFile("test_flashmla.py", 300),
TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_chunked_prefill.py", 108),
TestFile("test_no_overlap_scheduler.py", 216), TestFile("test_no_overlap_scheduler.py", 216),
TestFile("test_openai_function_calling.py", 60),
TestFile("test_openai_server.py", 149), TestFile("test_openai_server.py", 149),
TestFile("test_penalty.py", 41), TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60), TestFile("test_page_size.py", 60),
......
import json
import unittest
from xgrammar import GrammarCompiler, TokenizerInfo
from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector
from sglang.srt.function_call.llama32_detector import Llama32Detector
from sglang.srt.function_call.mistral_detector import MistralDetector
from sglang.srt.function_call.pythonic_detector import PythonicDetector
from sglang.srt.function_call.qwen25_detector import Qwen25Detector
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import Function, Tool
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestPythonicDetector(unittest.TestCase):
    """Unit tests for PythonicDetector's streaming incremental parsing.

    Covers plain text, complete and partial tool calls, text appearing
    before/after calls, multiple calls, and nested-bracket arguments split
    across chunk boundaries. Each test also verifies the internal buffer
    state (_buffer) to pin the streaming contract precisely.
    """

    def setUp(self):
        # Create sample tools for testing
        self.tools = [
            Tool(
                type="function",
                function=Function(
                    name="get_weather",
                    description="Get weather information",
                    parameters={
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "Location to get weather for",
                            },
                            "unit": {
                                "type": "string",
                                "description": "Temperature unit",
                                "enum": ["celsius", "fahrenheit"],
                            },
                        },
                        "required": ["location"],
                    },
                ),
            ),
            Tool(
                type="function",
                function=Function(
                    name="search",
                    description="Search for information",
                    parameters={
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                        },
                        "required": ["query"],
                    },
                ),
            ),
        ]
        # Fresh detector per test: streaming state (_buffer) must not leak
        # between cases.
        self.detector = PythonicDetector()

    def test_parse_streaming_no_brackets(self):
        """Test parsing text with no brackets (no tool calls)."""
        text = "This is just normal text without any tool calls."
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, text)
        self.assertEqual(result.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_complete_tool_call(self):
        """Test parsing a complete tool call."""
        text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "Here's a tool call: ")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(
            self.detector._buffer, ""
        )  # Buffer should be cleared after processing
        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "New York")
        self.assertEqual(params["unit"], "celsius")

    def test_parse_streaming_text_before_tool_call(self):
        """Test parsing text that appears before a tool call."""
        text = "This is some text before [get_weather(location='London')]"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "This is some text before ")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "London")

    def test_parse_streaming_partial_tool_call(self):
        """Test parsing a partial tool call that spans multiple chunks."""
        # First chunk with opening bracket but no closing bracket
        text1 = "Let me check the weather: [get_weather(location="
        result1 = self.detector.parse_streaming_increment(text1, self.tools)
        self.assertEqual(result1.normal_text, "Let me check the weather: ")
        self.assertEqual(result1.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location="
        )  # Partial tool call remains in buffer

        # Second chunk completing the tool call
        text2 = "'Paris')]"
        result2 = self.detector.parse_streaming_increment(text2, self.tools)
        self.assertEqual(result2.normal_text, "")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")
        # Check the parameters
        params = json.loads(result2.calls[0].parameters)
        self.assertEqual(params["location"], "Paris")
        self.assertEqual(
            self.detector._buffer, ""
        )  # Buffer should be cleared after processing

    def test_parse_streaming_bracket_without_text_before(self):
        """Test parsing a tool call that starts at the beginning of the text."""
        text = "[search(query='python programming')]"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "search")
        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "python programming")

    def test_parse_streaming_text_after_tool_call(self):
        """Test parsing text that appears after a tool call."""
        # First chunk with complete tool call and some text after
        text = "[get_weather(location='Tokyo')] Here's the forecast:"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(
            self.detector._buffer, " Here's the forecast:"
        )  # Text after tool call remains in buffer

        # Process the remaining text in buffer
        result2 = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(result2.normal_text, " Here's the forecast:")
        self.assertEqual(result2.calls, [])
        self.assertEqual(self.detector._buffer, "")  # Buffer should be cleared

    def test_parse_streaming_multiple_tool_calls(self):
        """Test parsing multiple tool calls in sequence."""
        text = "[get_weather(location='Berlin')] and [search(query='restaurants')]"
        # First tool call
        result1 = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(len(result1.calls), 1)
        self.assertEqual(result1.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, " and [search(query='restaurants')]")

        # Second tool call
        result2 = self.detector.parse_streaming_increment("", self.tools)
        self.assertEqual(result2.normal_text, " and ")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")

    def test_parse_streaming_opening_bracket_only(self):
        """Test parsing text with only an opening bracket but no closing bracket."""
        text = "Let's try this: ["
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "Let's try this: ")
        self.assertEqual(result.calls, [])
        self.assertEqual(
            self.detector._buffer, "["
        )  # Opening bracket remains in buffer

    def test_parse_streaming_nested_brackets(self):
        """Test parsing tool calls with nested brackets in arguments."""
        # Test with list argument containing nested brackets
        text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")
        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["location"], "New York")
        self.assertEqual(params["unit"], "celsius")
        self.assertEqual(params["data"], [1, 2, 3])

    def test_parse_streaming_nested_brackets_dict(self):
        """Test parsing tool calls with nested dictionaries and lists."""
        # Test with nested dict and list arguments
        text = "[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 1)
        self.assertEqual(result.calls[0].name, "search")
        self.assertEqual(self.detector._buffer, "")
        # Check the parameters
        params = json.loads(result.calls[0].parameters)
        self.assertEqual(params["query"], "test")
        self.assertEqual(params["config"]["options"], [1, 2])
        self.assertEqual(params["config"]["nested"]["key"], "value")

    def test_parse_streaming_multiple_tools_with_nested_brackets(self):
        """Test parsing multiple tool calls with nested brackets."""
        text = "[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]"
        result = self.detector.parse_streaming_increment(text, self.tools)
        self.assertEqual(result.normal_text, "")
        self.assertEqual(len(result.calls), 2)
        self.assertEqual(self.detector._buffer, "")
        # Check first tool call
        params1 = json.loads(result.calls[0].parameters)
        self.assertEqual(result.calls[0].name, "get_weather")
        self.assertEqual(params1["location"], "Paris")
        self.assertEqual(params1["data"], [10, 20])
        # Check second tool call
        params2 = json.loads(result.calls[1].parameters)
        self.assertEqual(result.calls[1].name, "search")
        self.assertEqual(params2["query"], "test")
        self.assertEqual(params2["filters"], ["a", "b"])

    def test_parse_streaming_partial_nested_brackets(self):
        """Test parsing partial tool calls with nested brackets across chunks."""
        # First chunk with nested brackets but incomplete
        text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2"
        result1 = self.detector.parse_streaming_increment(text1, self.tools)
        self.assertEqual(result1.normal_text, "Here's a call: ")
        self.assertEqual(result1.calls, [])
        self.assertEqual(
            self.detector._buffer, "[get_weather(location='Tokyo', data=[1, 2"
        )

        # Second chunk completing the nested brackets
        text2 = ", 3])]"
        result2 = self.detector.parse_streaming_increment(text2, self.tools)
        self.assertEqual(result2.normal_text, "")
        self.assertEqual(len(result2.calls), 1)
        self.assertEqual(result2.calls[0].name, "get_weather")
        self.assertEqual(self.detector._buffer, "")
        # Check the parameters
        params = json.loads(result2.calls[0].parameters)
        self.assertEqual(params["location"], "Tokyo")
        self.assertEqual(params["data"], [1, 2, 3])
class TestEBNFGeneration(unittest.TestCase):
def setUp(self):
# Create sample tools for testing
self.tools = [
Tool(
type="function",
function=Function(
name="get_weather",
description="Get weather information",
parameters={
"properties": {
"location": {
"type": "string",
"description": "Location to get weather for",
},
"unit": {
"type": "string",
"description": "Temperature unit",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["location"],
},
),
),
Tool(
type="function",
function=Function(
name="search",
description="Search for information",
parameters={
"properties": {
"query": {
"type": "string",
"description": "Search query",
},
},
"required": ["query"],
},
),
),
]
self.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
tokenizer_info = TokenizerInfo.from_huggingface(self.tokenizer)
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
# Initialize all detectors
self.pythonic_detector = PythonicDetector()
self.deepseekv3_detector = DeepSeekV3Detector()
self.llama32_detector = Llama32Detector()
self.mistral_detector = MistralDetector()
self.qwen25_detector = Qwen25Detector()
def test_pythonic_detector_ebnf(self):
"""Test that the PythonicDetector generates valid EBNF."""
ebnf = self.pythonic_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn('call_get_weather ::= "get_weather" "(" ', ebnf)
self.assertIn('"location" "=" basic_string', ebnf)
self.assertIn('[ "unit" "=" ("\\"celsius\\"" | "\\"fahrenheit\\"") ]', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_deepseekv3_detector_ebnf(self):
"""Test that the DeepSeekV3Detector generates valid EBNF."""
ebnf = self.deepseekv3_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn("<|tool▁calls▁begin|>", ebnf)
self.assertIn("<|tool▁call▁begin|>function<|tool▁sep|>get_weather", ebnf)
self.assertIn('\\"location\\"" ":" basic_string ', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_llama32_detector_ebnf(self):
"""Test that the Llama32Detector generates valid EBNF."""
ebnf = self.llama32_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn('\\"name\\"" ":" "\\"get_weather\\"', ebnf)
self.assertIn('"\\"arguments\\"" ":"', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_mistral_detector_ebnf(self):
"""Test that the MistralDetector generates valid EBNF."""
ebnf = self.mistral_detector.build_ebnf(self.tools)
self.assertIsNotNone(ebnf)
# Check that the EBNF contains expected patterns
self.assertIn('"[TOOL_CALLS] ["', ebnf)
self.assertIn("call_get_weather | call_search", ebnf)
self.assertIn('"\\"arguments\\"" ":"', ebnf)
# Validate that the EBNF can be compiled by GrammarCompiler
try:
ctx = self.grammar_compiler.compile_grammar(ebnf)
self.assertIsNotNone(ctx, "EBNF should be valid and compile successfully")
except RuntimeError as e:
self.fail(f"Failed to compile EBNF: {e}")
def test_qwen25_detector_ebnf(self):
    """Verify the Qwen25Detector emits a well-formed, compilable EBNF grammar."""
    grammar = self.qwen25_detector.build_ebnf(self.tools)
    self.assertIsNotNone(grammar)
    # Qwen 2.5's <tool_call> tag plus name/arguments rules must appear.
    for fragment in (
        "<tool_call>",
        '\\"name\\"" ":" "\\"get_weather\\"',
        '"\\"arguments\\"" ":"',
    ):
        self.assertIn(fragment, grammar)
    # The grammar must also be accepted by the xgrammar compiler.
    try:
        compiled = self.grammar_compiler.compile_grammar(grammar)
    except RuntimeError as err:
        self.fail(f"Failed to compile EBNF: {err}")
    else:
        self.assertIsNotNone(
            compiled, "EBNF should be valid and compile successfully"
        )
# Allow running this test module directly (e.g. `python test_function_calling.py`).
if __name__ == "__main__":
    unittest.main()
...@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase): ...@@ -290,6 +290,151 @@ class TestOpenAIServerFunctionCalling(CustomTestCase):
self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5") self.assertEqual(str(args_obj["int_a"]), "5", "Parameter int_a should be 5")
self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7") self.assertEqual(str(args_obj["int_b"]), "7", "Parameter int_b should be 7")
def test_function_call_required(self):
    """
    Exercise tool_choice="required".

    When tool_choice == "required", the server must return at least one
    tool call; for this prompt the model is expected to pick get_weather.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    # Two candidate tools: an integer subtraction helper and a weather lookup.
    sub_tool = {
        "type": "function",
        "function": {
            "name": "sub",
            "description": "Compute the difference of two integers",
            "parameters": {
                "type": "object",
                "properties": {
                    "int_a": {
                        "type": "integer",
                        "description": "First integer",
                    },
                    "int_b": {
                        "type": "integer",
                        "description": "Second integer",
                    },
                },
                "required": ["int_a", "int_b"],
            },
            "strict": True,
        },
    }
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "use this to get latest weather information for a city given its name",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "name of the city to get weather for",
                    }
                },
                "required": ["city"],
            },
        },
    }
    response = client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=[sub_tool, weather_tool],
        tool_choice="required",
    )

    calls = response.choices[0].message.tool_calls
    self.assertIsNotNone(calls, "No tool_calls in the response")
    first_call = calls[0].function
    parsed_args = json.loads(first_call.arguments)
    self.assertEqual(
        first_call.name, "get_weather", "Function name should be 'get_weather'"
    )
    self.assertIn("city", parsed_args, "Function arguments should have 'city'")
    self.assertIn(
        "Paris", parsed_args["city"], "Parameter city should contain 'Paris'"
    )  # might be flaky
def test_function_call_specific(self):
    """
    Exercise a specific tool_choice object.

    When tool_choice names a concrete function, the server must return a
    tool call invoking exactly that function (get_weather here).
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    # Two candidate tools; tool_choice below forces get_weather regardless.
    sub_tool = {
        "type": "function",
        "function": {
            "name": "sub",
            "description": "Compute the difference of two integers",
            "parameters": {
                "type": "object",
                "properties": {
                    "int_a": {
                        "type": "integer",
                        "description": "First integer",
                    },
                    "int_b": {
                        "type": "integer",
                        "description": "Second integer",
                    },
                },
                "required": ["int_a", "int_b"],
            },
            "strict": True,
        },
    }
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "use this to get latest weather information for a city given its name",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "name of the city to get weather for",
                    }
                },
                "required": ["city"],
            },
        },
    }
    response = client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        temperature=0.8,
        top_p=0.8,
        stream=False,
        tools=[sub_tool, weather_tool],
        tool_choice={"type": "function", "function": {"name": "get_weather"}},
    )

    calls = response.choices[0].message.tool_calls
    self.assertIsNotNone(calls, "No tool_calls in the response")
    first_call = calls[0].function
    parsed_args = json.loads(first_call.arguments)
    self.assertEqual(
        first_call.name, "get_weather", "Function name should be 'get_weather'"
    )
    self.assertIn("city", parsed_args, "Function arguments should have 'city'")
class TestOpenAIPythonicFunctionCalling(CustomTestCase): class TestOpenAIPythonicFunctionCalling(CustomTestCase):
PYTHONIC_TOOLS = [ PYTHONIC_TOOLS = [
...@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): ...@@ -385,11 +530,13 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
stream=False, stream=False,
) )
tool_calls = response.choices[0].message.tool_calls tool_calls = response.choices[0].message.tool_calls
self.assertIsInstance(tool_calls, list) self.assertIsInstance(tool_calls, list, "No tool_calls found")
self.assertGreaterEqual(len(tool_calls), 1) self.assertGreaterEqual(len(tool_calls), 1)
names = [tc.function.name for tc in tool_calls] names = [tc.function.name for tc in tool_calls]
self.assertIn("get_weather", names) self.assertTrue(
self.assertIn("get_tourist_attractions", names) "get_weather" in names or "get_tourist_attractions" in names,
f"Function name '{names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
def test_pythonic_tool_call_streaming(self): def test_pythonic_tool_call_streaming(self):
""" """
...@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase): ...@@ -419,8 +566,10 @@ class TestOpenAIPythonicFunctionCalling(CustomTestCase):
self.assertTrue(found_tool_calls, "No tool_calls found in streaming response") self.assertTrue(found_tool_calls, "No tool_calls found in streaming response")
self.assertTrue(found_index, "No index field found in any streamed tool_call") self.assertTrue(found_index, "No index field found in any streamed tool_call")
self.assertIn("get_weather", found_names) self.assertTrue(
self.assertIn("get_tourist_attractions", found_names) "get_weather" in found_names or "get_tourist_attractions" in found_names,
f"Function name '{found_names}' should container either 'get_weather' or 'get_tourist_attractions'",
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment