Unverified Commit a7184bec authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: Tool call parsers incremental improvements + Model Specific Parsers (#2457)

parent ffae72b7
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use regex::RegexBuilder;
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
...@@ -23,6 +24,33 @@ pub struct CalledFunctionArguments { ...@@ -23,6 +24,33 @@ pub struct CalledFunctionArguments {
pub arguments: HashMap<String, Value>, pub arguments: HashMap<String, Value>,
} }
fn extract_tool_call_content<'a>(
input: &'a str,
start_token: &str,
end_token: &str,
) -> Option<&'a str> {
let escaped_start = regex::escape(start_token);
let escaped_end = regex::escape(end_token);
let pattern = format!(r"{}(.*?){}", escaped_start, escaped_end);
match RegexBuilder::new(&pattern)
.dot_matches_new_line(true)
.build()
{
Ok(regex) => {
// Get all matches and take the last one for now. TODO : Handle multiple tool calls
let matches: Vec<_> = regex
.captures_iter(input)
.filter_map(|captures| captures.get(1))
.map(|m| m.as_str().trim())
.collect();
matches.last().copied()
}
Err(_) => None,
}
}
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format. /// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
/// ///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls, /// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
...@@ -72,24 +100,30 @@ pub fn try_tool_call_parse_json( ...@@ -72,24 +100,30 @@ pub fn try_tool_call_parse_json(
tracing::debug!("Using JSON parser config: {:?}", config); tracing::debug!("Using JSON parser config: {:?}", config);
let trimmed = message.trim(); let trimmed = message.trim();
// Support <TOOLCALL>[ ... ] or <tool_call>[ ... ] // Use config to get tool call start and end token vectors, then use the first element for now
let json = if let Some(stripped) = trimmed.strip_prefix("<TOOLCALL>[") { let tool_call_start_tokens = &config.tool_call_start_tokens;
if let Some(stripped) = stripped.strip_suffix("]</TOOLCALL>") { let tool_call_end_tokens = &config.tool_call_end_tokens;
tracing::debug!("Stripping <TOOLCALL> wrapper from tool call payload");
stripped
} else {
trimmed
}
// Support custom/LLM-formatted `<|python_tag|>` preamble assert!(
} else if let Some(stripped) = trimmed.strip_prefix("<|python_tag|>") { tool_call_start_tokens.len() == tool_call_end_tokens.len(),
tracing::debug!("Stripping <|python_tag|> prefix from tool call payload"); "Tool call start and end tokens must have the same length"
stripped );
// Otherwise, assume input is clean JSON // Iterate over all start and end tokens and try to extract the content between them
} else { let mut json = trimmed;
trimmed for (start_token, end_token) in tool_call_start_tokens
}; .iter()
.zip(tool_call_end_tokens.iter())
{
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
json = if !start_token.is_empty() && end_token.is_empty() {
json.strip_prefix(start_token).unwrap_or(json)
} else if let Some(content) = extract_tool_call_content(json, start_token, end_token) {
content
} else {
json
};
}
// Anonymous function to attempt deserialization into a known representation // Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> { let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
......
...@@ -44,22 +44,13 @@ impl Default for JsonParserConfig { ...@@ -44,22 +44,13 @@ impl Default for JsonParserConfig {
parallel_tool_calls_start_tokens: vec![], parallel_tool_calls_start_tokens: vec![],
parallel_tool_calls_end_tokens: vec![], parallel_tool_calls_end_tokens: vec![],
tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()], tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()], tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
function_name_keys: vec!["name".to_string()], function_name_keys: vec!["name".to_string()],
arguments_keys: vec!["arguments".to_string(), "parameters".to_string()], arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
} }
} }
} }
/// Configuration for parsing tool calls with different formats
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// The format type for tool calls
pub format: ToolCallParserType,
/// The config for the JSON parser
pub json: JsonParserConfig,
}
impl Default for ToolCallConfig { impl Default for ToolCallConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
...@@ -69,6 +60,56 @@ impl Default for ToolCallConfig { ...@@ -69,6 +60,56 @@ impl Default for ToolCallConfig {
} }
} }
impl ToolCallConfig {
/// Default configuration for hermes tool calls
/// <tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>
pub fn hermes() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["\n</tool_call>".to_string()],
..Default::default()
},
}
}
/// Default configuration for nemotron tool calls
/// <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>
pub fn nemotron_deci() -> Self {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
..Default::default()
},
}
}
pub fn llama3_json() -> Self {
// <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
// or { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
}
}
}
/// Configuration for parsing tool calls with different formats
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// The format type for tool calls
pub format: ToolCallParserType,
/// The config for the JSON parser
pub json: JsonParserConfig,
}
pub fn try_tool_call_parse( pub fn try_tool_call_parse(
message: &str, message: &str,
config: &ToolCallConfig, config: &ToolCallConfig,
...@@ -91,6 +132,30 @@ pub fn try_tool_call_parse( ...@@ -91,6 +132,30 @@ pub fn try_tool_call_parse(
} }
} }
// Base Detector to call for all tool parsing
pub fn detect_and_parse_tool_call(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Option<ToolCallResponse>> {
let mut parser_map: std::collections::HashMap<&str, ToolCallConfig> =
std::collections::HashMap::new();
parser_map.insert("hermes", ToolCallConfig::hermes());
parser_map.insert("nemotron_deci", ToolCallConfig::nemotron_deci());
parser_map.insert("llama3_json", ToolCallConfig::llama3_json());
parser_map.insert("default", ToolCallConfig::default()); // Add default key
// Handle None or empty string by defaulting to "default"
let parser_key = match parser_str {
Some(s) if !s.is_empty() => s,
_ => "default", // None or empty string
};
match parser_map.get(parser_key) {
Some(config) => try_tool_call_parse(message, config),
None => anyhow::bail!("Parser for the given config is not implemented"), // Original message
}
}
// Tests // Tests
// cargo test postprocessor::tool_calling::parsers // cargo test postprocessor::tool_calling::parsers
#[cfg(test)] #[cfg(test)]
...@@ -163,9 +228,19 @@ mod tests { ...@@ -163,9 +228,19 @@ mod tests {
#[test] #[test]
fn parses_python_tag_prefixed_payload() { fn parses_python_tag_prefixed_payload() {
let input = r#"<|python_tag|>{ "name": "pyfunc", "arguments": { "k": "v" } }"#; let input = r#"<|python_tag|>{ "name": "pyfunc", "arguments": { "k": "v" } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = try_tool_call_parse(
.unwrap() input,
.unwrap(); &ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
},
)
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result);
assert_eq!(name, "pyfunc"); assert_eq!(name, "pyfunc");
assert_eq!(args["k"], "v"); assert_eq!(args["k"], "v");
...@@ -187,14 +262,13 @@ mod tests { ...@@ -187,14 +262,13 @@ mod tests {
// Tests for real model outputs - disabled by default // Tests for real model outputs - disabled by default
#[test] #[test]
#[ignore]
fn test_nvidia_llama3_nemotron_super_49b_simple() { fn test_nvidia_llama3_nemotron_super_49b_simple() {
let input = r#"<think> let input = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available. Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
</think> </think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#; <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default()) let result = detect_and_parse_tool_call(input, Some("nemotron_deci"))
.unwrap() .unwrap()
.unwrap(); .unwrap();
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result);
...@@ -205,20 +279,26 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -205,20 +279,26 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[test] #[test]
#[ignore] #[ignore]
// TODO : Implement extracting function arrays
fn test_nvidia_llama3_nemotron_super_49b_with_function_array() {
let input = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
</think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let config = ToolCallConfig::nemotron_deci();
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
println!("{:?}", result);
}
#[test]
fn test_qwen_qwq_32b_simple() { fn test_qwen_qwq_32b_simple() {
let input = r#"<tool_call> let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#; </tool_call>"#;
let config = ToolCallConfig { let result = detect_and_parse_tool_call(input, Some("hermes"))
format: ToolCallParserType::Json, .unwrap()
json: JsonParserConfig { .unwrap();
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
...@@ -226,27 +306,35 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -226,27 +306,35 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
} }
#[test] #[test]
#[ignore]
fn test_nousresearch_hermes3_llama31_8b_simple() { fn test_nousresearch_hermes3_llama31_8b_simple() {
let input = r#"<tool_call> let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#; </tool_call>"#;
let config = ToolCallConfig { let result = detect_and_parse_tool_call(input, Some("hermes"))
format: ToolCallParserType::Json, .unwrap()
json: JsonParserConfig { .unwrap();
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
#[ignore]
// TODO : Implement this
fn test_qwen_qwq_32b_multiple_tool_calls() {
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>
<tool_call>
{"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}
</tool_call>
"#;
let config = ToolCallConfig::hermes();
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
println!("{:?}", result);
}
#[test] #[test]
#[ignore] #[ignore]
fn test_ibm_granite_40_tiny_preview_simple() { fn test_ibm_granite_40_tiny_preview_simple() {
...@@ -288,25 +376,55 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me ...@@ -288,25 +376,55 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
} }
#[test] #[test]
#[ignore]
fn test_meta_llama_llama31_8b_instruct_simple() { fn test_meta_llama_llama31_8b_instruct_simple() {
let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#; let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
let config = ToolCallConfig { let result = detect_and_parse_tool_call(input, Some("llama3_json"))
format: ToolCallParserType::Json, .unwrap()
json: JsonParserConfig { .unwrap();
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["parameters".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let (name, args) = extract_name_and_args(result); let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather"); assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let result = detect_and_parse_tool_call(input, Some("llama3_json"))
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_error_handling() {
// Unknown parser string should return an error
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}}"#;
let result = detect_and_parse_tool_call(input, Some("unknown_parser"));
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("is not implemented"),
"Unexpected error message: {}",
err
);
// Known parser, but invalid input (not JSON) should return Ok(None)
let input = "not a json";
let result = detect_and_parse_tool_call(input, Some("hermes"));
assert!(result.is_ok());
assert!(result.unwrap().is_none());
// Known parser, but valid JSON with wrong shape should return Ok(None)
let input = r#"{"foo": "bar"}"#;
let result = detect_and_parse_tool_call(input, Some("hermes"));
assert!(result.is_ok());
assert!(result.unwrap().is_none());
}
#[test] #[test]
#[ignore] #[ignore]
fn test_internlm_internlm2_5_7b_chat_simple() { fn test_internlm_internlm2_5_7b_chat_simple() {
...@@ -360,4 +478,34 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -360,4 +478,34 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["location"], "San Francisco, CA"); assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test]
fn test_detect_and_parse_tool_call_default_parser_nemotron_deci() {
let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let result = detect_and_parse_tool_call(input, None).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let result = detect_and_parse_tool_call(input, None).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() {
let input = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let result = detect_and_parse_tool_call(input, None).unwrap().unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
} }
...@@ -6,16 +6,16 @@ pub use crate::preprocessor::tools::request::*; ...@@ -6,16 +6,16 @@ pub use crate::preprocessor::tools::request::*;
// Import json_parser from postprocessor module // Import json_parser from postprocessor module
pub use super::json_parser::*; pub use super::json_parser::*;
pub use super::parsers::{try_tool_call_parse, ToolCallConfig}; pub use super::parsers::{detect_and_parse_tool_call, ToolCallConfig};
/// Try parsing a string as a structured tool call, for aggregation usage. /// Try parsing a string as a structured tool call, for aggregation usage.
/// ///
/// If successful, returns a `ChatCompletionMessageToolCall`. /// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_tool_call_parse_aggregate( pub fn try_tool_call_parse_aggregate(
message: &str, message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> { ) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> {
let config = ToolCallConfig::default(); let parsed = detect_and_parse_tool_call(message, parser_str)?;
let parsed = try_tool_call_parse(message, &config)?;
if let Some(parsed) = parsed { if let Some(parsed) = parsed {
Ok(Some(async_openai::types::ChatCompletionMessageToolCall { Ok(Some(async_openai::types::ChatCompletionMessageToolCall {
id: parsed.id, id: parsed.id,
...@@ -35,9 +35,9 @@ pub fn try_tool_call_parse_aggregate( ...@@ -35,9 +35,9 @@ pub fn try_tool_call_parse_aggregate(
/// If successful, returns a `ChatCompletionMessageToolCallChunk`. /// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_tool_call_parse_stream( pub fn try_tool_call_parse_stream(
message: &str, message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> { ) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> {
let config = ToolCallConfig::default(); let parsed = detect_and_parse_tool_call(message, parser_str)?;
let parsed = try_tool_call_parse(message, &config)?;
if let Some(parsed) = parsed { if let Some(parsed) = parsed {
Ok(Some( Ok(Some(
async_openai::types::ChatCompletionMessageToolCallChunk { async_openai::types::ChatCompletionMessageToolCallChunk {
......
...@@ -166,6 +166,7 @@ impl DeltaAggregator { ...@@ -166,6 +166,7 @@ impl DeltaAggregator {
if let Ok(Some(tool_call)) = if let Ok(Some(tool_call)) =
crate::postprocessor::tool_calling::tools::try_tool_call_parse_aggregate( crate::postprocessor::tool_calling::tools::try_tool_call_parse_aggregate(
&choice.text, &choice.text,
None,
) )
{ {
tracing::debug!( tracing::debug!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment