Unverified commit 1193f131, authored by Xinyuan Tong, committed by GitHub
Browse files

fix: KimiK2Detector Improve tool call ID parsing with regex (#10972)

parent 84a9f5d6
...@@ -50,6 +50,11 @@ class KimiK2Detector(BaseFormatDetector): ...@@ -50,6 +50,11 @@ class KimiK2Detector(BaseFormatDetector):
self._last_arguments = "" self._last_arguments = ""
# Robust parser for ids like "functions.search:0" or fallback "search:0"
self.tool_call_id_regex = re.compile(
r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$"
)
def has_tool_call(self, text: str) -> bool: def has_tool_call(self, text: str) -> bool:
"""Check if the text contains a KimiK2 format tool call.""" """Check if the text contains a KimiK2 format tool call."""
return self.bot_token in text return self.bot_token in text
...@@ -76,14 +81,18 @@ class KimiK2Detector(BaseFormatDetector): ...@@ -76,14 +81,18 @@ class KimiK2Detector(BaseFormatDetector):
tool_calls = [] tool_calls = []
for match in function_call_tuples: for match in function_call_tuples:
function_id, function_args = match function_id, function_args = match
function_name = function_id.split(".")[1].split(":")[0] m = self.tool_call_id_regex.match(function_id)
function_idx = int(function_id.split(".")[1].split(":")[1]) if not m:
logger.warning("Unexpected tool_call_id format: %s", function_id)
continue
function_name = m.group("name")
function_idx = int(m.group("index"))
logger.info(f"function_name {function_name}") logger.info(f"function_name {function_name}")
tool_calls.append( tool_calls.append(
ToolCallItem( ToolCallItem(
tool_index=function_idx, # Use the call index in the response, not tool position tool_index=function_idx,
name=function_name, name=function_name,
parameters=function_args, parameters=function_args,
) )
...@@ -128,7 +137,11 @@ class KimiK2Detector(BaseFormatDetector): ...@@ -128,7 +137,11 @@ class KimiK2Detector(BaseFormatDetector):
function_id = match.group("tool_call_id") function_id = match.group("tool_call_id")
function_args = match.group("function_arguments") function_args = match.group("function_arguments")
function_name = function_id.split(".")[1].split(":")[0] m = self.tool_call_id_regex.match(function_id)
if not m:
logger.warning("Unexpected tool_call_id format: %s", function_id)
return StreamingParseResult(normal_text="", calls=calls)
function_name = m.group("name")
# Initialize state if this is the first tool call # Initialize state if this is the first tool call
if self.current_tool_id == -1: if self.current_tool_id == -1:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment