Unverified commit a7184bec, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: Tool call parsers incremental improvements + Model Specific Parsers (#2457)

parent ffae72b7
......@@ -3,6 +3,7 @@
use std::collections::HashMap;
use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
......@@ -23,6 +24,33 @@ pub struct CalledFunctionArguments {
pub arguments: HashMap<String, Value>,
}
/// Extracts the text enclosed by `start_token` and `end_token` in `input`.
///
/// Both tokens are regex-escaped, so they are matched literally. The span between
/// them is matched non-greedily and may cross line breaks. When the token pair
/// occurs more than once, the content of the last occurrence is returned, with
/// leading and trailing whitespace trimmed.
///
/// Returns `None` when no start/end pair is found or the pattern fails to compile.
fn extract_tool_call_content<'a>(
    input: &'a str,
    start_token: &str,
    end_token: &str,
) -> Option<&'a str> {
    let pattern = format!(
        r"{}(.*?){}",
        regex::escape(start_token),
        regex::escape(end_token)
    );
    // `.` has to match `\n` as well so that multi-line payloads are captured.
    let regex = RegexBuilder::new(&pattern)
        .dot_matches_new_line(true)
        .build()
        .ok()?;

    // Take the last match for now. TODO : Handle multiple tool calls
    // `Iterator::last` avoids materialising every match in a `Vec` first.
    regex
        .captures_iter(input)
        .filter_map(|captures| captures.get(1))
        .map(|m| m.as_str().trim())
        .last()
}
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
///
/// This is a flexible helper that handles a variety of potential formats emitted by LLMs for function/tool calls,
......@@ -72,24 +100,30 @@ pub fn try_tool_call_parse_json(
tracing::debug!("Using JSON parser config: {:?}", config);
let trimmed = message.trim();
// Support <TOOLCALL>[ ... ] or <tool_call>[ ... ]
let json = if let Some(stripped) = trimmed.strip_prefix("<TOOLCALL>[") {
if let Some(stripped) = stripped.strip_suffix("]</TOOLCALL>") {
tracing::debug!("Stripping <TOOLCALL> wrapper from tool call payload");
stripped
} else {
trimmed
}
// Use config to get tool call start and end token vectors, then use the first element for now
let tool_call_start_tokens = &config.tool_call_start_tokens;
let tool_call_end_tokens = &config.tool_call_end_tokens;
// Support custom/LLM-formatted `<|python_tag|>` preamble
} else if let Some(stripped) = trimmed.strip_prefix("<|python_tag|>") {
tracing::debug!("Stripping <|python_tag|> prefix from tool call payload");
stripped
assert!(
tool_call_start_tokens.len() == tool_call_end_tokens.len(),
"Tool call start and end tokens must have the same length"
);
// Otherwise, assume input is clean JSON
// Iterate over all start and end tokens and try to extract the content between them
let mut json = trimmed;
for (start_token, end_token) in tool_call_start_tokens
.iter()
.zip(tool_call_end_tokens.iter())
{
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
json = if !start_token.is_empty() && end_token.is_empty() {
json.strip_prefix(start_token).unwrap_or(json)
} else if let Some(content) = extract_tool_call_content(json, start_token, end_token) {
content
} else {
trimmed
json
};
}
// Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
......
......@@ -44,22 +44,13 @@ impl Default for JsonParserConfig {
parallel_tool_calls_start_tokens: vec![],
parallel_tool_calls_end_tokens: vec![],
tool_call_start_tokens: vec!["<TOOLCALL>".to_string(), "<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string(), "".to_string()],
function_name_keys: vec!["name".to_string()],
arguments_keys: vec!["arguments".to_string(), "parameters".to_string()],
}
}
}
/// Top-level configuration selecting how tool calls are parsed out of model output.
///
/// Pairs a parser format with the settings used by the JSON parser. Ready-made
/// configurations for specific model families are exposed as associated functions
/// (`hermes`, `nemotron_deci`, `llama3_json`).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// Which parser format to use for tool calls (e.g. `ToolCallParserType::Json`).
pub format: ToolCallParserType,
/// Settings for the JSON parser, such as the start/end tokens and the keys that
/// hold the function name and arguments.
pub json: JsonParserConfig,
}
impl Default for ToolCallConfig {
fn default() -> Self {
Self {
......@@ -69,6 +60,56 @@ impl Default for ToolCallConfig {
}
}
impl ToolCallConfig {
    /// Configuration for Hermes-style tool calls.
    ///
    /// Typical model output:
    /// `<tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}\n</tool_call>`
    ///
    /// The end token is `</tool_call>` without a leading newline: the extracted
    /// content is trimmed anyway, so both the multi-line form above and a
    /// single-line `<tool_call>{...}</tool_call>` payload are accepted.
    pub fn hermes() -> Self {
        Self {
            format: ToolCallParserType::Json,
            json: JsonParserConfig {
                tool_call_start_tokens: vec!["<tool_call>".to_string()],
                tool_call_end_tokens: vec!["</tool_call>".to_string()],
                ..Default::default()
            },
        }
    }

    /// Configuration for Nemotron tool calls.
    ///
    /// Typical model output:
    /// `<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>`
    pub fn nemotron_deci() -> Self {
        Self {
            format: ToolCallParserType::Json,
            json: JsonParserConfig {
                tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
                tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
                ..Default::default()
            },
        }
    }

    /// Configuration for Llama 3 JSON tool calls.
    ///
    /// Typical model output:
    /// `<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"} }`
    /// or the same JSON object without the `<|python_tag|>` prefix.
    ///
    /// The empty end token marks `<|python_tag|>` as a prefix-only marker with no
    /// closing counterpart.
    pub fn llama3_json() -> Self {
        Self {
            format: ToolCallParserType::Json,
            json: JsonParserConfig {
                tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
                tool_call_end_tokens: vec!["".to_string()],
                ..Default::default()
            },
        }
    }
}
/// Top-level configuration selecting how tool calls are parsed out of model output.
///
/// Pairs a parser format with the settings used by the JSON parser. Ready-made
/// configurations for specific model families are exposed as associated functions
/// (`hermes`, `nemotron_deci`, `llama3_json`).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct ToolCallConfig {
/// Which parser format to use for tool calls (e.g. `ToolCallParserType::Json`).
pub format: ToolCallParserType,
/// Settings for the JSON parser, such as the start/end tokens and the keys that
/// hold the function name and arguments.
pub json: JsonParserConfig,
}
pub fn try_tool_call_parse(
message: &str,
config: &ToolCallConfig,
......@@ -91,6 +132,30 @@ pub fn try_tool_call_parse(
}
}
/// Base detector to call for all tool parsing: selects a parser configuration by
/// name and parses `message` with it.
///
/// `parser_str` may be `"hermes"`, `"nemotron_deci"`, `"llama3_json"` or
/// `"default"`. `None` and the empty string also select the default configuration.
///
/// # Errors
///
/// Returns an error when `parser_str` names a parser that is not implemented, and
/// propagates any error returned by [`try_tool_call_parse`].
pub fn detect_and_parse_tool_call(
    message: &str,
    parser_str: Option<&str>,
) -> anyhow::Result<Option<ToolCallResponse>> {
    // Build only the configuration that was asked for instead of a lookup table
    // containing every known configuration.
    let config = match parser_str {
        // None or empty string falls back to the default parser.
        None | Some("") | Some("default") => ToolCallConfig::default(),
        Some("hermes") => ToolCallConfig::hermes(),
        Some("nemotron_deci") => ToolCallConfig::nemotron_deci(),
        Some("llama3_json") => ToolCallConfig::llama3_json(),
        Some(_) => anyhow::bail!("Parser for the given config is not implemented"),
    };

    try_tool_call_parse(message, &config)
}
// Tests
// cargo test postprocessor::tool_calling::parsers
#[cfg(test)]
......@@ -163,7 +228,17 @@ mod tests {
#[test]
fn parses_python_tag_prefixed_payload() {
let input = r#"<|python_tag|>{ "name": "pyfunc", "arguments": { "k": "v" } }"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
let result = try_tool_call_parse(
input,
&ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
},
},
)
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
......@@ -187,14 +262,13 @@ mod tests {
// Tests for real model outputs - disabled by default
#[test]
#[ignore]
fn test_nvidia_llama3_nemotron_super_49b_simple() {
let input = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
</think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let result = try_tool_call_parse(input, &ToolCallConfig::default())
let result = detect_and_parse_tool_call(input, Some("nemotron_deci"))
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
......@@ -205,20 +279,26 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[test]
#[ignore]
// TODO : Implement extracting function arrays
fn test_nvidia_llama3_nemotron_super_49b_with_function_array() {
    // Reasoning preamble followed by a <TOOLCALL> array that holds two calls.
    let model_output = r#"<think>
Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.
</think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
    let parsed = try_tool_call_parse(model_output, &ToolCallConfig::nemotron_deci())
        .unwrap()
        .unwrap();
    println!("{:?}", parsed);
}
#[test]
fn test_qwen_qwq_32b_simple() {
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let result = detect_and_parse_tool_call(input, Some("hermes"))
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
......@@ -226,27 +306,35 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[test]
#[ignore]
fn test_nousresearch_hermes3_llama31_8b_simple() {
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
arguments_keys: vec!["arguments".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let result = detect_and_parse_tool_call(input, Some("hermes"))
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
#[ignore]
// TODO : Implement this
fn test_qwen_qwq_32b_multiple_tool_calls() {
    // Two consecutive <tool_call> sections in a single model response.
    let model_output = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>
<tool_call>
{"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}
</tool_call>
"#;
    let parsed = try_tool_call_parse(model_output, &ToolCallConfig::hermes())
        .unwrap()
        .unwrap();
    println!("{:?}", parsed);
}
#[test]
#[ignore]
fn test_ibm_granite_40_tiny_preview_simple() {
......@@ -288,25 +376,55 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
}
#[test]
#[ignore]
fn test_meta_llama_llama31_8b_instruct_simple() {
let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
let config = ToolCallConfig {
format: ToolCallParserType::Json,
json: JsonParserConfig {
tool_call_start_tokens: vec![],
tool_call_end_tokens: vec![],
arguments_keys: vec!["parameters".to_string()],
..Default::default()
},
};
let result = try_tool_call_parse(input, &config).unwrap().unwrap();
let result = detect_and_parse_tool_call(input, Some("llama3_json"))
.unwrap()
.unwrap();
let (name, args) = extract_name_and_args(result);
assert_eq!(name, "get_weather");
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_meta_llama_llama31_8b_instruct_with_python_tag() {
    // The `llama3_json` parser must strip the `<|python_tag|>` prefix before parsing.
    let payload = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
    let (name, args) = extract_name_and_args(
        detect_and_parse_tool_call(payload, Some("llama3_json"))
            .unwrap()
            .unwrap(),
    );
    assert_eq!(name, "get_weather");
    assert_eq!(args["location"], "San Francisco, CA");
    assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_error_handling() {
    // An unrecognised parser name is reported as an error.
    let valid_call = r#"{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}}"#;
    let err = detect_and_parse_tool_call(valid_call, Some("unknown_parser"))
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("is not implemented"),
        "Unexpected error message: {}",
        err
    );

    // A known parser given input that is not JSON yields Ok(None).
    let not_json = "not a json";
    assert!(matches!(
        detect_and_parse_tool_call(not_json, Some("hermes")),
        Ok(None)
    ));

    // A known parser given valid JSON of the wrong shape also yields Ok(None).
    let wrong_shape = r#"{"foo": "bar"}"#;
    assert!(matches!(
        detect_and_parse_tool_call(wrong_shape, Some("hermes")),
        Ok(None)
    ));
}
#[test]
#[ignore]
fn test_internlm_internlm2_5_7b_chat_simple() {
......@@ -360,4 +478,34 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_nemotron_deci() {
    // No parser name given: the default configuration handles <TOOLCALL> payloads.
    let payload = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
    let parsed = detect_and_parse_tool_call(payload, None);
    let (name, args) = extract_name_and_args(parsed.unwrap().unwrap());
    assert_eq!(name, "get_weather");
    assert_eq!(args["location"], "San Francisco, CA");
    assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
    // No parser name given: the default configuration handles the `<|python_tag|>` prefix.
    let payload = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
    let parsed = detect_and_parse_tool_call(payload, None);
    let (name, args) = extract_name_and_args(parsed.unwrap().unwrap());
    assert_eq!(name, "get_weather");
    assert_eq!(args["location"], "San Francisco, CA");
    assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() {
    // No parser name given: the default configuration handles a bare JSON object.
    let payload = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
    let parsed = detect_and_parse_tool_call(payload, None);
    let (name, args) = extract_name_and_args(parsed.unwrap().unwrap());
    assert_eq!(name, "get_weather");
    assert_eq!(args["location"], "San Francisco, CA");
    assert_eq!(args["unit"], "fahrenheit");
}
}
......@@ -6,16 +6,16 @@ pub use crate::preprocessor::tools::request::*;
// Import json_parser from postprocessor module
pub use super::json_parser::*;
pub use super::parsers::{try_tool_call_parse, ToolCallConfig};
pub use super::parsers::{detect_and_parse_tool_call, ToolCallConfig};
/// Try parsing a string as a structured tool call, for aggregation usage.
///
/// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_tool_call_parse_aggregate(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCall>> {
let config = ToolCallConfig::default();
let parsed = try_tool_call_parse(message, &config)?;
let parsed = detect_and_parse_tool_call(message, parser_str)?;
if let Some(parsed) = parsed {
Ok(Some(async_openai::types::ChatCompletionMessageToolCall {
id: parsed.id,
......@@ -35,9 +35,9 @@ pub fn try_tool_call_parse_aggregate(
/// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_tool_call_parse_stream(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Option<async_openai::types::ChatCompletionMessageToolCallChunk>> {
let config = ToolCallConfig::default();
let parsed = try_tool_call_parse(message, &config)?;
let parsed = detect_and_parse_tool_call(message, parser_str)?;
if let Some(parsed) = parsed {
Ok(Some(
async_openai::types::ChatCompletionMessageToolCallChunk {
......
......@@ -166,6 +166,7 @@ impl DeltaAggregator {
if let Ok(Some(tool_call)) =
crate::postprocessor::tool_calling::tools::try_tool_call_parse_aggregate(
&choice.text,
None,
)
{
tracing::debug!(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment