Unverified Commit b045841b authored by YAMY, committed by GitHub

Feature/function calling update (#2700)


Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
parent f265d15b
@@ -4,32 +4,23 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tool and Function Calling\n",
    "\n",
    "This guide demonstrates how to use SGLang’s **Tool Calling** functionality."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OpenAI Compatible API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Launching the Server"
   ]
  },
  {
@@ -38,6 +29,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "import json\n",
    "from sglang.utils import (\n",
    "    execute_shell_command,\n",
    "    wait_for_server,\n",
@@ -45,21 +38,30 @@
    "    print_highlight,\n",
    ")\n",
    "\n",
    "server_process = execute_shell_command(\n",
    "    \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\"  # llama3\n",
    ")\n",
    "wait_for_server(\"http://localhost:30333\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n",
    "\n",
    "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n",
    "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n",
    "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)."
   ]
  },
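For example (an illustrative launch, not part of the committed notebook), serving one of the Qwen 2.5 models listed above would simply pair it with the `qwen25` parser:

```python
# Hypothetical alternative launch: a Qwen 2.5 model paired with the qwen25 parser
server_process = execute_shell_command(
    "python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct "
    "--tool-call-parser qwen25 --port 30333 --host 0.0.0.0"
)
wait_for_server("http://localhost:30333")
```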
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Tools for Function Call\n",
    "Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and the parameters the tool accepts, each defined as a property."
   ]
  },
  {
@@ -68,8 +70,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define tools\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
@@ -79,22 +80,264 @@
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n",
    "                    },\n",
    "                    \"state\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"the two-letter abbreviation for the state that the city is\"\n",
    "                        \" in, e.g. 'CA' which would mean 'California'\",\n",
    "                    },\n",
    "                    \"unit\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The unit to fetch the temperature in\",\n",
    "                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"city\", \"state\", \"unit\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]"
   ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_messages():\n",
" return [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n",
" }\n",
" ]\n",
"\n",
"\n",
"messages = get_messages()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initialize the Client"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize OpenAI-like client\n",
"client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n",
"model_name = client.models.list().data[0].id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Non-Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Non-streaming mode test\n",
"response_non_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.8,\n",
" top_p=0.8,\n",
" stream=False, # Non-streaming\n",
" tools=tools,\n",
")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(response_non_stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming Request"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Streaming mode test\n",
"print_highlight(\"Streaming response:\")\n",
"response_stream = client.chat.completions.create(\n",
" model=model_name,\n",
" messages=messages,\n",
" temperature=0.8,\n",
" top_p=0.8,\n",
" stream=True, # Enable streaming\n",
" tools=tools,\n",
")\n",
"\n",
"chunks = []\n",
"for chunk in response_stream:\n",
" chunks.append(chunk)\n",
" if chunk.choices[0].delta.tool_calls:\n",
" print(chunk.choices[0].delta.tool_calls[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"### Handle Tool Calls\n",
"\n",
"When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Non-Streaming Request**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n",
"arguments_non_stream = (\n",
" response_non_stream.choices[0].message.tool_calls[0].function.arguments\n",
")\n",
"\n",
"print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n",
"print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Streaming Request**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parse and combine function call arguments\n",
"arguments = []\n",
"for chunk in chunks:\n",
" choice = chunk.choices[0]\n",
" delta = choice.delta\n",
" if delta.tool_calls:\n",
" tool_call = delta.tool_calls[0]\n",
" if tool_call.function.name:\n",
" print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n",
"\n",
" if tool_call.function.arguments:\n",
" arguments.append(tool_call.function.arguments)\n",
" print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n",
"\n",
"# Combine all fragments into a single JSON string\n",
"full_arguments = \"\".join(arguments)\n",
"print_highlight(f\"Final streamed function call arguments: {full_arguments}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define a Tool Function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is a demonstration, define real function according to your usage.\n",
"def get_current_weather(city: str, state: str, unit: \"str\"):\n",
" return (\n",
" f\"The weather in {city}, {state} is 85 degrees {unit}. It is \"\n",
" \"partly cloudly, with highs in the 90's.\"\n",
" )\n",
"\n",
"\n",
"available_tools = {\"get_current_weather\": get_current_weather}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"## Execute the Tool"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"call_data = json.loads(full_arguments)\n",
"\n",
"messages.append(\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"\",\n",
" \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n",
" }\n",
")\n",
"\n", "\n",
"client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n", "# Call the corresponding tool function\n",
"model_name = client.models.list().data[0].id\n", "tool_name = messages[-1][\"tool_calls\"][\"name\"]\n",
"response = client.chat.completions.create(\n", "tool_to_call = available_tools[tool_name]\n",
"result = tool_to_call(**call_data)\n",
"print_highlight(f\"Function call result: {result}\")\n",
"messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n",
"\n",
"print_highlight(f\"Updated message history: {messages}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Send Results Back to Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final_response = client.chat.completions.create(\n",
" model=model_name,\n", " model=model_name,\n",
" messages=messages,\n", " messages=messages,\n",
" temperature=0.8,\n", " temperature=0.8,\n",
...@@ -102,17 +345,56 @@ ...@@ -102,17 +345,56 @@
" stream=False,\n", " stream=False,\n",
" tools=tools,\n", " tools=tools,\n",
")\n", ")\n",
"print_highlight(\"Non-stream response:\")\n",
"print(final_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Native API and SGLang Runtime (SRT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"import requests\n",
"\n",
"# generate an answer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"\n",
"messages = get_messages()\n",
"\n",
"input = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" tools=tools,\n",
")\n",
"\n", "\n",
"print(response)\n", "gen_url = \"http://localhost:30333/generate\"\n",
"gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n",
"gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n",
"print(gen_response)\n",
"\n", "\n",
"\"\"\"\n", "# parse the response\n",
"parse_url = \"http://localhost:30333/function_call\"\n",
"\n", "\n",
"ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n", "function_call_input = {\n",
"role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n", " \"text\": gen_response,\n",
"matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n", " \"tool_call_parser\": \"llama3\",\n",
"usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n", " \"tools\": tools,\n",
"}\n",
"\n", "\n",
"\"\"\"" "function_call_response = requests.post(parse_url, json=function_call_input)\n",
"function_call_response_json = function_call_response.json()\n",
"print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n",
"print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])"
] ]
  },
  {
@@ -128,11 +410,98 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Offline Engine API"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sglang as sgl\n",
"from sglang.srt.function_call_parser import FunctionCallParser\n",
"from sglang.srt.managers.io_struct import Tool, Function\n",
"\n",
"llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n",
"tokenizer = llm.tokenizer_manager.tokenizer\n",
"input_ids = tokenizer.apply_chat_template(\n",
" messages, tokenize=True, add_generation_prompt=True, tools=tools\n",
")\n",
"\n",
"sampling_params = {\n",
" \"max_new_tokens\": 128,\n",
" \"temperature\": 0.3,\n",
" \"top_p\": 0.95,\n",
" \"skip_special_tokens\": False,\n",
"}\n",
"\n",
"# 1) Offline generation\n",
"result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n",
"generated_text = result[\"text\"] # Assume there is only one prompt\n",
"\n",
"print(\"=== Offline Engine Output Text ===\")\n",
"print(generated_text)\n",
"\n",
"\n",
"# 2) Parse using FunctionCallParser\n",
"def convert_dict_to_tool(tool_dict: dict) -> Tool:\n",
" function_dict = tool_dict.get(\"function\", {})\n",
" return Tool(\n",
" type=tool_dict.get(\"type\", \"function\"),\n",
" function=Function(\n",
" name=function_dict.get(\"name\"),\n",
" description=function_dict.get(\"description\"),\n",
" parameters=function_dict.get(\"parameters\"),\n",
" ),\n",
" )\n",
"\n",
"\n",
"tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n",
"\n",
"parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n",
"normal_text, calls = parser.parse_non_stream(generated_text)\n",
"\n",
"print(\"\\n=== Parsing Result ===\")\n",
"print(\"Normal text portion:\", normal_text)\n",
"print(\"Function call portion:\")\n",
"for call in calls:\n",
" # call: ToolCallItem\n",
" print(f\" - tool name: {call.name}\")\n",
" print(f\" parameters: {call.parameters}\")\n",
"\n", "\n",
"For adding support of more different models:\n", "# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc."
" 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n", ]
" 2. Add support in `parse_tool_response` function for converting into tool calls `sglang/srt/utils.py`\n" },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm.shutdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to support a new model?\n",
"1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n",
"```\n",
"\tTOOLS_TAG_LIST = [\n",
"\t “<|plugin|>“,\n",
"\t “<function=“,\n",
"\t “<tool_call>“,\n",
"\t “<|python_tag|>“,\n",
"\t “[TOOL_CALLS]”\n",
"\t]\n",
"```\n",
"2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n",
"```\n",
" class NewModelDetector(BaseFormatDetector):\n",
"```\n",
"3. Add the new detector to the MultiFormatParser class that manages all the format detectors."
   ]
  }
 ],
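Putting steps 2 and 3 together, a minimal sketch of a custom detector might look like the following. The `<my_tool>` tags and the `NewModelDetector` name are hypothetical; adapt the tokens and the regex to the target model's real tool-call format.

```python
import json
import re
from typing import List

from sglang.srt.function_call_parser import (
    BaseFormatDetector,
    Function,
    MultiFormatParser,
    ToolCallItem,
)


class NewModelDetector(BaseFormatDetector):
    """Hypothetical detector for a model that wraps calls in <my_tool>...</my_tool>."""

    def __init__(self):
        super().__init__()
        self.bot_token = "<my_tool>"   # assumed begin-of-tool-call tag
        self.eot_token = "</my_tool>"  # assumed end-of-tool-call tag

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        if self.bot_token not in text:
            return []
        calls = []
        for match in re.findall(r"<my_tool>(.*?)</my_tool>", text, re.DOTALL):
            # parse_base_json maps {"name": ..., "arguments": {...}} onto ToolCallItem
            calls.extend(self.parse_base_json(json.loads(match), tools))
        return calls


# Step 3: make the detector reachable, e.g. by registering it in
# FunctionCallParser.ToolCallParserEnum or by building a MultiFormatParser directly.
parser = MultiFormatParser(detectors=[NewModelDetector()])
```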
......
@@ -39,10 +39,12 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    FunctionCallReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
@@ -369,6 +371,28 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
    return Response(status_code=200)
@app.post("/function_call")
async def function_call_request(obj: FunctionCallReqInput, request: Request):
"""
A native API endpoint to parse function calls from a text.
"""
# 1) Initialize the parser based on the request body
parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser)
# 2) Call the non-stream parsing method (non-stream)
normal_text, calls = parser.parse_non_stream(obj.text)
# 3) Organize the response content
response_data = {
"normal_text": normal_text,
"calls": [
call.model_dump() for call in calls
], # Convert pydantic objects to dictionaries
}
return ORJSONResponse(content=response_data, status_code=200)
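For reference, a hypothetical response body from this endpoint (the values are illustrative; the field names follow `ToolCallItem`, where `parameters` is a JSON string):

```python
# Hypothetical response from POST /function_call for a get_current_weather call
example_response = {
    "normal_text": "",
    "calls": [
        {
            "tool_index": 0,
            "name": "get_current_weather",
            "parameters": '{"city": "Boston", "state": "MA", "unit": "celsius"}',
        }
    ],
}
```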
##### OpenAI-compatible API endpoints #####
......
import json
import re
from abc import ABC, abstractmethod
from json import JSONDecodeError, JSONDecoder
from typing import Any, Dict, List, Optional, Tuple
import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field
TOOLS_TAG_LIST = [
"<|plugin|>",
"<function=",
"<tool_call>",
"<|python_tag|>",
"[TOOL_CALLS]",
]
class Function(BaseModel):
"""Function Tool Template."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
class ToolCallItem(BaseModel):
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
tool_index: int
name: Optional[str] = None
parameters: str # JSON string
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
raise
def _is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
class StreamingParseResult:
"""Result of streaming incremental parsing."""
def __init__(
self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
):
self.normal_text = normal_text
self.calls = calls or []
class BaseFormatDetector:
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
def __init__(self):
# initialize properties used for state when parsing tool calls in
self._buffer = ""
# streaming mode
self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: List[str] = (
[]
) # map what has been streamed for each tool so far to a list
self.bot_token = ""
self.eot_token = ""
def parse_base_json(self, action: Dict, tools: List[Function]):
name, parameters = action["name"], json.dumps(
action.get("parameters", action.get("arguments", {})),
ensure_ascii=False,
)
tool_index = [tool.function.name for tool in tools].index(name)
tool_call_item = ToolCallItem(
tool_index=tool_index, name=name, parameters=parameters
)
calls = [tool_call_item]
return calls
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""
Parses the text in one go. Returns success=True if the format matches, otherwise False.
Note that leftover_text here represents "content that this parser will not consume further".
"""
action = json.loads(text)
return self.parse_base_json(action, tools)
def parse_streaming_increment(
self, new_text: str, tools: List[Function]
) -> StreamingParseResult:
"""
Streaming incremental parsing, referencing the logic of Llama32Detector.
We partially parse JSON within <tool_call>...</tool_call>, and handle
incremental argument output.
"""
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")):
self._buffer = ""
if self.eot_token in new_text:
new_text = new_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=new_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = (
len(self.bot_token)
if current_text.startswith(self.bot_token)
else 0
)
while start_idx < len(current_text):
(obj, end_idx) = _partial_json_loads(
current_text[start_idx:], flags
)
is_complete.append(
_is_complete_json(current_text[start_idx : start_idx + end_idx])
)
start_idx += end_idx + len("; ")
# depending on the prompt Llama can use
# either arguments or parameters
if "parameters" in obj:
assert (
"arguments" not in obj
), "model generated both parameters and arguments"
obj["arguments"] = obj["parameters"]
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
# not enough tokens to parse into JSON yet
return StreamingParseResult()
# select as the current tool call the one we're on the state at
current_tool_call: Dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return StreamingParseResult()
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
cur_args_json = json.dumps(cur_arguments)
sent = len(self.streamed_args_for_tool[self.current_tool_id])
argument_diff = cur_args_json[sent:]
res = StreamingParseResult(
normal_text=None,
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
name="",
parameters=argument_diff,
)
],
)
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff
else:
res = StreamingParseResult()
else:
res = StreamingParseResult()
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
print("starting on new tool %d", self.current_tool_id)
return res
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
res = StreamingParseResult(
normal_text=None,
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
],
)
self.current_tool_name_sent = True
else:
res = StreamingParseResult()
# now we know we're on the same tool call and we're streaming
# arguments
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()
if cur_arguments:
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
argument_diff = None
if is_complete[self.current_tool_id]:
argument_diff = cur_args_json[sent:]
self._buffer = ""
self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool[self.current_tool_id] = ""
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
if argument_diff is not None:
res = StreamingParseResult(
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
name="",
parameters=argument_diff,
)
],
)
if not is_complete[self.current_tool_id]:
self.streamed_args_for_tool[
self.current_tool_id
] += argument_diff
self.prev_tool_call_arr = tool_call_arr
return res
except Exception as e:
print(e)
# Skipping chunk as a result of tool streaming extraction error
return StreamingParseResult()
class Qwen25Detector(BaseFormatDetector):
"""
Detector for Qwen 2.5 models.
Assumes function call format:
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
"""
def __init__(self):
"""
Initializes the detector with necessary state variables.
"""
super().__init__()
self.bot_token = "<tool_call>"
self.eot_token = "</tool_call>"
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
if "<tool_call>" not in text:
return []
pattern = r"<tool_call>(.*?)</tool_call>"
match_result_list = re.findall(pattern, text, re.DOTALL)
calls = []
for match_result in match_result_list:
match_result = json.loads(match_result)
calls.extend(self.parse_base_json(match_result, tools))
return calls
class MistralDetector(BaseFormatDetector):
"""
Detector for Mistral models.
Assumes function call format:
<|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
"""
def __init__(self):
"""
Initializes the detector with necessary state variables.
"""
super().__init__()
self.bot_token = "[TOOL_CALLS] ["
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
def _clean_text(self, text: str) -> str:
"""
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
for example,
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
The key pattern is [TOOL_CALLS] [...]
"""
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
if len(find_results) > 0:
return find_results[0]
else:
return ""
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
text = self._clean_text(text)
tool_content = text.replace("[TOOL_CALLS]", "").strip()
raw_tool_calls = self.tool_call_regex.findall(tool_content)
calls = []
if len(raw_tool_calls) > 0:
raw_tool_call = raw_tool_calls[0]
function_call_arr = json.loads(raw_tool_call)
for match_result in function_call_arr:
calls.extend(self.parse_base_json(match_result, tools))
return calls
class Llama32Detector(BaseFormatDetector):
"""
Detector for Llama 3.2 models.
Assumes function call format:
<|python_tag|>{"name":"xxx", "arguments":{...}}
Does not require a closing tag "</python_tag|>",
relies on json.loads(...) success to determine if JSON is complete.
"""
def __init__(self):
"""
Initializes the detector with necessary state variables.
"""
super().__init__()
self.bot_token = "<|python_tag|>"
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""
One-time parsing: Detects and parses tool calls in the provided text.
:param text: The complete text to parse.
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
if "<|python_tag|>" not in text:
return []
_, action = text.split("<|python_tag|>")
action = json.loads(action)
return self.parse_base_json(action, tools)
class MultiFormatParser:
def __init__(self, detectors: List[BaseFormatDetector]):
"""
:param detectors: A series of available Detector instances passed in
"""
self.detectors = detectors
def parse_once(self, text: str, tools: List[Function]):
"""
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
Return: (final_text, all_calls)
- final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
- all_calls: All calls parsed by the Detectors
"""
final_calls = []
final_normal_text = text
for detector in self.detectors:
tool_call_list = detector.detect_and_parse(text, tools)
if len(tool_call_list) > 0: # parsed successfully
final_calls = tool_call_list
break
# leftover_text is the normal text not consumed by any Detector
return final_normal_text, final_calls
def parse_streaming_increment(self, new_text: str, tools: List[Function]):
"""
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
and merge their produced normal_text/calls to return.
(The logic here can be "priority-based" or "parallel parsing" based on your needs)
"""
final_normal_text = ""
final_calls = []
for detector in self.detectors:
sp_result = detector.parse_streaming_increment(new_text, tools)
# Merge normal_text and calls
# If one sp_result contains result call, this should be a successful parse
# If one sp_result only contains normal_text, this can either be a successful
# parse or it is not using the desired parsing tool.
if sp_result.normal_text:
final_normal_text = sp_result.normal_text
if sp_result.calls:
final_calls.extend(sp_result.calls)
final_normal_text = sp_result.normal_text
break
return final_normal_text, final_calls
class FunctionCallParser:
"""
In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
and returns the resulting normal_text and calls to the upper layer (or SSE).
"""
ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
"llama3": Llama32Detector,
"qwen25": Qwen25Detector,
"mistral": MistralDetector,
}
def __init__(self, tools: List[Function], tool_call_parser: str = None):
detectors = []
if tool_call_parser:
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
if detector_class:
detectors.append(detector_class())
else:
raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
else:
raise ValueError("Tool Call Parser Not Given!")
self.multi_format_parser = MultiFormatParser(detectors)
self.tools = tools
def parse_non_stream(self, full_text: str):
"""
Non-streaming call: one-time parsing
"""
full_normal_text, calls = self.multi_format_parser.parse_once(
full_text, self.tools
)
return full_normal_text, calls
def parse_stream_chunk(self, chunk_text: str):
"""
Streaming call: incremental parsing
"""
normal_text, calls = self.multi_format_parser.parse_streaming_increment(
chunk_text, self.tools
)
return normal_text, calls
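As the class docstring notes, the streaming path feeds each received delta to `parse_stream_chunk`, which returns any plain text plus incremental `ToolCallItem` fragments. A usage sketch (not part of the commit; the tool definition and simulated deltas are illustrative, and real engine output would arrive in much smaller pieces):

```python
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import Function, Tool

tools = [Tool(function=Function(name="get_current_weather", parameters={"type": "object"}))]
parser = FunctionCallParser(tools=tools, tool_call_parser="llama3")

simulated_deltas = [
    '<|python_tag|>{"name": "get_current_weather", ',
    '"arguments": {"city": "Boston", "state": "MA", "unit": "celsius"}}',
    "",  # an empty delta gives the parser a chance to flush completed arguments
]

argument_fragments = []
for delta in simulated_deltas:
    normal_text, calls = parser.parse_stream_chunk(delta)
    if normal_text:
        print(normal_text, end="")
    for call in calls:
        if call.name:          # the tool name is emitted once per call
            print("tool name:", call.name)
        if call.parameters:    # argument fragments are emitted incrementally
            argument_fragments.append(call.parameters)

full_arguments = "".join(argument_fragments)
```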
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
"""
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union
@@ -540,3 +540,27 @@ class CloseSessionReqInput:
class OpenSessionReqOutput:
    session_id: Optional[str]
    success: bool
@dataclass
class Function:
description: Optional[str] = None
name: Optional[str] = None
parameters: Optional[object] = None
@dataclass
class Tool:
function: Function
type: Optional[str] = "function"
@dataclass
class FunctionCallReqInput:
text: str # The text to parse.
tools: List[Tool] = field(
default_factory=list
) # A list of available function tools (name, parameters, etc.).
tool_call_parser: Optional[str] = (
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
)
@@ -20,7 +20,7 @@ import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List, Optional
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
    generate_chat_conv,
    register_conv_template,
)
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
    BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
    TopLogprob,
    UsageInfo,
)
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
ret,
to_file=True,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
else:
responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if request.stream:
logger.warning("Streaming is not supported with tools.")
request.stream = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
@@ -908,12 +906,26 @@
openai_compatible_messages = openai_compatible_messages[:-1]
else:
assistant_prefix = None
try:
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
except:
# This except branch will be triggered when the chosen model
# has a different tools input format that is not compatible
# with OpenAI's apply_chat_template tool_call format, like Mistral.
tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
stop = request.stop
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_chat_generate_response(
request, ret, to_file=False, cache_report=False, tool_call_parser=None
):
choices = []
for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
if finish_reason == "stop": if finish_reason == "stop":
finish_reason = "tool_calls" finish_reason = "tool_calls"
try: try:
text, call_info_list = parse_tool_response(text, tools) # noqa parser = FunctionCallParser(tools, tool_call_parser)
full_normal_text, call_info_list = parser.parse_non_stream(text)
tool_calls = [ tool_calls = [
ToolCall( ToolCall(
id=str(call_info[0]), id=str(call_info.tool_index),
function=FunctionResponse( function=FunctionResponse(
name=call_info[1], arguments=call_info[2] name=call_info.name, arguments=call_info.parameters
), ),
) )
for call_info in call_info_list for call_info in call_info_list
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
if adapted_request.stream:
parser_dict = {}
async def generate_stream_resp():
is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
adapted_request, raw_request
):
index = content.get("index", 0)
text = content["text"]
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
text = content["text"]
delta = text[len(stream_buffer) :]
new_stream_buffer = stream_buffer + delta
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta),
finish_reason=(finish_reason["type"] if finish_reason else ""),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
if request.tool_choice != "none" and request.tools:
if index not in parser_dict:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
parser = parser_dict[index]
# parse_increment => returns (normal_text, calls)
normal_text, calls = parser.parse_stream_chunk(delta)
# 1) if there's normal_text, output it as normal content
if normal_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=normal_text),
finish_reason=(
finish_reason["type"] if finish_reason else ""
),
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# 2) if we found calls, we output them as separate chunk(s)
for call_item in calls:
# transform call_item -> FunctionResponse + ToolCall
if (
content["meta_info"]["finish_reason"]
and content["meta_info"]["finish_reason"]["type"]
== "stop"
):
latest_delta_len = 0
if isinstance(call_item.parameters, str):
latest_delta_len = len(call_item.parameters)
expected_call = json.dumps(
parser.multi_format_parser.detectors[0]
.prev_tool_call_arr[index]
.get("arguments", {}),
ensure_ascii=False,
)
actual_call = parser.multi_format_parser.detectors[
0
].streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
remaining_call = expected_call.replace(
actual_call, "", 1
)
call_item.parameters = remaining_call
tool_call = ToolCall(
id=str(call_item.tool_index),
function=FunctionResponse(
name=call_item.name,
arguments=call_item.parameters,
),
)
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
role="assistant", tool_calls=[tool_call]
),
finish_reason="tool_call",
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
yield f"data: {chunk.model_dump_json()}\n\n" stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
else:
# No tool calls => just treat this as normal text
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta),
finish_reason=(
finish_reason["type"] if finish_reason else ""
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
ret = [ret]
response = v1_chat_generate_response(
request,
ret,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
return response
......
@@ -262,7 +262,7 @@ class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: Optional[str] = None
parameters: Optional[object] = None
@@ -276,7 +276,7 @@ class Tool(BaseModel):
class ToolChoiceFuncName(BaseModel):
"""The name of tool choice function."""
name: Optional[str] = None
class ToolChoice(BaseModel):
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
class FunctionResponse(BaseModel):
"""Function response."""
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseStreamChoice(BaseModel):
......
@@ -161,6 +161,7 @@ class ServerArgs:
# Custom logit processor
enable_custom_logit_processor: bool = False
tool_call_parser: str = None
def __post_init__(self):
# Set missing default values
@@ -877,6 +878,14 @@ class ServerArgs:
action="store_true",
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
)
# Function Calling
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["qwen25", "mistral", "llama3"],
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
@@ -1243,68 +1243,6 @@ def dataclass_to_string_truncated(data, max_length=2048):
return str(data)
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
def parse_tool_response(text, tools, **kwargs):
"""Parse model response containing tool information.
Args:
text(str): model response in string format
tools(List): tools from user request
"""
if "<|plugin|>" in text: # internlm2
text, action = text.split("<|action_start|><|plugin|>")
action = action.split("<|action_end|>".strip())[0]
action = action[action.find("{") :]
action = json.loads(action)
name, parameters = action["name"], json.dumps(
action.get("parameters", action.get("arguments", {})), ensure_ascii=False
)
call_info_list = [(name, parameters)]
elif "<function=" in text: # llama3.1
action, _ = text.split("</function>")
parameters = action[action.find("{") :]
name = action.split("<function=")[1].split(">{")[0]
call_info_list = [(name, parameters)]
elif "<tool_call>" in text and "</tool_call>" in text: # qwen2.5
# get tool_call in text
pattern = r"<tool_call>(.*?)</tool_call>"
match_result_list = re.findall(pattern, text, re.DOTALL)
call_info_list = []
for match_result in match_result_list:
action = json.loads(match_result)
call_info_list.append(
(action["name"], json.dumps(action["arguments"], ensure_ascii=False))
)
# get text outside of tags
if not text.startswith("<tool_call>"):
text = text[: text.find("<tool_call>")]
elif not text.endswith("</tool_call>"):
text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
else:
text = ""
elif "<|python_tag|>" in text: # llama3.2
_, action = text.split("<|python_tag|>")
action = json.loads(action)
name, parameters = action["name"], json.dumps(
action.get("parameters", action.get("arguments", {})), ensure_ascii=False
)
call_info_list = [(name, parameters)]
else:
raise RuntimeError(f"Unexpected model response: {text}")
call_info_list = [
(
[tool.function.name for tool in tools].index(call_info[0]),
call_info[0],
call_info[1],
)
for call_info in call_info_list
]
return text, call_info_list
def permute_weight(x: torch.Tensor) -> torch.Tensor:
b_ = x.shape[0]
n_ = x.shape[1]
......
import json
import time
import unittest
import openai
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestOpenAIServerFunctionCalling(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
# Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools.
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
# If your server needs extra parameters to test function calling, please add them here.
"--tool-call-parser",
"llama3",
],
)
cls.base_url += "/v1"
cls.tokenizer = get_tokenizer(cls.model)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_function_calling_format(self):
"""
Test: Whether the function call format returned by the AI is correct.
When returning a tool call, message.content should be None, and tool_calls should be a list.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "int",
"description": "A number",
},
"b": {
"type": "int",
"description": "A number",
},
},
"required": ["a", "b"],
},
},
}
]
messages = [{"role": "user", "content": "Compute (3+5)"}]
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
content = response.choices[0].message.content
tool_calls = response.choices[0].message.tool_calls
assert content is None, (
"When function call is successful, message.content should be None, "
f"but got: {content}"
)
assert (
isinstance(tool_calls, list) and len(tool_calls) > 0
), "tool_calls should be a non-empty list"
function_name = tool_calls[0].function.name
assert function_name == "add", "Function name should be 'add'"
def test_function_calling_streaming_simple(self):
"""
Test: Whether the function name can be correctly recognized in streaming mode.
- Expect a function call to be found, and the function name to be correct.
- Verify that streaming mode returns at least multiple chunks.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for",
},
"unit": {
"type": "string",
"description": "Weather unit (celsius or fahrenheit)",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "unit"],
},
},
}
]
messages = [{"role": "user", "content": "What is the temperature in Paris?"}]
response_stream = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=True,
tools=tools,
)
chunks = list(response_stream)
self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk")
found_function_name = False
for chunk in chunks:
choice = chunk.choices[0]
# Check whether the current chunk contains tool_calls
if choice.delta.tool_calls:
tool_call = choice.delta.tool_calls[0]
if tool_call.function.name:
self.assertEqual(
tool_call.function.name,
"get_current_weather",
"Function name should be 'get_current_weather'",
)
found_function_name = True
break
self.assertTrue(
found_function_name,
"Target function name 'get_current_weather' was not found in the streaming chunks",
)
def test_function_calling_streaming_args_parsing(self):
"""
Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON.
- The user request requires multiple parameters.
- AI may return the arguments in chunks that need to be concatenated.
"""
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two integers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "int",
"description": "First integer",
},
"b": {
"type": "int",
"description": "Second integer",
},
},
"required": ["a", "b"],
},
},
}
]
messages = [
{"role": "user", "content": "Please sum 5 and 7, just call the function."}
]
response_stream = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.9,
top_p=0.9,
stream=True,
tools=tools,
)
argument_fragments = []
function_name = None
for chunk in response_stream:
choice = chunk.choices[0]
if choice.delta.tool_calls:
tool_call = choice.delta.tool_calls[0]
# Record the function name on first occurrence
function_name = tool_call.function.name or function_name
# In case of multiple chunks, JSON fragments may need to be concatenated
if tool_call.function.arguments:
argument_fragments.append(tool_call.function.arguments)
self.assertEqual(function_name, "add", "Function name should be 'add'")
joined_args = "".join(argument_fragments)
self.assertTrue(
len(joined_args) > 0,
"No parameter fragments were returned in the function call",
)
# Check whether the concatenated JSON is valid
try:
args_obj = json.loads(joined_args)
except json.JSONDecodeError:
self.fail(
"The concatenated tool call arguments are not valid JSON, parsing failed"
)
self.assertIn("a", args_obj, "Missing parameter 'a'")
self.assertIn("b", args_obj, "Missing parameter 'b'")
self.assertEqual(
args_obj["a"],
5,
"Parameter a should be 5",
)
self.assertEqual(args_obj["b"], 7, "Parameter b should be 7")
if __name__ == "__main__":
unittest.main()
@@ -623,58 +623,6 @@ class TestOpenAIServerEBNF(unittest.TestCase):
text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
)
def test_function_calling_format(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
tools = [
{
"type": "function",
"function": {
"name": "add",
"description": "Compute the sum of two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "int",
"description": "A number",
},
"b": {
"type": "int",
"description": "A number",
},
},
"required": ["a", "b"],
},
},
}
]
messages = [{"role": "user", "content": "Compute (3+5)"}]
response = client.chat.completions.create(
model=self.model,
messages=messages,
temperature=0.8,
top_p=0.8,
stream=False,
tools=tools,
)
content = response.choices[0].message.content
tool_calls = response.choices[0].message.tool_calls
assert (
content is None
), "When tools provided by the response, content should be None"
assert (
isinstance(tool_calls, list) and len(tool_calls) > 0
), "Format not matched, tool_calls should be a list"
function_name = tool_calls[0].function.name
assert (
function_name == "add"
), "Function name should be add for the above response"
class TestOpenAIEmbedding(unittest.TestCase):
@classmethod
......