"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "9851a69f6d294f5d672d973d8a1dbeebdd2aa04e"
Unverified Commit bdb962d7 authored by Chang Su, committed by GitHub
Browse files

fix(tool call): Fix tool_index in PythonicDetector and issues with mixed...

fix(tool call): Fix tool_index in PythonicDetector and issues with mixed output in non-streaming (#6678)
parent 0b9557fc
...@@ -72,7 +72,7 @@ class BaseFormatDetector(ABC): ...@@ -72,7 +72,7 @@ class BaseFormatDetector(ABC):
action = json.loads(text) action = json.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools)) return StreamingParseResult(calls=self.parse_base_json(action, tools))
def ends_with_partial_token(self, buffer: str, bot_token: str) -> int: def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
""" """
Check if buffer ends with a partial bot_token. Check if buffer ends with a partial bot_token.
Return the length of the partial bot_token. Return the length of the partial bot_token.
...@@ -108,7 +108,7 @@ class BaseFormatDetector(ABC): ...@@ -108,7 +108,7 @@ class BaseFormatDetector(ABC):
current_text = self._buffer current_text = self._buffer
if not (self.bot_token in current_text or current_text.startswith("{")): if not (self.bot_token in current_text or current_text.startswith("{")):
# Only clear buffer if we're sure no tool call is starting # Only clear buffer if we're sure no tool call is starting
if not self.ends_with_partial_token(self._buffer, self.bot_token): if not self._ends_with_partial_token(self._buffer, self.bot_token):
normal_text = self._buffer normal_text = self._buffer
self._buffer = "" self._buffer = ""
if self.eot_token in normal_text: if self.eot_token in normal_text:
......
...@@ -33,46 +33,67 @@ class PythonicDetector(BaseFormatDetector): ...@@ -33,46 +33,67 @@ class PythonicDetector(BaseFormatDetector):
) )
def has_tool_call(self, text: str) -> bool: def has_tool_call(self, text: str) -> bool:
return bool(self.tool_call_regex.match(text.strip())) return bool(self.tool_call_regex.search(text.strip()))
def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
# Try parsing the text as a Python list of function calls # Try parsing the text as a Python list of function calls
text = text.strip() text = text.strip()
if not (text.startswith("[") and text.endswith("]")):
# Not a pythonic tool call format match = self.tool_call_regex.search(text)
if match is None:
return StreamingParseResult(normal_text=text, calls=[]) return StreamingParseResult(normal_text=text, calls=[])
# Extract the tool call part and any text before/after it
tool_call_start = match.start()
tool_call_end = match.end()
normal_text_before = text[:tool_call_start] if tool_call_start > 0 else ""
tool_call_text = text[tool_call_start:tool_call_end]
normal_text_after = text[tool_call_end:] if tool_call_end < len(text) else ""
# Combine normal text
normal_text = normal_text_before + normal_text_after
try: try:
module = ast.parse(text) module = ast.parse(tool_call_text)
parsed = getattr(module.body[0], "value", None) parsed = getattr(module.body[0], "value", None)
if not ( if not (
isinstance(parsed, ast.List) isinstance(parsed, ast.List)
and all(isinstance(e, ast.Call) for e in parsed.elts) and all(isinstance(e, ast.Call) for e in parsed.elts)
): ):
return StreamingParseResult(normal_text=text, calls=[]) return StreamingParseResult(normal_text=normal_text, calls=[])
calls = [] calls = []
tool_indices = { tool_indices = {
tool.function.name: i tool.function.name: i
for i, tool in enumerate(tools) for i, tool in enumerate(tools)
if tool.function.name if tool.function.name
} }
for call in parsed.elts: for call_index, call in enumerate(parsed.elts):
if not isinstance(call.func, ast.Name): if not isinstance(call.func, ast.Name):
continue continue
function_name = call.func.id function_name = call.func.id
# Validate that the function exists in the tools
if function_name not in tool_indices:
logger.warning(
f"Model attempted to call undefined function: {function_name}"
)
continue
arguments = {} arguments = {}
for keyword in call.keywords: for keyword in call.keywords:
arguments[keyword.arg] = self._get_parameter_value(keyword.value) arguments[keyword.arg] = self._get_parameter_value(keyword.value)
calls.append( calls.append(
ToolCallItem( ToolCallItem(
tool_index=tool_indices.get(function_name, -1), tool_index=call_index, # Use the call index in the response, not tool position
name=function_name, name=function_name,
parameters=json.dumps(arguments, ensure_ascii=False), parameters=json.dumps(arguments, ensure_ascii=False),
) )
) )
return StreamingParseResult(normal_text="", calls=calls)
return StreamingParseResult(normal_text=normal_text, calls=calls)
except Exception: except Exception:
logger.exception("Error in pythonic tool call parsing.") logger.exception("Error in pythonic tool call parsing.")
return StreamingParseResult(normal_text=text, calls=[]) return StreamingParseResult(normal_text=normal_text, calls=[])
def _find_matching_bracket(self, buffer: str, start: int) -> int: def _find_matching_bracket(self, buffer: str, start: int) -> int:
""" """
......
...@@ -86,7 +86,7 @@ class Qwen25Detector(BaseFormatDetector): ...@@ -86,7 +86,7 @@ class Qwen25Detector(BaseFormatDetector):
result.normal_text = cleaned_text result.normal_text = cleaned_text
else: else:
# Check if buffer might contain partial end token at the end # Check if buffer might contain partial end token at the end
partial_match_len = self.ends_with_partial_token( partial_match_len = self._ends_with_partial_token(
self._normal_text_buffer, end_token_without_newline self._normal_text_buffer, end_token_without_newline
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment