Unverified Commit 8601ccdb authored by Vladislav Nosivskoy's avatar Vladislav Nosivskoy Committed by GitHub
Browse files

fix: parallel tool call indices in streaming (#4723)


Signed-off-by: default avatarVladislav Nosivskoy <vladnosiv@gmail.com>
parent 67273aba
......@@ -75,6 +75,8 @@ struct ChoiceJailState {
partial_match_buffer: String,
/// Stream finish reason
stream_finish_reason: Option<FinishReason>,
/// Number of tool calls already emitted for this choice
emitted_tool_calls_count: usize,
}
fn create_choice_stream(
......@@ -110,6 +112,7 @@ impl ChoiceJailState {
accumulated_content: String::new(),
partial_match_buffer: String::new(),
stream_finish_reason: None,
emitted_tool_calls_count: 0,
}
}
......@@ -178,10 +181,18 @@ impl ChoiceJailState {
// Create the tool call choice
let tool_choice = jail_stream
.create_tool_call_choice(choice.index, jailed_part, choice)
.create_tool_call_choice(
choice.index,
jailed_part,
choice,
self.emitted_tool_calls_count,
)
.await;
if tool_choice.delta.tool_calls.is_some() {
if let Some(ref tool_calls) = tool_choice.delta.tool_calls {
self.emitted_tool_calls_count += tool_calls.len();
}
emissions.push(ChoiceEmission::ToolCall(tool_choice));
} else {
emissions.push(ChoiceEmission::Content(tool_choice));
......@@ -297,11 +308,19 @@ impl ChoiceJailState {
// Create the unjailed choice
let unjailed_choice = jail_stream
.create_tool_call_choice(choice.index, jailed_part, choice)
.create_tool_call_choice(
choice.index,
jailed_part,
choice,
self.emitted_tool_calls_count,
)
.await;
// Determine emission type based on whether tool calls were parsed
if unjailed_choice.delta.tool_calls.is_some() {
if let Some(ref tool_calls) = unjailed_choice.delta.tool_calls {
self.emitted_tool_calls_count += tool_calls.len();
}
emissions.push(ChoiceEmission::ToolCall(unjailed_choice));
} else {
emissions.push(ChoiceEmission::Content(unjailed_choice));
......@@ -349,9 +368,18 @@ impl ChoiceJailState {
);
let final_choice = jail_stream
.create_tool_call_choice(self.index, &self.accumulated_content, &dummy_choice)
.create_tool_call_choice(
self.index,
&self.accumulated_content,
&dummy_choice,
self.emitted_tool_calls_count,
)
.await;
if let Some(ref tool_calls) = final_choice.delta.tool_calls {
self.emitted_tool_calls_count += tool_calls.len();
}
// End jailing
self.end_jail();
......@@ -714,6 +742,7 @@ impl JailedStream {
choice_index: u32,
accumulated_content: &str,
base_choice: &ChatChoiceStream,
tool_call_offset: usize,
) -> ChatChoiceStream {
if let Ok((tool_calls, normal_text)) =
try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref())
......@@ -725,7 +754,7 @@ impl JailedStream {
.into_iter()
.enumerate()
.map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk {
index: idx as u32,
index: (tool_call_offset + idx) as u32,
id: Some(tool_call.id),
r#type: Some(tool_call.r#type),
function: Some(FunctionCallStream {
......
......@@ -2392,6 +2392,13 @@ mod parallel_jail_tests {
for (i, (expected_name, expected_args)) in expected_tool_calls.iter().enumerate() {
let tool_call = &all_tool_calls[i];
assert!(tool_call.id.is_some(), "Tool call {} should have an ID", i);
assert_eq!(
tool_call.index, i as u32,
"Tool call {} should have index {}, got {}",
i, i, tool_call.index
);
assert_eq!(
tool_call.r#type,
Some(dynamo_async_openai::types::ChatCompletionToolType::Function),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment