"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "b19de4ed77c11a550e037582b92f91ca19eebde1"
Unverified Commit 8601ccdb authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

fix: parallel tool call indices in streaming (#4723)


Signed-off-by: default avatarVladislav Nosivskoy <vladnosiv@gmail.com>
parent 67273aba
...@@ -75,6 +75,8 @@ struct ChoiceJailState { ...@@ -75,6 +75,8 @@ struct ChoiceJailState {
partial_match_buffer: String, partial_match_buffer: String,
/// Stream finish reason /// Stream finish reason
stream_finish_reason: Option<FinishReason>, stream_finish_reason: Option<FinishReason>,
/// Number of tool calls already emitted for this choice
emitted_tool_calls_count: usize,
} }
fn create_choice_stream( fn create_choice_stream(
...@@ -110,6 +112,7 @@ impl ChoiceJailState { ...@@ -110,6 +112,7 @@ impl ChoiceJailState {
accumulated_content: String::new(), accumulated_content: String::new(),
partial_match_buffer: String::new(), partial_match_buffer: String::new(),
stream_finish_reason: None, stream_finish_reason: None,
emitted_tool_calls_count: 0,
} }
} }
...@@ -178,10 +181,18 @@ impl ChoiceJailState { ...@@ -178,10 +181,18 @@ impl ChoiceJailState {
// Create the tool call choice // Create the tool call choice
let tool_choice = jail_stream let tool_choice = jail_stream
.create_tool_call_choice(choice.index, jailed_part, choice) .create_tool_call_choice(
choice.index,
jailed_part,
choice,
self.emitted_tool_calls_count,
)
.await; .await;
if tool_choice.delta.tool_calls.is_some() { if tool_choice.delta.tool_calls.is_some() {
if let Some(ref tool_calls) = tool_choice.delta.tool_calls {
self.emitted_tool_calls_count += tool_calls.len();
}
emissions.push(ChoiceEmission::ToolCall(tool_choice)); emissions.push(ChoiceEmission::ToolCall(tool_choice));
} else { } else {
emissions.push(ChoiceEmission::Content(tool_choice)); emissions.push(ChoiceEmission::Content(tool_choice));
...@@ -297,11 +308,19 @@ impl ChoiceJailState { ...@@ -297,11 +308,19 @@ impl ChoiceJailState {
// Create the unjailed choice // Create the unjailed choice
let unjailed_choice = jail_stream let unjailed_choice = jail_stream
.create_tool_call_choice(choice.index, jailed_part, choice) .create_tool_call_choice(
choice.index,
jailed_part,
choice,
self.emitted_tool_calls_count,
)
.await; .await;
// Determine emission type based on whether tool calls were parsed // Determine emission type based on whether tool calls were parsed
if unjailed_choice.delta.tool_calls.is_some() { if unjailed_choice.delta.tool_calls.is_some() {
if let Some(ref tool_calls) = unjailed_choice.delta.tool_calls {
self.emitted_tool_calls_count += tool_calls.len();
}
emissions.push(ChoiceEmission::ToolCall(unjailed_choice)); emissions.push(ChoiceEmission::ToolCall(unjailed_choice));
} else { } else {
emissions.push(ChoiceEmission::Content(unjailed_choice)); emissions.push(ChoiceEmission::Content(unjailed_choice));
...@@ -349,9 +368,18 @@ impl ChoiceJailState { ...@@ -349,9 +368,18 @@ impl ChoiceJailState {
); );
let final_choice = jail_stream let final_choice = jail_stream
.create_tool_call_choice(self.index, &self.accumulated_content, &dummy_choice) .create_tool_call_choice(
self.index,
&self.accumulated_content,
&dummy_choice,
self.emitted_tool_calls_count,
)
.await; .await;
if let Some(ref tool_calls) = final_choice.delta.tool_calls {
self.emitted_tool_calls_count += tool_calls.len();
}
// End jailing // End jailing
self.end_jail(); self.end_jail();
...@@ -714,6 +742,7 @@ impl JailedStream { ...@@ -714,6 +742,7 @@ impl JailedStream {
choice_index: u32, choice_index: u32,
accumulated_content: &str, accumulated_content: &str,
base_choice: &ChatChoiceStream, base_choice: &ChatChoiceStream,
tool_call_offset: usize,
) -> ChatChoiceStream { ) -> ChatChoiceStream {
if let Ok((tool_calls, normal_text)) = if let Ok((tool_calls, normal_text)) =
try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref())
...@@ -725,7 +754,7 @@ impl JailedStream { ...@@ -725,7 +754,7 @@ impl JailedStream {
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk {
index: idx as u32, index: (tool_call_offset + idx) as u32,
id: Some(tool_call.id), id: Some(tool_call.id),
r#type: Some(tool_call.r#type), r#type: Some(tool_call.r#type),
function: Some(FunctionCallStream { function: Some(FunctionCallStream {
......
...@@ -2392,6 +2392,13 @@ mod parallel_jail_tests { ...@@ -2392,6 +2392,13 @@ mod parallel_jail_tests {
for (i, (expected_name, expected_args)) in expected_tool_calls.iter().enumerate() { for (i, (expected_name, expected_args)) in expected_tool_calls.iter().enumerate() {
let tool_call = &all_tool_calls[i]; let tool_call = &all_tool_calls[i];
assert!(tool_call.id.is_some(), "Tool call {} should have an ID", i); assert!(tool_call.id.is_some(), "Tool call {} should have an ID", i);
assert_eq!(
tool_call.index, i as u32,
"Tool call {} should have index {}, got {}",
i, i, tool_call.index
);
assert_eq!( assert_eq!(
tool_call.r#type, tool_call.r#type,
Some(dynamo_async_openai::types::ChatCompletionToolType::Function), Some(dynamo_async_openai::types::ChatCompletionToolType::Function),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment