Unverified Commit ff7aea54 authored by William Zhang's avatar William Zhang Committed by GitHub
Browse files

feat: Pass tool definitions to parsers (#4948)


Signed-off-by: default avatarWilliam Zhang <133824995+2ez4bz@users.noreply.github.com>
parent 7043707e
......@@ -801,6 +801,7 @@ impl OpenAIPreprocessor {
pub fn apply_tool_calling_jail<S>(
tool_call_parser: Option<String>,
tool_choice: Option<dynamo_async_openai::types::ChatCompletionToolChoiceOption>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
......@@ -810,6 +811,13 @@ impl OpenAIPreprocessor {
let mut builder = JailedStream::builder();
// Set tool definitions if provided
if let Some(tool_definitions) = tool_definitions
&& !tool_definitions.is_empty()
{
builder = builder.tool_definitions(tool_definitions);
}
// Configure jail based on tool_choice
match tool_choice {
Some(ChatCompletionToolChoiceOption::Named(named)) => {
......@@ -991,11 +999,23 @@ impl
has_tools,
)?;
// Convert OpenAI tools to parser ToolDefinition format before applying jail
let tool_definitions = request.inner.tools.as_ref().map(|tools| {
tools
.iter()
.map(|tool| dynamo_parsers::tool_calling::ToolDefinition {
name: tool.function.name.clone(),
parameters: tool.function.parameters.clone(),
})
.collect()
});
// Apply jail conditionally
let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
Box::pin(Self::apply_tool_calling_jail(
self.tool_call_parser.clone(),
request.inner.tool_choice.clone(),
tool_definitions,
stream,
))
} else {
......
......@@ -470,6 +470,7 @@ pub struct JailedStream {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode,
marker_matcher: MarkerMatcher,
jail_mode: JailMode,
......@@ -758,8 +759,13 @@ impl JailedStream {
} else if early_exit {
// For early exit, find where the complete tool call ends
if let Some(parser) = &self.tool_call_parser {
if let Ok((_, _)) =
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await
let tools_slice = self.tool_definitions.as_deref();
if let Ok((_, _)) = try_tool_call_parse_aggregate(
accumulated_content,
Some(parser),
tools_slice,
)
.await
{
let split_pos =
find_tool_call_end_position(accumulated_content, Some(parser));
......@@ -814,9 +820,11 @@ impl JailedStream {
match &self.jail_mode {
JailMode::MarkerBased => {
// Traditional marker-based tool call parsing
let tools_slice = self.tool_definitions.as_deref();
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_content,
self.tool_call_parser.as_deref(),
tools_slice,
)
.await
&& !tool_calls.is_empty()
......@@ -952,7 +960,8 @@ impl JailedStream {
async fn should_exit_jail_early(&self, accumulated: &str) -> bool {
if let Some(ref parser) = self.tool_call_parser {
// Try to parse - if successful and we have complete tool calls, exit early
match try_tool_call_parse_aggregate(accumulated, Some(parser)).await {
let tools_slice = self.tool_definitions.as_deref();
match try_tool_call_parse_aggregate(accumulated, Some(parser), tools_slice).await {
Ok((tool_calls, _normal_text)) => {
let result = !tool_calls.is_empty();
return result;
......@@ -1034,6 +1043,7 @@ pub struct JailedStreamBuilder {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
tool_definitions: Option<Vec<dynamo_parsers::tool_calling::ToolDefinition>>,
emission_mode: EmissionMode,
jail_mode: JailMode,
}
......@@ -1045,6 +1055,7 @@ impl JailedStreamBuilder {
jail_start_sequences: Vec::new(),
jail_end_sequences: Vec::new(),
tool_call_parser: None,
tool_definitions: None,
emission_mode: EmissionMode::default(),
jail_mode: JailMode::MarkerBased,
}
......@@ -1088,6 +1099,15 @@ impl JailedStreamBuilder {
self
}
/// Set the tool definitions for runtime validation and parsing
pub fn tool_definitions(
mut self,
tools: Vec<dynamo_parsers::tool_calling::ToolDefinition>,
) -> Self {
self.tool_definitions = Some(tools);
self
}
/// Set the emission mode for handling multiple choices
pub fn emission_mode(mut self, mode: EmissionMode) -> Self {
self.emission_mode = mode;
......@@ -1198,6 +1218,7 @@ impl JailedStreamBuilder {
jail_start_sequences: self.jail_start_sequences,
jail_end_sequences: self.jail_end_sequences,
tool_call_parser: self.tool_call_parser,
tool_definitions: self.tool_definitions,
emission_mode: self.emission_mode,
marker_matcher,
jail_mode: self.jail_mode,
......
......@@ -197,7 +197,7 @@ async fn test_parallel_tool_call_parsing() {
// Parse the tool calls using the hermes parser (works well with <tool_call> format)
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some("hermes"))
detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await
.expect("Should successfully parse tool calls");
......@@ -239,7 +239,7 @@ async fn test_parallel_tool_call_with_explicit_parser() {
for parser in parsers_to_test {
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(&response_content, Some(parser))
detect_and_parse_tool_call(&response_content, Some(parser), None)
.await
.unwrap_or_else(|e| panic!("Should successfully parse with {parser} parser: {e}"));
......@@ -267,7 +267,7 @@ async fn test_parallel_tool_call_with_explicit_parser() {
async fn test_tool_call_json_structure() {
let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"))
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await
.expect("Should parse tool calls");
......@@ -288,7 +288,7 @@ async fn test_tool_call_json_structure() {
async fn test_openai_compatibility_structure() {
let response_content = get_mock_response_content();
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"))
let (tool_calls, _) = detect_and_parse_tool_call(&response_content, Some("hermes"), None)
.await
.expect("Should parse tool calls");
......@@ -335,7 +335,7 @@ async fn test_parallel_tool_call_error_handling() {
{"invalid_json": }
</tool_call>"#;
let result = detect_and_parse_tool_call(malformed_content, Some("hermes")).await;
let result = detect_and_parse_tool_call(malformed_content, Some("hermes"), None).await;
// Should handle partial parsing gracefully
match result {
......@@ -368,7 +368,7 @@ async fn test_empty_tool_calls() {
let content_without_tools = "This is just a regular response without any tool calls.";
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(content_without_tools, Some("hermes"))
detect_and_parse_tool_call(content_without_tools, Some("hermes"), None)
.await
.expect("Should handle content without tool calls");
......@@ -412,7 +412,7 @@ async fn test_deepseek_v3_1_tool_call_parsing() {
// Parse the tool calls using the deepseek_v3_1 parser
let (tool_calls, remaining_content) =
detect_and_parse_tool_call(response_content, Some("deepseek_v3_1"))
detect_and_parse_tool_call(response_content, Some("deepseek_v3_1"), None)
.await
.expect("Should successfully parse deepseek_v3_1 tool calls");
......
......@@ -2045,6 +2045,8 @@ mod tests {
#[tokio::test]
async fn test_jailed_stream_qwen3_coder_multiple_params() {
use dynamo_parsers::tool_calling::ToolDefinition;
let chunks = vec![
create_mock_response_chunk("Let me search for that. ".to_string(), 0),
create_mock_response_chunk(
......@@ -2054,9 +2056,23 @@ mod tests {
create_mock_response_chunk(" Searching now.".to_string(), 0),
];
// Define the web_search tool with its parameters
let tool_defs = vec![ToolDefinition {
name: "web_search".to_string(),
parameters: Some(serde_json::json!({
"type": "object",
"properties": {
"query": {"type": "string"},
"max_results": {"type": "integer"},
"filter": {"type": "string"},
},
})),
}];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("qwen3_coder")
.tool_definitions(tool_defs)
.build();
let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
......
......@@ -487,6 +487,7 @@ mod tests {
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("nemotron_deci".to_string()),
None, // No tool_choice in this test
None, // No tool_definitions in this test
reasoning_parsed_stream,
);
......@@ -600,6 +601,7 @@ mod tests {
let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail(
Some("harmony".to_string()),
None, // No tool_choice in this test
None, // No tool_definitions in this test
reasoning_parsed_stream,
);
......
......@@ -160,6 +160,7 @@ async fn parse_response_stream(
Box::pin(OpenAIPreprocessor::apply_tool_calling_jail(
Some(tool_parser),
None, // No tool_choice in this test
None, // No tool_definitions in this test
stream,
))
} else {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use openai_harmony::chat::{Content::Text, Role};
......@@ -46,6 +47,7 @@ pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::
pub async fn parse_tool_calls_harmony_complete(
text: &str,
_config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let enc = match get_harmony_encoding().await.as_ref() {
Ok(e) => e,
......@@ -212,7 +214,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_complete_basic() {
let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(normal_content, Some("".to_string()));
......@@ -226,7 +228,7 @@ mod tests {
async fn test_parse_tools_harmony_without_start_token() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(normal_content, Some(text.trim().to_string()));
......@@ -237,7 +239,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_with_multi_args() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(
......@@ -255,7 +257,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_with_normal_text() {
let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(
......@@ -272,7 +274,7 @@ mod tests {
async fn test_parse_tool_calls_harmony_without_call_token() {
let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
let (tool_calls, normal_content) =
parse_tool_calls_harmony_complete(text, &Default::default())
parse_tool_calls_harmony_complete(text, &Default::default(), None)
.await
.unwrap();
assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
......
......@@ -7,6 +7,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -165,6 +166,7 @@ fn try_parse_normal_text(input: &str, start_token: &str) -> String {
pub fn try_tool_call_parse_basic_json(
message: &str,
config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config);
......
......@@ -5,6 +5,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -119,6 +120,7 @@ fn parse_single_tool_call_v3_1(
pub fn parse_tool_calls_deepseek_v3_1(
message: &str,
config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Format Structure:
// <|tool▁calls▁begin|><|tool▁call▁begin|>{function_name}<|tool▁sep|>{json_arguments}<|tool▁call▁end|><|tool▁calls▁end|>
......@@ -275,7 +277,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -293,7 +295,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(
content,
Some("The following tool call retrieves weather information: ".to_string())
......@@ -311,7 +313,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0);
}
......@@ -323,7 +325,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -349,7 +351,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -362,7 +364,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -388,7 +390,7 @@ mod tests {
};
let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3_1(text, &config).unwrap();
parse_tool_calls_deepseek_v3_1(text, &config, None).unwrap();
assert_eq!(tool_call_results.len(), 1);
......
......@@ -5,6 +5,7 @@ use regex::RegexBuilder;
use serde_json::Value;
use uuid::Uuid;
use super::super::ToolDefinition;
use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
......@@ -129,6 +130,7 @@ fn parse_single_tool_call_v3(block: &str, separator_tokens: &[String]) -> Option
pub fn parse_tool_calls_deepseek_v3(
message: &str,
config: &JsonParserConfig,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Format Structure:
// <|tool▁calls▁begin|><|tool▁call▁begin|>{type}<|tool▁sep|>{function_name}\n```json\n{json_arguments}\n```<|tool▁call▁end|><|tool▁calls▁end|>
......@@ -285,7 +287,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -306,7 +308,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(
content,
Some("The following tool call retrieves weather information: ".to_string())
......@@ -327,7 +329,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.to_string()));
assert_eq!(result.len(), 0);
}
......@@ -348,7 +350,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert_eq!(result.len(), 3);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -377,7 +379,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -399,7 +401,7 @@ mod tests {
super::super::config::ParserConfig::Json(cfg) => cfg,
_ => panic!("Expected JSON parser config"),
};
let (result, content) = parse_tool_calls_deepseek_v3(text, &config).unwrap();
let (result, content) = parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(content, Some(text.trim().to_string()));
assert_eq!(result.len(), 0);
}
......@@ -428,7 +430,7 @@ mod tests {
};
let (tool_call_results, normal_content) =
parse_tool_calls_deepseek_v3(text, &config).unwrap();
parse_tool_calls_deepseek_v3(text, &config, None).unwrap();
assert_eq!(tool_call_results.len(), 1);
......
......@@ -33,11 +33,12 @@ impl Default for JsonParserType {
pub fn try_tool_call_parse_json(
message: &str,
config: &JsonParserConfig,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
match config.parser_type {
JsonParserType::Basic => try_tool_call_parse_basic_json(message, config),
JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config),
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config),
JsonParserType::Basic => try_tool_call_parse_basic_json(message, config, tools),
JsonParserType::DeepseekV3 => parse_tool_calls_deepseek_v3(message, config, tools),
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config, tools),
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use serde_json::Value;
pub mod config;
pub mod dsml;
pub mod harmony;
......@@ -13,6 +15,13 @@ pub mod tests;
pub mod tools;
pub mod xml;
/// Represents a tool definition with function schema.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
pub name: String,
pub parameters: Option<Value>,
}
// Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig};
pub use dsml::try_tool_call_parse_dsml;
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::super::ToolDefinition;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use regex::Regex;
use rustpython_parser::{
......@@ -161,6 +162,7 @@ fn const_expr(e: &Expr) -> Result<Value, Box<dyn std::error::Error>> {
pub fn try_tool_call_parse_pythonic(
message: &str,
_tools: Option<&[ToolDefinition]>,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
let stripped = strip_text(message).trim().to_string();
......@@ -263,7 +265,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_basic() {
let message = "[foo(a=1, b=2), bar(x=3)]";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -279,7 +281,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_text() {
let message = "Hey yo ! [foo(a=1, b=2), bar(x=3)] Hey yo";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey yo !".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -295,7 +297,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_text_and_new_line() {
let message = "Hey \n yo ! [foo(a=1, b=2), bar(x=3)] Hey yo";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey \n yo !".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -311,7 +313,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_no_calls() {
let message = "Hey \n yo !";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("Hey \n yo !".to_string()));
assert!(result.is_empty());
assert_eq!(result.len(), 0)
......@@ -320,7 +322,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_python_tags() {
let message = "<|python_start|>[foo(a=1, b=2), bar(x=3)]<|python_end|>";
let (result, content) = try_tool_call_parse_pythonic(message).unwrap();
let (result, content) = try_tool_call_parse_pythonic(message, None).unwrap();
assert_eq!(content, Some("".to_string()));
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
......@@ -336,7 +338,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_list_arg_values() {
let message = "[foo(a=[1, 2, 3], b=2), bar(x=[3, 4, 5])]";
let (result, _) = try_tool_call_parse_pythonic(message).unwrap();
let (result, _) = try_tool_call_parse_pythonic(message, None).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......@@ -351,7 +353,7 @@ mod tests {
#[test]
fn test_parse_tool_call_parse_pythonic_with_dict_arg_values() {
let message = "[foo(a={'a': 1, 'b': 2}, b=2), bar(x={'x': 3, 'y': {'e': 'f'}})]";
let (result, _) = try_tool_call_parse_pythonic(message).unwrap();
let (result, _) = try_tool_call_parse_pythonic(message, None).unwrap();
assert!(!result.is_empty());
assert_eq!(result.len(), 2);
let (name, args) = extract_name_and_args(result[0].clone());
......
......@@ -10,6 +10,7 @@ pub use super::parsers::detect_and_parse_tool_call;
pub async fn try_tool_call_parse_aggregate(
message: &str,
parser_str: Option<&str>,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>,
Option<String>,
......@@ -19,7 +20,7 @@ pub async fn try_tool_call_parse_aggregate(
} else {
tracing::info!("Using tool parser: {:?}", parser_str);
}
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str, tools).await?;
if parsed.is_empty() {
return Ok((vec![], content));
}
......@@ -47,11 +48,12 @@ pub async fn try_tool_call_parse_aggregate(
pub async fn try_tool_call_parse_stream(
message: &str,
parser_str: Option<&str>,
tools: Option<&[super::ToolDefinition]>,
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>,
)> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str, tools).await?;
if parsed.is_empty() {
return Ok((vec![], content));
}
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment