Unverified Commit 591e751e authored by Shi Shuai's avatar Shi Shuai Committed by GitHub
Browse files

Fix: Runtime error for function calling (#3300)

parent 40022d07
......@@ -20,7 +20,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker).
## Optimisations
## Optimizations
### Multi-head Latent Attention (MLA) Throughput Optimizations
......
import json
import logging
import re
from abc import ABC, abstractmethod
from json import JSONDecodeError, JSONDecoder
......@@ -8,6 +9,8 @@ import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
TOOLS_TAG_LIST = [
"<|plugin|>",
"<function=",
......@@ -88,17 +91,43 @@ class BaseFormatDetector:
self.bot_token = ""
self.eot_token = ""
def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
    """Convert decoded JSON tool call(s) into ``ToolCallItem`` objects.

    Args:
        action: A decoded tool-call dict, or a list of such dicts. Each dict
            is expected to carry ``"name"`` plus either ``"parameters"`` or
            ``"arguments"``.
        tools: The tools offered to the model. A call is kept only when its
            name matches one of them.

    Returns:
        One item per recognised call, in input order. Entries that are not
        dicts, lack a name, or name an unknown tool are logged and dropped,
        so malformed model output yields fewer (possibly zero) items rather
        than an exception.
    """
    tool_indices = {
        tool.function.name: i
        for i, tool in enumerate(tools)
        if tool.function and tool.function.name
    }

    # Treat a single call as a one-element list so both shapes share one
    # validation path (and both log when a call is rejected).
    actions = action if isinstance(action, list) else [action]

    calls = []
    for act in actions:
        name = act.get("name") if isinstance(act, dict) else None
        # The isinstance check also keeps unhashable names (e.g. a list)
        # away from the dict lookup.
        if not isinstance(name, str) or not name or name not in tool_indices:
            logger.warning("Model attempted to call undefined function: %s", name)
            continue
        calls.append(
            ToolCallItem(
                tool_index=tool_indices[name],
                name=name,
                # Depending on the prompt, the model may emit either key.
                parameters=json.dumps(
                    act.get("parameters") or act.get("arguments", {}),
                    ensure_ascii=False,
                ),
            )
        )
    return calls
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
"""
......@@ -112,9 +141,7 @@ class BaseFormatDetector:
self, new_text: str, tools: List[Function]
) -> StreamingParseResult:
"""
Streaming incremental parsing, referencing the logic of Llama32Detector.
We partially parse JSON within <tool_call>...</tool_call>, and handle
incremental argument output.
Streaming incremental parsing with tool validation.
"""
# Append new text to buffer
self._buffer += new_text
......@@ -125,17 +152,19 @@ class BaseFormatDetector:
new_text = new_text.replace(self.eot_token, "")
return StreamingParseResult(normal_text=new_text)
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
# Build tool indices if not already built
if not hasattr(self, "_tool_indices"):
self._tool_indices = {
tool.function.name: i
for i, tool in enumerate(tools)
if tool.function and tool.function.name
}
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
tool_call_arr = []
is_complete = []
try:
# depending on the prompt format the Llama model may or may not
# prefix the output with the <|python_tag|> token
start_idx = (
len(self.bot_token)
if current_text.startswith(self.bot_token)
......@@ -149,8 +178,18 @@ class BaseFormatDetector:
_is_complete_json(current_text[start_idx : start_idx + end_idx])
)
start_idx += end_idx + len("; ")
# depending on the prompt Llama can use
# either arguments or parameters
# Validate tool name if present
if "name" in obj and obj["name"] not in self._tool_indices:
# Invalid tool name - reset state
self._buffer = ""
self.current_tool_id = -1
self.current_tool_name_sent = False
if self.streamed_args_for_tool:
self.streamed_args_for_tool.pop()
return StreamingParseResult()
# Handle parameters/arguments consistency
if "parameters" in obj:
assert (
"arguments" not in obj
......@@ -159,29 +198,17 @@ class BaseFormatDetector:
tool_call_arr.append(obj)
except partial_json_parser.core.exceptions.MalformedJSON:
# not enough tokens to parse into JSON yet
return StreamingParseResult()
# select as the current tool call the one we're on the state at
current_tool_call: Dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return StreamingParseResult()
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
current_tool_call: Dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
# Handle new tool in array
if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
if self.current_tool_id >= 0:
cur_arguments = current_tool_call.get("arguments")
if cur_arguments:
......@@ -190,7 +217,6 @@ class BaseFormatDetector:
argument_diff = cur_args_json[sent:]
res = StreamingParseResult(
normal_text=None,
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
......@@ -206,23 +232,20 @@ class BaseFormatDetector:
res = StreamingParseResult()
else:
res = StreamingParseResult()
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
print("starting on new tool %d", self.current_tool_id)
return res
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
# Handle tool name
elif not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
if function_name and function_name in self._tool_indices:
res = StreamingParseResult(
normal_text=None,
calls=[
ToolCallItem(
tool_index=self.current_tool_id,
tool_index=self._tool_indices[function_name],
name=function_name,
parameters="",
)
......@@ -232,8 +255,7 @@ class BaseFormatDetector:
else:
res = StreamingParseResult()
# now we know we're on the same tool call and we're streaming
# arguments
# Handle streaming arguments
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()
......@@ -250,13 +272,12 @@ class BaseFormatDetector:
argument_diff = cur_args_json[sent:]
self._buffer = ""
self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent: bool = False
self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = ""
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
......@@ -279,8 +300,7 @@ class BaseFormatDetector:
return res
except Exception as e:
print(e)
# Skipping chunk as a result of tool streaming extraction error
logger.error(f"Error in parse_streaming_increment: {e}")
return StreamingParseResult()
......@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
Detector for Llama 3.2 models.
Assumes function call format:
<|python_tag|>{"name":"xxx", "arguments":{...}}
Does not require a closing tag "</python_tag|>",
relies on json.loads(...) success to determine if JSON is complete.
"""
def __init__(self):
    """Set up the base detector state, then register the Llama 3.2 start token."""
    super().__init__()
    # Must follow super().__init__(), which would otherwise overwrite it.
    # This marker precedes the JSON tool-call payload in model output.
    self.bot_token = "<|python_tag|>"
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
    """Parse every tool call that follows ``<|python_tag|>`` in complete text.

    The text after the first tag is read as a sequence of JSON values
    separated by ``;`` and/or whitespace. Values are decoded one at a time
    with a JSON decoder instead of splitting the string on ``;``, so a
    semicolon inside a JSON string value does not break the call apart.

    Args:
        text: The complete model output to parse.
        tools: List of available tools.

    Returns:
        The tool calls produced by ``parse_base_json`` for all decoded JSON
        objects, or an empty list when the tag is absent or nothing
        decodable follows it.
    """
    tag = "<|python_tag|>"
    if tag not in text:
        return []

    # partition() tolerates a repeated tag; unpacking split() would raise.
    _, _, action_text = text.partition(tag)

    decoder = json.JSONDecoder()
    all_actions = []
    pos, end = 0, len(action_text)
    while pos < end:
        # Skip a repeated start tag, separators and whitespace between values.
        if action_text.startswith(tag, pos):
            pos += len(tag)
            continue
        if action_text[pos] == ";" or action_text[pos].isspace():
            pos += 1
            continue

        try:
            value, pos = decoder.raw_decode(action_text, pos)
        except json.JSONDecodeError as e:
            # Drop the malformed fragment and resynchronise after the next
            # separator; pos always advances, so the loop terminates.
            sep = action_text.find(";", pos)
            resume = end if sep == -1 else sep
            logger.warning(
                "Failed to parse JSON part: %s", action_text[pos:resume].strip()
            )
            logger.warning("JSON parse error: %s", e)
            pos = resume + 1
            continue

        # A JSON array of calls is flattened; anything that is not an
        # object cannot describe a tool call and is ignored.
        for item in value if isinstance(value, list) else [value]:
            if isinstance(item, dict):
                all_actions.append(item)
            else:
                logger.warning("Ignoring non-object tool call payload: %r", item)

    # Only process if we found valid JSON objects
    if all_actions:
        return self.parse_base_json(all_actions, tools)
    return []
class MultiFormatParser:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment