Unverified Commit ff7aea54 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

feat: Pass tool definitions to parsers (#4948)


Signed-off-by: default avatarWilliam Zhang <133824995+2ez4bz@users.noreply.github.com>
parent 7043707e
...@@ -801,6 +801,7 @@ impl OpenAIPreprocessor { ...@@ -801,6 +801,7 @@ impl OpenAIPreprocessor {
pub fn apply_tool_calling_jail<S>( pub fn apply_tool_calling_jail<S>(
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>, tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
stream: S, stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where where
...@@ -810,6 +811,13 @@ impl OpenAIPreprocessor { ...@@ -810,6 +811,13 @@ impl OpenAIPreprocessor {
let mut builder = JailedStream::builder(); let mut builder = JailedStream::builder();
// Set tool definitions if provided
if let Some(tool_definitions) = tool_definitions
&& !tool_definitions.is_empty()
{
builder = builder.tool_definitions(tool_definitions);
}
// Configure jail based on tool_choice // Configure jail based on tool_choice
match tool_choice { match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => { Some(ChatCompletionToolChoiceOption::Named(named)) => {
...@@ -991,11 +999,23 @@ impl ...@@ -991,11 +999,23 @@ impl
has_tools, has_tools,
)?; )?;
// Convert OpenAI tools to parser ToolDefinition format before applying jail
let tool_definitions = request.inner.tools.as_ref().map(|tools| {
tools
.iter()
.map(|tool| dynamo_parsers::tool_calling::ToolDefinition {
name: tool.function.name.clone(),
parameters: tool.function.parameters.clone(),
})
.collect()
});
// Apply jail conditionally // Apply jail conditionally
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail { let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
Box::pin(Self::apply_tool_calling_jail( Box::pin(Self::apply_tool_calling_jail(
self.tool_call_parser.clone(), self.tool_call_parser.clone(),
request.inner.tool_choice.clone(), request.inner.tool_choice.clone(),
tool_definitions,
stream, stream,
)) ))
} else { } else {
......
...@@ -470,6 +470,7 @@ pub struct JailedStream { ...@@ -470,6 +470,7 @@ pub struct JailedStream {
jail_start_sequences: Vec<String>, jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>, jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode, emission_mode: EmissionMode,
marker_matcher: MarkerMatcher, marker_matcher: MarkerMatcher,
jail_mode: JailMode, jail_mode: JailMode,
...@@ -758,8 +759,13 @@ impl JailedStream { ...@@ -758,8 +759,13 @@ impl JailedStream {
} else if early_exit { } else if early_exit {
// For early exit, find where the complete tool call ends // For early exit, find where the complete tool call ends
if let Some(parser) = &self.tool_call_parser { if let Some(parser) = &self.tool_call_parser {
if let Ok((_, _)) = let tools_slice = self.tool_definitions.as_deref();
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await if let Ok((_, _)) = try_tool_call_parse_aggregate(
accumulated_content,
Some(parser),
tools_slice,
)
.await
{ {
let split_pos = let split_pos =
find_tool_call_end_position(accumulated_content, Some(parser)); find_tool_call_end_position(accumulated_content, Some(parser));
...@@ -814,9 +820,11 @@ impl JailedStream { ...@@ -814,9 +820,11 @@ impl JailedStream {
match &self.jail_mode { match &self.jail_mode {
JailMode::MarkerBased => { JailMode::MarkerBased => {
// Traditional marker-based tool call parsing // Traditional marker-based tool call parsing
let tools_slice = self.tool_definitions.as_deref();
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate( if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_content, accumulated_content,
self.tool_call_parser.as_deref(), self.tool_call_parser.as_deref(),
tools_slice,
) )
.await .await
&& !tool_calls.is_empty() && !tool_calls.is_empty()
...@@ -952,7 +960,8 @@ impl JailedStream { ...@@ -952,7 +960,8 @@ impl JailedStream {
async fn should_exit_jail_early(&self, accumulated: &str) -> bool { async fn should_exit_jail_early(&self, accumulated: &str) -> bool {
if let Some(ref parser) = self.tool_call_parser { if let Some(ref parser) = self.tool_call_parser {
// Try to parse - if successful and we have complete tool calls, exit early // Try to parse - if successful and we have complete tool calls, exit early
match try_tool_call_parse_aggregate(accumulated, Some(parser)).await { let tools_slice = self.tool_definitions.as_deref();
match try_tool_call_parse_aggregate(accumulated, Some(parser), tools_slice).await {
Ok((tool_calls, _normal_text)) => { Ok((tool_calls, _normal_text)) => {
let result = !tool_calls.is_empty(); let result = !tool_calls.is_empty();
return result; return result;
...@@ -1034,6 +1043,7 @@ pub struct JailedStreamBuilder { ...@@ -1034,6 +1043,7 @@ pub struct JailedStreamBuilder {
jail_start_sequences: Vec<String>, jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>, jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode, emission_mode: EmissionMode,
jail_mode: JailMode, jail_mode: JailMode,
} }
...@@ -1045,6 +1055,7 @@ impl JailedStreamBuilder { ...@@ -1045,6 +1055,7 @@ impl JailedStreamBuilder {
jail_start_sequences: Vec::new(), jail_start_sequences: Vec::new(),
jail_end_sequences: Vec::new(), jail_end_sequences: Vec::new(),
tool_call_parser: None, tool_call_parser: None,
tool_definitions: None,
emission_mode: EmissionMode::default(), emission_mode: EmissionMode::default(),
jail_mode: JailMode::MarkerBased, jail_mode: JailMode::MarkerBased,
} }
...@@ -1088,6 +1099,15 @@ impl JailedStreamBuilder { ...@@ -1088,6 +1099,15 @@ impl JailedStreamBuilder {
self self
} }
/// Set the tool definitions for runtime validation and parsing
pub fn tool_definitions(
mut self,
tools: Vec<dynamo_parsers::tool_calling::ToolDefinition>,
) -> Self {
self.tool_definitions = Some(tools);
self
}
/// Set the emission mode for handling multiple choices /// Set the emission mode for handling multiple choices
pub fn emission_mode(mut self, mode: EmissionMode) -> Self { pub fn emission_mode(mut self, mode: EmissionMode) -> Self {
self.emission_mode = mode; self.emission_mode = mode;
...@@ -1198,6 +1218,7 @@ impl JailedStreamBuilder { ...@@ -1198,6 +1218,7 @@ impl JailedStreamBuilder {
jail_start_sequences: self.jail_start_sequences, jail_start_sequences: self.jail_start_sequences,
jail_end_sequences: self.jail_end_sequences, jail_end_sequences: self.jail_end_sequences,
tool_call_parser: self.tool_call_parser, tool_call_parser: self.tool_call_parser,
tool_definitions: self.tool_definitions,
emission_mode: self.emission_mode, emission_mode: self.emission_mode,
marker_matcher, marker_matcher,
jail_mode: self.jail_mode, jail_mode: self.jail_mode,
......
...@@ -197,7 +197,7 @@ async fn test_parallel_tool_call_parsing() { ...@@ -197,7 +197,7 @@ async fn test_parallel_tool_call_parsing() {
// Parse the tool calls using the hermes parser (works well with <tool_call> format) // Parse the tool calls using the hermes parser (works well with <tool_call> format)
let (tool_calls, remaining_content) = let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some("hermes")) detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await .await
.expect("Should successfully parse tool calls"); .expect("Should successfully parse tool calls");
...@@ -239,7 +239,7 @@ async fn test_parallel_tool_call_with_explicit_parser() { ...@@ -239,7 +239,7 @@ async fn test_parallel_tool_call_with_explicit_parser() {
for parser in parsers_to_test { for parser in parsers_to_test {
let (tool_calls, remaining_content) = let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some(parser)) detect_and_parse_tool_call(&response_content, Some(parser), None)
.await .await
.unwrap_or_else(|e| panic!("Should successfully parse with {parser} parser: {e}")); .unwrap_or_else(|e| panic!("Should successfully parse with {parser} parser: {e}"));
...@@ -267,7 +267,7 @@ async fn test_parallel_tool_call_with_explicit_parser() { ...@@ -267,7 +267,7 @@ async fn test_parallel_tool_call_with_explicit_parser() {
async fn test_tool_call_json_structure() { async fn test_tool_call_json_structure() {
let response_content = get_mock_response_content(); let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes")) let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await .await
.expect("Should parse tool calls"); .expect("Should parse tool calls");
...@@ -288,7 +288,7 @@ async fn test_tool_call_json_structure() { ...@@ -288,7 +288,7 @@ async fn test_tool_call_json_structure() {
async fn test_openai_compatibility_structure() { async fn test_openai_compatibility_structure() {
let response_content = get_mock_response_content(); let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes")) let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await .await
.expect("Should parse tool calls"); .expect("Should parse tool calls");
...@@ -335,7 +335,7 @@ async fn test_parallel_tool_call_error_handling() { ...@@ -335,7 +335,7 @@ async fn test_parallel_tool_call_error_handling() {
{"invalid_json": } {"invalid_json": }
</tool_call>"#; </tool_call>"#;
let result = detect_and_parse_tool_call(malformed_content, Some("hermes")).await; let result = detect_and_parse_tool_call(malformed_content, Some("hermes"), None).await;
// Should handle partial parsing gracefully // Should handle partial parsing gracefully
match result { match result {
...@@ -368,7 +368,7 @@ async fn test_empty_tool_calls() { ...@@ -368,7 +368,7 @@ async fn test_empty_tool_calls() {
let content_without_tools = "This is just a regular response without any tool calls."; let content_without_tools = "This is just a regular response without any tool calls.";
let (tool_calls, remaining_content) = let (tool_calls, remaining_content) =
detect_and_parse_tool_call(content_without_tools, Some("hermes")) detect_and_parse_tool_call(content_without_tools, Some("hermes"), None)
.await .await
.expect("Should handle content without tool calls"); .expect("Should handle content without tool calls");
...@@ -412,7 +412,7 @@ async fn test_deepseek_v3_1_tool_call_parsing() { ...@@ -412,7 +412,7 @@ async fn test_deepseek_v3_1_tool_call_parsing() {
// Parse the tool calls using the deepseek_v3_1 parser // Parse the tool calls using the deepseek_v3_1 parser
let (tool_calls, remaining_content) = let (tool_calls, remaining_content) =
detect_and_parse_tool_call(response_content, Some("deepseek_v3_1")) detect_and_parse_tool_call(response_content, Some("deepseek_v3_1"), None)
.await .await
.expect("Should successfully parse deepseek_v3_1 tool calls"); .expect("Should successfully parse deepseek_v3_1 tool calls");
......
...@@ -2045,6 +2045,8 @@ mod tests { ...@@ -2045,6 +2045,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_jailed_stream_qwen3_coder_multiple_params() { async fn test_jailed_stream_qwen3_coder_multiple_params() {
use dynamo_parsers::tool_calling::ToolDefinition;
let chunks = vec![ let chunks = vec![
create_mock_response_chunk("Let me search for that. ".to_string(), 0), create_mock_response_chunk("Let me search for that. ".to_string(), 0),
create_mock_response_chunk( create_mock_response_chunk(
...@@ -2054,9 +2056,23 @@ mod tests { ...@@ -2054,9 +2056,23 @@ mod tests {
create_mock_response_chunk(" Searching now.".to_string(), 0), create_mock_response_chunk(" Searching now.".to_string(), 0),
]; ];
// Define the web_search tool with its parameters
let tool_defs = vec![ToolDefinition {
name: "web_search".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"query": {"type": "string"},
"max_results": {"type": "integer"},
"filter": {"type": "string"},
},
})),
}];
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder() let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder") .tool_call_parser("qwen3_coder")
.tool_definitions(tool_defs)
.build(); .build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
......
...@@ -487,6 +487,7 @@ mod tests { ...@@ -487,6 +487,7 @@ mod tests {
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("nemotron_deci".to_string()), Some("nemotron_deci".to_string()),
None, // No tool_choice in this test None, // No tool_choice in this test
None, // No tool_definitions in this test
reasoning_parsed_stream, reasoning_parsed_stream,
); );
...@@ -600,6 +601,7 @@ mod tests { ...@@ -600,6 +601,7 @@ mod tests {
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("harmony".to_string()), Some("harmony".to_string()),
None, // No tool_choice in this test None, // No tool_choice in this test
None, // No tool_definitions in this test
reasoning_parsed_stream, reasoning_parsed_stream,
); );
......
...@@ -160,6 +160,7 @@ async fn parse_response_stream( ...@@ -160,6 +160,7 @@ async fn parse_response_stream(
Box::pin(OpenAIPreprocessor::apply_tool_calling_jail( Box::pin(OpenAIPreprocessor::apply_tool_calling_jail(
Some(tool_parser), Some(tool_parser),
None, // No tool_choice in this test None, // No tool_choice in this test
None, // No tool_definitions in this test
stream, stream,
)) ))
} else { } else {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::super::ToolDefinition;
use super::config::JsonParserConfig; use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use openai_harmony::chat::{Content::Text, Role}; use openai_harmony::chat::{Content::Text, Role};
...@@ -46,6 +47,7 @@ pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow:: ...@@ -46,6 +47,7 @@ pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::
pub async fn parse_tool_calls_harmony_complete( pub async fn parse_tool_calls_harmony_complete(
text: &str, text: &str,
_config: &JsonParserConfig, _config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let enc = match get_harmony_encoding().await.as_ref() { let enc = match get_harmony_encoding().await.as_ref() {
Ok(e) => e, Ok(e) => e,
...@@ -212,7 +214,7 @@ mod tests { ...@@ -212,7 +214,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_complete_basic() { async fn test_parse_tool_calls_harmony_complete_basic() {
let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#; let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
let (tool_calls, normal_content) = let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default()) parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await .await
.unwrap(); .unwrap();
assert_eq!(normal_content, Some("".to_string())); assert_eq!(normal_content, Some("".to_string()));
...@@ -226,7 +228,7 @@ mod tests { ...@@ -226,7 +228,7 @@ mod tests {
async fn test_parse_tools_harmony_without_start_token() { async fn test_parse_tools_harmony_without_start_token() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#; let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#;
let (tool_calls, normal_content) = let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default()) parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await .await
.unwrap(); .unwrap();
assert_eq!(normal_content, Some(text.trim().to_string())); assert_eq!(normal_content, Some(text.trim().to_string()));
...@@ -237,7 +239,7 @@ mod tests { ...@@ -237,7 +239,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_with_multi_args() { async fn test_parse_tool_calls_harmony_with_multi_args() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#; let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#;
let (tool_calls, normal_content) = let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default()) parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
...@@ -255,7 +257,7 @@ mod tests { ...@@ -255,7 +257,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_with_normal_text() { async fn test_parse_tool_calls_harmony_with_normal_text() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#; let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#;
let (tool_calls, normal_content) = let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default()) parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
...@@ -272,7 +274,7 @@ mod tests { ...@@ -272,7 +274,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_without_call_token() { async fn test_parse_tool_calls_harmony_without_call_token() {
let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#; let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
let (tool_calls, normal_content) = let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default()) parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await .await
.unwrap(); .unwrap();
assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string())); assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
......
...@@ -7,6 +7,7 @@ use regex::RegexBuilder; ...@@ -7,6 +7,7 @@ use regex::RegexBuilder;
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig; use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
...@@ -165,6 +166,7 @@ fn try_parse_normal_text(input: &str, start_token: &str) -> String { ...@@ -165,6 +166,7 @@ fn try_parse_normal_text(input: &str, start_token: &str) -> String {
pub fn try_tool_call_parse_basic_json( pub fn try_tool_call_parse_basic_json(
message: &str, message: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Log the config we are using // Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config); tracing::debug!("Using JSON parser config: {:?}", config);
......
...@@ -5,6 +5,7 @@ use regex::RegexBuilder; ...@@ -5,6 +5,7 @@ use regex::RegexBuilder;
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig; use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
...@@ -119,6 +120,7 @@ fn parse_single_tool_call_v3_1( ...@@ -119,6 +120,7 @@ fn parse_single_tool_call_v3_1(
pub fn parse_tool_calls_deepseek_v3_1( pub fn parse_tool_calls_deepseek_v3_1(
message: &str, message: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Format Structure: // Format Structure:
// <|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁call▁end|><|tool▁calls▁end|> // <|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁call▁end|><|tool▁calls▁end|>
...@@ -275,7 +277,7 @@ mod tests { ...@@ -275,7 +277,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone()); let (name, args) = extract_name_and_args(result[0].clone());
...@@ -293,7 +295,7 @@ mod tests { ...@@ -293,7 +295,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!( assert_eq!(
content, content,
Some("The following tool call retrieves weather information: ".to_string()) Some("The following tool call retrieves weather information: ".to_string())
...@@ -311,7 +313,7 @@ mod tests { ...@@ -311,7 +313,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.to_string())); assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
...@@ -323,7 +325,7 @@ mod tests { ...@@ -323,7 +325,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
let (name, args) = extract_name_and_args(result[0].clone()); let (name, args) = extract_name_and_args(result[0].clone());
...@@ -349,7 +351,7 @@ mod tests { ...@@ -349,7 +351,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
...@@ -362,7 +364,7 @@ mod tests { ...@@ -362,7 +364,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
...@@ -388,7 +390,7 @@ mod tests { ...@@ -388,7 +390,7 @@ mod tests {
}; };
let (tool_call_results, normal_content) = let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3_1(text, &config).unwrap(); parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(tool_call_results.len(), 1); assert_eq!(tool_call_results.len(), 1);
......
...@@ -5,6 +5,7 @@ use regex::RegexBuilder; ...@@ -5,6 +5,7 @@ use regex::RegexBuilder;
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig; use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
...@@ -129,6 +130,7 @@ fn parse_single_tool_call_v3(block: &str, separator_tokens: &[String]) -> Option ...@@ -129,6 +130,7 @@ fn parse_single_tool_call_v3(block: &str, separator_tokens: &[String]) -> Option
pub fn parse_tool_calls_deepseek_v3( pub fn parse_tool_calls_deepseek_v3(
message: &str, message: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Format Structure: // Format Structure:
// <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|> // <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
...@@ -285,7 +287,7 @@ mod tests { ...@@ -285,7 +287,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone()); let (name, args) = extract_name_and_args(result[0].clone());
...@@ -306,7 +308,7 @@ mod tests { ...@@ -306,7 +308,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!( assert_eq!(
content, content,
Some("The following tool call retrieves weather information: ".to_string()) Some("The following tool call retrieves weather information: ".to_string())
...@@ -327,7 +329,7 @@ mod tests { ...@@ -327,7 +329,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.to_string())); assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
...@@ -348,7 +350,7 @@ mod tests { ...@@ -348,7 +350,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
let (name, args) = extract_name_and_args(result[0].clone()); let (name, args) = extract_name_and_args(result[0].clone());
...@@ -377,7 +379,7 @@ mod tests { ...@@ -377,7 +379,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
...@@ -399,7 +401,7 @@ mod tests { ...@@ -399,7 +401,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg, super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"), _ => panic!("Expected JSON parser config"),
}; };
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap(); let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string())); assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0); assert_eq!(result.len(), 0);
} }
...@@ -428,7 +430,7 @@ mod tests { ...@@ -428,7 +430,7 @@ mod tests {
}; };
let (tool_call_results, normal_content) = let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3(text, &config).unwrap(); parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(tool_call_results.len(), 1); assert_eq!(tool_call_results.len(), 1);
......
...@@ -33,11 +33,12 @@ impl Default for JsonParserType { ...@@ -33,11 +33,12 @@ impl Default for JsonParserType {
pub fn try_tool_call_parse_json( pub fn try_tool_call_parse_json(
message: &str, message: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
match config.parser_type { match config.parser_type {
JsonParserType::Basic => try_tool_call_parse_basic_json(message, config), JsonParserType::Basic => try_tool_call_parse_basic_json(message, config, tools),
JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config), JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config, tools),
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config), JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config, tools),
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use serde_json::Value;
pub mod config; pub mod config;
pub mod dsml; pub mod dsml;
pub mod harmony; pub mod harmony;
...@@ -13,6 +15,13 @@ pub mod tests; ...@@ -13,6 +15,13 @@ pub mod tests;
pub mod tools; pub mod tools;
pub mod xml; pub mod xml;
/// Represents a tool definition with function schema.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
pub name: String,
pub parameters: Option<Value>,
}
// Re-export main types and functions for convenience // Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig}; pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig};
pub use dsml::try_tool_call_parse_dsml; pub use dsml::try_tool_call_parse_dsml;
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::super::ToolDefinition;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use regex::Regex; use regex::Regex;
use rustpython_parser::{ use rustpython_parser::{
...@@ -161,6 +162,7 @@ fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> { ...@@ -161,6 +162,7 @@ fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> {
pub fn try_tool_call_parse_pythonic( pub fn try_tool_call_parse_pythonic(
message: &str, message: &str,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let stripped = strip_text(message).trim().to_string(); let stripped = strip_text(message).trim().to_string();
...@@ -263,7 +265,7 @@ mod tests { ...@@ -263,7 +265,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_basic() { fn test_parse_tool_call_parse_pythonic_basic() {
let message = "[foo(a=1, b=2), bar(x=3)]"; let message = "[foo(a=1, b=2), bar(x=3)]";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap(); let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -279,7 +281,7 @@ mod tests { ...@@ -279,7 +281,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_with_text() { fn test_parse_tool_call_parse_pythonic_with_text() {
let message = "Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo"; let message = "Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap(); let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey yo !".to_string())); assert_eq!(content, Some("Hey yo !".to_string()));
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -295,7 +297,7 @@ mod tests { ...@@ -295,7 +297,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() { fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() {
let message = "Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo"; let message = "Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap(); let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey \n yo !".to_string())); assert_eq!(content, Some("Hey \n yo !".to_string()));
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -311,7 +313,7 @@ mod tests { ...@@ -311,7 +313,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_with_no_calls() { fn test_parse_tool_call_parse_pythonic_with_no_calls() {
let message = "Hey \n yo !"; let message = "Hey \n yo !";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap(); let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey \n yo !".to_string())); assert_eq!(content, Some("Hey \n yo !".to_string()));
assert!(result.is_empty()); assert!(result.is_empty());
assert_eq!(result.len(), 0) assert_eq!(result.len(), 0)
...@@ -320,7 +322,7 @@ mod tests { ...@@ -320,7 +322,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_with_python_tags() { fn test_parse_tool_call_parse_pythonic_with_python_tags() {
let message = "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>"; let message = "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap(); let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("".to_string())); assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -336,7 +338,7 @@ mod tests { ...@@ -336,7 +338,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_with_list_arg_values() { fn test_parse_tool_call_parse_pythonic_with_list_arg_values() {
let message = "[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]"; let message = "[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]";
let (result, _) = try_tool_call_parse_pythonic(message).unwrap(); let (result, _) = try_tool_call_parse_pythonic(message, None).unwrap();
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone()); let (name, args) = extract_name_and_args(result[0].clone());
...@@ -351,7 +353,7 @@ mod tests { ...@@ -351,7 +353,7 @@ mod tests {
#[test] #[test]
fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() { fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() {
let message = "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]"; let message = "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]";
let (result, _) = try_tool_call_parse_pythonic(message).unwrap(); let (result, _) = try_tool_call_parse_pythonic(message, None).unwrap();
assert!(!result.is_empty()); assert!(!result.is_empty());
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone()); let (name, args) = extract_name_and_args(result[0].clone());
......
...@@ -10,6 +10,7 @@ pub use super::parsers::detect_and_parse_tool_call; ...@@ -10,6 +10,7 @@ pub use super::parsers::detect_and_parse_tool_call;
pub async fn try_tool_call_parse_aggregate( pub async fn try_tool_call_parse_aggregate(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<( ) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>, Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>,
Option<String>, Option<String>,
...@@ -19,7 +20,7 @@ pub async fn try_tool_call_parse_aggregate( ...@@ -19,7 +20,7 @@ pub async fn try_tool_call_parse_aggregate(
} else { } else {
tracing::info!("Using tool parser: {:?}", parser_str); tracing::info!("Using tool parser: {:?}", parser_str);
} }
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?; let (parsed, content) = detect_and_parse_tool_call(message, parser_str, tools).await?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok((vec![], content)); return Ok((vec![], content));
} }
...@@ -47,11 +48,12 @@ pub async fn try_tool_call_parse_aggregate( ...@@ -47,11 +48,12 @@ pub async fn try_tool_call_parse_aggregate(
pub async fn try_tool_call_parse_stream( pub async fn try_tool_call_parse_stream(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<( ) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>, Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>, Option<String>,
)> { )> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?; let (parsed, content) = detect_and_parse_tool_call(message, parser_str, tools).await?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok((vec![], content)); return Ok((vec![], content));
} }
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment