Unverified Commit 04f7579b authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

fix: no more multiple finish reasons in stream (#4154)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent d3b5e9f2
...@@ -764,7 +764,7 @@ impl OpenAIPreprocessor { ...@@ -764,7 +764,7 @@ impl OpenAIPreprocessor {
let jail = JailedStream::builder() let jail = JailedStream::builder()
.tool_call_parser(tool_call_parser) .tool_call_parser(tool_call_parser)
.build(); .build();
jail.apply(stream) jail.apply_with_finish_reason(stream)
} }
// Motivation: Each transformation on the stream should be a separate step to allow for more flexibility // Motivation: Each transformation on the stream should be a separate step to allow for more flexibility
......
...@@ -13,6 +13,7 @@ use dynamo_parsers::tool_calling::{ ...@@ -13,6 +13,7 @@ use dynamo_parsers::tool_calling::{
}; };
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use std::collections::HashMap;
use crate::utils::{MarkerMatcher, MatchResult}; use crate::utils::{MarkerMatcher, MatchResult};
...@@ -72,6 +73,8 @@ struct ChoiceJailState { ...@@ -72,6 +73,8 @@ struct ChoiceJailState {
accumulated_content: String, accumulated_content: String,
/// Buffer for partial marker matches across chunks /// Buffer for partial marker matches across chunks
partial_match_buffer: String, partial_match_buffer: String,
/// Stream finish reason
stream_finish_reason: Option<FinishReason>,
} }
fn create_choice_stream( fn create_choice_stream(
...@@ -106,6 +109,7 @@ impl ChoiceJailState { ...@@ -106,6 +109,7 @@ impl ChoiceJailState {
is_jailed: false, is_jailed: false,
accumulated_content: String::new(), accumulated_content: String::new(),
partial_match_buffer: String::new(), partial_match_buffer: String::new(),
stream_finish_reason: None,
} }
} }
...@@ -130,7 +134,6 @@ impl ChoiceJailState { ...@@ -130,7 +134,6 @@ impl ChoiceJailState {
jail_stream: &JailedStream, jail_stream: &JailedStream,
) -> Vec<ChoiceEmission> { ) -> Vec<ChoiceEmission> {
let mut emissions = Vec::new(); let mut emissions = Vec::new();
if !self.is_jailed { if !self.is_jailed {
// Use the marker matcher to detect complete/partial markers // Use the marker matcher to detect complete/partial markers
let match_result = jail_stream let match_result = jail_stream
...@@ -152,7 +155,7 @@ impl ChoiceJailState { ...@@ -152,7 +155,7 @@ impl ChoiceJailState {
choice.delta.role, choice.delta.role,
&prefix, &prefix,
None, None,
None, choice.finish_reason,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::PassThrough(prefix_choice)); emissions.push(ChoiceEmission::PassThrough(prefix_choice));
...@@ -192,7 +195,7 @@ impl ChoiceJailState { ...@@ -192,7 +195,7 @@ impl ChoiceJailState {
choice.delta.role, choice.delta.role,
trailing_part, trailing_part,
None, None,
None, choice.finish_reason,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::Trailing(trailing_choice)); emissions.push(ChoiceEmission::Trailing(trailing_choice));
...@@ -224,7 +227,7 @@ impl ChoiceJailState { ...@@ -224,7 +227,7 @@ impl ChoiceJailState {
choice.delta.role, choice.delta.role,
&prefix, &prefix,
None, None,
None, choice.finish_reason,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::PassThrough(prefix_choice)); emissions.push(ChoiceEmission::PassThrough(prefix_choice));
...@@ -267,7 +270,7 @@ impl ChoiceJailState { ...@@ -267,7 +270,7 @@ impl ChoiceJailState {
choice.delta.role, choice.delta.role,
&content, &content,
None, None,
None, choice.finish_reason,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
...@@ -312,7 +315,7 @@ impl ChoiceJailState { ...@@ -312,7 +315,7 @@ impl ChoiceJailState {
choice.delta.role, choice.delta.role,
trailing_part, trailing_part,
None, None,
None, choice.finish_reason,
choice.logprobs.clone(), choice.logprobs.clone(),
); );
emissions.push(ChoiceEmission::Trailing(trailing_choice)); emissions.push(ChoiceEmission::Trailing(trailing_choice));
...@@ -323,7 +326,6 @@ impl ChoiceJailState { ...@@ -323,7 +326,6 @@ impl ChoiceJailState {
} }
// If not unjailing, don't emit anything (still accumulating) // If not unjailing, don't emit anything (still accumulating)
} }
emissions emissions
} }
...@@ -342,7 +344,7 @@ impl ChoiceJailState { ...@@ -342,7 +344,7 @@ impl ChoiceJailState {
Some(Role::Assistant), Some(Role::Assistant),
&self.accumulated_content, &self.accumulated_content,
None, None,
None, self.stream_finish_reason, // For the accumulated content, assign the original stream finish reason, otherwise it will get lost
None, None,
); );
...@@ -428,6 +430,19 @@ impl JailedStream { ...@@ -428,6 +430,19 @@ impl JailedStream {
JailedStreamBuilder::new() JailedStreamBuilder::new()
} }
/// Apply jail stream transformation with finish_reason fix
/// This is a convenience method that applies both apply() and fix_finish_reason()
pub fn apply_with_finish_reason<S>(
self,
stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
let jailed_stream = self.apply(stream);
JailedStream::fix_finish_reason(jailed_stream)
}
/// Apply the jail transformation to a stream of chat completion responses /// Apply the jail transformation to a stream of chat completion responses
/// Consumes self and returns the transformed stream /// Consumes self and returns the transformed stream
pub fn apply<S>( pub fn apply<S>(
...@@ -449,6 +464,7 @@ impl JailedStream { ...@@ -449,6 +464,7 @@ impl JailedStream {
// Pin the stream for iteration (stack pinning is more efficient) // Pin the stream for iteration (stack pinning is more efficient)
tokio::pin!(stream); tokio::pin!(stream);
// Process each item in the stream // Process each item in the stream
while let Some(response) = stream.next().await { while let Some(response) = stream.next().await {
if let Some(chat_response) = response.data.as_ref() { if let Some(chat_response) = response.data.as_ref() {
...@@ -467,6 +483,9 @@ impl JailedStream { ...@@ -467,6 +483,9 @@ impl JailedStream {
last_annotated_comment = response.comment.clone(); last_annotated_comment = response.comment.clone();
} }
// Track actual stream finish reason in the choice state
choice_state.stream_finish_reason = choice.finish_reason;
// Process this choice and get emissions // Process this choice and get emissions
let emissions = choice_state.process_content(choice, content, &self).await; let emissions = choice_state.process_content(choice, content, &self).await;
all_emissions.extend(emissions); all_emissions.extend(emissions);
...@@ -707,16 +726,16 @@ impl JailedStream { ...@@ -707,16 +726,16 @@ impl JailedStream {
}), }),
}) })
.collect(); .collect();
// Create choice with tool calls // Create choice with tool calls
return create_choice_stream( let choice = create_choice_stream(
choice_index, choice_index,
Some(Role::Assistant), Some(Role::Assistant),
normal_text.as_deref().unwrap_or(""), normal_text.as_deref().unwrap_or(""),
Some(tool_call_chunks), Some(tool_call_chunks),
Some(FinishReason::ToolCalls), None,
None, None,
); );
return choice;
} }
// No tool calls found or parsing failed, return content choice // No tool calls found or parsing failed, return content choice
...@@ -725,7 +744,7 @@ impl JailedStream { ...@@ -725,7 +744,7 @@ impl JailedStream {
Some(Role::Assistant), Some(Role::Assistant),
accumulated_content, accumulated_content,
None, None,
None, base_choice.finish_reason,
base_choice.logprobs.clone(), base_choice.logprobs.clone(),
) )
} }
...@@ -745,6 +764,44 @@ impl JailedStream { ...@@ -745,6 +764,44 @@ impl JailedStream {
} }
false false
} }
/// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted
/// This should be called after apply() to fix the finish_reason for tool call chunks
pub fn fix_finish_reason<S>(
input_stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
stream! {
tokio::pin!(input_stream);
let mut has_tool_calls_per_choice: HashMap<u32, bool> = HashMap::new();
while let Some(mut response) = input_stream.next().await {
// Track if any choice emitted tool calls
if let Some(ref data) = response.data {
for choice in &data.choices {
if choice.delta.tool_calls.is_some() {
has_tool_calls_per_choice.insert(choice.index, true);
}
}
}
// If this chunk has finish_reason and the choice had tool calls, override to ToolCalls
if let Some(ref mut data) = response.data {
for choice in &mut data.choices {
if choice.finish_reason.is_some() && choice.finish_reason == Some(FinishReason::Stop)
&& has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false)
{
choice.finish_reason = Some(FinishReason::ToolCalls);
}
}
}
yield response;
}
}
}
} }
/// Builder for configuring a JailedStream /// Builder for configuring a JailedStream
......
{
"request_id": "8f33c28b-cb52-4272-9ac5-0cb9f80386d3",
"expected_output": {
"normal_content": " the requested format.\n</think>\n\n<tool_call>\n\n{\"name\":\"get"
},
"input_stream": [
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" the","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" requested","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" format","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":".\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"</think>","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\n\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"<tool_call>","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"{\"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"name","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\":","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" \"","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"get","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}, "finish_reason":"length"}]}}
]
}
{
"request_id": "8f33c28b-cb52-4272-9ac5-0cb9f80386d3",
"expected_output": {
"normal_content": "<think>\nOkay, the user is asking for the weather in San Francisco in"
},
"input_stream": [
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"<think>","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"\n","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":"Okay","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":",","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" the","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" user","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" is","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" asking","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" for","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" the","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" weather","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null}}]}},
{"data":{"id":"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3","choices":[{"index":0,"delta":{"content":" in","function_call":null,"tool_calls":null,"role":"assistant","refusal":null,"reasoning_content":null},"finish_reason":"length"}]}}
]
}
...@@ -179,6 +179,49 @@ mod tests { ...@@ -179,6 +179,49 @@ mod tests {
} }
} }
/// Helper function to create a multi-choice finish_reason chunk
pub fn create_multi_choice_finish_chunk(
choice_indices: Vec<u32>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
let choices: Vec<ChatChoiceStream> = choice_indices
.into_iter()
.map(|index| {
#[allow(deprecated)]
ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: None,
content: None,
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
logprobs: None,
}
})
.collect();
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices,
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id: None,
event: None,
comment: None,
}
}
/// Helper to assert content in a result /// Helper to assert content in a result
pub fn assert_content( pub fn assert_content(
result: &Annotated<NvCreateChatCompletionStreamResponse>, result: &Annotated<NvCreateChatCompletionStreamResponse>,
...@@ -336,8 +379,7 @@ mod tests { ...@@ -336,8 +379,7 @@ mod tests {
.jail_end_sequence("</jail>") .jail_end_sequence("</jail>")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// We should only get 3 chunks now: // We should only get 3 chunks now:
// 1. "Hello " (before jail) // 1. "Hello " (before jail)
...@@ -393,8 +435,7 @@ mod tests { ...@@ -393,8 +435,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have jailed the content and parsed tool calls at the end // Should have jailed the content and parsed tool calls at the end
assert!(!results.is_empty()); assert!(!results.is_empty());
...@@ -431,8 +472,7 @@ mod tests { ...@@ -431,8 +472,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// We should get 2 chunks: // We should get 2 chunks:
// 1. "Normal text " (before jail) // 1. "Normal text " (before jail)
...@@ -475,8 +515,7 @@ mod tests { ...@@ -475,8 +515,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 2 chunks: tool call + trailing content // Should have exactly 2 chunks: tool call + trailing content
assert_eq!( assert_eq!(
...@@ -518,8 +557,7 @@ mod tests { ...@@ -518,8 +557,7 @@ mod tests {
.jail_start_sequence("<NOTPRESENT>") .jail_start_sequence("<NOTPRESENT>")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count === // === Verify chunk count ===
assert_eq!( assert_eq!(
...@@ -572,8 +610,7 @@ mod tests { ...@@ -572,8 +610,7 @@ mod tests {
// Create JailedStream with Hermes parser // Create JailedStream with Hermes parser
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content // Should have exactly 3 chunks: content + tool call + content
assert_eq!( assert_eq!(
...@@ -618,8 +655,7 @@ mod tests { ...@@ -618,8 +655,7 @@ mod tests {
// Create JailedStream with Mistral parser // Create JailedStream with Mistral parser
let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content // Should have exactly 3 chunks: content + tool call + content
assert_eq!( assert_eq!(
...@@ -660,8 +696,7 @@ mod tests { ...@@ -660,8 +696,7 @@ mod tests {
// Create JailedStream with Mistral parser // Create JailedStream with Mistral parser
let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content // Should have exactly 3 chunks: content + tool call + content
assert_eq!( assert_eq!(
...@@ -709,8 +744,7 @@ mod tests { ...@@ -709,8 +744,7 @@ mod tests {
// Create JailedStream with Phi4 parser // Create JailedStream with Phi4 parser
let jail = JailedStream::builder().tool_call_parser("phi4").build(); let jail = JailedStream::builder().tool_call_parser("phi4").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content // Should have exactly 3 chunks: content + tool call + content
assert_eq!( assert_eq!(
...@@ -756,8 +790,7 @@ mod tests { ...@@ -756,8 +790,7 @@ mod tests {
.tool_call_parser("llama3_json") .tool_call_parser("llama3_json")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content // Should have exactly 3 chunks: content + tool call + content
assert_eq!( assert_eq!(
...@@ -797,8 +830,7 @@ mod tests { ...@@ -797,8 +830,7 @@ mod tests {
// Create JailedStream with mistral parser (which specifically looks for [{ or [TOOL_CALLS] patterns) // Create JailedStream with mistral parser (which specifically looks for [{ or [TOOL_CALLS] patterns)
let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// The "{" pattern triggers jailing, so some chunks get combined // The "{" pattern triggers jailing, so some chunks get combined
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
...@@ -839,8 +871,7 @@ mod tests { ...@@ -839,8 +871,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Jailing combines the tool call content into fewer chunks // Jailing combines the tool call content into fewer chunks
assert_eq!( assert_eq!(
...@@ -884,8 +915,7 @@ mod tests { ...@@ -884,8 +915,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should handle partial tool call gracefully - releases accumulated content on stream end // Should handle partial tool call gracefully - releases accumulated content on stream end
assert_eq!( assert_eq!(
...@@ -924,8 +954,7 @@ mod tests { ...@@ -924,8 +954,7 @@ mod tests {
.jail_end_sequence("</jail>") .jail_end_sequence("</jail>")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count === // === Verify chunk count ===
assert_eq!( assert_eq!(
...@@ -979,8 +1008,7 @@ mod tests { ...@@ -979,8 +1008,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count === // === Verify chunk count ===
assert_eq!( assert_eq!(
...@@ -1087,8 +1115,7 @@ mod tests { ...@@ -1087,8 +1115,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should consolidate extreme fragmentation into 3 clean chunks // Should consolidate extreme fragmentation into 3 clean chunks
// Input: "I'll process your request. " + 54-char tool call + " Processing complete!" // Input: "I'll process your request. " + 54-char tool call + " Processing complete!"
...@@ -1142,6 +1169,7 @@ mod tests { ...@@ -1142,6 +1169,7 @@ mod tests {
create_mock_response_chunk("\"arguments\": {\"query\": \"test\"}}".to_string(), 0), create_mock_response_chunk("\"arguments\": {\"query\": \"test\"}}".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0), create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" Processing complete.".to_string(), 0), create_mock_response_chunk(" Processing complete.".to_string(), 0),
test_utils::create_final_response_chunk(0), // Backend finish_reason chunk
]; ];
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
...@@ -1149,8 +1177,7 @@ mod tests { ...@@ -1149,8 +1177,7 @@ mod tests {
// Create JailedStream with Hermes parser // Create JailedStream with Hermes parser
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should get 3 chunks: before jail, tool call response, after jail // Should get 3 chunks: before jail, tool call response, after jail
assert!( assert!(
...@@ -1159,14 +1186,14 @@ mod tests { ...@@ -1159,14 +1186,14 @@ mod tests {
results.len() results.len()
); );
// Find the synthesized tool call response chunk // Find the tool call chunk (the one with tool_calls, not the finish_reason chunk)
let tool_call_chunk = results let tool_call_chunk = results
.iter() .iter()
.find(|r| { .find(|r| {
r.data r.data
.as_ref() .as_ref()
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.map(|c| c.finish_reason == Some(FinishReason::ToolCalls)) .map(|c| c.delta.tool_calls.is_some())
.unwrap_or(false) .unwrap_or(false)
}) })
.expect("Should have a tool call response chunk"); .expect("Should have a tool call response chunk");
...@@ -1232,8 +1259,7 @@ mod tests { ...@@ -1232,8 +1259,7 @@ mod tests {
// Create JailedStream with Hermes parser // Create JailedStream with Hermes parser
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should get 2 chunks: first chunk passes through, stream end releases accumulated // Should get 2 chunks: first chunk passes through, stream end releases accumulated
assert_eq!(results.len(), 2, "Should have exactly 2 chunks"); assert_eq!(results.len(), 2, "Should have exactly 2 chunks");
...@@ -1291,23 +1317,23 @@ mod tests { ...@@ -1291,23 +1317,23 @@ mod tests {
), ),
create_mock_response_chunk("{\"name\": \"test\", \"arguments\": {}}".to_string(), 0), create_mock_response_chunk("{\"name\": \"test\", \"arguments\": {}}".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0), create_mock_response_chunk("</tool_call>".to_string(), 0),
test_utils::create_final_response_chunk(0), // Backend finish_reason chunk
]; ];
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Find the tool call response // Find the tool call chunk (the one with tool_calls, not the finish_reason chunk)
let tool_call_chunk = results let tool_call_chunk = results
.iter() .iter()
.find(|r| { .find(|r| {
r.data r.data
.as_ref() .as_ref()
.and_then(|d| d.choices.first()) .and_then(|d| d.choices.first())
.map(|c| c.finish_reason == Some(FinishReason::ToolCalls)) .map(|c| c.delta.tool_calls.is_some())
.unwrap_or(false) .unwrap_or(false)
}) })
.expect("Should have a tool call response chunk"); .expect("Should have a tool call response chunk");
...@@ -1352,8 +1378,7 @@ mod tests { ...@@ -1352,8 +1378,7 @@ mod tests {
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count === // === Verify chunk count ===
assert_eq!( assert_eq!(
...@@ -1395,8 +1420,7 @@ mod tests { ...@@ -1395,8 +1420,7 @@ mod tests {
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + trailing // Should have exactly 3 chunks: content + tool call + trailing
assert_eq!( assert_eq!(
...@@ -1453,14 +1477,15 @@ mod tests { ...@@ -1453,14 +1477,15 @@ mod tests {
("Done with B. ".to_string(), 1), // Choice 1 continues ("Done with B. ".to_string(), 1), // Choice 1 continues
("</tool_call>".to_string(), 2), // Choice 2 unjails ("</tool_call>".to_string(), 2), // Choice 2 unjails
]), ]),
// Chunk 6: Backend finish_reason chunks for all choices
test_utils::create_multi_choice_finish_chunk(vec![0, 1, 2]),
]; ];
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// EXPECTED BEHAVIOR (will fail with current implementation): // EXPECTED BEHAVIOR (will fail with current implementation):
// - Choice 1 should stream continuously (never jailed) // - Choice 1 should stream continuously (never jailed)
...@@ -1529,14 +1554,14 @@ mod tests { ...@@ -1529,14 +1554,14 @@ mod tests {
2, 2,
), ),
]), ]),
test_utils::create_multi_choice_finish_chunk(vec![0, 1, 2]),
]; ];
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Find all tool call responses // Find all tool call responses
let mut tool_call_responses: Vec<_> = results let mut tool_call_responses: Vec<_> = results
...@@ -1559,25 +1584,30 @@ mod tests { ...@@ -1559,25 +1584,30 @@ mod tests {
// Run this test multiple times to verify determinism // Run this test multiple times to verify determinism
for run in 0..5 { for run in 0..5 {
let chunks = vec![create_multi_choice_chunk(vec![ let chunks = vec![
( create_multi_choice_chunk(vec![
"<tool_call>{\"name\": \"tool_0\", \"arguments\": {}}</tool_call>".to_string(), (
0, "<tool_call>{\"name\": \"tool_0\", \"arguments\": {}}</tool_call>"
), .to_string(),
( 0,
"<tool_call>{\"name\": \"tool_1\", \"arguments\": {}}</tool_call>".to_string(), ),
1, (
), "<tool_call>{\"name\": \"tool_1\", \"arguments\": {}}</tool_call>"
( .to_string(),
"<tool_call>{\"name\": \"tool_2\", \"arguments\": {}}</tool_call>".to_string(), 1,
2, ),
), (
])]; "<tool_call>{\"name\": \"tool_2\", \"arguments\": {}}</tool_call>"
.to_string(),
2,
),
]),
test_utils::create_multi_choice_finish_chunk(vec![0, 1, 2]),
];
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build(); let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream); let run_results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let run_results: Vec<_> = jailed_stream.collect().await;
let run_responses: Vec<_> = run_results let run_responses: Vec<_> = run_results
.iter() .iter()
...@@ -1616,8 +1646,7 @@ mod tests { ...@@ -1616,8 +1646,7 @@ mod tests {
let jail = JailedStream::builder().build(); let jail = JailedStream::builder().build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// TODO: Once usage aggregation is implemented, verify: // TODO: Once usage aggregation is implemented, verify:
// - Usage chunk has choices: [] (empty array) // - Usage chunk has choices: [] (empty array)
...@@ -1652,8 +1681,7 @@ mod tests { ...@@ -1652,8 +1681,7 @@ mod tests {
.tool_call_parser("nemotron_deci") .tool_call_parser("nemotron_deci")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count === // === Verify chunk count ===
assert_eq!( assert_eq!(
...@@ -1708,8 +1736,7 @@ mod tests { ...@@ -1708,8 +1736,7 @@ mod tests {
.jail_end_sequence("</TOOLCALL>") .jail_end_sequence("</TOOLCALL>")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count === // === Verify chunk count ===
assert_eq!( assert_eq!(
...@@ -1763,8 +1790,7 @@ mod tests { ...@@ -1763,8 +1790,7 @@ mod tests {
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("harmony").build(); let jail = JailedStream::builder().tool_call_parser("harmony").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should have at least one output containing both analysis text and parsed tool call // Should have at least one output containing both analysis text and parsed tool call
assert!(!results.is_empty()); assert!(!results.is_empty());
...@@ -1804,7 +1830,7 @@ mod tests { ...@@ -1804,7 +1830,7 @@ mod tests {
let jail = JailedStream::builder() let jail = JailedStream::builder()
.tool_call_parser("deepseek_v3_1") .tool_call_parser("deepseek_v3_1")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let jailed_stream = jail.apply_with_finish_reason(input_stream);
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
// Should have at least one output containing both analysis text and parsed tool call // Should have at least one output containing both analysis text and parsed tool call
...@@ -1878,7 +1904,7 @@ mod tests { ...@@ -1878,7 +1904,7 @@ mod tests {
let jail = JailedStream::builder() let jail = JailedStream::builder()
.tool_call_parser("deepseek_v3_1") .tool_call_parser("deepseek_v3_1")
.build(); .build();
let jailed_stream = jail.apply(input_stream); let jailed_stream = jail.apply_with_finish_reason(input_stream);
let results: Vec<_> = jailed_stream.collect().await; let results: Vec<_> = jailed_stream.collect().await;
// Should have at least one output containing both analysis text and parsed tool call // Should have at least one output containing both analysis text and parsed tool call
...@@ -1920,8 +1946,7 @@ mod tests { ...@@ -1920,8 +1946,7 @@ mod tests {
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(results.len() >= 2); assert!(results.len() >= 2);
assert_content(&results[0], "Hey How"); assert_content(&results[0], "Hey How");
...@@ -1956,8 +1981,7 @@ mod tests { ...@@ -1956,8 +1981,7 @@ mod tests {
let input_stream = stream::iter(chunks); let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("mistral").build(); let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream); let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
let results: Vec<_> = jailed_stream.collect().await;
// Should preserve earlier content and also produce a tool call // Should preserve earlier content and also produce a tool call
assert!(results.len() >= 2); assert!(results.len() >= 2);
...@@ -2130,7 +2154,7 @@ mod parallel_jail_tests { ...@@ -2130,7 +2154,7 @@ mod parallel_jail_tests {
]; ];
let input_stream = stream::iter(input_chunks); let input_stream = stream::iter(input_chunks);
let results: Vec<_> = jail.apply(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
// Should have tool call results // Should have tool call results
assert!(!results.is_empty(), "Should have results"); assert!(!results.is_empty(), "Should have results");
...@@ -2203,7 +2227,7 @@ mod parallel_jail_tests { ...@@ -2203,7 +2227,7 @@ mod parallel_jail_tests {
]; ];
let input_stream = stream::iter(input_chunks); let input_stream = stream::iter(input_chunks);
let results: Vec<_> = jail.apply(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert!(!results.is_empty(), "Should have results"); assert!(!results.is_empty(), "Should have results");
...@@ -2240,7 +2264,7 @@ mod parallel_jail_tests { ...@@ -2240,7 +2264,7 @@ mod parallel_jail_tests {
]; ];
let input_stream = stream::iter(input_chunks); let input_stream = stream::iter(input_chunks);
let results: Vec<_> = jail.apply(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert!(!results.is_empty(), "Should have results"); assert!(!results.is_empty(), "Should have results");
...@@ -2310,7 +2334,7 @@ mod parallel_jail_tests { ...@@ -2310,7 +2334,7 @@ mod parallel_jail_tests {
]; ];
let input_stream = stream::iter(input_chunks); let input_stream = stream::iter(input_chunks);
let results: Vec<_> = jail.apply(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert!(!results.is_empty(), "Should have results"); assert!(!results.is_empty(), "Should have results");
...@@ -2548,7 +2572,7 @@ mod parallel_jail_tests { ...@@ -2548,7 +2572,7 @@ mod parallel_jail_tests {
]; ];
let input_stream = stream::iter(input_chunks); let input_stream = stream::iter(input_chunks);
let results: Vec<_> = jail.apply(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
assert!(!results.is_empty(), "Should have results"); assert!(!results.is_empty(), "Should have results");
...@@ -2593,7 +2617,7 @@ mod parallel_jail_tests { ...@@ -2593,7 +2617,7 @@ mod parallel_jail_tests {
]; ];
let input_stream = stream::iter(input_chunks); let input_stream = stream::iter(input_chunks);
let results: Vec<_> = jail.apply(input_stream).collect().await; let results: Vec<_> = jail.apply_with_finish_reason(input_stream).collect().await;
// Should still handle the incomplete stream gracefully // Should still handle the incomplete stream gracefully
assert!( assert!(
......
...@@ -26,7 +26,7 @@ across backends. ...@@ -26,7 +26,7 @@ across backends.
*/ */
use dynamo_async_openai::types::ChatChoiceStream; use dynamo_async_openai::types::{ChatChoiceStream, FinishReason};
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse; use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
...@@ -251,6 +251,71 @@ fn aggregate_content_from_chunks( ...@@ -251,6 +251,71 @@ fn aggregate_content_from_chunks(
} }
} }
/// Helper function to validate finish_reason in the stream
/// Returns true if:
/// 1. There is exactly one finish_reason in the entire stream
/// 2. The finish_reason is in the last chunk
/// 3. The finish_reason matches the expected value
fn validate_finish_reason(
chunks: &[Annotated<NvCreateChatCompletionStreamResponse>],
expected_finish_reason: FinishReason,
) -> bool {
let mut finish_reason_count = 0;
let mut last_chunk_index = None;
let mut finish_reason_value = None;
// Count finish_reason occurrences and track position
for (idx, chunk) in chunks.iter().enumerate() {
if let Some(ref response_data) = chunk.data {
for choice in &response_data.choices {
if let Some(reason) = choice.finish_reason {
finish_reason_count += 1;
last_chunk_index = Some(idx);
finish_reason_value = Some(reason);
}
}
}
}
// Validate:
// 1. Exactly one finish_reason in the stream
if finish_reason_count != 1 {
eprintln!(
"Expected exactly 1 finish_reason, but found {}",
finish_reason_count
);
return false;
}
// 2. finish_reason is in the last chunk
if let Some(idx) = last_chunk_index {
if idx != chunks.len() - 1 {
eprintln!(
"Expected finish_reason in last chunk (index {}), but found at index {}",
chunks.len() - 1,
idx
);
return false;
}
} else {
eprintln!("No finish_reason found in stream");
return false;
}
// 3. finish_reason matches expected value
if let Some(reason) = finish_reason_value
&& reason != expected_finish_reason
{
eprintln!(
"Expected finish_reason {:?}, but found {:?}",
expected_finish_reason, reason
);
return false;
}
true
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
...@@ -304,6 +369,12 @@ mod tests { ...@@ -304,6 +369,12 @@ mod tests {
aggregated.has_tool_calls, expected_has_tool_calls, aggregated.has_tool_calls, expected_has_tool_calls,
"Tool calls presence should match expected value" "Tool calls presence should match expected value"
); );
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert!(
validate_finish_reason(&output_chunks, FinishReason::Stop),
"finish_reason validation failed for non-tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -360,6 +431,12 @@ mod tests { ...@@ -360,6 +431,12 @@ mod tests {
// Verify tool calls // Verify tool calls
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert!(
validate_finish_reason(&output_chunks, FinishReason::ToolCalls),
"finish_reason validation failed for tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -403,6 +480,12 @@ mod tests { ...@@ -403,6 +480,12 @@ mod tests {
aggregated.has_tool_calls, expected_has_tool_calls, aggregated.has_tool_calls, expected_has_tool_calls,
"Tool calls presence should match expected value" "Tool calls presence should match expected value"
); );
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert!(
validate_finish_reason(&output_chunks, FinishReason::Stop),
"finish_reason validation failed for non-tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -455,6 +538,12 @@ mod tests { ...@@ -455,6 +538,12 @@ mod tests {
// Verify tool calls // Verify tool calls
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert!(
validate_finish_reason(&output_chunks, FinishReason::ToolCalls),
"finish_reason validation failed for tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -511,6 +600,12 @@ mod tests { ...@@ -511,6 +600,12 @@ mod tests {
); );
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert!(
validate_finish_reason(&output_chunks, FinishReason::Stop),
"finish_reason validation failed for non-tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -567,6 +662,12 @@ mod tests { ...@@ -567,6 +662,12 @@ mod tests {
); );
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert!(
validate_finish_reason(&output_chunks, FinishReason::ToolCalls),
"finish_reason validation failed for tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -620,6 +721,12 @@ mod tests { ...@@ -620,6 +721,12 @@ mod tests {
); );
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert!(
validate_finish_reason(&output_chunks, FinishReason::Stop),
"finish_reason validation failed for non-tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -674,6 +781,12 @@ mod tests { ...@@ -674,6 +781,12 @@ mod tests {
"Tool calls presence should match expected value" "Tool calls presence should match expected value"
); );
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert!(
validate_finish_reason(&output_chunks, FinishReason::ToolCalls),
"finish_reason validation failed for tool call case"
);
} }
#[tokio::test] #[tokio::test]
...@@ -726,5 +839,46 @@ mod tests { ...@@ -726,5 +839,46 @@ mod tests {
// Verify tool calls // Verify tool calls
assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls); assert_tool_calls(&aggregated.tool_calls, &test_data.expected_tool_calls);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert!(
validate_finish_reason(&output_chunks, FinishReason::ToolCalls),
"finish_reason validation failed for tool call case"
);
}
#[tokio::test]
async fn test_qwen_finish_reason_length_vllm() {
let file_paths = vec![
format!(
"{}/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json",
DATA_ROOT_PATH
),
format!(
"{}/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json",
DATA_ROOT_PATH
),
];
for file_path in file_paths {
let test_data = load_test_data(&file_path);
// Create a stream from the mock chunks
let input_stream = stream::iter(test_data.stream_chunks);
// Parse the response stream with tool parsing enabled
let output_chunks =
parse_response_stream(input_stream, true, false, Some("hermes".to_string()), None)
.await;
// Verify we got output chunks
assert!(!output_chunks.is_empty(), "Should have output chunks");
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Length
assert!(
validate_finish_reason(&output_chunks, FinishReason::Length),
"finish_reason validation failed for length finish case"
);
}
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment