"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "9acaa8d1a2010fbbf5ad959f89dc2d6a32869dd5"
Unverified commit 656b4c44, authored by zhongdaor-nv, committed by GitHub
Browse files

fix: Migrate to the new implementation that uses parse_tool_calls_harmony_complete (#3685)


Signed-off-by: zhongdaor <zhongdaor@nvidia.com>
parent d68e4b8a
...@@ -4,9 +4,7 @@ ...@@ -4,9 +4,7 @@
use super::config::JsonParserConfig; use super::config::JsonParserConfig;
use super::response::{CalledFunction, ToolCallResponse, ToolCallType}; use super::response::{CalledFunction, ToolCallResponse, ToolCallType};
use openai_harmony::chat::{Content::Text, Role}; use openai_harmony::chat::{Content::Text, Role};
use openai_harmony::{ use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding};
HarmonyEncoding, HarmonyEncodingName, StreamableParser, load_harmony_encoding,
};
use serde_json::Value; use serde_json::Value;
static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell< static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
...@@ -26,142 +24,6 @@ pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow:: ...@@ -26,142 +24,6 @@ pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::
.await .await
} }
/// Parse tool calls from Harmony Format text
/// <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>
///
/// The text is tokenized with the shared harmony encoding, replayed through a
/// `StreamableParser` acting as the assistant, and the resulting messages are
/// scanned for tool calls and analysis text.
///
/// # Arguments
///
/// * `text` - Raw model output, possibly containing Harmony Format tool calls.
/// * `config` - Supplies the tool call end token (the first entry of
///   `tool_call_end_tokens`, defaulting to `<|call|>`) and is forwarded to
///   `detect_tool_call_start_harmony`.
///
/// # Returns
///
/// A tuple of the extracted tool calls and the normal text:
///
/// * When no tool call start is detected, or the harmony encoding / parser
///   cannot be created, the tool call list is empty and the normal text is the
///   trimmed input.
/// * Otherwise the normal text is the concatenation of the assistant messages
///   on the `analysis` channel, and each assistant `commentary` message whose
///   recipient is `functions.<name>` with a non-null JSON payload yields one
///   tool call with id `call-<n>` (1-based, counting emitted calls only).
///
/// Every failure path in this function degrades to `Ok`; it does not return
/// `Err` itself.
pub async fn parse_tool_calls_harmony(
    text: &str,
    config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
    let original_text = text.trim();

    // Check if tool call start tokens are present, if not return everything as normal text
    // Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present
    // End Token: "<|call|>"
    if !detect_tool_call_start_harmony(text, config, true) {
        return Ok((vec![], Some(original_text.to_string())));
    }

    // Workaround to add <|call|> token to the end of the text if it is not present.
    // Otherwise, StreamableParser will not be able to parse the text.
    let end_token = config
        .tool_call_end_tokens
        .first()
        .map(String::as_str)
        .unwrap_or("<|call|>");
    let mut to_parse = original_text.to_string();
    if !to_parse.ends_with(end_token) {
        to_parse.push_str(end_token);
    }

    let enc = match get_harmony_encoding().await.as_ref() {
        Ok(e) => e,
        Err(e) => {
            tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed.");
            return Ok((vec![], Some(original_text.to_string())));
        }
    };

    // Encode the text into tokens using harmony encoding
    let tokens = enc.tokenizer().encode_with_special_tokens(&to_parse);

    // Create StreamableParser to process each token and create Harmony Format messages
    // Set Role to Assistant because we are parsing tool calls from an assistant message
    let mut parser = match StreamableParser::new(enc.clone(), Some(Role::Assistant)) {
        Ok(p) => p,
        Err(e) => {
            tracing::debug!(
                "Failed to create harmony streamable parser: {e}. Tool calls will not be parsed."
            );
            return Ok((vec![], Some(original_text.to_string())));
        }
    };

    // Process each token to create Harmony Format messages.
    // A token that causes an error is deliberately skipped: some special tokens
    // are not supported by the parser.
    for token in tokens {
        let _ = parser.process(token);
    }

    // Get the Harmony Format messages
    let messages = parser.messages();

    let mut normal_text = String::new();
    let mut res = Vec::with_capacity(messages.len());
    let mut call_idx = 0usize; // Index of the tool call

    // Iterate through messages and extract tool calls if there are any.
    // For tool call, role should be Assistant, channel should be commentary and recipient should start with functions.
    // Message {
    //     author: Author {
    //         role: Assistant,
    //         name: None
    //     },
    //     recipient: Some("functions.get_current_weather"),
    //     content: [
    //         Text(
    //             TextContent {
    //                 text: "{\"location\":\"San Francisco\"}"
    //             }
    //         )
    //     ],
    //     channel: Some("commentary"),
    //     content_type: Some("<|constrain|>json")
    for message in messages.iter() {
        if message.author.role != Role::Assistant {
            continue;
        }

        match message.channel.as_deref() {
            Some("commentary") => {
                // Everything after the "functions." prefix is the function name, so a
                // name that itself contains a '.' is preserved instead of truncated.
                let Some(fname) = message
                    .recipient
                    .as_deref()
                    .and_then(|r| r.strip_prefix("functions."))
                    .filter(|name| !name.is_empty())
                else {
                    continue;
                };

                // Only text content can carry JSON arguments.
                let Some(Text(body)) = message.content.first() else {
                    continue;
                };

                // A tool call is emitted only when the payload is valid, non-null JSON.
                let args = match serde_json::from_str::<Value>(body.text.trim()) {
                    Ok(Value::Null) => continue,
                    Ok(value) => value,
                    Err(e) => {
                        tracing::debug!(
                            "Skipping harmony tool call to {fname}: arguments are not valid JSON: {e}"
                        );
                        continue;
                    }
                };

                call_idx += 1;
                res.push(ToolCallResponse {
                    id: format!("call-{}", call_idx),
                    tp: ToolCallType::Function,
                    function: CalledFunction {
                        name: fname.to_string(),
                        // `Value`'s `Display` impl renders compact JSON and cannot fail,
                        // so no `unwrap` is needed here.
                        arguments: args.to_string(),
                    },
                });
            }
            Some("analysis") => {
                // `first()` rather than `[0]`: a message without content must not panic.
                if let Some(Text(t)) = message.content.first() {
                    normal_text.push_str(&t.text);
                }
            }
            _ => {}
        }
    }

    Ok((res, Some(normal_text)))
}
/// Parse tool calls from a complete Harmony Format text chunk using direct token parsing. /// Parse tool calls from a complete Harmony Format text chunk using direct token parsing.
/// ///
/// This function is optimized for parsing complete text chunks where the entire content /// This function is optimized for parsing complete text chunks where the entire content
...@@ -169,7 +31,7 @@ pub async fn parse_tool_calls_harmony( ...@@ -169,7 +31,7 @@ pub async fn parse_tool_calls_harmony(
/// parse all tokens into Harmony Format messages, then extracts tool calls from messages /// parse all tokens into Harmony Format messages, then extracts tool calls from messages
/// with the "commentary" channel and "functions.*" recipients. /// with the "commentary" channel and "functions.*" recipients.
/// ///
/// Unlike `parse_tool_calls_harmony`, this function doesn't perform start token detection /// This function doesn't perform start token detection
/// or token-by-token streaming, making it more efficient for complete chunks. /// or token-by-token streaming, making it more efficient for complete chunks.
/// ///
/// # Arguments /// # Arguments
...@@ -346,29 +208,6 @@ mod tests { ...@@ -346,29 +208,6 @@ mod tests {
(call.function.name, args) (call.function.name, args)
} }
#[tokio::test]
async fn test_parse_tool_calls_harmony_basic() {
    // Harmony start/end markers; every other config field keeps its default.
    let config = JsonParserConfig {
        tool_call_end_tokens: vec![String::from("<|call|>")],
        tool_call_start_tokens: vec![String::from("<|start|>assistant<|channel|>commentary")],
        ..Default::default()
    };

    // One analysis message followed by a single commentary tool call.
    let input = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
<|message|>{"location":"San Francisco"}<|call|>
"#;

    let (calls, content) = parse_tool_calls_harmony(input, &config).await.unwrap();

    // The analysis text is surfaced as the normal content.
    assert_eq!(
        content.as_deref(),
        Some("Need to use function get_current_weather.")
    );

    // Exactly one tool call, carrying the expected name and arguments.
    assert_eq!(calls.len(), 1);
    let (name, args) = extract_name_and_args(calls[0].clone());
    assert_eq!(name, "get_current_weather");
    assert_eq!(args["location"], "San Francisco");
}
#[tokio::test] #[tokio::test]
async fn test_parse_tool_calls_harmony_complete_basic() { async fn test_parse_tool_calls_harmony_complete_basic() {
let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#; let text = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"format":"celsius","location":"San Francisco"}"#;
...@@ -385,33 +224,22 @@ mod tests { ...@@ -385,33 +224,22 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_parse_tools_harmony_without_start_token() { async fn test_parse_tools_harmony_without_start_token() {
let text = r#" let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|message|>{"location":"San Francisco"}<|call|>"#;
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> let (tool_calls, normal_content) =
<|message|>{"location":"San Francisco"}<|call|> parse_tool_calls_harmony_complete(text, &Default::default())
"#; .await
let config = JsonParserConfig { .unwrap();
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(normal_content, Some(text.trim().to_string())); assert_eq!(normal_content, Some(text.trim().to_string()));
assert_eq!(tool_calls.len(), 0); assert_eq!(tool_calls.len(), 0);
} }
#[tokio::test] #[tokio::test]
async fn test_parse_tool_calls_harmony_with_multi_args() { async fn test_parse_tool_calls_harmony_with_multi_args() {
let text = r#" let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|>"#;
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> let (tool_calls, normal_content) =
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json parse_tool_calls_harmony_complete(text, &Default::default())
<|message|>{"location":"San Francisco", "unit":"fahrenheit"}<|call|> .await
"#; .unwrap();
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!( assert_eq!(
normal_content, normal_content,
Some("Need to use function get_current_weather.".to_string()) Some("Need to use function get_current_weather.".to_string())
...@@ -425,17 +253,11 @@ mod tests { ...@@ -425,17 +253,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_parse_tool_calls_harmony_with_normal_text() { async fn test_parse_tool_calls_harmony_with_normal_text() {
let text = r#" let text = r#"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>"#;
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> let (tool_calls, normal_content) =
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json parse_tool_calls_harmony_complete(text, &Default::default())
<|message|>{"location":"San Francisco"}<|call|> .await
"#; .unwrap();
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!( assert_eq!(
normal_content, normal_content,
Some("Need to use function get_current_weather.".to_string()) Some("Need to use function get_current_weather.".to_string())
...@@ -449,12 +271,10 @@ mod tests { ...@@ -449,12 +271,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_parse_tool_calls_harmony_without_call_token() { async fn test_parse_tool_calls_harmony_without_call_token() {
let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#; let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
let config = JsonParserConfig { let (tool_calls, normal_content) =
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()], parse_tool_calls_harmony_complete(text, &Default::default())
tool_call_end_tokens: vec!["<|call|>".to_string()], .await
..Default::default() .unwrap();
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string())); assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
let (name, args) = extract_name_and_args(tool_calls[0].clone()); let (name, args) = extract_name_and_args(tool_calls[0].clone());
......
...@@ -5,9 +5,7 @@ pub mod harmony_parser; ...@@ -5,9 +5,7 @@ pub mod harmony_parser;
pub use super::config::JsonParserConfig; pub use super::config::JsonParserConfig;
pub use super::{config, response}; pub use super::{config, response};
pub use harmony_parser::{ pub use harmony_parser::{detect_tool_call_start_harmony, parse_tool_calls_harmony_complete};
detect_tool_call_start_harmony, parse_tool_calls_harmony, parse_tool_calls_harmony_complete,
};
pub fn find_tool_call_end_position_harmony(chunk: &str, config: &JsonParserConfig) -> usize { pub fn find_tool_call_end_position_harmony(chunk: &str, config: &JsonParserConfig) -> usize {
let end_token = config let end_token = config
......
...@@ -13,7 +13,7 @@ pub mod tools; ...@@ -13,7 +13,7 @@ pub mod tools;
// Re-export main types and functions for convenience // Re-export main types and functions for convenience
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType}; pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use harmony::{parse_tool_calls_harmony, parse_tool_calls_harmony_complete}; pub use harmony::parse_tool_calls_harmony_complete;
pub use json::try_tool_call_parse_json; pub use json::try_tool_call_parse_json;
pub use parsers::{ pub use parsers::{
detect_and_parse_tool_call, detect_tool_call_start, find_tool_call_end_position, detect_and_parse_tool_call, detect_tool_call_start, find_tool_call_end_position,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::config::{ToolCallConfig, ToolCallParserType}; use super::config::{ToolCallConfig, ToolCallParserType};
use super::harmony::parse_tool_calls_harmony;
use super::harmony::{ use super::harmony::{
detect_tool_call_start_harmony, find_tool_call_end_position_harmony, detect_tool_call_start_harmony, find_tool_call_end_position_harmony,
parse_tool_calls_harmony_complete, parse_tool_calls_harmony_complete,
...@@ -54,13 +53,6 @@ pub async fn try_tool_call_parse( ...@@ -54,13 +53,6 @@ pub async fn try_tool_call_parse(
ToolCallParserType::Harmony => { ToolCallParserType::Harmony => {
let (results, normal_content) = let (results, normal_content) =
parse_tool_calls_harmony_complete(message, &config.json).await?; parse_tool_calls_harmony_complete(message, &config.json).await?;
if results.is_empty() {
// Fallback: attempt streaming parser when direct parse yields no calls
// This increases resilience to multi-call inputs and minor format drift
let (fallback_results, fallback_normal) =
parse_tool_calls_harmony(message, &config.json).await?;
return Ok((fallback_results, fallback_normal));
}
Ok((results, normal_content)) Ok((results, normal_content))
} }
ToolCallParserType::Pythonic => { ToolCallParserType::Pythonic => {
...@@ -1776,7 +1768,7 @@ fahrenheit ...@@ -1776,7 +1768,7 @@ fahrenheit
#[tokio::test] #[tokio::test]
async fn test_parallel_harmony_format_multiple_tools() { async fn test_parallel_harmony_format_multiple_tools() {
// Test with harmony parser for multiple tool calls // Test with harmony parser for multiple tool calls
let input = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}<|call|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}<|call|>"#; let input = r#"<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}<|call|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}<|call|>"#;
let (result, _content) = detect_and_parse_tool_call(input, Some("harmony")) let (result, _content) = detect_and_parse_tool_call(input, Some("harmony"))
.await .await
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment