Unverified Commit 889d6529 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: parse normal text along with tool calls (#2709)

parent 766d3f2c
......@@ -166,7 +166,7 @@ impl DeltaAggregator {
// After aggregation, inspect each choice's text for tool call syntax
for choice in aggregator.choices.values_mut() {
if choice.tool_calls.is_none()
&& let Ok(tool_calls) = try_tool_call_parse_aggregate(
&& let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
&choice.text,
parsing_options.tool_call_parser.as_deref(),
)
......@@ -184,6 +184,10 @@ impl DeltaAggregator {
}
choice.tool_calls = Some(tool_calls);
choice.text.clear();
// If normal text is not empty, update the choice text
if let Some(normal_text) = normal_text.filter(|text| !text.is_empty()) {
choice.text = normal_text;
}
choice.finish_reason = Some(dynamo_async_openai::types::FinishReason::ToolCalls);
}
}
......@@ -223,7 +227,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
dynamo_async_openai::types::ChatChoice {
message: dynamo_async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"),
content: if delta.tool_calls.is_some() {
content: if delta.text.is_empty() {
None
} else {
Some(delta.text)
......@@ -582,4 +586,68 @@ mod tests {
dynamo_async_openai::types::Role::Assistant
);
}
#[tokio::test]
async fn test_tool_calling_output_with_normal_text() {
    // Content that mixes free-form text with a JSON tool call in a single string.
    let mixed_content = r#"Hey, I'm a normal text! {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}"#;

    // Build the delta with the shared helper and keep only its inner payload.
    let payload = create_test_delta(
        0,
        mixed_content,
        Some(dynamo_async_openai::types::Role::Assistant),
        Some(dynamo_async_openai::types::FinishReason::ToolCalls),
    )
    .data
    .unwrap();

    // Re-wrap the payload in a fresh Annotated envelope and feed it through a one-item stream.
    let envelope = Annotated {
        data: Some(payload),
        id: Some("test_id".to_string()),
        event: None,
        comment: None,
    };
    let input = Box::pin(stream::iter(vec![envelope]));

    // Run the aggregation with default parsing options.
    let outcome = DeltaAggregator::apply(input, ParsingOptions::default()).await;
    assert!(outcome.is_ok());
    let response = outcome.unwrap();

    // Exactly one choice is expected.
    assert_eq!(response.choices.len(), 1);
    let choice = &response.choices[0];
    let message = &choice.message;

    // The tool call must have been extracted from the content.
    assert!(message.tool_calls.is_some());
    let calls = message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let call = &calls[0];
    assert_eq!(call.function.name, "get_weather");

    // The arguments are carried as a JSON string; decode it and check both keys.
    let decoded: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
    assert_eq!(decoded["location"], "San Francisco, CA");
    assert_eq!(decoded["unit"], "fahrenheit");

    // The text in front of the tool call must survive as the message content.
    assert!(message.content.is_some());
    assert_eq!(
        message.content.as_ref().unwrap(),
        "Hey, I'm a normal text!"
    );

    // Finish reason and role are reported as ToolCalls / Assistant.
    assert_eq!(
        choice.finish_reason,
        Some(dynamo_async_openai::types::FinishReason::ToolCalls)
    );
    assert_eq!(
        message.role,
        dynamo_async_openai::types::Role::Assistant
    );
}
}
......@@ -60,10 +60,10 @@ fn extract_tool_call_content(input: &str, start_token: &str, end_token: &str) ->
// Special case for start tokens such as <|python_tag|> that have no end token, so the regex-based extraction does not work well for them.
// Handles both the single and the multiple tool call cases that follow a lone start_token like <|python_tag|>.
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
fn handle_single_token_tool_calls(input: &str, start_token: &str) -> Option<String> {
// Return None if the input doesn't contain the start token
if !input.contains(start_token) {
return input.to_string();
return None;
}
// Split on the start token and keep only JSON-looking segments
......@@ -89,13 +89,23 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
// Remove everything up to and including the first occurrence of the start token
if let Some(idx) = input.find(start_token) {
let rest = &input[idx + start_token.len()..];
return rest.trim_start().to_string();
return Some(rest.trim_start().to_string());
} else {
// Shouldn't happen because we checked contains() above, but be defensive
return input.to_string();
return None;
}
}
format!("[{}]", items.join(","))
Some(format!("[{}]", items.join(",")))
}
/// Returns the text that precedes the first occurrence of `start_token` in `input`.
///
/// The prefix has surrounding whitespace trimmed. If `start_token` does not occur
/// in `input`, an empty string is returned.
fn try_parse_normal_text(input: &str, start_token: &str) -> String {
    match input.split_once(start_token) {
        // Keep only what comes before the token; the remainder is tool-call payload.
        Some((prefix, _)) => prefix.trim().to_owned(),
        // No start token present: there is no leading normal text to report.
        None => String::new(),
    }
}
/// Attempts to parse a tool call from a raw LLM message string into a unified [`ToolCallResponse`] format.
......@@ -142,40 +152,81 @@ fn handle_single_token_tool_calls(input: &str, start_token: &str) -> String {
pub fn try_tool_call_parse_json(
message: &str,
config: &JsonParserConfig,
) -> anyhow::Result<Vec<ToolCallResponse>> {
) -> anyhow::Result<(Vec<ToolCallResponse>, Option<String>)> {
// Log the config we are using
tracing::debug!("Using JSON parser config: {:?}", config);
let trimmed = message.trim();
// Use config to get tool call start and end token vectors, then use the first element for now
// Early exit if no content
if trimmed.is_empty() {
return Ok((vec![], Some(String::new())));
}
let tool_call_start_tokens = &config.tool_call_start_tokens;
let tool_call_end_tokens = &config.tool_call_end_tokens;
assert!(
tool_call_start_tokens.len() == tool_call_end_tokens.len(),
"Tool call start and end tokens must have the same length"
);
// Early exit if no tokens configured
if tool_call_start_tokens.is_empty() {
return Ok((vec![], Some(trimmed.to_string())));
}
// Iterate over all start and end tokens and try to extract the content between them
// Assumption : One message will not contain different tags for tool calls. Iteration over tags is to support different tags by default for multiple models
let mut json = trimmed.to_string();
let mut normal_text = trimmed.to_string();
// First, check if ANY start token exists in the input
let has_start_token = tool_call_start_tokens
.iter()
.any(|token| !token.is_empty() && normal_text.contains(token));
if !has_start_token {
// No start tokens found, try to extract JSON directly. Everything that starts with { or [ is considered a potential JSON.
if let Some(idx) = normal_text.find(['{', '[']) {
let extracted_normal = normal_text[..idx].trim().to_string();
let extracted_json = normal_text[idx..].trim().to_string();
if !extracted_json.is_empty() {
normal_text = extracted_normal;
json = extracted_json;
}
}
} else {
// Start tokens exist, use regex-based parsing
for (start_token, end_token) in tool_call_start_tokens
.iter()
.zip(tool_call_end_tokens.iter())
{
// Special case for <|python_tag|> . Regex pattern does not work well with it as it has no end token
json = if !start_token.is_empty() && end_token.is_empty() {
handle_single_token_tool_calls(&json, start_token)
} else if let Some(content) = extract_tool_call_content(&json, start_token, end_token) {
content
} else {
json
};
}
let new_normal_text = try_parse_normal_text(&normal_text, start_token);
// Process based on token types
match (start_token.is_empty(), end_token.is_empty()) {
(false, true) => {
// Single token case
let result = handle_single_token_tool_calls(&json, start_token);
if let Some(content) = result {
json = content;
// For single token case, use the normal text we extracted earlier
normal_text = new_normal_text;
break; // Found content, exit early
}
}
(false, false) => {
// Start and end token case
let result = extract_tool_call_content(&json, start_token, end_token);
if let Some(content) = result {
json = content;
normal_text = new_normal_text;
break; // Found content, exit early
}
}
_ => {
continue;
}
}
}
}
// Convert json (String) to &str
let json = json.as_str();
// Anonymous function to attempt deserialization into a known representation
let parse = |name: String, args: HashMap<String, Value>| -> anyhow::Result<_> {
Ok(ToolCallResponse {
......@@ -198,7 +249,10 @@ pub fn try_tool_call_parse_json(
// }
// }
if let Ok(single) = serde_json::from_str::<CalledFunctionParameters>(json) {
return Ok(vec![parse(single.name, single.parameters)?]);
return Ok((
vec![parse(single.name, single.parameters)?],
Some(normal_text),
));
//parse(single.name, single.parameters).map(Some);
// CalledFunctionArguments: Single { name, arguments }
......@@ -211,7 +265,10 @@ pub fn try_tool_call_parse_json(
// }
// }
} else if let Ok(single) = serde_json::from_str::<CalledFunctionArguments>(json) {
return Ok(vec![parse(single.name, single.arguments)?]);
return Ok((
vec![parse(single.name, single.arguments)?],
Some(normal_text),
));
// Vec<CalledFunctionParameters>: List of { name, parameters }
// Example:
......@@ -225,7 +282,7 @@ pub fn try_tool_call_parse_json(
for item in list {
results.push(parse(item.name, item.parameters)?);
}
return Ok(results);
return Ok((results, Some(normal_text)));
// Vec<CalledFunctionArguments>: List of { name, arguments }
// Example:
......@@ -244,8 +301,8 @@ pub fn try_tool_call_parse_json(
for item in list {
results.push(parse(item.name, item.arguments)?);
}
return Ok(results);
return Ok((results, Some(normal_text)));
}
Ok(vec![])
Ok((vec![], Some(trimmed.to_string())))
}
This diff is collapsed.
......@@ -13,17 +13,21 @@ pub use super::parsers::{ToolCallConfig, detect_and_parse_tool_call};
pub fn try_tool_call_parse_aggregate(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>> {
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCall>,
Option<String>,
)> {
if parser_str.is_none() {
tracing::info!("No tool parser provided. Trying parsing with default parser.");
} else {
tracing::info!("Using tool parser: {:?}", parser_str);
}
let parsed = detect_and_parse_tool_call(message, parser_str)?;
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?;
if parsed.is_empty() {
return Ok(vec![]);
return Ok((vec![], content));
}
Ok(parsed
Ok((
parsed
.into_iter()
.map(
|parsed| dynamo_async_openai::types::ChatCompletionMessageToolCall {
......@@ -35,7 +39,9 @@ pub fn try_tool_call_parse_aggregate(
},
},
)
.collect())
.collect(),
content,
))
}
/// Try parsing a string as a structured tool call, for streaming (delta) usage.
......@@ -44,12 +50,16 @@ pub fn try_tool_call_parse_aggregate(
pub fn try_tool_call_parse_stream(
message: &str,
parser_str: Option<&str>,
) -> anyhow::Result<Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>> {
let parsed = detect_and_parse_tool_call(message, parser_str)?;
) -> anyhow::Result<(
Vec<dynamo_async_openai::types::ChatCompletionMessageToolCallChunk>,
Option<String>,
)> {
let (parsed, content) = detect_and_parse_tool_call(message, parser_str)?;
if parsed.is_empty() {
return Ok(vec![]);
return Ok((vec![], content));
}
Ok(parsed
Ok((
parsed
.into_iter()
.enumerate()
.map(
......@@ -64,5 +74,7 @@ pub fn try_tool_call_parse_stream(
// Add other fields as needed if required by the struct definition
},
)
.collect())
.collect(),
content,
))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment