Unverified Commit cd814377 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

fix: harmony parser streaming fix (#3074)


Signed-off-by: ayushag <ayushag@nvidia.com>
Signed-off-by: Graham King <grahamk@nvidia.com>
Co-authored-by: Graham King <grahamk@nvidia.com>
parent 78a3feda
......@@ -2080,6 +2080,7 @@ dependencies = [
"rustpython-parser",
"serde",
"serde_json",
"tokio",
"tracing",
"uuid 1.18.0",
]
......
......@@ -653,17 +653,17 @@ impl OpenAIPreprocessor {
}
/// Apply tool calling jail to the stream using the preprocessor's tool call parser
pub fn apply_tool_calling_jail_with_parser(
pub async fn apply_tool_calling_jail_with_parser(
&self,
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
apply_tool_calling_jail_internal(stream, self.tool_call_parser.clone())
apply_tool_calling_jail_internal(stream, self.tool_call_parser.clone()).await
}
}
/// Apply tool calling jail to the stream - stops/jails the stream under certain conditions
/// When jailed, the stream will be unjailed when the input stream ends
pub fn apply_tool_calling_jail_internal(
pub async fn apply_tool_calling_jail_internal(
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
tool_call_parser: Option<String>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
......@@ -677,6 +677,7 @@ pub fn apply_tool_calling_jail_internal(
last_response_metadata: None,
finished: false,
};
// Transform the stream using unfold to maintain state
// Input: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
// Returns None if the stream is finished
......@@ -814,7 +815,9 @@ pub fn apply_tool_calling_jail_internal(
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_text,
state.tool_call_parser.as_deref(),
) {
)
.await
{
// Found tool calls, create a final response with them
tracing::debug!(
"Parsed {} tool calls from accumulated content",
......@@ -952,7 +955,7 @@ impl
// transform the postprocessor stream
let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
let stream = self.apply_tool_calling_jail_with_parser(stream);
let stream = self.apply_tool_calling_jail_with_parser(stream).await;
let context = stream.context();
// prepend the annotations to the response stream
let stream = annotations_stream.chain(stream);
......
......@@ -169,7 +169,7 @@ async fn test_apply_tool_calling_jail_internal_with_tool_call_detection() {
// Apply the jail with nemotron_deci parser - should trigger jailing on first chunk
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())).await;
// Collect all results
let results: Vec<_> = jailed_stream.collect().await;
......@@ -225,7 +225,7 @@ async fn test_apply_tool_calling_jail_internal_no_tool_calls() {
// Apply the jail with nemotron_deci parser - regular text should NOT be jailed
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())).await;
// Collect all results
let results: Vec<_> = jailed_stream.collect().await;
......@@ -276,7 +276,7 @@ async fn test_apply_tool_calling_jail_internal_with_empty_stream() {
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = apply_tool_calling_jail_internal(response_stream, None);
let jailed_stream = apply_tool_calling_jail_internal(response_stream, None).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(results.is_empty(), "Empty stream should produce no results");
......@@ -300,7 +300,7 @@ async fn test_apply_tool_calling_jail_internal_with_different_parsers() {
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(!results.is_empty(), "Should have results for hermes parser");
......@@ -360,7 +360,7 @@ async fn test_apply_tool_calling_jail_internal_hermes_parser() {
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(!results.is_empty(), "Should have results for hermes parser");
......@@ -458,7 +458,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_no_tool_call_
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
......@@ -532,7 +532,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
......@@ -583,7 +583,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string()));
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
......@@ -635,3 +635,77 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv
assert_eq!(arguments["location"], "San Francisco");
assert_eq!(arguments["unit"], "fahrenheit");
}
/// End-to-end check of the tool-calling jail with the `harmony` parser.
///
/// Feeds a Harmony-formatted response split over several chunks and expects
/// two results, the second of which carries both the analysis text as content
/// and a single parsed `get_current_weather` tool call.
#[tokio::test]
async fn test_tool_calling_jail_internal_with_harmony_parser() {
    let mock_context = Arc::new(MockAsyncEngineContext::new(
        "test-request-id-harmony".to_string(),
    ));

    // The complete Harmony message, delivered piece by piece:
    //   <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
    //   <|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
    //   <|message|>{"location":"San Francisco"}<|call|>
    let pieces = [
        "<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>",
        "<|start|>",
        "assistant",
        "<|channel|>",
        "commentary to=functions.get_current_weather <|constrain|>json",
        "<|message|>{\"location\":\"San Francisco\"}<|call|>",
    ];
    let mut chunks: Vec<_> = pieces
        .into_iter()
        .map(|piece| create_mock_response_chunk(piece.to_string(), 0))
        .collect();
    chunks.push(create_final_response_chunk(0));

    let response_stream =
        ResponseStream::new(Box::pin(stream::iter(chunks)), mock_context.clone());
    let results: Vec<_> =
        apply_tool_calling_jail_internal(response_stream, Some("harmony".to_string()))
            .await
            .collect()
            .await;

    assert!(
        !results.is_empty(),
        "Should have results for harmony parser"
    );
    assert_eq!(results.len(), 2);

    // Everything of interest lives in the delta of the second result.
    let delta = &results[1].data.as_ref().unwrap().choices[0].delta;
    assert_eq!(
        delta.content,
        Some("Need to use function get_current_weather.".to_string())
    );
    assert!(delta.tool_calls.is_some());

    let tools = delta.tool_calls.as_ref().unwrap();
    assert_eq!(tools.len(), 1);

    let function = tools[0].function.as_ref().unwrap();
    let name = function.name.as_ref().unwrap();
    let arguments: serde_json::Value =
        serde_json::from_str(function.arguments.as_ref().unwrap()).unwrap();
    assert_eq!(name, "get_current_weather");
    assert_eq!(arguments["location"], "San Francisco");
}
......@@ -29,6 +29,7 @@ anyhow = { workspace = true }
dynamo-async-openai = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
......
......@@ -7,19 +7,27 @@ use openai_harmony::StreamableParser;
use openai_harmony::chat::{Content::Text, Role};
use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding};
use serde_json::Value;
use std::sync::OnceLock;
static GLOBAL_HARMONY_GPTOSS_ENCODING: OnceLock<Result<HarmonyEncoding, anyhow::Error>> =
OnceLock::new();
static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
Result<HarmonyEncoding, anyhow::Error>,
> = tokio::sync::OnceCell::const_new();
pub fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
GLOBAL_HARMONY_GPTOSS_ENCODING
.get_or_init(|| load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss))
.get_or_init(|| async {
tokio::task::spawn_blocking(|| {
load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)
})
.await
.map_err(anyhow::Error::msg)
.flatten()
})
.await
}
/// Parse tool calls from Harmony Format text
/// <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>
pub fn parse_tool_calls_harmony(
pub async fn parse_tool_calls_harmony(
text: &str,
config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
......@@ -29,7 +37,7 @@ pub fn parse_tool_calls_harmony(
// Check if tool call start tokens are present, if not return everything as normal text
// Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present
// End Token: "<|call|>"
if !detect_tool_call_start_harmony(text, config) {
if !detect_tool_call_start_harmony(text, config, true) {
return Ok((vec![], Some(trimmed)));
}
......@@ -43,7 +51,7 @@ pub fn parse_tool_calls_harmony(
trimmed.push_str(end_token);
}
let enc = match get_harmony_encoding().as_ref() {
let enc = match get_harmony_encoding().await.as_ref() {
Ok(e) => e,
Err(e) => {
tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed.");
......@@ -154,15 +162,28 @@ pub fn parse_tool_calls_harmony(
Ok((res, Some(normal_text.to_string())))
}
pub fn detect_tool_call_start_harmony(chunk: &str, config: &JsonParserConfig) -> bool {
/// Reports whether `chunk` looks like the start of a Harmony tool call.
///
/// The chunk is trimmed first; an empty or whitespace-only chunk never matches.
///
/// * `chunk` - text to inspect.
/// * `config` - supplies `tool_call_start_tokens`, the markers searched for.
/// * `strict` - when `true`, only a configured start token counts. When
///   `false`, a chunk containing the bare `<|channel|>` marker is accepted too.
///
/// Returns `true` if the trimmed chunk matches under the selected mode.
pub fn detect_tool_call_start_harmony(
    chunk: &str,
    config: &JsonParserConfig,
    strict: bool,
) -> bool {
    let text = chunk.trim();
    if text.is_empty() {
        return false;
    }

    let has_start_token = config
        .tool_call_start_tokens
        .iter()
        .any(|start| text.contains(start.as_str()));

    // Lenient mode additionally accepts the channel marker on its own.
    has_start_token || (!strict && text.contains("<|channel|>"))
}
#[cfg(test)]
......@@ -174,8 +195,8 @@ mod tests {
(call.function.name, args)
}
#[test]
fn test_parse_tool_calls_harmony_basic() {
#[tokio::test]
async fn test_parse_tool_calls_harmony_basic() {
let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
......@@ -186,7 +207,7 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap();
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(
normal_content,
Some("Need to use function get_current_weather.".to_string())
......@@ -197,8 +218,8 @@ mod tests {
assert_eq!(args["location"], "San Francisco");
}
#[test]
fn test_parse_tools_harmony_without_start_token() {
#[tokio::test]
async fn test_parse_tools_harmony_without_start_token() {
let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|message|>{"location":"San Francisco"}<|call|>
......@@ -208,13 +229,13 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap();
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(normal_content, Some(text.trim().to_string()));
assert_eq!(tool_calls.len(), 0);
}
#[test]
fn test_parse_tool_calls_harmony_with_multi_args() {
#[tokio::test]
async fn test_parse_tool_calls_harmony_with_multi_args() {
let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
......@@ -225,7 +246,7 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap();
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(
normal_content,
Some("Need to use function get_current_weather.".to_string())
......@@ -237,8 +258,8 @@ mod tests {
assert_eq!(args["unit"], "fahrenheit");
}
#[test]
fn test_parse_tool_calls_harmony_with_normal_text() {
#[tokio::test]
async fn test_parse_tool_calls_harmony_with_normal_text() {
let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
......@@ -249,7 +270,7 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap();
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(
normal_content,
Some("Need to use function get_current_weather.".to_string())
......@@ -260,15 +281,15 @@ mod tests {
assert_eq!(args["location"], "San Francisco");
}
#[test]
fn test_parse_tool_calls_harmony_without_call_token() {
#[tokio::test]
async fn test_parse_tool_calls_harmony_without_call_token() {
let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap();
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
assert_eq!(tool_calls.len(), 1);
let (name, args) = extract_name_and_args(tool_calls[0].clone());
......@@ -290,19 +311,21 @@ mod detect_parser_tests {
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_harmony(text, &config);
let result = detect_tool_call_start_harmony(text, &config, false);
assert!(result);
}
#[test]
fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() {
// This is a workaround for now: in non-strict mode, any chunk containing
// `<|channel|>` is treated as a tool call start.
// We need to improve this in the future.
let text = r#"<|channel|>commentary to=functions.get_current_weather"#;
let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default()
};
let result = detect_tool_call_start_harmony(text, &config);
assert!(!result);
let result = detect_tool_call_start_harmony(text, &config, false);
assert!(result);
}
}
This diff is collapsed.
......@@ -7,7 +7,7 @@ pub use super::parsers::detect_and_parse_tool_call;
/// Try parsing a string as a structured tool call, for aggregation usage.
///
/// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_tool_call_parse_aggregate(
pub async fn try_tool_call_parse_aggregate(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<(
......@@ -19,7 +19,7 @@ pub fn try_tool_call_parse_aggregate(
} else {
tracing::info!("Using tool parser: {:?}", parser_str);
}
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
if parsed.is_empty() {
return Ok((vec![], content));
}
......@@ -44,14 +44,14 @@ pub fn try_tool_call_parse_aggregate(
/// Try parsing a string as a structured tool call, for streaming (delta) usage.
///
/// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_tool_call_parse_stream(
pub async fn try_tool_call_parse_stream(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>,
)> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
if parsed.is_empty() {
return Ok((vec![], content));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment