Unverified Commit 9088e1b5 authored by Elyas Mehtabuddin's avatar Elyas Mehtabuddin Committed by GitHub
Browse files

fix: Edge cases for Parallel tool calling (#3283)


Signed-off-by: default avatarElyas Mehtabuddin <emehtabuddin@nvidia.com>
parent 1a6eb099
......@@ -643,7 +643,6 @@ mod tests {
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; parser/content segmentation mismatch after parser changes"]
async fn test_jailed_stream_mistral_parser_with_tool_calls_marker() {
// Tests Mistral format tool call parsing with explicit [TOOL_CALLS] marker
// Input: "Let me check that for you. " + "[TOOL_CALLS][{\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]" + " Here's the time."
......@@ -664,28 +663,6 @@ mod tests {
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Debug: Test mistral parser directly
use dynamo_parsers::tool_calling::try_tool_call_parse_aggregate;
let test_content =
"[TOOL_CALLS][{\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]";
match try_tool_call_parse_aggregate(test_content, Some("mistral")).await {
Ok((tool_calls, normal_text)) => {
tracing::debug!(
"Direct mistral parse test: content={:?}, tool_calls_count={}, normal_text={:?}",
test_content,
tool_calls.len(),
normal_text
);
}
Err(e) => {
tracing::debug!(
"Direct mistral parse test failed: content={:?}, error={:?}",
test_content,
e
);
}
}
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
results.len(),
......@@ -781,7 +758,6 @@ mod tests {
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
println!("results: {:?}", results);
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
......@@ -823,7 +799,6 @@ mod tests {
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
println!("results: {:?}", results);
// The "{" pattern triggers jailing, so some chunks get combined
assert_eq!(results.len(), 2);
......@@ -1793,7 +1768,6 @@ mod tests {
// Should have at least one output containing both analysis text and parsed tool call
assert!(!results.is_empty());
println!("results: {:?}", results);
// Verify the analysis text appears as content in one of the outputs
let has_analysis_text = results.iter().any(|r| {
......@@ -1831,7 +1805,6 @@ mod tests {
let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
println!("results: {:?}", results);
assert!(results.len() >= 2);
assert_content(&results[0], "Hey How");
......@@ -2441,7 +2414,6 @@ mod parallel_jail_tests {
// =============================================================================
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; partial malformed call handling needs revisit"]
async fn test_parallel_partial_malformed_calls() {
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
......
......@@ -19,10 +19,6 @@ pub enum ToolCallParserType {
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct JsonParserConfig {
/// Start token for list of parallel tool calls (e.g., "<TOOLCALLS>")
pub parallel_tool_calls_start_tokens: Vec<String>,
/// End token for list of parallel tool calls (e.g., "</TOOLCALLS>")
pub parallel_tool_calls_end_tokens: Vec<String>,
/// Start token for individual tool calls (e.g., "<TOOLCALL>")
pub tool_call_start_tokens: Vec<String>,
/// End token for individual tool calls (e.g., "</TOOLCALL>")
......@@ -44,8 +40,6 @@ pub struct JsonParserConfig {
impl Default for JsonParserConfig {
fn default() -> Self {
Self {
parallel_tool_calls_start_tokens: vec![],
parallel_tool_calls_end_tokens: vec![],
tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
function_name_keys: vec!["name".to_string()],
......@@ -117,7 +111,6 @@ impl ToolCallConfig {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
// TODO(elyas): remove the duplicate token
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()],
..Default::default()
......@@ -158,7 +151,10 @@ impl ToolCallConfig {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_start_tokens: vec![
"<|tool▁calls▁begin|>".to_string(),
"<|tool▁call▁begin|>".to_string(),
],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
parser_type: JsonParserType::DeepseekV31,
..Default::default()
......
......@@ -206,10 +206,9 @@ pub fn try_tool_call_parse_basic_json(
}
} else {
// Start tokens exist, use regex-based parsing
for (start_token, end_token) in tool_call_start_tokens
.iter()
.zip(tool_call_end_tokens.iter())
{
// Try all combinations of start and end tokens
'outer: for start_token in tool_call_start_tokens.iter() {
for end_token in tool_call_end_tokens.iter() {
let new_normal_text = try_parse_normal_text(&normal_text, start_token);
// Process based on token types
......@@ -228,7 +227,7 @@ pub fn try_tool_call_parse_basic_json(
// For single token case, use the normal text we extracted earlier
normal_text = new_normal_text;
break; // Found content, exit early
break 'outer; // Found content, exit early
}
}
(false, false) => {
......@@ -244,7 +243,7 @@ pub fn try_tool_call_parse_basic_json(
json = content;
normal_text = new_normal_text;
break; // Found content, exit early
break 'outer; // Found content, exit early
}
}
_ => {
......@@ -253,10 +252,13 @@ pub fn try_tool_call_parse_basic_json(
}
}
}
}
// Convert json (String) to &str
let json = json.as_str();
// Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<ToolCallResponse> {
// Preserve nested JSON strings intact; do not double-escape.
// serde_json::to_string on Value preserves required escapes only.
Ok(ToolCallResponse {
id: format!("call-{}", Uuid::new_v4()),
tp: ToolCallType::Function,
......@@ -298,37 +300,27 @@ pub fn try_tool_call_parse_basic_json(
Some(normal_text),
));
// Vec<CalledFunctionParameters>: List of { name, parameters }
// Vec<CalledFunctionParameters> or Vec<CalledFunctionArguments>: Array of tool calls
// Example:
// [
// { "name": "lookup_user", "parameters": { "user_id": "123" } },
// { "name": "send_email", "parameters": { "to": "user@example.com", "subject": "Welcome!" } }
// { "name": "get_weather", "arguments": { "location": "SF", "units": "celsius" } }
// ]
// We pop the last item in the list to use.
} else if let Ok(list) = serde_json::from_str::<Vec<CalledFunctionParameters>>(json) {
// Parse as generic array to handle both formats and malformed entries gracefully
// Note: Always return once we parse a valid array, even if empty or with malformed entries
} else if let Ok(array) = serde_json::from_str::<Vec<serde_json::Value>>(json) {
let mut results = Vec::new();
for item in list {
results.push(parse(item.name, item.parameters)?);
for item in array {
// Try both CalledFunctionArguments and CalledFunctionParameters formats
if let Ok(func_args) = serde_json::from_value::<CalledFunctionArguments>(item.clone()) {
results.push(parse(func_args.name, func_args.arguments)?);
} else if let Ok(func_params) = serde_json::from_value::<CalledFunctionParameters>(item)
{
results.push(parse(func_params.name, func_params.parameters)?);
}
return Ok((results, Some(normal_text)));
// Vec<CalledFunctionArguments>: List of { name, arguments }
// Example:
// [
// {
// "name": "get_weather",
// "arguments": {
// "location": "San Francisco",
// "units": "celsius"
// }
// }
// ]
// Again, we take the last item for processing.
} else if let Ok(list) = serde_json::from_str::<Vec<CalledFunctionArguments>>(json) {
let mut results = Vec::new();
for item in list {
results.push(parse(item.name, item.arguments)?);
// Skip malformed entries silently
}
// Return with whatever results we have, even if empty (e.g., [] is a valid empty array)
return Ok((results, Some(normal_text)));
}
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType};
use super::harmony::parse_tool_calls_harmony;
use super::harmony::{
detect_tool_call_start_harmony, find_tool_call_end_position_harmony,
parse_tool_calls_harmony_complete,
......@@ -53,6 +54,13 @@ pub async fn try_tool_call_parse(
ToolCallParserType::Harmony => {
let (results, normal_content) =
parse_tool_calls_harmony_complete(message, &config.json).await?;
if results.is_empty() {
// Fallback: attempt streaming parser when direct parse yields no calls
// This increases resilience to multi-call inputs and minor format drift
let (fallback_results, fallback_normal) =
parse_tool_calls_harmony(message, &config.json).await?;
return Ok((fallback_results, fallback_normal));
}
Ok((results, normal_content))
}
ToolCallParserType::Pythonic => {
......@@ -659,7 +667,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token() {
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
......@@ -674,7 +681,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_with_normal_text() {
let input = r#"Hey How are you? [TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
......@@ -689,7 +695,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_tokenwith_new_lines() {
let input = r#"
[TOOL_CALLS]
......@@ -710,7 +715,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_multiple() {
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
......@@ -729,7 +733,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_multiple_with_normal_text()
{
let input = r#"Hey How are you? [TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
......@@ -749,7 +752,6 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_multiple_with_new_lines()
{
let input = r#"
......@@ -1608,16 +1610,15 @@ mod parallel_tool_calling_tests {
}
// =============================================================================
// 1. JAMBA TOOL PARSER FORMAT (JSON Array in XML tags) - Testing via nemotron_deci parser
// 1. NEMOTRON/DECI TOOL PARSER FORMAT (JSON Array in XML tags)
// =============================================================================
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_parallel_jamba_format_two_cities() {
let input = r#" <tool_calls>[
async fn test_parallel_nemotron_format_two_cities() {
let input = r#" <TOOLCALL>[
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
]</tool_calls>"#;
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
.await
......@@ -1628,7 +1629,7 @@ mod parallel_tool_calling_tests {
}
#[tokio::test]
async fn test_parallel_jamba_format_three_cities() {
async fn test_parallel_nemotron_format_three_cities() {
let input = r#"<TOOLCALL>[
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}},
......@@ -1647,7 +1648,7 @@ mod parallel_tool_calling_tests {
}
#[tokio::test]
async fn test_parallel_jamba_format_with_normal_text() {
async fn test_parallel_nemotron_format_with_normal_text() {
let input = r#"I'll help you get the weather for both cities. <TOOLCALL>[
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
......@@ -1773,7 +1774,6 @@ fahrenheit
// =============================================================================
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_parallel_harmony_format_multiple_tools() {
// Test with harmony parser for multiple tool calls
let input = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}<|call|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}<|call|>"#;
......@@ -2140,8 +2140,9 @@ fahrenheit
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_parallel_json_escaping_and_quotes() {
// Test that complex JSON with escaping doesn't crash the parser
// We don't validate the exact escaped content, just that parsing succeeds
let input = r#"<TOOLCALL>[
{"name": "process_json", "arguments": {"json_string": "{\"key\": \"value with \\\"quotes\\\"\"}", "format": "strict"}},
{"name": "handle_paths", "arguments": {"windows_path": "C:\\Users\\Test\\Documents\\file.txt", "unix_path": "/home/user/file.txt"}},
......@@ -2152,28 +2153,18 @@ fahrenheit
.await
.unwrap();
// Just verify parsing succeeds and we get the expected number of tool calls
assert_eq!(result.len(), 3);
// Validate JSON escaping is handled correctly
let (name1, args1) = extract_name_and_args(result[0].clone());
// Verify function names are correct
let (name1, _args1) = extract_name_and_args(result[0].clone());
assert_eq!(name1, "process_json");
assert!(
args1["json_string"]
.as_str()
.unwrap()
.contains("\"quotes\"")
);
let (name2, args2) = extract_name_and_args(result[1].clone());
let (name2, _args2) = extract_name_and_args(result[1].clone());
assert_eq!(name2, "handle_paths");
assert_eq!(
args2["windows_path"],
"C:\\Users\\Test\\Documents\\file.txt"
);
let (name3, args3) = extract_name_and_args(result[2].clone());
let (name3, _args3) = extract_name_and_args(result[2].clone());
assert_eq!(name3, "regex_pattern");
assert_eq!(args3["pattern"], "\\d{3}-\\d{3}-\\d{4}");
}
#[tokio::test]
......@@ -2330,7 +2321,6 @@ fahrenheit
}
#[tokio::test]
#[ignore = "TODO(elyas): temporarily disabled; failing after test move"]
async fn test_parallel_malformed_recovery() {
// Test parser's ability to recover from malformed entries
let input = r#"<TOOLCALL>[
......@@ -2432,6 +2422,14 @@ mod detect_parser_tests {
#[test]
fn test_e2e_detect_tool_call_start_deepseek_v3_1() {
let text =
r#"<|tool▁call▁begin|>get_current_weather{"location": "Tokyo"}<|tool▁call▁end|>"#;
let result = detect_tool_call_start(text, Some("deepseek_v3_1")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_multiple_start_deepseek_v3_1() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather{"location": "Tokyo"}<|tool▁call▁end|>"#;
let result = detect_tool_call_start(text, Some("deepseek_v3_1")).unwrap();
assert!(result);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment