Unverified Commit 889d6529 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: parse normal text along with tool calls (#2709)

parent 766d3f2c
...@@ -166,7 +166,7 @@ impl DeltaAggregator { ...@@ -166,7 +166,7 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax // After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() { for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none() if choice.tool_calls.is_none()
&& let Ok(tool_calls) = try_tool_call_parse_aggregate( && let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
&choice.text, &choice.text,
parsing_options.tool_call_parser.as_deref(), parsing_options.tool_call_parser.as_deref(),
) )
...@@ -184,6 +184,10 @@ impl DeltaAggregator { ...@@ -184,6 +184,10 @@ impl DeltaAggregator {
} }
choice.tool_calls = Some(tool_calls); choice.tool_calls = Some(tool_calls);
choice.text.clear(); choice.text.clear();
// If normal text is not empty, update the choice text
if let Some(normal_text) = normal_text.filter(|text| !text.is_empty()) {
choice.text = normal_text;
}
choice.finish_reason = Some(dynamo_async_openai::types::FinishReason::ToolCalls); choice.finish_reason = Some(dynamo_async_openai::types::FinishReason::ToolCalls);
} }
} }
...@@ -223,7 +227,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -223,7 +227,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
dynamo_async_openai::types::ChatChoice { dynamo_async_openai::types::ChatChoice {
message: dynamo_async_openai::types::ChatCompletionResponseMessage { message: dynamo_async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"), role: delta.role.expect("delta should have a Role"),
content: if delta.tool_calls.is_some() { content: if delta.text.is_empty() {
None None
} else { } else {
Some(delta.text) Some(delta.text)
...@@ -582,4 +586,68 @@ mod tests { ...@@ -582,4 +586,68 @@ mod tests {
dynamo_async_openai::types::Role::Assistant dynamo_async_openai::types::Role::Assistant
); );
} }
#[tokio::test]
async fn test_tool_calling_output_with_normal_text() {
// Simulate a delta with a tool call in the content
let tool_call_json = r#"Hey, I'm a normal text! {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;
// Use create_test_delta to generate the annotated delta, then extract the inner delta for the test
let annotated_delta = create_test_delta(
0,
tool_call_json,
Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls),
);
let data = annotated_delta.data.unwrap();
// Wrap it in Annotated and create a stream
let annotated_delta = Annotated {
data: Some(data),
id: Some("test_id".to_string()),
event: None,
comment: None,
};
let stream = Box::pin(stream::iter(vec![annotated_delta]));
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream, ParsingOptions::default()).await;
// Check the result
assert!(result.is_ok());
let response = result.unwrap();
// There should be one choice
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
// The tool_calls field should be present and parsed
assert!(choice.message.tool_calls.is_some());
let tool_calls = choice.message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
let tool_call = &tool_calls[0];
assert_eq!(tool_call.function.name, "get_weather");
// The arguments should be a JSON string containing the expected keys
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap();
assert_eq!(args["location"], "San Francisco, CA");
assert_eq!(args["unit"], "fahrenheit");
// The content should be the normal text
assert!(choice.message.content.is_some());
assert_eq!(
choice.message.content.as_ref().unwrap(),
"Hey, I'm a normal text!"
);
// The finish_reason should be ToolCalls
assert_eq!(
choice.finish_reason,
Some(dynamo_async_openai::types::FinishReason::ToolCalls)
);
assert_eq!(
choice.message.role,
dynamo_async_openai::types::Role::Assistant
);
}
} }
...@@ -60,10 +60,10 @@ fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) -> ...@@ -60,10 +60,10 @@ fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) ->
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token // Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
// Handles single tool and multiple tool call cases for single start_token like <|python_tag|> // Handles single tool and multiple tool call cases for single start_token like <|python_tag|>
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String { fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<String> {
// Return the input if it doesn't contain the start token // Return the input if it doesn't contain the start token
if !input.contains(start_token) { if !input.contains(start_token) {
return input.to_string(); return None;
} }
// Split on the start token and keep only JSON-looking segments // Split on the start token and keep only JSON-looking segments
...@@ -89,13 +89,23 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String { ...@@ -89,13 +89,23 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
// Remove everything up to and including the first occurrence of the start token // Remove everything up to and including the first occurrence of the start token
if let Some(idx) = input.find(start_token) { if let Some(idx) = input.find(start_token) {
let rest = &input[idx + start_token.len()..]; let rest = &input[idx + start_token.len()..];
return rest.trim_start().to_string(); return Some(rest.trim_start().to_string());
} else { } else {
// Shouldn't happen because we checked contains() above, but be defensive // Shouldn't happen because we checked contains() above, but be defensive
return input.to_string(); return None;
} }
} }
format!("[{}]", items.join(",")) Some(format!("[{}]", items.join(",")))
}
fn try_parse_normal_text(input: &str, start_token: &str) -> String {
// If input contains start token, just take the part before it
if let Some(idx) = input.find(start_token) {
return input[..idx].trim().to_string();
}
// No start token found, return empty string
String::new()
} }
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format. /// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
...@@ -142,40 +152,81 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String { ...@@ -142,40 +152,81 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
pub fn try_tool_call_parse_json( pub fn try_tool_call_parse_json(
message: &str, message: &str,
config: &JsonParserConfig, config: &JsonParserConfig,
) -> anyhow::Result<Vec<ToolCallResponse>> { ) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Log the config we are using // Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config); tracing::debug!("Using JSON parser config: {:?}", config);
let trimmed = message.trim(); let trimmed = message.trim();
// Use config to get tool call start and end token vectors, then use the first element for now // Early exit if no content
if trimmed.is_empty() {
return Ok((vec![], Some(String::new())));
}
let tool_call_start_tokens = &config.tool_call_start_tokens; let tool_call_start_tokens = &config.tool_call_start_tokens;
let tool_call_end_tokens = &config.tool_call_end_tokens; let tool_call_end_tokens = &config.tool_call_end_tokens;
assert!( // Early exit if no tokens configured
tool_call_start_tokens.len() == tool_call_end_tokens.len(), if tool_call_start_tokens.is_empty() {
"Tool call start and end tokens must have the same length" return Ok((vec![], Some(trimmed.to_string())));
); }
// Iterate over all start and end tokens and try to extract the content between them // Iterate over all start and end tokens and try to extract the content between them
// Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models // Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
let mut json = trimmed.to_string(); let mut json = trimmed.to_string();
let mut normal_text = trimmed.to_string();
// First, check if ANY start token exists in the input
let has_start_token = tool_call_start_tokens
.iter()
.any(|token| !token.is_empty() && normal_text.contains(token));
if !has_start_token {
// No start tokens found, try to extract JSON directly. Everything that starts with { or [ is considered a potential JSON.
if let Some(idx) = normal_text.find(['{', '[']) {
let extracted_normal = normal_text[..idx].trim().to_string();
let extracted_json = normal_text[idx..].trim().to_string();
if !extracted_json.is_empty() {
normal_text = extracted_normal;
json = extracted_json;
}
}
} else {
// Start tokens exist, use regex-based parsing
for (start_token, end_token) in tool_call_start_tokens for (start_token, end_token) in tool_call_start_tokens
.iter() .iter()
.zip(tool_call_end_tokens.iter()) .zip(tool_call_end_tokens.iter())
{ {
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token let new_normal_text = try_parse_normal_text(&normal_text, start_token);
json = if !start_token.is_empty() && end_token.is_empty() {
handle_single_token_tool_calls(&json, start_token)
} else if let Some(content) = extract_tool_call_content(&json, start_token, end_token) {
content
} else {
json
};
}
// Process based on token types
match (start_token.is_empty(), end_token.is_empty()) {
(false, true) => {
// Single token case
let result = handle_single_token_tool_calls(&json, start_token);
if let Some(content) = result {
json = content;
// For single token case, use the normal text we extracted earlier
normal_text = new_normal_text;
break; // Found content, exit early
}
}
(false, false) => {
// Start and end token case
let result = extract_tool_call_content(&json, start_token, end_token);
if let Some(content) = result {
json = content;
normal_text = new_normal_text;
break; // Found content, exit early
}
}
_ => {
continue;
}
}
}
}
// Convert json (String) to &str // Convert json (String) to &str
let json = json.as_str(); let json = json.as_str();
// Anonymous function to attempt deserialization into a known representation // Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> { let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
Ok(ToolCallResponse { Ok(ToolCallResponse {
...@@ -198,7 +249,10 @@ pub fn try_tool_call_parse_json( ...@@ -198,7 +249,10 @@ pub fn try_tool_call_parse_json(
// } // }
// } // }
if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) { if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
return Ok(vec![parse(single.name, single.parameters)?]); return Ok((
vec![parse(single.name, single.parameters)?],
Some(normal_text),
));
//parse(single.name, single.parameters).map(Some); //parse(single.name, single.parameters).map(Some);
// CalledFunctionArguments: Single { name, arguments } // CalledFunctionArguments: Single { name, arguments }
...@@ -211,7 +265,10 @@ pub fn try_tool_call_parse_json( ...@@ -211,7 +265,10 @@ pub fn try_tool_call_parse_json(
// } // }
// } // }
} else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) { } else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) {
return Ok(vec![parse(single.name, single.arguments)?]); return Ok((
vec![parse(single.name, single.arguments)?],
Some(normal_text),
));
// Vec<CalledFunctionParameters>: List of { name, parameters } // Vec<CalledFunctionParameters>: List of { name, parameters }
// Example: // Example:
...@@ -225,7 +282,7 @@ pub fn try_tool_call_parse_json( ...@@ -225,7 +282,7 @@ pub fn try_tool_call_parse_json(
for item in list { for item in list {
results.push(parse(item.name, item.parameters)?); results.push(parse(item.name, item.parameters)?);
} }
return Ok(results); return Ok((results, Some(normal_text)));
// Vec<CalledFunctionArguments>: List of { name, arguments } // Vec<CalledFunctionArguments>: List of { name, arguments }
// Example: // Example:
...@@ -244,8 +301,8 @@ pub fn try_tool_call_parse_json( ...@@ -244,8 +301,8 @@ pub fn try_tool_call_parse_json(
for item in list { for item in list {
results.push(parse(item.name, item.arguments)?); results.push(parse(item.name, item.arguments)?);
} }
return Ok(results); return Ok((results, Some(normal_text)));
} }
Ok(vec![]) Ok((vec![], Some(trimmed.to_string())))
} }
This diff is collapsed.
...@@ -13,17 +13,21 @@ pub use super::parsers::{ToolCallConfig, detect_and_parse_tool_call}; ...@@ -13,17 +13,21 @@ pub use super::parsers::{ToolCallConfig, detect_and_parse_tool_call};
pub fn try_tool_call_parse_aggregate( pub fn try_tool_call_parse_aggregate(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>> { ) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>,
Option<String>,
)> {
if parser_str.is_none() { if parser_str.is_none() {
tracing::info!("No tool parser provided. Trying parsing with default parser."); tracing::info!("No tool parser provided. Trying parsing with default parser.");
} else { } else {
tracing::info!("Using tool parser: {:?}", parser_str); tracing::info!("Using tool parser: {:?}", parser_str);
} }
let parsed = detect_and_parse_tool_call(message, parser_str)?; let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok(vec![]); return Ok((vec![], content));
} }
Ok(parsed Ok((
parsed
.into_iter() .into_iter()
.map( .map(
|parsed| dynamo_async_openai::types::ChatCompletionMessageToolCall { |parsed| dynamo_async_openai::types::ChatCompletionMessageToolCall {
...@@ -35,7 +39,9 @@ pub fn try_tool_call_parse_aggregate( ...@@ -35,7 +39,9 @@ pub fn try_tool_call_parse_aggregate(
}, },
}, },
) )
.collect()) .collect(),
content,
))
} }
/// Try parsing a string as a structured tool call, for streaming (delta) usage. /// Try parsing a string as a structured tool call, for streaming (delta) usage.
...@@ -44,12 +50,16 @@ pub fn try_tool_call_parse_aggregate( ...@@ -44,12 +50,16 @@ pub fn try_tool_call_parse_aggregate(
pub fn try_tool_call_parse_stream( pub fn try_tool_call_parse_stream(
message: &str, message: &str,
parser_str: Option<&str>, parser_str: Option<&str>,
) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>> { ) -> anyhow::Result<(
let parsed = detect_and_parse_tool_call(message, parser_str)?; Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>,
)> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?;
if parsed.is_empty() { if parsed.is_empty() {
return Ok(vec![]); return Ok((vec![], content));
} }
Ok(parsed Ok((
parsed
.into_iter() .into_iter()
.enumerate() .enumerate()
.map( .map(
...@@ -64,5 +74,7 @@ pub fn try_tool_call_parse_stream( ...@@ -64,5 +74,7 @@ pub fn try_tool_call_parse_stream(
// Add other fields as needed if required by the struct definition // Add other fields as needed if required by the struct definition
}, },
) )
.collect()) .collect(),
content,
))
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment