Unverified Commit 41f095cf authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: enable tool call array parsing (#2466)

parent d177cdff
...@@ -95,7 +95,7 @@ fn extract_tool_call_content<'a>( ...@@ -95,7 +95,7 @@ fn extract_tool_call_content<'a>(
pub fn try_tool_call_parse_json( pub fn try_tool_call_parse_json(
message: &str, message: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
) -> anyhow::Result<Option<ToolCallResponse>> { ) -> anyhow::Result<Vec<ToolCallResponse>> {
// Log the config we are using // Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config); tracing::debug!("Using JSON parser config: {:?}", config);
let trimmed = message.trim(); let trimmed = message.trim();
...@@ -147,19 +147,20 @@ pub fn try_tool_call_parse_json( ...@@ -147,19 +147,20 @@ pub fn try_tool_call_parse_json(
// } // }
// } // }
if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) { if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
return parse(single.name, single.parameters).map(Some); return Ok(vec![parse(single.name, single.parameters)?]);
//parse(single.name, single.parameters).map(Some);
// CalledFunctionArguments: Single { name, arguments }
// Example: // CalledFunctionArguments: Single { name, arguments }
// { // Example:
// "name": "summarize", // {
// "arguments": { // "name": "summarize",
// "text": "Rust is a systems programming language.", // "arguments": {
// "length": "short" // "text": "Rust is a systems programming language.",
// } // "length": "short"
// } // }
// }
} else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) { } else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) {
return parse(single.name, single.arguments).map(Some); return Ok(vec![parse(single.name, single.arguments)?]);
// Vec<CalledFunctionParameters>: List of { name, parameters } // Vec<CalledFunctionParameters>: List of { name, parameters }
// Example: // Example:
...@@ -168,10 +169,12 @@ pub fn try_tool_call_parse_json( ...@@ -168,10 +169,12 @@ pub fn try_tool_call_parse_json(
// { "name": "send_email", "parameters": { "to": "user@example.com", "subject": "Welcome!" } } // { "name": "send_email", "parameters": { "to": "user@example.com", "subject": "Welcome!" } }
// ] // ]
// We pop the last item in the list to use. // Every item in the list is parsed and returned, in order.
} else if let Ok(mut list) = serde_json::from_str::<Vec<CalledFunctionParameters>>(json) { } else if let Ok(list) = serde_json::from_str::<Vec<CalledFunctionParameters>>(json) {
if let Some(item) = list.pop() { let mut results = Vec::new();
return parse(item.name, item.parameters).map(Some); for item in list {
results.push(parse(item.name, item.parameters)?);
} }
return Ok(results);
// Vec<CalledFunctionArguments>: List of { name, arguments } // Vec<CalledFunctionArguments>: List of { name, arguments }
// Example: // Example:
...@@ -185,11 +188,13 @@ pub fn try_tool_call_parse_json( ...@@ -185,11 +188,13 @@ pub fn try_tool_call_parse_json(
// } // }
// ] // ]
// Again, we take the last item for processing. // Again, every item in the list is parsed and returned, in order.
} else if let Ok(mut list) = serde_json::from_str::<Vec<CalledFunctionArguments>>(json) { } else if let Ok(list) = serde_json::from_str::<Vec<CalledFunctionArguments>>(json) {
if let Some(item) = list.pop() { let mut results = Vec::new();
return parse(item.name, item.arguments).map(Some); for item in list {
results.push(parse(item.name, item.arguments)?);
} }
return Ok(results);
} }
Ok(None) Ok(vec![])
} }
...@@ -113,7 +113,7 @@ pub struct ToolCallConfig { ...@@ -113,7 +113,7 @@ pub struct ToolCallConfig {
pub fn try_tool_call_parse( pub fn try_tool_call_parse(
message: &str, message: &str,
config: &ToolCallConfig, config: &ToolCallConfig,
) -> anyhow::Result<Option<ToolCallResponse>> { ) -> anyhow::Result<Vec<ToolCallResponse>> {
// Use match statement (Rust's switch statement) to call the appropriate parser // Use match statement (Rust's switch statement) to call the appropriate parser
match config.format { match config.format {
ToolCallParserType::Json => try_tool_call_parse_json(message, &config.json), ToolCallParserType::Json => try_tool_call_parse_json(message, &config.json),
...@@ -136,7 +136,7 @@ pub fn try_tool_call_parse( ...@@ -136,7 +136,7 @@ pub fn try_tool_call_parse(
pub fn detect_and_parse_tool_call( pub fn detect_and_parse_tool_call(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<Option<ToolCallResponse>> { ) -> anyhow::Result<Vec<ToolCallResponse>> {
let mut parser_map: std::collections::HashMap<&str, ToolCallConfig> = let mut parser_map: std::collections::HashMap<&str, ToolCallConfig> =
std::collections::HashMap::new(); std::collections::HashMap::new();
parser_map.insert("hermes", ToolCallConfig::hermes()); parser_map.insert("hermes", ToolCallConfig::hermes());
...@@ -170,10 +170,10 @@ mod tests { ...@@ -170,10 +170,10 @@ mod tests {
#[test] #[test]
fn parses_single_parameters_object() { fn parses_single_parameters_object() {
let input = r#"{ "name": "hello", "parameters": { "x": 1, "y": 2 } }"#; let input = r#"{ "name": "hello", "parameters": { "x": 1, "y": 2 } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "hello"); assert_eq!(name, "hello");
assert_eq!(args["x"], 1); assert_eq!(args["x"], 1);
assert_eq!(args["y"], 2); assert_eq!(args["y"], 2);
...@@ -182,33 +182,39 @@ mod tests { ...@@ -182,33 +182,39 @@ mod tests {
#[test] #[test]
fn parses_single_arguments_object() { fn parses_single_arguments_object() {
let input = r#"{ "name": "world", "arguments": { "a": "abc", "b": 42 } }"#; let input = r#"{ "name": "world", "arguments": { "a": "abc", "b": 42 } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "world"); assert_eq!(name, "world");
assert_eq!(args["a"], "abc"); assert_eq!(args["a"], "abc");
assert_eq!(args["b"], 42); assert_eq!(args["b"], 42);
} }
#[test] #[test]
fn parses_vec_of_parameters_and_takes_last() { fn parses_vec_of_parameters() {
let input = r#"[{ "name": "first", "parameters": { "a": 1 } }, { "name": "second", "parameters": { "b": 2 } }]"#; let input = r#"[{ "name": "first", "parameters": { "a": 1 } }, { "name": "second", "parameters": { "b": 2 } }]"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "first");
assert_eq!(args["a"], 1);
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "second"); assert_eq!(name, "second");
assert_eq!(args["b"], 2); assert_eq!(args["b"], 2);
} }
#[test] #[test]
fn parses_vec_of_arguments_and_takes_last() { fn parses_vec_of_arguments() {
let input = r#"[{ "name": "alpha", "arguments": { "a": "x" } }, { "name": "omega", "arguments": { "z": "y" } }]"#; let input = r#"[{ "name": "alpha", "arguments": { "a": "x" } }, { "name": "omega", "arguments": { "z": "y" } }]"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "alpha");
assert_eq!(args["a"], "x");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "omega"); assert_eq!(name, "omega");
assert_eq!(args["z"], "y"); assert_eq!(args["z"], "y");
} }
...@@ -217,10 +223,10 @@ mod tests { ...@@ -217,10 +223,10 @@ mod tests {
fn parses_toolcall_wrapped_payload() { fn parses_toolcall_wrapped_payload() {
let input = let input =
r#"<TOOLCALL>[{ "name": "wrapped", "parameters": { "foo": "bar" } }]</TOOLCALL>"#; r#"<TOOLCALL>[{ "name": "wrapped", "parameters": { "foo": "bar" } }]</TOOLCALL>"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "wrapped"); assert_eq!(name, "wrapped");
assert_eq!(args["foo"], "bar"); assert_eq!(args["foo"], "bar");
} }
...@@ -239,9 +245,10 @@ mod tests { ...@@ -239,9 +245,10 @@ mod tests {
}, },
}, },
) )
.unwrap()
.unwrap(); .unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "pyfunc"); assert_eq!(name, "pyfunc");
assert_eq!(args["k"], "v"); assert_eq!(args["k"], "v");
} }
...@@ -250,14 +257,14 @@ mod tests { ...@@ -250,14 +257,14 @@ mod tests {
fn returns_none_on_invalid_input() { fn returns_none_on_invalid_input() {
let input = r#"not even json"#; let input = r#"not even json"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap(); let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
assert!(result.is_none()); assert!(result.is_empty());
} }
#[test] #[test]
fn returns_none_on_valid_json_wrong_shape() { fn returns_none_on_valid_json_wrong_shape() {
let input = r#"{ "foo": "bar" }"#; let input = r#"{ "foo": "bar" }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap(); let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
assert!(result.is_none()); assert!(result.is_empty());
} }
// Tests for real model outputs - disabled by default // Tests for real model outputs - disabled by default
...@@ -268,18 +275,16 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -268,18 +275,16 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</think> </think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#; <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let result = detect_and_parse_tool_call(input, Some("nemotron_deci")) let result = detect_and_parse_tool_call(input, Some("nemotron_deci")).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test] #[test]
#[ignore]
// TODO : Implement extracting function arrays
fn test_nvidia_llama3_nemotron_super_49b_with_function_array() { fn test_nvidia_llama3_nemotron_super_49b_with_function_array() {
let input = r#"<think> let input = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available. Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
...@@ -287,8 +292,17 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -287,8 +292,17 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#; <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let config = ToolCallConfig::nemotron_deci(); let config = ToolCallConfig::nemotron_deci();
let result = try_tool_call_parse(input, &config).unwrap().unwrap(); let result = try_tool_call_parse(input, &config).unwrap();
println!("{:?}", result); assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
} }
#[test] #[test]
...@@ -296,10 +310,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -296,10 +310,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
let input = r#"<tool_call> let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#; </tool_call>"#;
let result = detect_and_parse_tool_call(input, Some("hermes")) let result = detect_and_parse_tool_call(input, Some("hermes")).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -310,10 +324,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -310,10 +324,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
let input = r#"<tool_call> let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#; </tool_call>"#;
let result = detect_and_parse_tool_call(input, Some("hermes")) let result = detect_and_parse_tool_call(input, Some("hermes")).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -331,8 +345,17 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -331,8 +345,17 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</tool_call> </tool_call>
"#; "#;
let config = ToolCallConfig::hermes(); let config = ToolCallConfig::hermes();
let result = try_tool_call_parse(input, &config).unwrap().unwrap(); let result = try_tool_call_parse(input, &config).unwrap();
println!("{:?}", result); assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
let (name, args) = extract_name_and_args(result[1].clone());
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "New York, NY");
assert_eq!(args["unit"], "fahrenheit");
} }
#[test] #[test]
...@@ -348,8 +371,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -348,8 +371,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
..Default::default() ..Default::default()
}, },
}; };
let result = try_tool_call_parse(input, &config).unwrap().unwrap(); let result = try_tool_call_parse(input, &config).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -368,8 +393,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -368,8 +393,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
..Default::default() ..Default::default()
}, },
}; };
let result = try_tool_call_parse(input, &config).unwrap().unwrap(); let result = try_tool_call_parse(input, &config).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -378,10 +405,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -378,10 +405,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[test] #[test]
fn test_meta_llama_llama31_8b_instruct_simple() { fn test_meta_llama_llama31_8b_instruct_simple() {
let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#; let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
let result = detect_and_parse_tool_call(input, Some("llama3_json")) let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -390,10 +417,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -390,10 +417,10 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[test] #[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag() { fn test_meta_llama_llama31_8b_instruct_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#; let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let result = detect_and_parse_tool_call(input, Some("llama3_json")) let result = detect_and_parse_tool_call(input, Some("llama3_json")).unwrap();
.unwrap() assert!(!result.is_empty());
.unwrap(); assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -416,13 +443,13 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -416,13 +443,13 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
let input = "not a json"; let input = "not a json";
let result = detect_and_parse_tool_call(input, Some("hermes")); let result = detect_and_parse_tool_call(input, Some("hermes"));
assert!(result.is_ok()); assert!(result.is_ok());
assert!(result.unwrap().is_none()); assert!(result.unwrap().is_empty());
// Known parser, but valid JSON with wrong shape should return Ok(None) // Known parser, but valid JSON with wrong shape should return Ok with an empty Vec
let input = r#"{"foo": "bar"}"#; let input = r#"{"foo": "bar"}"#;
let result = detect_and_parse_tool_call(input, Some("hermes")); let result = detect_and_parse_tool_call(input, Some("hermes"));
assert!(result.is_ok()); assert!(result.is_ok());
assert!(result.unwrap().is_none()); assert!(result.unwrap().is_empty());
} }
#[test] #[test]
...@@ -434,7 +461,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -434,7 +461,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
Remember, San Francisco weather can be quite unpredictable, particularly with its famous fog, which can significantly lower temperatures. Always check a local weather forecast for the most accurate and up-to-date information."#; Remember, San Francisco weather can be quite unpredictable, particularly with its famous fog, which can significantly lower temperatures. Always check a local weather forecast for the most accurate and up-to-date information."#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap(); let result = try_tool_call_parse(input, &ToolCallConfig::default()).unwrap();
assert!(result.is_none()); // This model doesn't produce tool calls assert!(result.is_empty()); // This model doesn't produce tool calls
} }
#[test] #[test]
...@@ -452,8 +479,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -452,8 +479,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
..Default::default() ..Default::default()
}, },
}; };
let result = try_tool_call_parse(input, &config).unwrap().unwrap(); let result = try_tool_call_parse(input, &config).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -472,8 +501,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -472,8 +501,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
..Default::default() ..Default::default()
}, },
}; };
let result = try_tool_call_parse(input, &config).unwrap().unwrap(); let result = try_tool_call_parse(input, &config).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -482,8 +513,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -482,8 +513,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[test] #[test]
fn test_detect_and_parse_tool_call_default_parser_nemotron_deci() { fn test_detect_and_parse_tool_call_default_parser_nemotron_deci() {
let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#; let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let result = detect_and_parse_tool_call(input, None).unwrap().unwrap(); let result = detect_and_parse_tool_call(input, None).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -492,8 +525,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -492,8 +525,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[test] #[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() { fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#; let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let result = detect_and_parse_tool_call(input, None).unwrap().unwrap(); let result = detect_and_parse_tool_call(input, None).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
...@@ -502,8 +537,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -502,8 +537,10 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[test] #[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() { fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() {
let input = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#; let input = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let result = detect_and_parse_tool_call(input, None).unwrap().unwrap(); let result = detect_and_parse_tool_call(input, None).unwrap();
let (name, args) = extract_name_and_args(result); assert!(!result.is_empty());
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
......
...@@ -14,20 +14,24 @@ pub use super::parsers::{detect_and_parse_tool_call, ToolCallConfig}; ...@@ -14,20 +14,24 @@ pub use super::parsers::{detect_and_parse_tool_call, ToolCallConfig};
pub fn try_tool_call_parse_aggregate( pub fn try_tool_call_parse_aggregate(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> { ) -> anyhow::Result<Vec<async_openai::types::ChatCompletionMessageToolCall>> {
let parsed = detect_and_parse_tool_call(message, parser_str)?; let parsed = detect_and_parse_tool_call(message, parser_str)?;
if let Some(parsed) = parsed { if parsed.is_empty() {
Ok(Some(async_openai::types::ChatCompletionMessageToolCall { return Ok(vec![]);
id: parsed.id,
r#type: async_openai::types::ChatCompletionToolType::Function,
function: async_openai::types::FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
},
}))
} else {
Ok(None)
} }
Ok(parsed
.into_iter()
.map(
|parsed| async_openai::types::ChatCompletionMessageToolCall {
id: parsed.id,
r#type: async_openai::types::ChatCompletionToolType::Function,
function: async_openai::types::FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
},
},
)
.collect())
} }
/// Try parsing a string as a structured tool call, for streaming (delta) usage. /// Try parsing a string as a structured tool call, for streaming (delta) usage.
...@@ -36,21 +40,25 @@ pub fn try_tool_call_parse_aggregate( ...@@ -36,21 +40,25 @@ pub fn try_tool_call_parse_aggregate(
pub fn try_tool_call_parse_stream( pub fn try_tool_call_parse_stream(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> { ) -> anyhow::Result<Vec<async_openai::types::ChatCompletionMessageToolCallChunk>> {
let parsed = detect_and_parse_tool_call(message, parser_str)?; let parsed = detect_and_parse_tool_call(message, parser_str)?;
if let Some(parsed) = parsed { if parsed.is_empty() {
Ok(Some( return Ok(vec![]);
async_openai::types::ChatCompletionMessageToolCallChunk { }
index: 0, Ok(parsed
.into_iter()
.enumerate()
.map(
|(idx, parsed)| async_openai::types::ChatCompletionMessageToolCallChunk {
index: idx as u32,
id: Some(parsed.id), id: Some(parsed.id),
r#type: Some(async_openai::types::ChatCompletionToolType::Function), r#type: Some(async_openai::types::ChatCompletionToolType::Function),
function: Some(async_openai::types::FunctionCallStream { function: Some(async_openai::types::FunctionCallStream {
name: Some(parsed.function.name), name: Some(parsed.function.name),
arguments: Some(parsed.function.arguments), arguments: Some(parsed.function.arguments),
}), }),
// Add other fields as needed if required by the struct definition
}, },
)) )
} else { .collect())
Ok(None)
}
} }
...@@ -163,20 +163,24 @@ impl DeltaAggregator { ...@@ -163,20 +163,24 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax // After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() { for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() { if choice.tool_calls.is_none() {
if let Ok(Some(tool_call)) = if let Ok(tool_calls) =
crate::postprocessor::tool_calling::tools::try_tool_call_parse_aggregate( crate::postprocessor::tool_calling::tools::try_tool_call_parse_aggregate(
&choice.text, &choice.text,
None, None,
) )
{ {
tracing::debug!( if tool_calls.is_empty() {
tool_call_id = %tool_call.id, continue;
function_name = %tool_call.function.name, }
arguments = %tool_call.function.arguments, for tool_call in &tool_calls {
"Parsed structured tool call from aggregated content" tracing::debug!(
); tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
choice.tool_calls = Some(vec![tool_call]); arguments = %tool_call.function.arguments,
"Parsed structured tool call from aggregated content"
);
}
choice.tool_calls = Some(tool_calls);
choice.text.clear(); choice.text.clear();
choice.finish_reason = Some(async_openai::types::FinishReason::ToolCalls); choice.finish_reason = Some(async_openai::types::FinishReason::ToolCalls);
} }
...@@ -488,4 +492,63 @@ mod tests { ...@@ -488,4 +492,63 @@ mod tests {
); );
assert_eq!(choice1.message.role, async_openai::types::Role::Assistant); assert_eq!(choice1.message.role, async_openai::types::Role::Assistant);
} }
/// Feeds the aggregator a single chunk whose text content is a bare JSON tool
/// call and checks that the aggregated choice exposes it as a structured tool
/// call rather than as plain content.
#[tokio::test]
async fn test_tool_calling_output() {
    // Text content of the chunk: one JSON object in the `{ name, arguments }` shape.
    let payload = r#"{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;

    // Build the chunk with the shared test helper and pull out its inner response.
    let inner = create_test_delta(
        0,
        payload,
        Some(async_openai::types::Role::Assistant),
        Some(async_openai::types::FinishReason::ToolCalls),
    )
    .data
    .unwrap()
    .inner;

    // Re-wrap the inner response in an annotated envelope with a fixed id.
    let chunk = Annotated {
        data: Some(NvCreateChatCompletionStreamResponse { inner }),
        id: Some("test_id".to_string()),
        event: None,
        comment: None,
    };

    // Run the aggregator over a one-element stream.
    let outcome = DeltaAggregator::apply(Box::pin(stream::iter(vec![chunk]))).await;
    assert!(outcome.is_ok());
    let aggregated = outcome.unwrap();

    // Exactly one choice is expected.
    assert_eq!(aggregated.inner.choices.len(), 1);
    let first_choice = &aggregated.inner.choices[0];

    // The choice must carry exactly one parsed tool call.
    assert!(first_choice.message.tool_calls.is_some());
    let parsed_calls = first_choice.message.tool_calls.as_ref().unwrap();
    assert_eq!(parsed_calls.len(), 1);

    let call = &parsed_calls[0];
    assert_eq!(call.function.name, "get_weather");

    // The arguments are carried as a JSON string; decode it to inspect the keys.
    let decoded: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
    assert_eq!(decoded["location"], "San Francisco, CA");
    assert_eq!(decoded["unit"], "fahrenheit");

    // Once the tool call is extracted, the message content is expected to be None.
    assert!(first_choice.message.content.is_none());

    // Finish reason and role must reflect an assistant tool-call turn.
    assert_eq!(
        first_choice.finish_reason,
        Some(async_openai::types::FinishReason::ToolCalls)
    );
    assert_eq!(
        first_choice.message.role,
        async_openai::types::Role::Assistant
    );
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment