Unverified Commit 125be8ce authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: add support for multi-tool within nested tags (#2501)

parent 56e99232
...@@ -24,11 +24,9 @@ pub struct CalledFunctionArguments { ...@@ -24,11 +24,9 @@ pub struct CalledFunctionArguments {
pub arguments: HashMap<String, Value>, pub arguments: HashMap<String, Value>,
} }
fn extract_tool_call_content<'a>( // Extract the contents between start and end tokens using regex parsing.
input: &'a str, // Returns a JSON array string if there are multiple matches, otherwise returns the last match directly.
start_token: &str, fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) -> Option<String> {
end_token: &str,
) -> Option<&'a str> {
let escaped_start = regex::escape(start_token); let escaped_start = regex::escape(start_token);
let escaped_end = regex::escape(end_token); let escaped_end = regex::escape(end_token);
let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end); let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end);
...@@ -38,19 +36,62 @@ fn extract_tool_call_content<'a>( ...@@ -38,19 +36,62 @@ fn extract_tool_call_content<'a>(
.build() .build()
{ {
Ok(regex) => { Ok(regex) => {
// Get all matches and take the last one for now. TODO : Handle multiple tool calls // Get all matches and take the last one for now. TODO: Handle multiple tool calls
let matches: Vec<_> = regex let matches: Vec<_> = regex
.captures_iter(input) .captures_iter(input)
.filter_map(|captures| captures.get(1)) .filter_map(|captures| captures.get(1))
.map(|m| m.as_str().trim()) .map(|m| m.as_str().trim().to_string())
.collect(); .collect();
if !matches.is_empty() {
matches.last().copied() // If only one match, return it directly, otherwise return as a JSON array string
if matches.len() == 1 {
// Return the last match directly
return Some(matches.last().unwrap().clone());
} else {
// Join the matches into a JSON array string
return Some(format!("[{}]", matches.join(",")));
}
}
None
} }
Err(_) => None, Err(_) => None,
} }
} }
/// Normalizes tool calls introduced by a start token that has no end token
/// (for example `<|python_tag|>`) into a single JSON array string.
///
/// The regex-based extraction needs both a start and an end token, so this case is
/// handled by splitting `input` on `start_token` instead. Each trimmed segment that
/// begins with `{` or `[` is cut at its last `}` / `]` and kept only if what remains
/// parses as JSON:
///
/// * an object segment contributes itself;
/// * an array segment contributes its elements, so `<tok>[{..}, {..}]` yields the flat
///   array `[{..},{..}]` rather than a nested `[[{..}, {..}]]`.
///
/// Returns `input` unchanged when it does not contain `start_token`, or when no
/// segment holds valid JSON. Otherwise returns the collected items joined with `,`
/// and wrapped in `[` `]`.
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
    if !input.contains(start_token) {
        return input.to_string();
    }

    // Borrowed slices of `input`; nothing is copied until the final `format!`.
    let mut items: Vec<&str> = Vec::new();
    let mut saw_valid_segment = false;

    for segment in input.split(start_token) {
        let segment = segment.trim();

        // Only segments that look like JSON are candidates (this also skips empty ones).
        if !(segment.starts_with('{') || segment.starts_with('[')) {
            continue;
        }

        // Drop any trailing non-JSON text by cutting at the last closing brace/bracket.
        let candidate = match segment.rfind(['}', ']']) {
            Some(end) => &segment[..=end],
            None => continue,
        };

        if serde_json::from_str::<serde_json::Value>(candidate).is_err() {
            continue;
        }
        saw_valid_segment = true;

        // A valid JSON document that starts with `[` is an array and therefore ends with
        // `]`. Splice its elements into the result instead of nesting the whole array.
        match candidate
            .strip_prefix('[')
            .and_then(|rest| rest.strip_suffix(']'))
        {
            Some(elements) => {
                let elements = elements.trim();
                if !elements.is_empty() {
                    items.push(elements);
                }
            }
            None => items.push(candidate),
        }
    }

    if !saw_valid_segment {
        return input.to_string();
    }

    format!("[{}]", items.join(","))
}
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format. /// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
/// ///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls, /// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
...@@ -110,21 +151,25 @@ pub fn try_tool_call_parse_json( ...@@ -110,21 +151,25 @@ pub fn try_tool_call_parse_json(
); );
// Iterate over all start and end tokens and try to extract the content between them // Iterate over all start and end tokens and try to extract the content between them
let mut json = trimmed; // Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
let mut json = trimmed.to_string();
for (start_token, end_token) in tool_call_start_tokens for (start_token, end_token) in tool_call_start_tokens
.iter() .iter()
.zip(tool_call_end_tokens.iter()) .zip(tool_call_end_tokens.iter())
{ {
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token // Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
json = if !start_token.is_empty() && end_token.is_empty() { json = if !start_token.is_empty() && end_token.is_empty() {
json.strip_prefix(start_token).unwrap_or(json) handle_single_token_tool_calls(&json, start_token)
} else if let Some(content) = extract_tool_call_content(json, start_token, end_token) { } else if let Some(content) = extract_tool_call_content(&json, start_token, end_token) {
content content
} else { } else {
json json
}; };
} }
// `json` is an owned String at this point; shadow it with a borrowed &str view
let json = json.as_str();
// Anonymous function to attempt deserialization into a known representation // Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> { let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
Ok(ToolCallResponse { Ok(ToolCallResponse {
......
...@@ -334,8 +334,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -334,8 +334,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
} }
#[test] #[test]
#[ignore]
// TODO : Implement this
fn test_qwen_qwq_32b_multiple_tool_calls() { fn test_qwen_qwq_32b_multiple_tool_calls() {
let input = r#"<tool_call> let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
...@@ -426,6 +424,22 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -426,6 +424,22 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag_multiple() {
    // Two `<|python_tag|>`-prefixed JSON objects back to back, parsed with the
    // "llama3_json" parser; each one should come out as its own tool call.
    let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }<|python_tag|>{ "name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit" } }"#;
    let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
    assert!(!result.is_empty());
    assert_eq!(result.len(), 2);

    // Calls must be returned in input order; only the location differs between them.
    let expected_locations = ["San Francisco, CA", "New York, NY"];
    for (idx, &expected_location) in expected_locations.iter().enumerate() {
        let (name, args) = extract_name_and_args(result[idx].clone());
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], expected_location);
        assert_eq!(args["unit"], "fahrenheit");
    }
}
#[test] #[test]
fn test_detect_and_parse_tool_call_error_handling() { fn test_detect_and_parse_tool_call_error_handling() {
// Unknown parser string should return an error // Unknown parser string should return an error
...@@ -522,6 +536,22 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -522,6 +536,22 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_detect_and_parse_tool_call_default_parser_nemotron_deci_multiple() {
    // One `<TOOLCALL>...</TOOLCALL>` block wrapping a JSON array of two calls, parsed
    // with no explicit parser name (`None`).
    let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
    let result = detect_and_parse_tool_call(input, None).unwrap();
    assert!(!result.is_empty());
    assert_eq!(result.len(), 2);

    // Calls must be returned in array order; only the location differs between them.
    let expected_locations = ["San Francisco, CA", "New York, NY"];
    for (idx, &expected_location) in expected_locations.iter().enumerate() {
        let (name, args) = extract_name_and_args(result[idx].clone());
        assert_eq!(name, "get_weather");
        assert_eq!(args["location"], expected_location);
        assert_eq!(args["unit"], "fahrenheit");
    }
}
#[test] #[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() { fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#; let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment