Unverified Commit ff7aea54 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

feat: Pass tool definitions to parsers (#4948)


Signed-off-by: default avatarWilliam Zhang <133824995+2ez4bz@users.noreply.github.com>
parent 7043707e
......@@ -801,6 +801,7 @@ impl OpenAIPreprocessor {
pub fn apply_tool_calling_jail<S>(
tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
......@@ -810,6 +811,13 @@ impl OpenAIPreprocessor {
let mut builder = JailedStream::builder();
// Set tool definitions if provided
if let Some(tool_definitions) = tool_definitions
&& !tool_definitions.is_empty()
{
builder = builder.tool_definitions(tool_definitions);
}
// Configure jail based on tool_choice
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
......@@ -991,11 +999,23 @@ impl
has_tools,
)?;
// Convert OpenAI tools to parser ToolDefinition format before applying jail
let tool_definitions = request.inner.tools.as_ref().map(|tools| {
tools
.iter()
.map(|tool| dynamo_parsers::tool_calling::ToolDefinition {
name: tool.function.name.clone(),
parameters: tool.function.parameters.clone(),
})
.collect()
});
// Apply jail conditionally
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
Box::pin(Self::apply_tool_calling_jail(
self.tool_call_parser.clone(),
request.inner.tool_choice.clone(),
tool_definitions,
stream,
))
} else {
......
......@@ -470,6 +470,7 @@ pub struct JailedStream {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode,
marker_matcher: MarkerMatcher,
jail_mode: JailMode,
......@@ -758,8 +759,13 @@ impl JailedStream {
} else if early_exit {
// For early exit, find where the complete tool call ends
if let Some(parser) = &self.tool_call_parser {
if let Ok((_, _)) =
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await
let tools_slice = self.tool_definitions.as_deref();
if let Ok((_, _)) = try_tool_call_parse_aggregate(
accumulated_content,
Some(parser),
tools_slice,
)
.await
{
let split_pos =
find_tool_call_end_position(accumulated_content, Some(parser));
......@@ -814,9 +820,11 @@ impl JailedStream {
match &self.jail_mode {
JailMode::MarkerBased => {
// Traditional marker-based tool call parsing
let tools_slice = self.tool_definitions.as_deref();
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_content,
self.tool_call_parser.as_deref(),
tools_slice,
)
.await
&& !tool_calls.is_empty()
......@@ -952,7 +960,8 @@ impl JailedStream {
async fn should_exit_jail_early(&self, accumulated: &str) -> bool {
if let Some(ref parser) = self.tool_call_parser {
// Try to parse - if successful and we have complete tool calls, exit early
match try_tool_call_parse_aggregate(accumulated, Some(parser)).await {
let tools_slice = self.tool_definitions.as_deref();
match try_tool_call_parse_aggregate(accumulated, Some(parser), tools_slice).await {
Ok((tool_calls, _normal_text)) => {
let result = !tool_calls.is_empty();
return result;
......@@ -1034,6 +1043,7 @@ pub struct JailedStreamBuilder {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode,
jail_mode: JailMode,
}
......@@ -1045,6 +1055,7 @@ impl JailedStreamBuilder {
jail_start_sequences: Vec::new(),
jail_end_sequences: Vec::new(),
tool_call_parser: None,
tool_definitions: None,
emission_mode: EmissionMode::default(),
jail_mode: JailMode::MarkerBased,
}
......@@ -1088,6 +1099,15 @@ impl JailedStreamBuilder {
self
}
/// Set the tool definitions for runtime validation and parsing
pub fn tool_definitions(
mut self,
tools: Vec<dynamo_parsers::tool_calling::ToolDefinition>,
) -> Self {
self.tool_definitions = Some(tools);
self
}
/// Set the emission mode for handling multiple choices
pub fn emission_mode(mut self, mode: EmissionMode) -> Self {
self.emission_mode = mode;
......@@ -1198,6 +1218,7 @@ impl JailedStreamBuilder {
jail_start_sequences: self.jail_start_sequences,
jail_end_sequences: self.jail_end_sequences,
tool_call_parser: self.tool_call_parser,
tool_definitions: self.tool_definitions,
emission_mode: self.emission_mode,
marker_matcher,
jail_mode: self.jail_mode,
......
......@@ -197,7 +197,7 @@ async fn test_parallel_tool_call_parsing() {
// Parse the tool calls using the hermes parser (works well with <tool_call> format)
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some("hermes"))
detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await
.expect("Should successfully parse tool calls");
......@@ -239,7 +239,7 @@ async fn test_parallel_tool_call_with_explicit_parser() {
for parser in parsers_to_test {
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some(parser))
detect_and_parse_tool_call(&response_content, Some(parser), None)
.await
.unwrap_or_else(|e| panic!("Should successfully parse with {parser} parser: {e}"));
......@@ -267,7 +267,7 @@ async fn test_parallel_tool_call_with_explicit_parser() {
async fn test_tool_call_json_structure() {
let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"))
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await
.expect("Should parse tool calls");
......@@ -288,7 +288,7 @@ async fn test_tool_call_json_structure() {
async fn test_openai_compatibility_structure() {
let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"))
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await
.expect("Should parse tool calls");
......@@ -335,7 +335,7 @@ async fn test_parallel_tool_call_error_handling() {
{"invalid_json": }
</tool_call>"#;
let result = detect_and_parse_tool_call(malformed_content, Some("hermes")).await;
let result = detect_and_parse_tool_call(malformed_content, Some("hermes"), None).await;
// Should handle partial parsing gracefully
match result {
......@@ -368,7 +368,7 @@ async fn test_empty_tool_calls() {
let content_without_tools = "This is just a regular response without any tool calls.";
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(content_without_tools, Some("hermes"))
detect_and_parse_tool_call(content_without_tools, Some("hermes"), None)
.await
.expect("Should handle content without tool calls");
......@@ -412,7 +412,7 @@ async fn test_deepseek_v3_1_tool_call_parsing() {
// Parse the tool calls using the deepseek_v3_1 parser
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(response_content, Some("deepseek_v3_1"))
detect_and_parse_tool_call(response_content, Some("deepseek_v3_1"), None)
.await
.expect("Should successfully parse deepseek_v3_1 tool calls");
......
......@@ -2045,6 +2045,8 @@ mod tests {
#[tokio::test]
async fn test_jailed_stream_qwen3_coder_multiple_params() {
use dynamo_parsers::tool_calling::ToolDefinition;
let chunks = vec![
create_mock_response_chunk("Let me search for that. ".to_string(), 0),
create_mock_response_chunk(
......@@ -2054,9 +2056,23 @@ mod tests {
create_mock_response_chunk(" Searching now.".to_string(), 0),
];
// Define the web_search tool with its parameters
let tool_defs = vec![ToolDefinition {
name: "web_search".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"query": {"type": "string"},
"max_results": {"type": "integer"},
"filter": {"type": "string"},
},
})),
}];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.tool_definitions(tool_defs)
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
......
......@@ -487,6 +487,7 @@ mod tests {
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("nemotron_deci".to_string()),
None, // No tool_choice in this test
None, // No tool_definitions in this test
reasoning_parsed_stream,
);
......@@ -600,6 +601,7 @@ mod tests {
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("harmony".to_string()),
None, // No tool_choice in this test
None, // No tool_definitions in this test
reasoning_parsed_stream,
);
......
......@@ -160,6 +160,7 @@ async fn parse_response_stream(
Box::pin(OpenAIPreprocessor::apply_tool_calling_jail(
Some(tool_parser),
None, // No tool_choice in this test
None, // No tool_definitions in this test
stream,
))
} else {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use openai_harmony::chat::{Content::Text, Role};
......@@ -46,6 +47,7 @@ pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::
pub async fn parse_tool_calls_harmony_complete(
text: &str,
_config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let enc = match get_harmony_encoding().await.as_ref() {
Ok(e) => e,
......@@ -212,7 +214,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_complete_basic() {
let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(normal_content, Some("".to_string()));
......@@ -226,7 +228,7 @@ mod tests {
async fn test_parse_tools_harmony_without_start_token() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(normal_content, Some(text.trim().to_string()));
......@@ -237,7 +239,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_with_multi_args() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(
......@@ -255,7 +257,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_with_normal_text() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(
......@@ -272,7 +274,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_without_call_token() {
let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
......
......@@ -7,6 +7,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -165,6 +166,7 @@ fn try_parse_normal_text(input: &str, start_token: &str) -> String {
pub fn try_tool_call_parse_basic_json(
message: &str,
config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config);
......
......@@ -5,6 +5,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -119,6 +120,7 @@ fn parse_single_tool_call_v3_1(
pub fn parse_tool_calls_deepseek_v3_1(
message: &str,
config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Format Structure:
// <|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁call▁end|><|tool▁calls▁end|>
......@@ -275,7 +277,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -293,7 +295,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(
content,
Some("The following tool call retrieves weather information: ".to_string())
......@@ -311,7 +313,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0);
}
......@@ -323,7 +325,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -349,7 +351,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -362,7 +364,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -388,7 +390,7 @@ mod tests {
};
let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(tool_call_results.len(), 1);
......
......@@ -5,6 +5,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -129,6 +130,7 @@ fn parse_single_tool_call_v3(block: &str, separator_tokens: &[String]) -> Option
pub fn parse_tool_calls_deepseek_v3(
message: &str,
config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Format Structure:
// <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
......@@ -285,7 +287,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -306,7 +308,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(
content,
Some("The following tool call retrieves weather information: ".to_string())
......@@ -327,7 +329,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0);
}
......@@ -348,7 +350,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -377,7 +379,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -399,7 +401,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -428,7 +430,7 @@ mod tests {
};
let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3(text, &config).unwrap();
parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(tool_call_results.len(), 1);
......
......@@ -33,11 +33,12 @@ impl Default for JsonParserType {
pub fn try_tool_call_parse_json(
message: &str,
config: &JsonParserConfig,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
match config.parser_type {
JsonParserType::Basic => try_tool_call_parse_basic_json(message, config),
JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config),
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config),
JsonParserType::Basic => try_tool_call_parse_basic_json(message, config, tools),
JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config, tools),
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config, tools),
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use serde_json::Value;
pub mod config;
pub mod dsml;
pub mod harmony;
......@@ -13,6 +15,13 @@ pub mod tests;
pub mod tools;
pub mod xml;
/// Represents a tool definition with function schema.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
pub name: String,
pub parameters: Option<Value>,
}
// Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig};
pub use dsml::try_tool_call_parse_dsml;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::ToolDefinition;
use super::config::{ParserConfig, ToolCallConfig};
use super::dsml::{
detect_tool_call_start_dsml, find_tool_call_end_position_dsml, try_tool_call_parse_dsml,
......@@ -53,27 +54,28 @@ pub fn get_available_tool_parsers() -> Vec<&'static str> {
pub async fn try_tool_call_parse(
message: &str,
config: &ToolCallConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Use match statement (Rust's switch statement) to call the appropriate parser
match &config.parser_config {
ParserConfig::Json(json_config) => {
let (results, normal_content) = try_tool_call_parse_json(message, json_config)?;
let (results, normal_content) = try_tool_call_parse_json(message, json_config, tools)?;
Ok((results, normal_content))
}
ParserConfig::Harmony(json_config) => {
let (results, normal_content) =
parse_tool_calls_harmony_complete(message, json_config).await?;
parse_tool_calls_harmony_complete(message, json_config, tools).await?;
Ok((results, normal_content))
}
ParserConfig::Pythonic => {
let (results, normal_content) = try_tool_call_parse_pythonic(message)?;
let (results, normal_content) = try_tool_call_parse_pythonic(message, tools)?;
Ok((results, normal_content))
}
ParserConfig::Typescript => {
anyhow::bail!("Typescript parser not implemented");
}
ParserConfig::Xml(xml_config) => {
let (results, normal_content) = try_tool_call_parse_xml(message, xml_config)?;
let (results, normal_content) = try_tool_call_parse_xml(message, xml_config, tools)?;
Ok((results, normal_content))
}
ParserConfig::Dsml(dsml_config) => {
......@@ -87,6 +89,7 @@ pub async fn try_tool_call_parse(
pub async fn detect_and_parse_tool_call(
message: &str,
parser_str: Option<&str>,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Get the tool parser map
let parser_map = get_tool_parser_map();
......@@ -99,7 +102,7 @@ pub async fn detect_and_parse_tool_call(
match parser_map.get(parser_key) {
Some(config) => {
let (results, normal_content) = try_tool_call_parse(message, config).await?;
let (results, normal_content) = try_tool_call_parse(message, config, tools).await?;
Ok((results, normal_content))
}
None => anyhow::bail!(
......@@ -213,7 +216,7 @@ mod tests {
#[tokio::test]
async fn parses_single_parameters_object() {
let input = r#"{ "name": "hello", "parameters": { "x": 1, "y": 2 } }"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -228,7 +231,7 @@ mod tests {
#[tokio::test]
async fn parses_single_arguments_object() {
let input = r#"{ "name": "world", "arguments": { "a": "abc", "b": 42 } }"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -243,7 +246,7 @@ mod tests {
#[tokio::test]
async fn parses_vec_of_parameters() {
let input = r#"[{ "name": "first", "parameters": { "a": 1 } }, { "name": "second", "parameters": { "b": 2 } }]"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -260,7 +263,7 @@ mod tests {
#[tokio::test]
async fn parses_vec_of_arguments() {
let input = r#"[{ "name": "alpha", "arguments": { "a": "x" } }, { "name": "omega", "arguments": { "z": "y" } }]"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -278,7 +281,7 @@ mod tests {
async fn parses_toolcall_wrapped_payload() {
let input =
r#"<TOOLCALL>[{ "name": "wrapped", "parameters": { "foo": "bar" } }]</TOOLCALL>"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -301,6 +304,7 @@ mod tests {
..Default::default()
}),
},
None,
)
.await
.unwrap();
......@@ -315,7 +319,7 @@ mod tests {
#[tokio::test]
async fn returns_none_on_invalid_input() {
let input = r#"not even json"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("not even json".to_string()));
......@@ -325,7 +329,7 @@ mod tests {
#[tokio::test]
async fn returns_none_on_valid_json_wrong_shape() {
let input = r#"{ "foo": "bar" }"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some("{ \"foo\": \"bar\" }".to_string()));
......@@ -340,7 +344,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
assert!(!result.is_empty());
......@@ -355,7 +359,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[tokio::test]
async fn test_nvidia_llama3_nemotron_super_49b_simple_with_no_think() {
let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
assert!(!result.is_empty());
......@@ -375,7 +379,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let config = ToolCallConfig::nemotron_deci();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("<think>\nOkay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.\n</think>".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -406,7 +410,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</TOOLCALL>
"#;
let config = ToolCallConfig::nemotron_deci();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("<think>\nOkay, the user is asking for the weather in San Francisco in Fahrenheit. Let me check the tools available.\n</think>".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -425,7 +429,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"))
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -442,7 +446,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
let input = r#"Hey How are you? <tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"))
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -455,7 +459,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"))
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -477,7 +481,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</tool_call>
"#;
let config = ToolCallConfig::hermes();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -501,7 +505,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</tool_call>
"#;
let config = ToolCallConfig::hermes();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -529,7 +533,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
</tool_call>
"#;
let config = ToolCallConfig::hermes();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -555,7 +559,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
..Default::default()
}),
};
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -569,7 +573,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_simple() {
let input = r#" [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -583,7 +587,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_simple_with_normal_text() {
let input = r#"Hey How are you? [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -602,7 +606,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
"unit": "fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -616,7 +620,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_multiple() {
let input = r#" [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -634,7 +638,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_multiple_with_normal_text() {
let input = r#"Hey How are you? [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -660,7 +664,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
"fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -678,7 +682,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token() {
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -692,7 +696,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_with_normal_text() {
let input = r#"Hey How are you? [TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -712,7 +716,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
"unit": "fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -726,7 +730,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_mistralai_mistral_7b_instruct_v03_single_with_start_token_multiple() {
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -745,7 +749,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
{
let input = r#"Hey How are you? [TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -773,7 +777,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
"fahrenheit"}}]
"#;
let config = ToolCallConfig::mistral();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -790,7 +794,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[tokio::test]
async fn test_meta_llama_llama31_8b_instruct_simple() {
let input = r#"{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -805,7 +809,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[tokio::test]
async fn test_meta_llama_llama31_8b_instruct_simple_with_normal_text() {
let input = r#"Hey How are you? {"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral(), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -823,7 +827,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
{"name": "get_weather",
"parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
"#;
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"))
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -838,7 +842,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[tokio::test]
async fn test_meta_llama_llama31_8b_instruct_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"))
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -853,7 +857,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
#[tokio::test]
async fn test_meta_llama_llama31_8b_instruct_with_python_tag_with_normal_text() {
let input = r#"Hey How are you? <|python_tag|>{ "name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"))
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -871,7 +875,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
<|python_tag|>
{"name": "get_weather", "parameters": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
"#;
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"))
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -891,7 +895,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
<|python_tag|>
{"name": "get_weather", "parameters": {"location": "New York, NY", "unit": "fahrenheit" }}
"#;
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"))
let (result, content) = detect_and_parse_tool_call(input, Some("llama3_json"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -911,7 +915,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
async fn test_detect_and_parse_tool_call_error_handling() {
// Unknown parser string should return an error
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco, CA"}}"#;
let result = detect_and_parse_tool_call(input, Some("unknown_parser")).await;
let result = detect_and_parse_tool_call(input, Some("unknown_parser"), None).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
......@@ -922,7 +926,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
// Known parser, but invalid input (not JSON) should return Ok(None)
let input = "not a json";
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"))
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"), None)
.await
.unwrap();
assert_eq!(content, Some("not a json".to_string()));
......@@ -930,7 +934,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
// Known parser, but valid JSON with wrong shape should return Ok(None)
let input = r#"{"foo": "bar"}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"))
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"), None)
.await
.unwrap();
assert_eq!(content, Some(r#"{"foo": "bar"}"#.to_string()));
......@@ -945,7 +949,7 @@ Okay, the user is asking for the weather in San Francisco in Fahrenheit. Let me
- **Summer (June to August)**: Average highs range from the mid-60s to low 70s Fahrenheit, with cooler mornings and evenings. Coastal areas may be cooler than inland spots.
Remember, San Francisco weather can be quite unpredictable, particularly with its famous fog, which can significantly lower temperatures. Always check a local weather forecast for the most accurate and up-to-date information."#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::default(), None)
.await
.unwrap();
assert_eq!(content, Some(input.to_string()));
......@@ -958,7 +962,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}
]</tool_calls>"#;
let config = ToolCallConfig::jamba();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -975,7 +979,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "celsius"}}
]</tool_calls>"#;
let config = ToolCallConfig::jamba();
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -1003,7 +1007,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
..Default::default()
}),
};
let (result, content) = try_tool_call_parse(input, &config).await.unwrap();
let (result, content) = try_tool_call_parse(input, &config, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -1016,7 +1020,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_detect_and_parse_tool_call_default_parser_nemotron_deci() {
let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -1029,7 +1033,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_detect_and_parse_tool_call_default_parser_nemotron_deci_multiple() {
let input = r#"<TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -1047,7 +1051,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_detect_and_parse_tool_call_default_parser_nemotron_deci_multiple_with_normal_text()
{
let input = r#"Hey How are you? <TOOLCALL>[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "fahrenheit"}}]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -1064,7 +1068,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag() {
let input = r#"<|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -1078,7 +1082,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_detect_and_parse_tool_call_default_parser_llama3_json_with_python_tag_with_normal_text()
{
let input = r#"Hey How are you? <|python_tag|>{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -1099,7 +1103,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"location": "San Francisco, CA",
"unit": "fahrenheit" }}
"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -1117,7 +1121,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"location": "San Francisco, CA",
"unit": "fahrenheit" }}
"#;
let (result, content) = detect_and_parse_tool_call(input, None).await.unwrap();
let (result, content) = detect_and_parse_tool_call(input, None, None).await.unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 1);
......@@ -1130,7 +1134,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag() {
let input = r#"{ "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral(), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1146,7 +1150,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_detect_and_parse_tool_call_default_parser_llama3_json_without_python_tag_with_normal_text()
{
let input = r#"Hey How are you? { "name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit" } }"#;
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral())
let (result, content) = try_tool_call_parse(input, &ToolCallConfig::mistral(), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -1162,7 +1166,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_phi4_single_function_call() {
let input =
r#"functools[{"name": "get_country_capital", "arguments": {"country": "Poland"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1175,7 +1179,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_phi4_single_function_call_with_normal_text() {
let input = r#"Hey How are you? functools[{"name": "get_country_capital", "arguments": {"country": "Poland"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -1191,7 +1195,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"name": "get_country_capital", "arguments": {"country": "Poland"}},
{"name": "get_population", "arguments": {"city": "Warsaw"}}
]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1212,7 +1216,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
{"name": "get_country_capital", "arguments": {"country": "Poland"}},
{"name": "get_population", "arguments": {"city": "Warsaw"}}
]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -1232,7 +1236,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
let input = r#"functools[{"name": "get_weather_forecast", "arguments":
{"location": {"city": "San Francisco",
"state": "CA"}, "date": "2023-10-05"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1249,7 +1253,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
let input = r#"Hey How are you? functools[{"name": "get_weather_forecast", "arguments":
{"location": {"city": "San Francisco",
"state": "CA"}, "date": "2023-10-05"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -1265,7 +1269,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_phi4_function_call_with_parameters_instead_of_arguments() {
let input = r#"functools[{"name": "calculate_distance",
"parameters": {"from": "New York", "to": "Los Angeles"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1280,7 +1284,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_phi4_function_call_with_parameters_instead_of_arguments_with_normal_text() {
let input = r#"Hey How are you? functools[{"name": "calculate_distance",
"parameters": {"from": "New York", "to": "Los Angeles"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -1296,7 +1300,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
// Reproduce the issue where "functools" appears in content field
// This might happen when there's malformed JSON or parsing issues
let input = r#"functools{"name": "get_weather","arguments":{"location":"San Francisco"}}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
// Content should be empty, not contain "functools"
......@@ -1312,7 +1316,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
// Test the case where only the token appears without JSON
// This case is less critical but shouldn't leak the full token
let input = r#"functools"#;
let (result, _content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, _content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
// Content may contain the token if no valid JSON follows, but shouldn't crash
......@@ -1325,7 +1329,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_phi4_token_with_invalid_json() {
// Test the case where token is followed by invalid JSON
let input = r#"functools{invalid json}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
// Content should be empty, not contain "functools" or leak the token
......@@ -1381,7 +1385,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
// are correctly treated as normal content, not tool calls
let input = r#"funk music is great"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
// Should be treated as normal content, not tool call
......@@ -1402,7 +1406,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
// Test words that start with "func" but are not "functools"
let input = r#"The function works well"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(
......@@ -1413,7 +1417,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
assert_eq!(content, Some("The function works well".to_string()));
let input = r#"functional programming"#;
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(
......@@ -1438,7 +1442,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
];
for test_input in test_cases {
let (result, content) = detect_and_parse_tool_call(test_input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(test_input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(
......@@ -1468,7 +1472,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
];
for test_input in test_cases {
let (result, content) = detect_and_parse_tool_call(test_input, Some("phi4"))
let (result, content) = detect_and_parse_tool_call(test_input, Some("phi4"), None)
.await
.unwrap();
assert_eq!(
......@@ -1489,7 +1493,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_pythonic_parser_basic_with_constants() {
let input = r#"[get_weather(location="San Francisco", unit="fahrenheit"), get_weather(location="New York", unit="fahrenheit")]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("pythonic"))
let (result, content) = detect_and_parse_tool_call(input, Some("pythonic"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1508,7 +1512,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[ignore]
async fn test_pythonic_parser_with_constants_and_normal_text() {
let input = r#"Hey How are you? [get_weather(location="San Francisco", unit="fahrenheit"), get_weather(location="New York", unit="fahrenheit")]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("pythonic"))
let (result, content) = detect_and_parse_tool_call(input, Some("pythonic"), None)
.await
.unwrap();
assert_eq!(content, Some("Hey How are you?".to_string()));
......@@ -1528,7 +1532,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_harmony_parser_basic() {
let input = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}"#;
let (result, content) = detect_and_parse_tool_call(input, Some("harmony"))
let (result, content) = detect_and_parse_tool_call(input, Some("harmony"), None)
.await
.unwrap();
assert_eq!(
......@@ -1551,7 +1555,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
```json
{"location": "Paris"}
```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("deepseek_v3"))
let (result, content) = detect_and_parse_tool_call(input, Some("deepseek_v3"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1567,7 +1571,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
#[tokio::test]
async fn test_deepseek_v3_1_parser_basic() {
let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Tokyo"}<|tool▁call▁end|><|tool▁call▁begin|>get_current_weather<|tool▁sep|>{"location": "Paris"}<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("deepseek_v3_1"))
let (result, content) = detect_and_parse_tool_call(input, Some("deepseek_v3_1"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1588,7 +1592,8 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
</|DSML|invoke>
</|DSML|function_calls>"#;
let (tool_calls, normal_text) = detect_and_parse_tool_call(input, Some("deepseek_v3_2"))
let (tool_calls, normal_text) =
detect_and_parse_tool_call(input, Some("deepseek_v3_2"), None)
.await
.expect("Failed to parse");
......@@ -1614,7 +1619,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
</|DSML|invoke>
</|DSML|function_calls>"#;
let (tool_calls, _) = detect_and_parse_tool_call(input, Some("deepseek_v3_2"))
let (tool_calls, _) = detect_and_parse_tool_call(input, Some("deepseek_v3_2"), None)
.await
.expect("Failed to parse");
......@@ -1642,7 +1647,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
</|DSML|invoke>
</|DSML|function_calls>"#;
let (tool_calls, _) = detect_and_parse_tool_call(input, Some("deepseek_v3_2"))
let (tool_calls, _) = detect_and_parse_tool_call(input, Some("deepseek_v3_2"), None)
.await
.expect("Failed to parse");
......@@ -1660,7 +1665,7 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it
async fn test_hermes_parser_without_new_line() {
let input = r#"<tool_call>{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "celsius"}}</tool_call>"
"#;
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"))
let (result, content) = detect_and_parse_tool_call(input, Some("hermes"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -1743,7 +1748,7 @@ mod parallel_tool_calling_tests {
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1759,7 +1764,7 @@ mod parallel_tool_calling_tests {
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "fahrenheit"}}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1777,7 +1782,7 @@ mod parallel_tool_calling_tests {
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1821,7 +1826,7 @@ fahrenheit
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
......@@ -1837,7 +1842,7 @@ fahrenheit
async fn test_parallel_xlam_format_pure_json() {
let input = r#"[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("mistral"))
let (result, content) = detect_and_parse_tool_call(input, Some("mistral"), None)
.await
.unwrap();
......@@ -1852,7 +1857,7 @@ fahrenheit
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
]"#;
let (result, content) = detect_and_parse_tool_call(input, Some("mistral"))
let (result, content) = detect_and_parse_tool_call(input, Some("mistral"), None)
.await
.unwrap();
......@@ -1879,7 +1884,7 @@ fahrenheit
]</TOOLCALL>"#;
let (result, content) =
detect_and_parse_tool_call(input_nemotron_format, Some("nemotron_deci"))
detect_and_parse_tool_call(input_nemotron_format, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1896,7 +1901,7 @@ fahrenheit
// Test with harmony parser for multiple tool calls
let input = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}<|call|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}<|call|>"#;
let (result, _content) = detect_and_parse_tool_call(input, Some("harmony"))
let (result, _content) = detect_and_parse_tool_call(input, Some("harmony"), None)
.await
.unwrap();
......@@ -1920,7 +1925,7 @@ fahrenheit
{"name": "web_search", "arguments": {"query": "Orlando Florida attractions", "max_results": 5}}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1952,7 +1957,7 @@ fahrenheit
{"name": "get_current_weather", "arguments": {"city": "Orlando", "invalid_field": 123}}
]</TOOLCALL>"#;
let (result, _content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1971,7 +1976,7 @@ fahrenheit
async fn test_parallel_empty_array() {
let input = r#"<TOOLCALL>[]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -1989,7 +1994,7 @@ fahrenheit
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2012,7 +2017,7 @@ fahrenheit
{"name": "get_current_weather", "arguments": {"city": "Miami", "state": "FL", "unit": "fahrenheit"}}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2056,7 +2061,7 @@ fahrenheit
}
]</TOOLCALL>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, content) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2134,7 +2139,7 @@ fahrenheit
{"name": "web_search", "arguments": {"query": "weather forecast", "max_results": 3}}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2151,7 +2156,7 @@ fahrenheit
{"name": "function_three", "arguments": {"param5": {"nested": "object"}}}
][/TOOL_CALLS]"#;
let (result, _) = detect_and_parse_tool_call(input, Some("mistral"))
let (result, _) = detect_and_parse_tool_call(input, Some("mistral"), None)
.await
.unwrap();
......@@ -2181,7 +2186,7 @@ fahrenheit
let input = format!("<TOOLCALL>[{}]</TOOLCALL>", tool_calls.join(","));
let start = std::time::Instant::now();
let (result, _) = detect_and_parse_tool_call(&input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(&input, Some("nemotron_deci"), None)
.await
.unwrap();
let duration = start.elapsed();
......@@ -2208,7 +2213,7 @@ fahrenheit
large_data, large_data
);
let (result, _) = detect_and_parse_tool_call(&input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(&input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2237,7 +2242,7 @@ fahrenheit
{"name": "process_unicode", "arguments": {"data": "café naïve résumé", "encoding": "utf-8"}}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2267,7 +2272,7 @@ fahrenheit
{"name": "regex_pattern", "arguments": {"pattern": "\\d{3}-\\d{3}-\\d{4}", "test_string": "Phone: 123-456-7890"}}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2293,7 +2298,7 @@ fahrenheit
{"name": "object_test", "arguments": {"empty_object": {}, "nested": {"level1": {"level2": {"value": "deep"}}}}}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2341,7 +2346,7 @@ fahrenheit
}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2377,7 +2382,7 @@ fahrenheit
];
for (input, parser) in test_cases {
let (result, _) = detect_and_parse_tool_call(&input, Some(parser))
let (result, _) = detect_and_parse_tool_call(&input, Some(parser), None)
.await
.unwrap_or_else(|e| panic!("Failed to parse with {}: {}", parser, e));
assert_eq!(
......@@ -2404,7 +2409,7 @@ fahrenheit
{"name": "single_call", "arguments": {"test": true}}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input_single, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input_single, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2422,7 +2427,7 @@ fahrenheit
let input_many = format!("<TOOLCALL>[{}]</TOOLCALL>", many_calls.join(","));
let (result, _) = detect_and_parse_tool_call(&input_many, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(&input_many, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2451,7 +2456,7 @@ fahrenheit
{"name": "good_call_4", "arguments": {"param": "value4"}}
]</TOOLCALL>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"))
let (result, _) = detect_and_parse_tool_call(input, Some("nemotron_deci"), None)
.await
.unwrap();
......@@ -2608,7 +2613,7 @@ pwd && ls
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2633,7 +2638,7 @@ fahrenheit
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2657,7 +2662,7 @@ fahrenheit
</parameter>
</function>
</tool_call> Let me get that information for you."#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(
......@@ -2702,7 +2707,7 @@ fahrenheit
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2730,7 +2735,18 @@ fahrenheit
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let tools = vec![ToolDefinition {
name: "process_data".to_string(),
parameters: Some(serde_json::json!({
"properties": {
"config": {
"type": "array"
}
}
})),
}];
let (result, content) =
detect_and_parse_tool_call(input, Some("qwen3_coder"), Some(&tools))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2757,7 +2773,17 @@ true
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let tools = vec![ToolDefinition {
name: "calculate".to_string(),
parameters: Some(serde_json::json!({
"properties": {
"x": {"type": "int"},
"y": {"type": "float"},
"enabled": {"type": "bool"},
}
})),
}];
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"), Some(&tools))
.await
.unwrap();
assert_eq!(result.len(), 1);
......@@ -2771,7 +2797,7 @@ true
#[tokio::test]
async fn test_qwen3_coder_no_tool_calls() {
let input = "This is just normal text without any tool calls.";
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(result.len(), 0);
......@@ -2781,7 +2807,7 @@ true
#[tokio::test]
async fn test_qwen3_coder_compact_format() {
let input = r#"<tool_call><function=search><parameter=query>rust programming</parameter><parameter=limit>10</parameter></function></tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2789,7 +2815,7 @@ true
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "search");
assert_eq!(args["query"], "rust programming");
assert_eq!(args["limit"], 10);
assert_eq!(args["limit"], "10");
}
#[tokio::test]
......@@ -2801,7 +2827,7 @@ true
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(result.len(), 1);
......@@ -2833,7 +2859,7 @@ Seattle
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2869,7 +2895,18 @@ weather forecasting
</parameter>
</function>
</tool_call>"#;
let (result, content) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let tools = vec![ToolDefinition {
name: "web_search".to_string(),
parameters: Some(serde_json::json!({
"properties": {
"max_results": {
"type": "uint"
}
}
})),
}];
let (result, content) =
detect_and_parse_tool_call(input, Some("qwen3_coder"), Some(&tools))
.await
.unwrap();
assert_eq!(content, Some("".to_string()));
......@@ -2887,7 +2924,26 @@ weather forecasting
}
#[tokio::test]
async fn test_qwen3_coder_array_parameter_value() {
async fn test_qwen3_coder_array_parameter_value_without_tool_definition() {
let input = r#"<tool_call>
<function=process_list>
<parameter=items>
[1, 2, 3, 4, 5]
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"), None)
.await
.unwrap();
assert_eq!(result.len(), 1);
let (name, args) = extract_name_and_args(result[0].clone());
assert_eq!(name, "process_list");
// The default is to return it as a string.
assert_eq!(args["items"], serde_json::json!("[1, 2, 3, 4, 5]"));
}
#[tokio::test]
async fn test_qwen3_coder_array_parameter_value_with_tool_definition() {
let input = r#"<tool_call>
<function=process_list>
<parameter=items>
......@@ -2895,7 +2951,17 @@ weather forecasting
</parameter>
</function>
</tool_call>"#;
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"))
let tools = vec![ToolDefinition {
name: "process_list".to_string(),
parameters: Some(serde_json::json!({
"properties": {
"items": {
"type": "array"
}
}
})),
}];
let (result, _) = detect_and_parse_tool_call(input, Some("qwen3_coder"), Some(&tools))
.await
.unwrap();
assert_eq!(result.len(), 1);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::super::ToolDefinition;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use regex::Regex;
use rustpython_parser::{
......@@ -161,6 +162,7 @@ fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> {
pub fn try_tool_call_parse_pythonic(
message: &str,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let stripped = strip_text(message).trim().to_string();
......@@ -263,7 +265,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_basic() {
let message = "[foo(a=1, b=2), bar(x=3)]";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -279,7 +281,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_text() {
let message = "Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey yo !".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -295,7 +297,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() {
let message = "Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey \n yo !".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -311,7 +313,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_no_calls() {
let message = "Hey \n yo !";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey \n yo !".to_string()));
assert!(result.is_empty());
assert_eq!(result.len(), 0)
......@@ -320,7 +322,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_python_tags() {
let message = "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -336,7 +338,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_list_arg_values() {
let message = "[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]";
let (result, _) = try_tool_call_parse_pythonic(message).unwrap();
let (result, _) = try_tool_call_parse_pythonic(message, None).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -351,7 +353,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() {
let message = "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]";
let (result, _) = try_tool_call_parse_pythonic(message).unwrap();
let (result, _) = try_tool_call_parse_pythonic(message, None).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......
......@@ -10,6 +10,7 @@ pub use super::parsers::detect_and_parse_tool_call;
pub async fn try_tool_call_parse_aggregate(
message: &str,
parser_str: Option<&str>,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>,
Option<String>,
......@@ -19,7 +20,7 @@ pub async fn try_tool_call_parse_aggregate(
} else {
tracing::info!("Using tool parser: {:?}", parser_str);
}
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str, tools).await?;
if parsed.is_empty() {
return Ok((vec![], content));
}
......@@ -47,11 +48,12 @@ pub async fn try_tool_call_parse_aggregate(
pub async fn try_tool_call_parse_stream(
message: &str,
parser_str: Option<&str>,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>,
)> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str, tools).await?;
if parsed.is_empty() {
return Ok((vec![], content));
}
......
......@@ -7,8 +7,10 @@
use std::collections::HashMap;
use regex::Regex;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::super::config::XmlParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -51,8 +53,9 @@ pub fn find_tool_call_end_position_xml(chunk: &str, config: &XmlParserConfig) ->
pub fn try_tool_call_parse_xml(
message: &str,
config: &XmlParserConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let (normal_text, tool_calls) = extract_tool_calls(message, config)?;
let (normal_text, tool_calls) = extract_tool_calls(message, config, tools)?;
let normal_content = if normal_text.is_empty() {
Some("".to_string())
......@@ -67,6 +70,7 @@ pub fn try_tool_call_parse_xml(
fn extract_tool_calls(
text: &str,
config: &XmlParserConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(String, Vec<ToolCallResponse>)> {
let mut normal_parts = Vec::new();
let mut calls = Vec::new();
......@@ -89,7 +93,7 @@ fn extract_tool_calls(
let block = &text[abs_start..abs_end];
// Parse this tool call block.
if let Ok(mut parsed_calls) = parse_tool_call_block(block, config) {
if let Ok(mut parsed_calls) = parse_tool_call_block(block, config, tools) {
calls.append(&mut parsed_calls);
}
......@@ -115,6 +119,7 @@ fn extract_tool_calls(
fn parse_tool_call_block(
block: &str,
config: &XmlParserConfig,
tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<Vec<ToolCallResponse>> {
// Build regex patterns based on config
let function_start = regex::escape(&config.function_start_token);
......@@ -142,6 +147,9 @@ fn parse_tool_call_block(
continue;
}
// Get parameter config for this function
let param_config = get_arguments_config(function_name, tools);
// Parse parameters from the function body.
let mut parameters: HashMap<String, serde_json::Value> = HashMap::new();
......@@ -150,7 +158,8 @@ fn parse_tool_call_block(
let param_value = param_cap.get(2).map(|m| m.as_str()).unwrap_or("");
if !param_name.is_empty() {
let parsed_value = safe_parse_value(param_value);
let parsed_value =
convert_param_value(param_value, param_name, &param_config, function_name);
parameters.insert(param_name.to_string(), parsed_value);
}
}
......@@ -173,8 +182,316 @@ fn parse_tool_call_block(
Ok(results)
}
/// Extract argument configuration for a function from the tool definitions.
/// Returns a HashMap of parameter names to their schema definitions.
fn get_arguments_config(
func_name: &str,
tools: Option<&[ToolDefinition]>,
) -> HashMap<String, Value> {
let Some(tools) = tools else {
return HashMap::new();
};
for tool in tools {
if tool.name == func_name {
if let Some(params) = &tool.parameters {
// Try to extract "properties" from the parameters schema
if let Some(properties) = params.get("properties") {
if let Some(props_obj) = properties.as_object() {
return props_obj
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
}
} else if let Some(params_obj) = params.as_object() {
// If no "properties" field, treat the whole thing as the config
return params_obj
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
}
}
return HashMap::new();
}
}
tracing::warn!("Tool '{}' is not defined in the tools list.", func_name);
HashMap::new()
}
/// Convert parameter value based on its type in the schema.
/// This matches the behavior of the Python implementation.
/// Converts a string parameter value from XML into a typed JSON Value.
///
/// # Examples
///
/// **String types:**
/// ```text
/// Input: param_value="hello world", param_type="string"
/// Output: Value::String("hello world")
/// ```
///
/// ```text
/// Input: param_value="42", param_type="string"
/// Output: Value::String("42")
/// ```
///
/// **Integer types:**
/// ```text
/// Input: param_value="42", param_type="integer"
/// Output: Value::Number(42)
///
/// Input: param_value="not_a_number", param_type="int"
/// Output: Value::String("not_a_number") // Falls back to string with warning
/// ```
///
/// **Float/Number types:**
/// ```text
/// Input: param_value="3.14", param_type="number"
/// Output: Value::Number(3.14)
///
/// Input: param_value="42.0", param_type="float"
/// Output: Value::Number(42) // Whole numbers stored as integers
/// ```
///
/// **Boolean types:**
/// ```text
/// Input: param_value="true", param_type="boolean"
/// Output: Value::Bool(true)
///
/// Input: param_value="yes", param_type="bool"
/// Output: Value::Bool(false) // Falls back to false with warning
/// ```
///
/// **Complex types (objects/arrays):**
/// ```text
/// Input: param_value='{"key": "value"}', param_type="object"
/// Output: Value::Object({"key": "value"})
///
/// Input: param_value="[1, 2, 3]", param_type="array"
/// Output: Value::Array([1, 2, 3])
///
/// Input: param_value="{'key': 'value'}", param_type="dict"
/// Output: Value::Object({"key": "value"}) // Uses ast.literal_eval-style parsing
/// ```
///
/// **Special cases:**
/// ```text
/// Input: param_value="null", param_type=<any>
/// Output: Value::Null // Handled before type checking
///
/// Input: param_value="&lt;tag&gt;", param_type="string"
/// Output: Value::String("<tag>") // HTML entities are unescaped
///
/// Input: param_value="123", param_type=<undefined/not in schema>
/// Output: Value::String("123") // Unknown params returned as strings
/// ```
///
/// # Arguments
///
/// * `param_value` - The raw string value from XML parameter
/// * `param_name` - The parameter name (used for schema lookup and error messages)
/// * `param_config` - Schema defining expected types for each parameter
/// * `func_name` - The function/tool name (used for error messages)
///
/// # Type Aliases
///
/// The function recognizes various type name aliases:
/// - Strings: "string", "str", "text", "varchar", "char", "enum"
/// - Integers: "int", "integer", "int32", "int64", "uint", "long", "short", "unsigned"
/// - Numbers: "number", "num", "float", "float32", "float64", "double"
/// - Booleans: "boolean", "bool", "binary"
/// - Objects: "object", "dict", "dictionary"
/// - Arrays: "array", "arr", "list"
fn convert_param_value(
param_value: &str,
param_name: &str,
param_config: &HashMap<String, Value>,
func_name: &str,
) -> Value {
// HTML unescape and trim
let param_value = html_unescape(param_value.trim());
// Handle null
if param_value.to_lowercase() == "null" {
return Value::Null;
}
// Check if parameter is in config
if !param_config.contains_key(param_name) {
tracing::debug!(
"Parsed parameter '{}' is not defined in the tool parameters for tool '{}', directly returning the string value.",
param_name,
func_name
);
return Value::String(param_value);
}
// Get the type from schema
let param_type = param_config
.get(param_name)
.and_then(|v| v.get("type"))
.and_then(|t| t.as_str())
.unwrap_or("string")
.to_lowercase();
// The follow `match` block follows this rough pattern for each block:
// 1. Match `param_type` against predefined string representations of each type,
// 2. Parse the string value and convert it to the appropriate Rust JSON Value type.
// Each branch handles a category of type aliases (e.g., "int"/"integer"/"int32" all map to i64).
// If parsing fails, we log a warning and fall back to returning the value as a string.
match param_type.as_str() {
// String types: Return value as-is (already HTML-unescaped above)
"string" | "str" | "text" | "varchar" | "char" | "enum" => Value::String(param_value),
// Integer types: Parse as i64, fall back to string on error.
// Matches: "int", "integer", "int32", "uint", "unsigned", "long", "short", etc.
t if t.starts_with("int")
|| t.starts_with("uint")
|| t.starts_with("long")
|| t.starts_with("short")
|| t.starts_with("unsigned") =>
{
match param_value.parse::<i64>() {
Ok(int_val) => Value::Number(int_val.into()),
Err(_) => {
tracing::warn!(
"Parsed value '{}' of parameter '{}' is not an integer in tool '{}', degenerating to string.",
param_value,
param_name,
func_name
);
Value::String(param_value)
}
}
}
// Float/Number types: Parse as f64.
// Matches: "number", "num", "float", "float32", "float64", "double", etc.
// Note: Whole numbers (e.g., 42.0) are stored as integers for better JSON compatibility.
t if t.starts_with("num") || t.starts_with("float") => {
match param_value.parse::<f64>() {
Ok(float_val) => {
// Return int if it's a whole number, otherwise float.
if float_val.fract() == 0.0 && float_val.is_finite() {
Value::Number((float_val as i64).into())
} else if let Some(num) = serde_json::Number::from_f64(float_val) {
Value::Number(num)
} else {
tracing::warn!(
"Parsed value '{}' of parameter '{}' is not a valid float in tool '{}', degenerating to string.",
param_value,
param_name,
func_name
);
Value::String(param_value)
}
}
Err(_) => {
tracing::warn!(
"Parsed value '{}' of parameter '{}' is not a float in tool '{}', degenerating to string.",
param_value,
param_name,
func_name
);
Value::String(param_value)
}
}
}
// Boolean types: Only "true" or "false" (case-insensitive) are valid.
// Any other value defaults to false with a warning.
"boolean" | "bool" | "binary" => {
let lower_val = param_value.to_lowercase();
if lower_val != "true" && lower_val != "false" {
tracing::warn!(
"Parsed value '{}' of parameter '{}' is not a boolean (`true` or `false`) in tool '{}', degenerating to false.",
param_value,
param_name,
func_name
);
}
Value::Bool(lower_val == "true")
}
// Complex types (objects/arrays): Try JSON parsing, then fall back to Python-style
// `ast.literal_eval` (or our own barebones version of it for the purposes of this
// parser).
// Matches: "object", "array", "arr", "dict", "dictionary", "list", etc.
// This handles both JSON syntax ({"a": 1}) and Python syntax ({'a': 1}).
t if t == "object"
|| t == "array"
|| t == "arr"
|| t.starts_with("dict")
|| t.starts_with("list") =>
{
// Try JSON parsing first (standard JSON with double quotes).
if let Ok(json_val) = serde_json::from_str::<Value>(&param_value) {
return json_val;
}
tracing::warn!(
"Parsed value '{}' of parameter '{}' cannot be parsed with json.loads in tool '{}', will try other methods to parse it.",
param_value,
param_name,
func_name
);
// Try `ast.literal_eval` equivalent (handles Python-style single quotes, etc.).
if let Ok(json_val) = try_literal_eval(&param_value) {
return json_val;
}
tracing::warn!(
"Parsed value '{}' of parameter '{}' cannot be converted via Python `ast.literal_eval()` in tool '{}', degenerating to string.",
param_value,
param_name,
func_name
);
Value::String(param_value)
}
// Unknown/custom types: Attempt best-effort parsing via `literal_eval`.
// This allows for flexible type names while still trying to parse structured data
_ => {
// Unknown type, try `literal_eval`.
if let Ok(json_val) = try_literal_eval(&param_value) {
return json_val;
}
tracing::warn!(
"Parsed value '{}' of parameter '{}' cannot be converted via Python `ast.literal_eval()` in tool '{}', degenerating to string.",
param_value,
param_name,
func_name
);
Value::String(param_value)
}
}
}
/// Try to parse a value similar to Python's ast.literal_eval.
/// This is a simplified version that handles common cases.
fn try_literal_eval(s: &str) -> Result<Value, ()> {
// First try standard JSON
if let Ok(val) = serde_json::from_str::<Value>(s) {
return Ok(val);
}
// Try to handle Python-style literals (single quotes, True/False/None)
let normalized = s
.replace('\'', "\"") // Replace single quotes with double quotes
.replace("True", "true")
.replace("False", "false")
.replace("None", "null");
serde_json::from_str::<Value>(&normalized).map_err(|_| ())
}
/// Safely parse a value - tries JSON, then falls back to string.
/// Mimics SGLang's `_safe_val` function in spirit.
/// NOTE: This function is deprecated and kept for reference. Use convert_param_value instead.
#[allow(dead_code)]
fn safe_parse_value(raw: &str) -> serde_json::Value {
// HTML unescape
let unescaped = html_unescape(raw.trim());
......@@ -279,7 +596,8 @@ pwd && ls
</function>
</tool_call>"#;
let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, normal) =
try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash");
assert_eq!(normal, Some("".to_string()));
......@@ -304,7 +622,7 @@ fahrenheit
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
......@@ -324,7 +642,8 @@ Dallas
</function>
</tool_call> Let me check that for you."#;
let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, normal) =
try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(
......@@ -350,7 +669,7 @@ Orlando
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].function.name, "get_weather");
assert_eq!(calls[1].function.name, "get_weather");
......@@ -363,6 +682,17 @@ Orlando
#[test]
fn test_parse_json_parameter_value() {
// With schema-aware parsing, we need to provide a schema to parse JSON objects
let tools = vec![ToolDefinition {
name: "process_data".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"config": {"type": "object"}
}
})),
}];
let input = r#"<tool_call>
<function=process_data>
<parameter=config>
......@@ -371,7 +701,8 @@ Orlando
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) =
try_tool_call_parse_xml(input, &XmlParserConfig::default(), Some(&tools)).unwrap();
assert_eq!(calls.len(), 1);
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
......@@ -383,7 +714,8 @@ Orlando
#[test]
fn test_parse_no_tool_calls() {
let input = "This is just normal text without any tool calls.";
let (calls, normal) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, normal) =
try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 0);
assert_eq!(normal, Some(input.to_string()));
}
......@@ -397,7 +729,7 @@ value
</tool_call>"#;
// Should handle gracefully - might parse or return empty
let result = try_tool_call_parse_xml(input, &XmlParserConfig::default());
let result = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None);
assert!(result.is_ok());
}
......@@ -410,7 +742,7 @@ ls -la
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "execute_bash");
......@@ -427,7 +759,7 @@ Boston
</parameter>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "get_weather");
......@@ -443,7 +775,7 @@ Boston
SELECT * FROM users
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "run_query");
......@@ -463,7 +795,7 @@ rust programming
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default()).unwrap();
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "search");
......@@ -471,4 +803,132 @@ rust programming
// This matches the original SGLang python implementation.
assert_eq!(args["query"], "rust programming\n<parameter=limit>\n10");
}
#[test]
fn test_schema_aware_type_conversion() {
// This test matches the Python test_parse_streaming_increment_multiple_parameters
// from the diff, showing schema-aware type conversion
let tools = vec![ToolDefinition {
name: "multi_param_func".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"param1": {"type": "string"},
"param2": {"type": "float"},
"param3": {"type": "integer"},
"param4": {"type": "boolean"},
"param5": {"type": "object"},
"param6": {"type": "array"},
"param7": {"type": "null"},
"param8": {"type": "other_type"}
},
"required": ["param1", "param2", "param3", "param4", "param5", "param6", "param7", "param8"]
})),
}];
let input = r#"<tool_call>
<function=multi_param_func>
<parameter=param1>42</parameter>
<parameter=param2>41.9</parameter>
<parameter=param3>42</parameter>
<parameter=param4>true</parameter>
<parameter=param5>{"key": "value"}</parameter>
<parameter=param6>[1, 2, 3]</parameter>
<parameter=param7>null</parameter>
<parameter=param8>{'arg1': 3, 'arg2': [1, 2]}</parameter>
</function>
</tool_call>"#;
let (calls, _) =
try_tool_call_parse_xml(input, &XmlParserConfig::default(), Some(&tools)).unwrap();
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].function.name, "multi_param_func");
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
// param1 is type "string" so "42" stays as string
assert_eq!(args["param1"], "42");
// param2 is type "float" so 41.9 is parsed as float
assert_eq!(args["param2"], 41.9);
// param3 is type "integer" so 42 is parsed as integer
assert_eq!(args["param3"], 42);
// param4 is type "boolean" so "true" is parsed as bool
assert_eq!(args["param4"], true);
// param5 is type "object" so JSON is parsed
assert_eq!(args["param5"], serde_json::json!({"key": "value"}));
// param6 is type "array" so JSON array is parsed
assert_eq!(args["param6"], serde_json::json!([1, 2, 3]));
// param7 is type "null" so "null" is parsed as null
assert_eq!(args["param7"], serde_json::Value::Null);
// param8 is other_type, uses literal_eval which converts Python-style dict
assert_eq!(
args["param8"],
serde_json::json!({"arg1": 3, "arg2": [1, 2]})
);
}
#[test]
fn test_schema_aware_type_conversion_fallback() {
// Test that invalid values fall back to strings with warnings
let tools = vec![ToolDefinition {
name: "test_func".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"int_param": {"type": "integer"},
"float_param": {"type": "float"},
"bool_param": {"type": "boolean"}
}
})),
}];
let input = r#"<tool_call>
<function=test_func>
<parameter=int_param>not_an_int</parameter>
<parameter=float_param>not_a_float</parameter>
<parameter=bool_param>not_a_bool</parameter>
</function>
</tool_call>"#;
let (calls, _) =
try_tool_call_parse_xml(input, &XmlParserConfig::default(), Some(&tools)).unwrap();
assert_eq!(calls.len(), 1);
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
// All should fall back to strings
assert_eq!(args["int_param"], "not_an_int");
assert_eq!(args["float_param"], "not_a_float");
// bool_param with invalid value defaults to false
assert_eq!(args["bool_param"], false);
}
#[test]
fn test_no_schema_fallback_behavior() {
// Without schema, behavior should match old safe_parse_value logic
let input = r#"<tool_call>
<function=unknown_func>
<parameter=param1>42</parameter>
<parameter=param2>true</parameter>
<parameter=param3>hello</parameter>
</function>
</tool_call>"#;
let (calls, _) = try_tool_call_parse_xml(input, &XmlParserConfig::default(), None).unwrap();
assert_eq!(calls.len(), 1);
let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
// Without schema, all values are returned as strings (no type inference)
assert_eq!(args["param1"], "42");
assert_eq!(args["param2"], "true");
assert_eq!(args["param3"], "hello");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment