Unverified Commit b929e134 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: added utility to detect possible tool call start for a chunk (#2923)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 5ea6b8d7
...@@ -29,11 +29,7 @@ pub fn parse_tool_calls_harmony( ...@@ -29,11 +29,7 @@ pub fn parse_tool_calls_harmony(
// Check if tool call start tokens are present, if not return everything as normal text // Check if tool call start tokens are present, if not return everything as normal text
// Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present // Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present
// End Token: "<|call|>" // End Token: "<|call|>"
if !config if !detect_tool_call_start_harmony(text, config) {
.tool_call_start_tokens
.iter()
.any(|token| trimmed.contains(token))
{
return Ok((vec![], Some(trimmed))); return Ok((vec![], Some(trimmed)));
} }
...@@ -158,6 +154,17 @@ pub fn parse_tool_calls_harmony( ...@@ -158,6 +154,17 @@ pub fn parse_tool_calls_harmony(
Ok((res, Some(normal_text.to_string()))) Ok((res, Some(normal_text.to_string())))
} }
pub fn detect_tool_call_start_harmony(chunk: &str, config: &JsonParserConfig) -> bool {
let trimmed = chunk.trim();
if trimmed.is_empty() {
return false;
}
config
.tool_call_start_tokens
.iter()
.any(|token| trimmed.contains(token))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -270,3 +277,32 @@ mod tests { ...@@ -270,3 +277,32 @@ mod tests {
assert_eq!(args["unit"], "celsius"); assert_eq!(args["unit"], "celsius");
} }
} }
#[cfg(test)]
mod detect_parser_tests {
use super::*;
#[test]
fn test_detect_tool_call_start_harmony_chunk_with_tool_call_start_token() {
let text = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_harmony(text, &config);
assert!(result);
}
#[test]
fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() {
let text = r#"<|channel|>commentary to=functions.get_current_weather"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_harmony(text, &config);
assert!(!result);
}
}
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
pub mod harmony_parser; pub mod harmony_parser;
pub use super::{config, response}; pub use super::{config, response};
pub use harmony_parser::parse_tool_calls_harmony; pub use harmony_parser::{detect_tool_call_start_harmony, parse_tool_calls_harmony};
...@@ -306,3 +306,133 @@ pub fn try_tool_call_parse_basic_json( ...@@ -306,3 +306,133 @@ pub fn try_tool_call_parse_basic_json(
Ok((vec![], Some(trimmed.to_string()))) Ok((vec![], Some(trimmed.to_string())))
} }
pub fn detect_tool_call_start_basic_json(chunk: &str, config: &JsonParserConfig) -> bool {
let trimmed = chunk.trim();
if trimmed.is_empty() {
return false;
}
config
.tool_call_start_tokens
.iter()
.any(|token| trimmed.contains(token))
|| trimmed.contains('{')
|| trimmed.contains('[')
}
#[cfg(test)]
mod detect_parser_tests {
use super::*;
#[test]
fn detect_tool_call_start_basic_json_chunk_with_tool_call_start_token_hermes() {
let text =
r#"<tool_call>{"name": "search", "parameters": { "query": "rust" } }</tool_call>"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_without_tool_call_start_token() {
let text = r#"{"name": "search", "parameters": { "query": "rust" } }"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_without_tool_call_start_token_with_normal_text() {
let text = r#"Here it is {"name": "#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_with_square_brackets() {
// These kind of false positives are expected when calling this function for stream=True
let text = r#"Here it is [{"name": "search","#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_false_positive() {
// These kind of false positives are expected when calling this function for stream=True
let text = r#"Here it is { Whats up"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<tool_call>".to_string()],
tool_call_end_tokens: vec!["</tool_call>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_with_tool_call_start_token_nemotron_deci() {
let text =
r#"<TOOLCALL>[{"name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_with_lllama3_json_token() {
let text = r#"<|python_tag|>{ "name": }"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_mistral_token() {
let text = r#"Hello Yo ! [TOOL_CALLS]{"name": "search", "#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
#[test]
fn detect_tool_call_start_basic_json_chunk_phi4_token() {
let text = r#"functools{"name": "search", "#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["functools".to_string()],
tool_call_end_tokens: vec!["".to_string()],
..Default::default()
};
let result = detect_tool_call_start_basic_json(text, &config);
assert!(result);
}
}
...@@ -43,12 +43,7 @@ pub fn parse_tool_calls_deepseek_v3_1( ...@@ -43,12 +43,7 @@ pub fn parse_tool_calls_deepseek_v3_1(
} }
// If tool call start token is not present then, no tool calls are there, return empty tool calls and the original trimmed string // If tool call start token is not present then, no tool calls are there, return empty tool calls and the original trimmed string
if let Some(start_token) = tool_call_start_tokens.first() { if !detect_tool_call_start_deepseek_v3_1(trimmed, config) {
if !trimmed.contains(start_token) {
return Ok((vec![], Some(trimmed.to_string())));
}
} else {
// Invalid start token
return Ok((vec![], Some(trimmed.to_string()))); return Ok((vec![], Some(trimmed.to_string())));
} }
...@@ -106,6 +101,15 @@ pub fn parse_tool_calls_deepseek_v3_1( ...@@ -106,6 +101,15 @@ pub fn parse_tool_calls_deepseek_v3_1(
Ok((tool_calls, Some(normal_text))) Ok((tool_calls, Some(normal_text)))
} }
pub fn detect_tool_call_start_deepseek_v3_1(chunk: &str, config: &JsonParserConfig) -> bool {
let trimmed = chunk.trim();
!trimmed.is_empty()
&& config
.tool_call_start_tokens
.iter()
.any(|token| trimmed.contains(token))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -220,3 +224,43 @@ mod tests { ...@@ -220,3 +224,43 @@ mod tests {
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
} }
#[cfg(test)]
mod detect_parser_tests {
use super::*;
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result);
}
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_without_tool_call_start_token() {
let text = r#"<|tool▁call▁begin|>get_current_weather宽带}"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(!result);
}
#[test]
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token_in_middle() {
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_deepseek_v3_1(text, &config);
assert!(result);
}
}
...@@ -5,8 +5,8 @@ pub mod base_json_parser; ...@@ -5,8 +5,8 @@ pub mod base_json_parser;
pub mod deepseek_parser; pub mod deepseek_parser;
pub use super::{config, response}; pub use super::{config, response};
pub use base_json_parser::try_tool_call_parse_basic_json; pub use base_json_parser::{detect_tool_call_start_basic_json, try_tool_call_parse_basic_json};
pub use deepseek_parser::parse_tool_calls_deepseek_v3_1; pub use deepseek_parser::{detect_tool_call_start_deepseek_v3_1, parse_tool_calls_deepseek_v3_1};
pub use super::config::JsonParserConfig; pub use super::config::JsonParserConfig;
pub use super::response::ToolCallResponse; pub use super::response::ToolCallResponse;
...@@ -34,3 +34,10 @@ pub fn try_tool_call_parse_json( ...@@ -34,3 +34,10 @@ pub fn try_tool_call_parse_json(
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config), JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config),
} }
} }
pub fn detect_tool_call_start_json(chunk: &str, config: &JsonParserConfig) -> bool {
match config.parser_type {
JsonParserType::Basic => detect_tool_call_start_basic_json(chunk, config),
JsonParserType::DeepseekV31 => detect_tool_call_start_deepseek_v3_1(chunk, config),
}
}
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType}; use super::config::{ToolCallConfig, ToolCallParserType};
use super::harmony::parse_tool_calls_harmony; use super::harmony::{detect_tool_call_start_harmony, parse_tool_calls_harmony};
use super::json::try_tool_call_parse_json; use super::json::{detect_tool_call_start_json, try_tool_call_parse_json};
use super::pythonic::try_tool_call_parse_pythonic; use super::pythonic::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
use super::response::ToolCallResponse; use super::response::ToolCallResponse;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::OnceLock; use std::sync::OnceLock;
...@@ -86,6 +86,33 @@ pub fn detect_and_parse_tool_call( ...@@ -86,6 +86,33 @@ pub fn detect_and_parse_tool_call(
} }
} }
pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow::Result<bool> {
let parser_map = get_tool_parser_map();
let parser_key = match parser_str {
Some(s) if !s.is_empty() => s,
_ => "default", // None or empty string
};
match parser_map.get(parser_key) {
Some(config) => match config.format {
ToolCallParserType::Json => Ok(detect_tool_call_start_json(chunk, &config.json)),
ToolCallParserType::Harmony => Ok(detect_tool_call_start_harmony(chunk, &config.json)),
ToolCallParserType::Pythonic => Ok(detect_tool_call_start_pythonic(chunk)),
ToolCallParserType::Typescript => {
anyhow::bail!("Typescript parser not implemented");
}
ToolCallParserType::Xml => {
anyhow::bail!("Xml parser not implemented");
}
},
None => anyhow::bail!(
"Parser '{}' is not implemented. Available parsers: {:?}",
parser_key,
get_available_tool_parsers()
),
}
}
// Tests // Tests
// cargo test postprocessor::tool_calling::parsers // cargo test postprocessor::tool_calling::parsers
#[cfg(test)] #[cfg(test)]
...@@ -1200,3 +1227,67 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it ...@@ -1200,3 +1227,67 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(args["unit"], "celsius"); assert_eq!(args["unit"], "celsius");
} }
} }
#[cfg(test)]
// Just e2e tests to test the flow. Detailed tests are covered in the individual parsers
mod detect_parser_tests {
use super::*;
#[test]
fn test_e2e_detect_tool_call_start_harmony() {
let text = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json"#;
let result = detect_tool_call_start(text, Some("harmony")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_hermes() {
let text = r#"{"name": "get_current_weather", "parameters": {"location": "Tokyo"}}"#;
let result = detect_tool_call_start(text, Some("hermes")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_pythonic() {
let text = r#"foo(a=1, b=2), bar(x=3)]"#;
let result = detect_tool_call_start(text, Some("pythonic")).unwrap();
assert!(!result);
}
#[test]
fn test_e2e_detect_tool_call_start_nemotron_deci() {
let text = r#"<TOOLCALL>[{"name": "get_current_weather", "parameters": {"location": "Tokyo"}}]</TOOLCALL>"#;
let result = detect_tool_call_start(text, Some("nemotron_deci")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_phi4() {
let text =
r#"functools{"name": "get_current_weather", "parameters": {"location": "Tokyo"}}"#;
let result = detect_tool_call_start(text, Some("phi4")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_llama3_json() {
let text = r#"<|python_tag|>{ "name": "get_current_weather", "parameters": {"location": "Tokyo"}}"#;
let result = detect_tool_call_start(text, Some("llama3_json")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_mistral() {
let text =
r#"[TOOL_CALLS]{"name": "get_current_weather", "parameters": {"location": "Tokyo"}}"#;
let result = detect_tool_call_start(text, Some("mistral")).unwrap();
assert!(result);
}
#[test]
fn test_e2e_detect_tool_call_start_deepseek_v3_1() {
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather{"location": "Tokyo"}<|tool▁call▁end|>"#;
let result = detect_tool_call_start(text, Some("deepseek_v3_1")).unwrap();
assert!(result);
}
}
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
pub mod pythonic_parser; pub mod pythonic_parser;
pub use super::{config, response}; pub use super::{config, response};
pub use pythonic_parser::try_tool_call_parse_pythonic; pub use pythonic_parser::{detect_tool_call_start_pythonic, try_tool_call_parse_pythonic};
...@@ -187,6 +187,16 @@ pub fn try_tool_call_parse_pythonic( ...@@ -187,6 +187,16 @@ pub fn try_tool_call_parse_pythonic(
Ok((tool_response?, Some(normal_text))) Ok((tool_response?, Some(normal_text)))
} }
pub fn detect_tool_call_start_pythonic(chunk: &str) -> bool {
let trimmed = chunk.trim();
// Early return for empty input
if trimmed.is_empty() {
return false;
}
// Heuristic: Pythonic tool calls always start with a '[' somewhere in the chunk
trimmed.contains('[')
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -353,3 +363,37 @@ mod tests { ...@@ -353,3 +363,37 @@ mod tests {
assert_eq!(args["x"], json!({"x": 3, "y": {"e": "f"}})); assert_eq!(args["x"], json!({"x": 3, "y": {"e": "f"}}));
} }
} }
#[cfg(test)]
mod detect_parser_tests {
use super::*;
#[test]
fn test_detect_tool_call_start_pythonic_chunk_with_tool_call_start_token() {
let text = r#"[foo(a=1, b=2), bar(x=3)]"#;
let result = detect_tool_call_start_pythonic(text);
assert!(result);
}
#[test]
fn test_detect_tool_call_start_pythonic_chunk_without_tool_call_start_token() {
let text = r#"foo(a=1, b=2)"#;
let result = detect_tool_call_start_pythonic(text);
assert!(!result);
}
#[test]
fn test_detect_tool_call_start_pythonic_chunk_with_tool_call_start_token_in_middle() {
let text = r#"information: [foo(a=1, b=2), bar(x=3)]"#;
let result = detect_tool_call_start_pythonic(text);
assert!(result);
}
#[test]
fn test_detect_tool_call_start_pythonic_false_positive() {
// Since we detect just "[" as tool call start token, this will be a false positive
let text = r#"Hey [ There is one tool call here . foo(a=1, b=2)"#;
let result = detect_tool_call_start_pythonic(text);
assert!(result);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment