Unverified Commit 67e1f6ee authored by Elyas Mehtabuddin's avatar Elyas Mehtabuddin Committed by GitHub
Browse files

feat: enable parallel tool calling and add testing (#3188)


Signed-off-by: default avatarElyas Mehtabuddin <emehtabuddin@nvidia.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 9d73be12
......@@ -2229,6 +2229,7 @@ version = "0.5.1"
dependencies = [
"anyhow",
"dynamo-async-openai",
"dynamo-llm",
"lazy_static",
"num-traits",
"openai-harmony",
......
......@@ -886,7 +886,7 @@ version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
......@@ -1630,12 +1630,14 @@ dependencies = [
"once_cell",
"prometheus",
"rand 0.9.2",
"rayon",
"regex",
"serde",
"serde_json",
"socket2 0.5.10",
"thiserror 2.0.16",
"tokio",
"tokio-rayon",
"tokio-stream",
"tokio-util",
"tower-http",
......@@ -3026,7 +3028,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
......@@ -6101,6 +6103,16 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "tokio-rayon"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cf33a76e0b1dd03b778f83244137bd59887abf25c0e87bc3e7071105f457693"
dependencies = [
"rayon",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.2"
......
......@@ -981,11 +981,6 @@ pub fn validate_response_unsupported_fields(
"`metadata` is not supported.",
));
}
if inner.parallel_tool_calls == Some(true) {
return Some(ErrorMessage::not_implemented_error(
"`parallel_tool_calls: true` is not supported.",
));
}
if inner.previous_response_id.is_some() {
return Some(ErrorMessage::not_implemented_error(
"`previous_response_id` is not supported.",
......@@ -1338,6 +1333,14 @@ mod tests {
assert!(result.is_none());
}
#[test]
fn test_validate_unsupported_fields_accepts_parallel_tool_calls() {
let mut request = make_base_request();
request.inner.parallel_tool_calls = Some(true);
let result = validate_response_unsupported_fields(&request);
assert!(result.is_none(), "parallel_tool_calls should be supported");
}
#[test]
fn test_validate_unsupported_fields_detects_flags() {
#[allow(clippy::type_complexity)]
......@@ -1353,10 +1356,6 @@ mod tests {
),
("max_tool_calls", Box::new(|r| r.max_tool_calls = Some(3))),
("metadata", Box::new(|r| r.metadata = Some(HashMap::new()))),
(
"parallel_tool_calls",
Box::new(|r| r.parallel_tool_calls = Some(true)),
),
(
"previous_response_id",
Box::new(|r| r.previous_response_id = Some("prev-id".into())),
......
......@@ -133,10 +133,11 @@ impl ChoiceJailState {
if !self.is_jailed {
// Use the marker matcher to detect complete/partial markers
match jail_stream
let match_result = jail_stream
.marker_matcher
.process_chunk(content, &self.partial_match_buffer)
{
.process_chunk(content, &self.partial_match_buffer);
match match_result {
MatchResult::Complete {
prefix,
marker,
......@@ -632,6 +633,14 @@ impl JailedStream {
let tool_call_match = self.tool_call_parser.is_some()
&& detect_tool_call_start(content, self.tool_call_parser.as_deref()).unwrap_or(false);
tracing::debug!(
"should_start_jail: content={:?}, sequence_match={}, tool_call_match={}, sequences={:?}",
content,
sequence_match,
tool_call_match,
self.jail_start_sequences
);
sequence_match || tool_call_match
}
......@@ -726,10 +735,12 @@ impl JailedStream {
async fn should_exit_jail_early(&self, accumulated: &str) -> bool {
if let Some(ref parser) = self.tool_call_parser {
// Try to parse - if successful and we have complete tool calls, exit early
if let Ok((tool_calls, _)) =
try_tool_call_parse_aggregate(accumulated, Some(parser)).await
{
return !tool_calls.is_empty();
match try_tool_call_parse_aggregate(accumulated, Some(parser)).await {
Ok((tool_calls, _normal_text)) => {
let result = !tool_calls.is_empty();
return result;
}
Err(_e) => {}
}
}
false
......@@ -878,6 +889,7 @@ impl JailedStreamBuilder {
MarkerMatcher::new(vec!["__NEVER_MATCH__".to_string()])
.expect("Failed to create dummy MarkerMatcher")
} else {
tracing::debug!("Creating MarkerMatcher with patterns: {:?}", all_patterns);
MarkerMatcher::new(all_patterns)
.expect("Failed to create MarkerMatcher with configured patterns")
};
......
This diff is collapsed.
......@@ -38,3 +38,6 @@ openai-harmony = "0.0.3"
lazy_static = "1.5.0"
rustpython-parser = "0.4.0"
num-traits = "0.2"
[dev-dependencies]
dynamo-llm = { workspace = true }
......@@ -117,8 +117,9 @@ impl ToolCallConfig {
Self {
format: ToolCallParserType::Json,
json: JsonParserConfig {
// TODO(elyas): remove the duplicate token
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
tool_call_end_tokens: vec!["".to_string()],
tool_call_end_tokens: vec!["[/TOOL_CALLS]".to_string(), "".to_string()],
..Default::default()
},
}
......
......@@ -7,6 +7,8 @@ pub mod json;
pub mod parsers;
pub mod pythonic;
pub mod response;
#[cfg(test)]
pub mod tests;
pub mod tools;
// Re-export main types and functions for convenience
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Internal tests module for `tool_calling`.
//!
//! Unit tests for submodules live alongside their implementations.
//! This placeholder exists to satisfy the conditional `pub mod tests` declaration.
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Tests for tool calling functionality
#[cfg(test)]
mod parallel_tool_call_integration;
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration test for parallel tool calling functionality
//!
//! This test simulates a complete chat completion request with parallel tool calls,
//! mocking the response and testing the tool call parsing functionality.
//!
//! The test covers:
//! - Creating a mock NvCreateChatCompletionRequest based on a curl request
//! - Mocking a chat completion response with parallel tool calls in <tool_call> format
//! - Parsing the tool calls using the hermes parser
//! - Validating OpenAI API compatibility
//! - Testing error handling with malformed content
//! - Ensuring tool call IDs are unique and properly formatted
use dynamo_llm::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, common_ext::CommonExt,
};
use dynamo_parsers::{ToolCallResponse, ToolCallType, detect_and_parse_tool_call};
use serde_json::json;
/// Creates a mock NvCreateChatCompletionRequest based on the curl request
fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest {
let messages = vec![
dynamo_async_openai::types::ChatCompletionRequestMessage::System(
dynamo_async_openai::types::ChatCompletionRequestSystemMessage {
content: dynamo_async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
"You MUST use two tools in parallel to resolve the user request: call get_current_weather for each city AND call is_holiday_today to check if today is a holiday. Do not answer without using both tools.".to_string()
),
name: None,
}
),
dynamo_async_openai::types::ChatCompletionRequestMessage::User(
dynamo_async_openai::types::ChatCompletionRequestUserMessage {
content: dynamo_async_openai::types::ChatCompletionRequestUserMessageContent::Text(
"What is the weather in Dallas, Texas? Is today a holiday?".to_string()
),
name: None,
}
),
];
let tools = vec![
dynamo_async_openai::types::ChatCompletionTool {
r#type: dynamo_async_openai::types::ChatCompletionToolType::Function,
function: dynamo_async_openai::types::FunctionObject {
name: "get_current_weather".to_string(),
description: Some("Get weather for a city/state in specified units".to_string()),
parameters: Some(json!({
"type": "object",
"properties": {
"city": { "type": "string", "description": "City name, e.g., Dallas" },
"state": { "type": "string", "description": "Two-letter state code, e.g., TX" },
"unit": { "type": "string", "enum": ["fahrenheit", "celsius"] }
},
"required": ["city", "state", "unit"],
"additionalProperties": false
})),
strict: None,
},
},
dynamo_async_openai::types::ChatCompletionTool {
r#type: dynamo_async_openai::types::ChatCompletionToolType::Function,
function: dynamo_async_openai::types::FunctionObject {
name: "is_holiday_today".to_string(),
description: Some("Return whether today is a public holiday".to_string()),
parameters: Some(json!({
"type": "object",
"properties": {},
"additionalProperties": false
})),
strict: None,
},
},
];
let inner = dynamo_async_openai::types::CreateChatCompletionRequestArgs::default()
.model("Qwen/Qwen3-0.6B")
.temperature(0.0)
.max_tokens(3000u32)
.stream(false)
.messages(messages)
.tools(tools)
.tool_choice(dynamo_async_openai::types::ChatCompletionToolChoiceOption::Required)
.build()
.expect("Failed to build chat completion request");
NvCreateChatCompletionRequest {
inner,
common: CommonExt::default(),
nvext: None,
chat_template_args: None,
}
}
/// Mock response content that contains parallel tool calls
fn get_mock_response_content() -> String {
r#"<think>Okay, the user is asking two things: the weather in Dallas, Texas, and whether today is a holiday. I need to use both tools here. First, I'll check the weather using get_current_weather with city Dallas and state Texas. Then, I'll use is_holiday_today to see if today is a public holiday. I have to make sure to call both functions in parallel. Let me structure the tool calls properly.</think>
<tool_call>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
</tool_call>
<tool_call>
{"name": "is_holiday_today", "arguments": {}}
</tool_call>"#.to_string()
}
/// Validates that a tool call response matches expected values
fn validate_weather_tool_call(tool_call: &ToolCallResponse) {
assert_eq!(tool_call.function.name, "get_current_weather");
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
.expect("Arguments should be valid JSON");
let args_obj = args.as_object().expect("Arguments should be an object");
assert_eq!(args_obj.get("city").unwrap().as_str().unwrap(), "Dallas");
assert_eq!(args_obj.get("state").unwrap().as_str().unwrap(), "TX");
assert_eq!(
args_obj.get("unit").unwrap().as_str().unwrap(),
"fahrenheit"
);
// Validate OpenAI compatibility
assert!(!tool_call.id.is_empty(), "Tool call should have an ID");
assert_eq!(tool_call.tp, ToolCallType::Function);
}
/// Validates that a holiday tool call response matches expected values
fn validate_holiday_tool_call(tool_call: &ToolCallResponse) {
assert_eq!(tool_call.function.name, "is_holiday_today");
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
.expect("Arguments should be valid JSON");
let args_obj = args.as_object().expect("Arguments should be an object");
assert!(
args_obj.is_empty(),
"Holiday tool should have empty arguments"
);
// Validate OpenAI compatibility
assert!(!tool_call.id.is_empty(), "Tool call should have an ID");
assert_eq!(tool_call.tp, ToolCallType::Function);
}
/// Validates that tool call IDs are unique
fn validate_unique_tool_call_ids(tool_calls: &[ToolCallResponse]) {
let mut ids = std::collections::HashSet::new();
for tool_call in tool_calls {
assert!(
ids.insert(tool_call.id.clone()),
"Tool call IDs should be unique: {}",
tool_call.id
);
}
}
#[tokio::test]
async fn test_parallel_tool_call_integration() {
// Create the mock request
let request = create_mock_chat_completion_request();
// Validate request structure
assert_eq!(request.inner.model, "Qwen/Qwen3-0.6B");
assert_eq!(request.inner.temperature, Some(0.0));
#[allow(deprecated)]
{
assert_eq!(request.inner.max_tokens, Some(3000));
}
assert_eq!(request.inner.stream, Some(false));
assert_eq!(request.inner.messages.len(), 2);
assert_eq!(request.inner.tools.as_ref().unwrap().len(), 2);
// Verify tool choice is required
match request.inner.tool_choice.as_ref().unwrap() {
dynamo_async_openai::types::ChatCompletionToolChoiceOption::Required => {
// This is expected
}
_ => panic!("Tool choice should be Required"),
}
// Get the mock response content
let response_content = get_mock_response_content();
// Verify the response contains both tool calls
assert!(response_content.contains("get_current_weather"));
assert!(response_content.contains("is_holiday_today"));
assert!(response_content.contains("Dallas"));
assert!(response_content.contains("Texas"));
assert!(response_content.contains("fahrenheit"));
}
#[tokio::test]
async fn test_parallel_tool_call_parsing() {
let response_content = get_mock_response_content();
// Parse the tool calls using the hermes parser (works well with <tool_call> format)
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some("hermes"))
.await
.expect("Should successfully parse tool calls");
// Validate we got exactly 2 tool calls
assert_eq!(
tool_calls.len(),
2,
"Should parse exactly 2 parallel tool calls"
);
// Validate remaining content (should be the thinking part)
assert!(remaining_content.is_some());
let remaining = remaining_content.unwrap();
assert!(remaining.contains("<think>"));
assert!(remaining.contains("</think>"));
// Sort tool calls by name for consistent testing
let mut sorted_calls = tool_calls;
sorted_calls.sort_by(|a, b| a.function.name.cmp(&b.function.name));
// Validate the weather tool call (first alphabetically)
validate_weather_tool_call(&sorted_calls[0]);
// Validate the holiday tool call (second alphabetically)
validate_holiday_tool_call(&sorted_calls[1]);
// Validate tool call IDs are unique
validate_unique_tool_call_ids(&sorted_calls);
}
#[tokio::test]
async fn test_parallel_tool_call_with_explicit_parser() {
let response_content = get_mock_response_content();
// Test with explicit parser selection
let parsers_to_test = vec![
"hermes", // Should work well with this format
];
for parser in parsers_to_test {
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some(parser))
.await
.unwrap_or_else(|e| panic!("Should successfully parse with {parser} parser: {e}"));
// Should get 2 tool calls regardless of parser
assert_eq!(
tool_calls.len(),
2,
"Parser {parser} should find 2 tool calls"
);
// Validate remaining content exists
assert!(remaining_content.is_some());
// Sort and validate calls
let mut sorted_calls = tool_calls;
sorted_calls.sort_by(|a, b| a.function.name.cmp(&b.function.name));
validate_weather_tool_call(&sorted_calls[0]);
validate_holiday_tool_call(&sorted_calls[1]);
validate_unique_tool_call_ids(&sorted_calls);
}
}
#[tokio::test]
async fn test_tool_call_json_structure() {
let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"))
.await
.expect("Should parse tool calls");
// Test JSON serialization
for tool_call in &tool_calls {
let json_str =
serde_json::to_string(tool_call).expect("Tool call should serialize to JSON");
// Verify the JSON contains expected fields
assert!(json_str.contains("\"id\""));
assert!(json_str.contains("\"type\""));
assert!(json_str.contains("\"function\""));
assert!(json_str.contains(&tool_call.function.name));
}
}
#[tokio::test]
async fn test_openai_compatibility_structure() {
let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"))
.await
.expect("Should parse tool calls");
// Validate OpenAI API compatibility
for tool_call in &tool_calls {
// Should have all required OpenAI fields
assert!(!tool_call.id.is_empty(), "Missing required 'id' field");
assert_eq!(
tool_call.tp,
ToolCallType::Function,
"Type should be 'function'"
);
assert!(
!tool_call.function.name.is_empty(),
"Function name should not be empty"
);
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)
.expect("Arguments should be valid JSON");
assert!(args.is_object(), "Arguments should be an object");
// ID should follow expected format (call-XXXXXXXX or call_XXXXXXXX)
assert!(
tool_call.id.starts_with("call-") || tool_call.id.starts_with("call_"),
"ID should start with 'call-' or 'call_': {}",
tool_call.id
);
assert!(
tool_call.id.len() > 5,
"ID should be longer than just 'call': {}",
tool_call.id
);
}
}
#[tokio::test]
async fn test_parallel_tool_call_error_handling() {
// Test with malformed content
let malformed_content = r#"<tool_call>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
</tool_call>
<tool_call>
{"invalid_json": }
</tool_call>"#;
let result = detect_and_parse_tool_call(malformed_content, Some("hermes")).await;
// Should handle partial parsing gracefully
match result {
Ok((tool_calls, _)) => {
// May parse valid tool calls and ignore malformed ones, or return empty
println!(
"Parsed {} tool calls from malformed content",
tool_calls.len()
);
if !tool_calls.is_empty() {
// If any were parsed, verify they're valid
for call in &tool_calls {
assert!(
!call.function.name.is_empty(),
"Parsed tool call should have valid name"
);
}
}
}
Err(e) => {
// Error handling is also acceptable for malformed input
println!("Expected error for malformed input: {}", e);
}
}
}
#[tokio::test]
async fn test_empty_tool_calls() {
let content_without_tools = "This is just a regular response without any tool calls.";
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(content_without_tools, Some("hermes"))
.await
.expect("Should handle content without tool calls");
assert!(
tool_calls.is_empty(),
"Should return empty tool calls array"
);
assert!(
remaining_content.is_some(),
"Should return the original content"
);
assert_eq!(remaining_content.unwrap(), content_without_tools);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment