"docs/kubernetes/api-reference.md" did not exist on "cf433e6825d83f41905da47d69ca5ee30d4eb1ba"
Unverified Commit cd814377 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

fix: harmony parser streaming fix (#3074)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
Co-authored-by: default avatarGraham King <grahamk@nvidia.com>
parent 78a3feda
...@@ -2080,6 +2080,7 @@ dependencies = [ ...@@ -2080,6 +2080,7 @@ dependencies = [
"rustpython-parser", "rustpython-parser",
"serde", "serde",
"serde_json", "serde_json",
"tokio",
"tracing", "tracing",
"uuid 1.18.0", "uuid 1.18.0",
] ]
......
...@@ -653,17 +653,17 @@ impl OpenAIPreprocessor { ...@@ -653,17 +653,17 @@ impl OpenAIPreprocessor {
} }
/// Apply tool calling jail to the stream using the preprocessor's tool call parser /// Apply tool calling jail to the stream using the preprocessor's tool call parser
pub fn apply_tool_calling_jail_with_parser( pub async fn apply_tool_calling_jail_with_parser(
&self, &self,
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> { ) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
apply_tool_calling_jail_internal(stream, self.tool_call_parser.clone()) apply_tool_calling_jail_internal(stream, self.tool_call_parser.clone()).await
} }
} }
/// Apply tool calling jail to the stream - stops/jails the stream under certain conditions /// Apply tool calling jail to the stream - stops/jails the stream under certain conditions
/// When jailed, the stream will be unjailed when the input stream ends /// When jailed, the stream will be unjailed when the input stream ends
pub fn apply_tool_calling_jail_internal( pub async fn apply_tool_calling_jail_internal(
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
tool_call_parser: Option<String>, tool_call_parser: Option<String>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> { ) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
...@@ -677,6 +677,7 @@ pub fn apply_tool_calling_jail_internal( ...@@ -677,6 +677,7 @@ pub fn apply_tool_calling_jail_internal(
last_response_metadata: None, last_response_metadata: None,
finished: false, finished: false,
}; };
// Transform the stream using unfold to maintain state // Transform the stream using unfold to maintain state
// Input: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> // Input: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
// Returns None if the stream is finished // Returns None if the stream is finished
...@@ -814,7 +815,9 @@ pub fn apply_tool_calling_jail_internal( ...@@ -814,7 +815,9 @@ pub fn apply_tool_calling_jail_internal(
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate( if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_text, accumulated_text,
state.tool_call_parser.as_deref(), state.tool_call_parser.as_deref(),
) { )
.await
{
// Found tool calls, create a final response with them // Found tool calls, create a final response with them
tracing::debug!( tracing::debug!(
"Parsed {} tool calls from accumulated content", "Parsed {} tool calls from accumulated content",
...@@ -952,7 +955,7 @@ impl ...@@ -952,7 +955,7 @@ impl
// transform the postprocessor stream // transform the postprocessor stream
let stream = Self::transform_postprocessor_stream(response_stream, response_generator); let stream = Self::transform_postprocessor_stream(response_stream, response_generator);
let stream = self.apply_tool_calling_jail_with_parser(stream); let stream = self.apply_tool_calling_jail_with_parser(stream).await;
let context = stream.context(); let context = stream.context();
// prepend the annotations to the response stream // prepend the annotations to the response stream
let stream = annotations_stream.chain(stream); let stream = annotations_stream.chain(stream);
......
...@@ -169,7 +169,7 @@ async fn test_apply_tool_calling_jail_internal_with_tool_call_detection() { ...@@ -169,7 +169,7 @@ async fn test_apply_tool_calling_jail_internal_with_tool_call_detection() {
// Apply the jail with nemotron_deci parser - should trigger jailing on first chunk // Apply the jail with nemotron_deci parser - should trigger jailing on first chunk
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())); apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())).await;
// Collect all results // Collect all results
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
...@@ -225,7 +225,7 @@ async fn test_apply_tool_calling_jail_internal_no_tool_calls() { ...@@ -225,7 +225,7 @@ async fn test_apply_tool_calling_jail_internal_no_tool_calls() {
// Apply the jail with nemotron_deci parser - regular text should NOT be jailed // Apply the jail with nemotron_deci parser - regular text should NOT be jailed
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())); apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())).await;
// Collect all results // Collect all results
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
...@@ -276,7 +276,7 @@ async fn test_apply_tool_calling_jail_internal_with_empty_stream() { ...@@ -276,7 +276,7 @@ async fn test_apply_tool_calling_jail_internal_with_empty_stream() {
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone()); let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = apply_tool_calling_jail_internal(response_stream, None); let jailed_stream = apply_tool_calling_jail_internal(response_stream, None).await;
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
assert!(results.is_empty(), "Empty stream should produce no results"); assert!(results.is_empty(), "Empty stream should produce no results");
...@@ -300,7 +300,7 @@ async fn test_apply_tool_calling_jail_internal_with_different_parsers() { ...@@ -300,7 +300,7 @@ async fn test_apply_tool_calling_jail_internal_with_different_parsers() {
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone()); let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())); apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
assert!(!results.is_empty(), "Should have results for hermes parser"); assert!(!results.is_empty(), "Should have results for hermes parser");
...@@ -360,7 +360,7 @@ async fn test_apply_tool_calling_jail_internal_hermes_parser() { ...@@ -360,7 +360,7 @@ async fn test_apply_tool_calling_jail_internal_hermes_parser() {
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone()); let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())); apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
assert!(!results.is_empty(), "Should have results for hermes parser"); assert!(!results.is_empty(), "Should have results for hermes parser");
...@@ -458,7 +458,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_no_tool_call_ ...@@ -458,7 +458,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_no_tool_call_
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone()); let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())); apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
...@@ -532,7 +532,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv ...@@ -532,7 +532,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone()); let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())); apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
assert!( assert!(
...@@ -583,7 +583,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv ...@@ -583,7 +583,7 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone()); let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())); apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
assert!( assert!(
...@@ -635,3 +635,77 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv ...@@ -635,3 +635,77 @@ async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positiv
assert_eq!(arguments["location"], "San Francisco"); assert_eq!(arguments["location"], "San Francisco");
assert_eq!(arguments["unit"], "fahrenheit"); assert_eq!(arguments["unit"], "fahrenheit");
} }
#[tokio::test]
async fn test_tool_calling_jail_internal_with_harmony_parser() {
let mock_context = Arc::new(MockAsyncEngineContext::new(
"test-request-id-harmony".to_string(),
));
// Harmony Format:
// <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
// <|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
// <|message|>{"location":"San Francisco"}<|call|>
let chunks = vec![
create_mock_response_chunk(
"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>"
.to_string(),
0,
),
create_mock_response_chunk("<|start|>".to_string(), 0),
create_mock_response_chunk("assistant".to_string(), 0),
create_mock_response_chunk("<|channel|>".to_string(), 0),
create_mock_response_chunk(
"commentary to=functions.get_current_weather <|constrain|>json".to_string(),
0,
),
create_mock_response_chunk(
"<|message|>{\"location\":\"San Francisco\"}<|call|>".to_string(),
0,
),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("harmony".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
!results.is_empty(),
"Should have results for harmony parser"
);
assert_eq!(results.len(), 2);
assert_eq!(
results[1].data.as_ref().unwrap().choices[0].delta.content,
Some("Need to use function get_current_weather.".to_string())
);
assert!(
results[1].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.is_some()
);
let tools = results[1].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tools.len(), 1);
let name = tools[0].function.as_ref().unwrap().name.as_ref().unwrap();
let arguments = serde_json::from_str::<serde_json::Value>(
tools[0]
.function
.as_ref()
.unwrap()
.arguments
.as_ref()
.unwrap(),
)
.unwrap();
assert_eq!(name, "get_current_weather");
assert_eq!(arguments["location"], "San Francisco");
}
...@@ -29,6 +29,7 @@ anyhow = { workspace = true } ...@@ -29,6 +29,7 @@ anyhow = { workspace = true }
dynamo-async-openai = { workspace = true } dynamo-async-openai = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
......
...@@ -7,19 +7,27 @@ use openai_harmony::StreamableParser; ...@@ -7,19 +7,27 @@ use openai_harmony::StreamableParser;
use openai_harmony::chat::{Content::Text, Role}; use openai_harmony::chat::{Content::Text, Role};
use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding}; use openai_harmony::{HarmonyEncoding, HarmonyEncodingName, load_harmony_encoding};
use serde_json::Value; use serde_json::Value;
use std::sync::OnceLock;
static GLOBAL_HARMONY_GPTOSS_ENCODING: OnceLock<Result<HarmonyEncoding, anyhow::Error>> = static GLOBAL_HARMONY_GPTOSS_ENCODING: tokio::sync::OnceCell<
OnceLock::new(); Result<HarmonyEncoding, anyhow::Error>,
> = tokio::sync::OnceCell::const_new();
pub fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> { pub async fn get_harmony_encoding() -> &'static Result<HarmonyEncoding, anyhow::Error> {
GLOBAL_HARMONY_GPTOSS_ENCODING GLOBAL_HARMONY_GPTOSS_ENCODING
.get_or_init(|| load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)) .get_or_init(|| async {
tokio::task::spawn_blocking(|| {
load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss)
})
.await
.map_err(anyhow::Error::msg)
.flatten()
})
.await
} }
/// Parse tool calls from Harmony Format text /// Parse tool calls from Harmony Format text
/// <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|> /// <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|>
pub fn parse_tool_calls_harmony( pub async fn parse_tool_calls_harmony(
text: &str, text: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
...@@ -29,7 +37,7 @@ pub fn parse_tool_calls_harmony( ...@@ -29,7 +37,7 @@ pub fn parse_tool_calls_harmony(
// Check if tool call start tokens are present, if not return everything as normal text // Check if tool call start tokens are present, if not return everything as normal text
// Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present // Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present
// End Token: "<|call|>" // End Token: "<|call|>"
if !detect_tool_call_start_harmony(text, config) { if !detect_tool_call_start_harmony(text, config, true) {
return Ok((vec![], Some(trimmed))); return Ok((vec![], Some(trimmed)));
} }
...@@ -43,7 +51,7 @@ pub fn parse_tool_calls_harmony( ...@@ -43,7 +51,7 @@ pub fn parse_tool_calls_harmony(
trimmed.push_str(end_token); trimmed.push_str(end_token);
} }
let enc = match get_harmony_encoding().as_ref() { let enc = match get_harmony_encoding().await.as_ref() {
Ok(e) => e, Ok(e) => e,
Err(e) => { Err(e) => {
tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed."); tracing::debug!("Failed to load harmony encoding: {e}. Tool calls will not be parsed.");
...@@ -154,15 +162,28 @@ pub fn parse_tool_calls_harmony( ...@@ -154,15 +162,28 @@ pub fn parse_tool_calls_harmony(
Ok((res, Some(normal_text.to_string()))) Ok((res, Some(normal_text.to_string())))
} }
pub fn detect_tool_call_start_harmony(chunk: &str, config: &JsonParserConfig) -> bool { pub fn detect_tool_call_start_harmony(
chunk: &str,
config: &JsonParserConfig,
strict: bool,
) -> bool {
let trimmed = chunk.trim(); let trimmed = chunk.trim();
if trimmed.is_empty() { if trimmed.is_empty() {
return false; return false;
} }
if strict {
config config
.tool_call_start_tokens .tool_call_start_tokens
.iter() .iter()
.any(|token| trimmed.contains(token)) .any(|token| trimmed.contains(token))
} else {
config
.tool_call_start_tokens
.iter()
.any(|token| trimmed.contains(token))
|| trimmed.contains("<|channel|>")
}
} }
#[cfg(test)] #[cfg(test)]
...@@ -174,8 +195,8 @@ mod tests { ...@@ -174,8 +195,8 @@ mod tests {
(call.function.name, args) (call.function.name, args)
} }
#[test] #[tokio::test]
fn test_parse_tool_calls_harmony_basic() { async fn test_parse_tool_calls_harmony_basic() {
let text = r#" let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json <|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
...@@ -186,7 +207,7 @@ mod tests { ...@@ -186,7 +207,7 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap(); let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!( assert_eq!(
normal_content, normal_content,
Some("Need to use function get_current_weather.".to_string()) Some("Need to use function get_current_weather.".to_string())
...@@ -197,8 +218,8 @@ mod tests { ...@@ -197,8 +218,8 @@ mod tests {
assert_eq!(args["location"], "San Francisco"); assert_eq!(args["location"], "San Francisco");
} }
#[test] #[tokio::test]
fn test_parse_tools_harmony_without_start_token() { async fn test_parse_tools_harmony_without_start_token() {
let text = r#" let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|message|>{"location":"San Francisco"}<|call|> <|message|>{"location":"San Francisco"}<|call|>
...@@ -208,13 +229,13 @@ mod tests { ...@@ -208,13 +229,13 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap(); let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(normal_content, Some(text.trim().to_string())); assert_eq!(normal_content, Some(text.trim().to_string()));
assert_eq!(tool_calls.len(), 0); assert_eq!(tool_calls.len(), 0);
} }
#[test] #[tokio::test]
fn test_parse_tool_calls_harmony_with_multi_args() { async fn test_parse_tool_calls_harmony_with_multi_args() {
let text = r#" let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json <|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
...@@ -225,7 +246,7 @@ mod tests { ...@@ -225,7 +246,7 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap(); let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!( assert_eq!(
normal_content, normal_content,
Some("Need to use function get_current_weather.".to_string()) Some("Need to use function get_current_weather.".to_string())
...@@ -237,8 +258,8 @@ mod tests { ...@@ -237,8 +258,8 @@ mod tests {
assert_eq!(args["unit"], "fahrenheit"); assert_eq!(args["unit"], "fahrenheit");
} }
#[test] #[tokio::test]
fn test_parse_tool_calls_harmony_with_normal_text() { async fn test_parse_tool_calls_harmony_with_normal_text() {
let text = r#" let text = r#"
<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|> <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json <|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
...@@ -249,7 +270,7 @@ mod tests { ...@@ -249,7 +270,7 @@ mod tests {
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap(); let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!( assert_eq!(
normal_content, normal_content,
Some("Need to use function get_current_weather.".to_string()) Some("Need to use function get_current_weather.".to_string())
...@@ -260,15 +281,15 @@ mod tests { ...@@ -260,15 +281,15 @@ mod tests {
assert_eq!(args["location"], "San Francisco"); assert_eq!(args["location"], "San Francisco");
} }
#[test] #[tokio::test]
fn test_parse_tool_calls_harmony_without_call_token() { async fn test_parse_tool_calls_harmony_without_call_token() {
let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#; let text = r#"<|channel|>analysis<|message|>We need to call get_weather function. The user asks "What's the weather like in San Francisco in Celsius?" So location: "San Francisco, CA" unit: "celsius". Let's call function.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco, CA","unit":"celsius"}"#;
let config = JsonParserConfig { let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()], tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).unwrap(); let (tool_calls, normal_content) = parse_tool_calls_harmony(text, &config).await.unwrap();
assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string())); assert_eq!(normal_content, Some("We need to call get_weather function. The user asks \"What's the weather like in San Francisco in Celsius?\" So location: \"San Francisco, CA\" unit: \"celsius\". Let's call function.".to_string()));
assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls.len(), 1);
let (name, args) = extract_name_and_args(tool_calls[0].clone()); let (name, args) = extract_name_and_args(tool_calls[0].clone());
...@@ -290,19 +311,21 @@ mod detect_parser_tests { ...@@ -290,19 +311,21 @@ mod detect_parser_tests {
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let result = detect_tool_call_start_harmony(text, &config); let result = detect_tool_call_start_harmony(text, &config, false);
assert!(result); assert!(result);
} }
#[test] #[test]
fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() { fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() {
// This is a warkaround for now. Right now everything is treated as tool call start token.
// We need to improve this in the future.
let text = r#"<|channel|>commentary to=functions.get_current_weather"#; let text = r#"<|channel|>commentary to=functions.get_current_weather"#;
let config = JsonParserConfig { let config = JsonParserConfig {
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()], tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
tool_call_end_tokens: vec!["<|call|>".to_string()], tool_call_end_tokens: vec!["<|call|>".to_string()],
..Default::default() ..Default::default()
}; };
let result = detect_tool_call_start_harmony(text, &config); let result = detect_tool_call_start_harmony(text, &config, false);
assert!(!result); assert!(result);
} }
} }
This diff is collapsed.
...@@ -7,7 +7,7 @@ pub use super::parsers::detect_and_parse_tool_call; ...@@ -7,7 +7,7 @@ pub use super::parsers::detect_and_parse_tool_call;
/// Try parsing a string as a structured tool call, for aggregation usage. /// Try parsing a string as a structured tool call, for aggregation usage.
/// ///
/// If successful, returns a `ChatCompletionMessageToolCall`. /// If successful, returns a `ChatCompletionMessageToolCall`.
pub fn try_tool_call_parse_aggregate( pub async fn try_tool_call_parse_aggregate(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<( ) -> anyhow::Result<(
...@@ -19,7 +19,7 @@ pub fn try_tool_call_parse_aggregate( ...@@ -19,7 +19,7 @@ pub fn try_tool_call_parse_aggregate(
} else { } else {
tracing::info!("Using tool parser: {:?}", parser_str); tracing::info!("Using tool parser: {:?}", parser_str);
} }
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?; let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok((vec![], content)); return Ok((vec![], content));
} }
...@@ -44,14 +44,14 @@ pub fn try_tool_call_parse_aggregate( ...@@ -44,14 +44,14 @@ pub fn try_tool_call_parse_aggregate(
/// Try parsing a string as a structured tool call, for streaming (delta) usage. /// Try parsing a string as a structured tool call, for streaming (delta) usage.
/// ///
/// If successful, returns a `ChatCompletionMessageToolCallChunk`. /// If successful, returns a `ChatCompletionMessageToolCallChunk`.
pub fn try_tool_call_parse_stream( pub async fn try_tool_call_parse_stream(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<( ) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>, Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>, Option<String>,
)> { )> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?; let (parsed, content) = detect_and_parse_tool_call(message, parser_str).await?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok((vec![], content)); return Ok((vec![], content));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment