Unverified Commit 591e751e authored by Shi Shuai, committed by GitHub

Fix: Runtime error for function calling (#3300)

parent 40022d07
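Background on the fix: the previous `parse_base_json` resolved the model-emitted function name with `[tool.function.name for tool in tools].index(name)`, so a call to a function that was never registered raised an uncaught `ValueError` at runtime. The new code builds a name-to-index dict and drops unknown calls with a warning instead. A minimal sketch of the two behaviors follows; the tool set and the `book_flight` name are hypothetical, and `SimpleNamespace` stands in for the real pydantic tool objects.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the pydantic tool objects used by the parser.
tools = [SimpleNamespace(function=SimpleNamespace(name="get_weather"))]

name = "book_flight"  # a function name the model invented

# Old behavior: .index() raises ValueError because "book_flight" is not registered.
try:
    tool_index = [tool.function.name for tool in tools].index(name)
except ValueError:
    print("runtime error: undefined function name")

# New behavior: a dict lookup plus an explicit membership check.
tool_indices = {
    tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
}
if name not in tool_indices:
    print(f"Model attempted to call undefined function: {name}")  # logged as a warning
    calls = []  # the bad call is dropped instead of crashing the request
```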
@@ -20,7 +20,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
 - [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker).
-## Optimisations
+## Optimizations
 ### Multi-head Latent Attention (MLA) Throughput Optimizations
 import json
+import logging
 import re
 from abc import ABC, abstractmethod
 from json import JSONDecodeError, JSONDecoder
@@ -8,6 +9,8 @@ import partial_json_parser
 from partial_json_parser.core.options import Allow
 from pydantic import BaseModel, Field

+logger = logging.getLogger(__name__)
+
 TOOLS_TAG_LIST = [
     "<|plugin|>",
     "<function=",
@@ -88,17 +91,43 @@ class BaseFormatDetector:
         self.bot_token = ""
         self.eot_token = ""

-    def parse_base_json(self, action: Dict, tools: List[Function]):
-        name, parameters = action["name"], json.dumps(
-            action.get("parameters", action.get("arguments", {})),
-            ensure_ascii=False,
-        )
-        tool_index = [tool.function.name for tool in tools].index(name)
-        tool_call_item = ToolCallItem(
-            tool_index=tool_index, name=name, parameters=parameters
-        )
-        calls = [tool_call_item]
-        return calls
+    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+        tool_indices = {
+            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
+        }
+        if not isinstance(action, list):
+            name = action.get("name")
+            if not name or name not in tool_indices:
+                logger.warning(f"Model attempted to call undefined function: {name}")
+                return []
+            return [
+                ToolCallItem(
+                    tool_index=tool_indices[name],
+                    name=name,
+                    parameters=json.dumps(
+                        action.get("parameters") or action.get("arguments", {}),
+                        ensure_ascii=False,
+                    ),
+                )
+            ]
+        results = []
+        for act in action:
+            name = act.get("name")
+            if name and name in tool_indices:
+                results.append(
+                    ToolCallItem(
+                        tool_index=tool_indices[name],
+                        name=name,
+                        parameters=json.dumps(
+                            act.get("parameters") or act.get("arguments", {}),
+                            ensure_ascii=False,
+                        ),
+                    )
+                )
+        return results
     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """
@@ -112,9 +141,7 @@ class BaseFormatDetector:
         self, new_text: str, tools: List[Function]
     ) -> StreamingParseResult:
         """
-        Streaming incremental parsing, referencing the logic of Llama32Detector.
-        We partially parse JSON within <tool_call>...</tool_call>, and handle
-        incremental argument output.
+        Streaming incremental parsing with tool validation.
         """
         # Append new text to buffer
         self._buffer += new_text
@@ -125,17 +152,19 @@ class BaseFormatDetector:
                 new_text = new_text.replace(self.eot_token, "")
             return StreamingParseResult(normal_text=new_text)

-        # bit mask flags for partial JSON parsing. If the name hasn't been
-        # sent yet, don't allow sending
-        # an incomplete string since OpenAI only ever (as far as I have
-        # seen) allows sending the entire tool/ function name at once.
+        # Build tool indices if not already built
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function and tool.function.name
+            }
+
         flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
         try:
             tool_call_arr = []
             is_complete = []
             try:
-                # depending on the prompt format the Llama model may or may not
-                # prefix the output with the <|python_tag|> token
                 start_idx = (
                     len(self.bot_token)
                     if current_text.startswith(self.bot_token)
@@ -149,8 +178,18 @@ class BaseFormatDetector:
                         _is_complete_json(current_text[start_idx : start_idx + end_idx])
                     )
                     start_idx += end_idx + len("; ")
-                    # depending on the prompt Llama can use
-                    # either arguments or parameters
+
+                    # Validate tool name if present
+                    if "name" in obj and obj["name"] not in self._tool_indices:
+                        # Invalid tool name - reset state
+                        self._buffer = ""
+                        self.current_tool_id = -1
+                        self.current_tool_name_sent = False
+                        if self.streamed_args_for_tool:
+                            self.streamed_args_for_tool.pop()
+                        return StreamingParseResult()
+
+                    # Handle parameters/arguments consistency
                     if "parameters" in obj:
                         assert (
                             "arguments" not in obj
@@ -159,29 +198,17 @@ class BaseFormatDetector:
                     tool_call_arr.append(obj)

             except partial_json_parser.core.exceptions.MalformedJSON:
-                # not enough tokens to parse into JSON yet
                 return StreamingParseResult()

-            # select as the current tool call the one we're on the state at
-            current_tool_call: Dict = (
-                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
-            )
-
-            # case -- if no tokens have been streamed for the tool, e.g.
-            # only the array brackets, stream nothing
             if len(tool_call_arr) == 0:
                 return StreamingParseResult()

-            # case: we are starting a new tool in the array
-            # -> array has > 0 length AND length has moved past cursor
-            elif (
-                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
-            ):
-                # if we're moving on to a new call, first make sure we
-                # haven't missed anything in the previous one that was
-                # auto-generated due to JSON completions, but wasn't
-                # streamed to the client yet.
+            current_tool_call: Dict = (
+                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+            )
+
+            # Handle new tool in array
+            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
                 if self.current_tool_id >= 0:
                     cur_arguments = current_tool_call.get("arguments")
                     if cur_arguments:
@@ -190,7 +217,6 @@ class BaseFormatDetector:
                         argument_diff = cur_args_json[sent:]
                         res = StreamingParseResult(
-                            normal_text=None,
                             calls=[
                                 ToolCallItem(
                                     tool_index=self.current_tool_id,
@@ -206,23 +232,20 @@ class BaseFormatDetector:
                         res = StreamingParseResult()
                 else:
                     res = StreamingParseResult()
-                # re-set stuff pertaining to progress in the current tool
                 self.current_tool_id = len(tool_call_arr) - 1
                 self.current_tool_name_sent = False
                 self.streamed_args_for_tool.append("")
-                print("starting on new tool %d", self.current_tool_id)
                 return res

-            # if the current tool name hasn't been sent, send if available
-            # - otherwise send nothing
+            # Handle tool name
             elif not self.current_tool_name_sent:
                 function_name = current_tool_call.get("name")
-                if function_name:
+                if function_name and function_name in self._tool_indices:
                     res = StreamingParseResult(
-                        normal_text=None,
                         calls=[
                             ToolCallItem(
-                                tool_index=self.current_tool_id,
+                                tool_index=self._tool_indices[function_name],
                                 name=function_name,
                                 parameters="",
                             )
@@ -232,8 +255,7 @@ class BaseFormatDetector:
                 else:
                     res = StreamingParseResult()

-            # now we know we're on the same tool call and we're streaming
-            # arguments
+            # Handle streaming arguments
             else:
                 cur_arguments = current_tool_call.get("arguments")
                 res = StreamingParseResult()
@@ -250,13 +272,12 @@ class BaseFormatDetector:
                         argument_diff = cur_args_json[sent:]
                         self._buffer = ""
                         self.prev_tool_call_arr[self.current_tool_id].clear()
-                        self.current_tool_name_sent: bool = False
+                        self.current_tool_name_sent = False
                         self.streamed_args_for_tool[self.current_tool_id] = ""

                     elif prev_arguments:
                         prev_args_json = json.dumps(prev_arguments)
                         if cur_args_json != prev_args_json:
                             prefix = _find_common_prefix(prev_args_json, cur_args_json)
                             argument_diff = prefix[sent:]
@@ -279,8 +300,7 @@ class BaseFormatDetector:
             return res

         except Exception as e:
-            print(e)
-            # Skipping chunk as a result of tool streaming extraction error
+            logger.error(f"Error in parse_streaming_increment: {e}")
             return StreamingParseResult()
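A note on the argument-streaming path above: each increment re-serializes the partially parsed `arguments` object and emits only the characters that have not been sent yet, falling back to the longest common prefix of the previous and current serializations while the JSON is still incomplete. A minimal sketch of that diffing idea, with a hypothetical `find_common_prefix` helper standing in for the module's `_find_common_prefix` and illustrative argument values:

```python
import json


def find_common_prefix(a: str, b: str) -> str:
    # Hypothetical helper standing in for the module's _find_common_prefix.
    i = 0
    while i < min(len(a), len(b)) and a[i] == b[i]:
        i += 1
    return a[:i]


# Arguments as serialized at two consecutive streaming increments (illustrative values).
prev_args_json = json.dumps({"city": "Par"})   # '{"city": "Par"}'
cur_args_json = json.dumps({"city": "Paris"})  # '{"city": "Paris"}'

sent = len('{"city": "P')  # characters of the arguments already streamed to the client
prefix = find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]  # only the stable, not-yet-sent characters are emitted
print(argument_diff)  # "ar"
```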
@@ -372,31 +392,38 @@ class Llama32Detector(BaseFormatDetector):
     Detector for Llama 3.2 models.
     Assumes function call format:
       <|python_tag|>{"name":"xxx", "arguments":{...}}
-    Does not require a closing tag "</python_tag|>",
-    relies on json.loads(...) success to determine if JSON is complete.
     """

     def __init__(self):
-        """
-        Initializes the detector with necessary state variables.
-        """
         super().__init__()
         self.bot_token = "<|python_tag|>"

     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
-        """
-        One-time parsing: Detects and parses tool calls in the provided text.
-        :param text: The complete text to parse.
-        :param tools: List of available tools.
-        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
-        """
-        if "<|python_tag|>" not in text:
-            return []
-        _, action = text.split("<|python_tag|>")
-        action = json.loads(action)
-        return self.parse_base_json(action, tools)
+        """Parse function calls from text, handling multiple JSON objects."""
+        if "<|python_tag|>" not in text:
+            return []
+
+        _, action_text = text.split("<|python_tag|>")
+
+        # Split by semicolon and process each part
+        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+
+        all_actions = []
+        for part in json_parts:
+            try:
+                # Parse each individual JSON object
+                action = json.loads(part)
+                all_actions.append(action)
+            except json.JSONDecodeError as e:
+                logger.warning(f"Failed to parse JSON part: {part}")
+                logger.warning(f"JSON parse error: {str(e)}")
+                continue
+
+        # Only process if we found valid JSON objects
+        if all_actions:
+            return self.parse_base_json(all_actions, tools)
+
+        return []

 class MultiFormatParser:
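The rewritten `Llama32Detector.detect_and_parse` above now tolerates several tool calls emitted after a single `<|python_tag|>` marker and separated by semicolons, skipping fragments that are not valid JSON instead of raising. A rough usage sketch of that splitting step; the tool names and payload below are invented for illustration:

```python
import json

text = (
    '<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}}; '
    '{"name": "get_time", "arguments": {"timezone": "UTC"}}'
)

_, action_text = text.split("<|python_tag|>")
json_parts = [part.strip() for part in action_text.split(";") if part.strip()]

all_actions = []
for part in json_parts:
    try:
        all_actions.append(json.loads(part))
    except json.JSONDecodeError:
        continue  # malformed fragments are skipped rather than crashing the request

print([a["name"] for a in all_actions])  # ['get_weather', 'get_time']
```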