Unverified Commit 125be8ce authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: add support for multi-tool within nested tags (#2501)

parent 56e99232
......@@ -24,11 +24,9 @@ pub struct CalledFunctionArguments {
pub arguments: HashMap<String, Value>,
}
fn extract_tool_call_content<'a>(
input: &'a str,
start_token: &str,
end_token: &str,
) -> Option<&'a str> {
// Extract the contents between start and end tokens using regex parsing.
// Returns a JSON array string if there are multiple matches, otherwise returns the last match directly.
fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) -> Option<String> {
let escaped_start = regex::escape(start_token);
let escaped_end = regex::escape(end_token);
let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end);
......@@ -38,19 +36,62 @@ fn extract_tool_call_content<'a>(
.build()
{
Ok(regex) => {
// Get all matches and take the last one for now. TODO : Handle multiple tool calls
// Get all matches and take the last one for now. TODO: Handle multiple tool calls
let matches: Vec<_> = regex
.captures_iter(input)
.filter_map(|captures| captures.get(1))
.map(|m| m.as_str().trim())
.map(|m| m.as_str().trim().to_string())
.collect();
matches.last().copied()
if !matches.is_empty() {
// If only one match, return it directly, otherwise return as a JSON array string
if matches.len() == 1 {
// Return the last match directly
return Some(matches.last().unwrap().clone());
} else {
// Join the matches into a JSON array string
return Some(format!("[{}]", matches.join(",")));
}
}
None
}
Err(_) => None,
}
}
/// Extracts tool calls from output whose calls are marked by a start token only.
///
/// Formats such as `<|python_tag|>` have no end token, so the regex-based start/end
/// extraction does not apply. The input is split on `start_token`, and every segment
/// that opens like JSON (`{` or `[`) is considered a candidate. Trailing non-JSON text
/// is dropped by cutting the segment at the last closing character that pairs with
/// its opening one, and only candidates that parse as valid JSON are kept.
///
/// # Returns
///
/// * `input` unchanged when it does not contain `start_token`, or when no segment
///   yields a tool call.
/// * Otherwise a JSON array string `[call,call,...]` holding every extracted call in
///   order of appearance (a single call is still wrapped in an array).
///
/// A segment that is itself a JSON array is treated as a list of calls: its elements
/// are spliced verbatim into the result instead of being nested as `[[...]]`.
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
    if !input.contains(start_token) {
        return input.to_string();
    }

    let mut items: Vec<String> = Vec::new();
    for segment in input.split(start_token) {
        let segment = segment.trim();

        // Only segments that open like JSON are candidates. Cut at the closer that
        // pairs with the opener, so trailing text containing the *other* bracket kind
        // (e.g. `{...} see [1]`) does not spoil an otherwise valid candidate.
        let closer = if segment.starts_with('{') {
            '}'
        } else if segment.starts_with('[') {
            ']'
        } else {
            continue;
        };
        let end = match segment.rfind(closer) {
            Some(end) => end,
            None => continue,
        };
        let candidate = &segment[..=end];

        match serde_json::from_str::<serde_json::Value>(candidate) {
            Ok(serde_json::Value::Array(calls)) => {
                // A valid array candidate starts with `[` and ends with `]`, so the
                // text between those two bytes is its comma-separated element list.
                // An empty array contributes nothing.
                if !calls.is_empty() {
                    items.push(candidate[1..candidate.len() - 1].trim().to_string());
                }
            }
            Ok(_) => items.push(candidate.to_string()),
            // Not valid JSON: ignore this segment.
            Err(_) => {}
        }
    }

    if items.is_empty() {
        return input.to_string();
    }
    format!("[{}]", items.join(","))
}
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
......@@ -110,21 +151,25 @@ pub fn try_tool_call_parse_json(
);
// Iterate over all start and end tokens and try to extract the content between them
let mut json = trimmed;
// Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
let mut json = trimmed.to_string();
for (start_token, end_token) in tool_call_start_tokens
.iter()
.zip(tool_call_end_tokens.iter())
{
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
json = if !start_token.is_empty() && end_token.is_empty() {
json.strip_prefix(start_token).unwrap_or(json)
} else if let Some(content) = extract_tool_call_content(json, start_token, end_token) {
handle_single_token_tool_calls(&json, start_token)
} else if let Some(content) = extract_tool_call_content(&json, start_token, end_token) {
content
} else {
json
};
}
// Convert json to &str if it's a String, otherwise keep as &str
let json = json.as_str();
// Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
Ok(ToolCallResponse {
......
......@@ -334,8 +334,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[test]
#[ignore]
// TODO : Implement this
fn test_qwen_qwq_32b_multiple_tool_calls() {
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
......@@ -426,6 +424,22 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag_multiple() {
    // Two back-to-back `<|python_tag|>` segments, each carrying one JSON tool call.
    let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }<|python_tag|>{ "name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit" } }"#;
    let calls = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
    assert!(!calls.is_empty());
    assert_eq!(calls.len(), 2);

    // Both calls target the same function and unit; only the location differs,
    // and the calls must come back in the order they appear in the input.
    let expected_locations = ["San Francisco, CA", "New York, NY"];
    for (call, location) in calls.iter().zip(expected_locations) {
        let (name, args) = extract_name_and_args(call.clone());
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], location);
        assert_eq!(args["unit"], "fahrenheit");
    }
}
#[test]
fn test_detect_and_parse_tool_call_error_handling() {
// Unknown parser string should return an error
......@@ -522,6 +536,22 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_nemotron_deci_multiple() {
    // A single `<TOOLCALL>` tag wrapping a JSON array of two calls, parsed with the
    // default parser (no parser name given).
    let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
    let calls = detect_and_parse_tool_call(input, None).unwrap();
    assert!(!calls.is_empty());
    assert_eq!(calls.len(), 2);

    // Both calls target the same function and unit; only the location differs,
    // and the calls must come back in the order they appear in the input.
    let expected_locations = ["San Francisco, CA", "New York, NY"];
    for (call, location) in calls.iter().zip(expected_locations) {
        let (name, args) = extract_name_and_args(call.clone());
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], location);
        assert_eq!(args["unit"], "fahrenheit");
    }
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment