Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
04f7579b
Unverified
Commit
04f7579b
authored
Nov 07, 2025
by
Ayush Agarwal
Committed by
GitHub
Nov 08, 2025
Browse files
fix: no more multiple finish reasons in stream (#4154)
Signed-off-by:
ayushag
<
ayushag@nvidia.com
>
parent
d3b5e9f2
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
376 additions
and
100 deletions
+376
-100
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+1
-1
lib/llm/src/protocols/openai/chat_completions/jail.rs
lib/llm/src/protocols/openai/chat_completions/jail.rs
+69
-12
lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json
...data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json
+21
-0
lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json
...vllm/qwen3-0.6B/chat_completion_stream_finish_length.json
+20
-0
lib/llm/tests/test_jail.rs
lib/llm/tests/test_jail.rs
+110
-86
lib/llm/tests/test_streaming_tool_parsers.rs
lib/llm/tests/test_streaming_tool_parsers.rs
+155
-1
No files found.
lib/llm/src/preprocessor.rs
View file @
04f7579b
...
@@ -764,7 +764,7 @@ impl OpenAIPreprocessor {
...
@@ -764,7 +764,7 @@ impl OpenAIPreprocessor {
let
jail
=
JailedStream
::
builder
()
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
tool_call_parser
)
.tool_call_parser
(
tool_call_parser
)
.build
();
.build
();
jail
.apply
(
stream
)
jail
.apply
_with_finish_reason
(
stream
)
}
}
// Motivation: Each transformation on the stream should be a separate step to allow for more flexibility
// Motivation: Each transformation on the stream should be a separate step to allow for more flexibility
...
...
lib/llm/src/protocols/openai/chat_completions/jail.rs
View file @
04f7579b
...
@@ -13,6 +13,7 @@ use dynamo_parsers::tool_calling::{
...
@@ -13,6 +13,7 @@ use dynamo_parsers::tool_calling::{
};
};
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
futures
::{
Stream
,
StreamExt
};
use
futures
::{
Stream
,
StreamExt
};
use
std
::
collections
::
HashMap
;
use
crate
::
utils
::{
MarkerMatcher
,
MatchResult
};
use
crate
::
utils
::{
MarkerMatcher
,
MatchResult
};
...
@@ -72,6 +73,8 @@ struct ChoiceJailState {
...
@@ -72,6 +73,8 @@ struct ChoiceJailState {
accumulated_content
:
String
,
accumulated_content
:
String
,
/// Buffer for partial marker matches across chunks
/// Buffer for partial marker matches across chunks
partial_match_buffer
:
String
,
partial_match_buffer
:
String
,
/// Stream finish reason
stream_finish_reason
:
Option
<
FinishReason
>
,
}
}
fn
create_choice_stream
(
fn
create_choice_stream
(
...
@@ -106,6 +109,7 @@ impl ChoiceJailState {
...
@@ -106,6 +109,7 @@ impl ChoiceJailState {
is_jailed
:
false
,
is_jailed
:
false
,
accumulated_content
:
String
::
new
(),
accumulated_content
:
String
::
new
(),
partial_match_buffer
:
String
::
new
(),
partial_match_buffer
:
String
::
new
(),
stream_finish_reason
:
None
,
}
}
}
}
...
@@ -130,7 +134,6 @@ impl ChoiceJailState {
...
@@ -130,7 +134,6 @@ impl ChoiceJailState {
jail_stream
:
&
JailedStream
,
jail_stream
:
&
JailedStream
,
)
->
Vec
<
ChoiceEmission
>
{
)
->
Vec
<
ChoiceEmission
>
{
let
mut
emissions
=
Vec
::
new
();
let
mut
emissions
=
Vec
::
new
();
if
!
self
.is_jailed
{
if
!
self
.is_jailed
{
// Use the marker matcher to detect complete/partial markers
// Use the marker matcher to detect complete/partial markers
let
match_result
=
jail_stream
let
match_result
=
jail_stream
...
@@ -152,7 +155,7 @@ impl ChoiceJailState {
...
@@ -152,7 +155,7 @@ impl ChoiceJailState {
choice
.delta.role
,
choice
.delta.role
,
&
prefix
,
&
prefix
,
None
,
None
,
None
,
choice
.finish_reason
,
choice
.logprobs
.clone
(),
choice
.logprobs
.clone
(),
);
);
emissions
.push
(
ChoiceEmission
::
PassThrough
(
prefix_choice
));
emissions
.push
(
ChoiceEmission
::
PassThrough
(
prefix_choice
));
...
@@ -192,7 +195,7 @@ impl ChoiceJailState {
...
@@ -192,7 +195,7 @@ impl ChoiceJailState {
choice
.delta.role
,
choice
.delta.role
,
trailing_part
,
trailing_part
,
None
,
None
,
None
,
choice
.finish_reason
,
choice
.logprobs
.clone
(),
choice
.logprobs
.clone
(),
);
);
emissions
.push
(
ChoiceEmission
::
Trailing
(
trailing_choice
));
emissions
.push
(
ChoiceEmission
::
Trailing
(
trailing_choice
));
...
@@ -224,7 +227,7 @@ impl ChoiceJailState {
...
@@ -224,7 +227,7 @@ impl ChoiceJailState {
choice
.delta.role
,
choice
.delta.role
,
&
prefix
,
&
prefix
,
None
,
None
,
None
,
choice
.finish_reason
,
choice
.logprobs
.clone
(),
choice
.logprobs
.clone
(),
);
);
emissions
.push
(
ChoiceEmission
::
PassThrough
(
prefix_choice
));
emissions
.push
(
ChoiceEmission
::
PassThrough
(
prefix_choice
));
...
@@ -267,7 +270,7 @@ impl ChoiceJailState {
...
@@ -267,7 +270,7 @@ impl ChoiceJailState {
choice
.delta.role
,
choice
.delta.role
,
&
content
,
&
content
,
None
,
None
,
None
,
choice
.finish_reason
,
choice
.logprobs
.clone
(),
choice
.logprobs
.clone
(),
);
);
emissions
.push
(
ChoiceEmission
::
PassThrough
(
pass_through_choice
));
emissions
.push
(
ChoiceEmission
::
PassThrough
(
pass_through_choice
));
...
@@ -312,7 +315,7 @@ impl ChoiceJailState {
...
@@ -312,7 +315,7 @@ impl ChoiceJailState {
choice
.delta.role
,
choice
.delta.role
,
trailing_part
,
trailing_part
,
None
,
None
,
None
,
choice
.finish_reason
,
choice
.logprobs
.clone
(),
choice
.logprobs
.clone
(),
);
);
emissions
.push
(
ChoiceEmission
::
Trailing
(
trailing_choice
));
emissions
.push
(
ChoiceEmission
::
Trailing
(
trailing_choice
));
...
@@ -323,7 +326,6 @@ impl ChoiceJailState {
...
@@ -323,7 +326,6 @@ impl ChoiceJailState {
}
}
// If not unjailing, don't emit anything (still accumulating)
// If not unjailing, don't emit anything (still accumulating)
}
}
emissions
emissions
}
}
...
@@ -342,7 +344,7 @@ impl ChoiceJailState {
...
@@ -342,7 +344,7 @@ impl ChoiceJailState {
Some
(
Role
::
Assistant
),
Some
(
Role
::
Assistant
),
&
self
.accumulated_content
,
&
self
.accumulated_content
,
None
,
None
,
None
,
self
.stream_finish_reason
,
// For the accumulated content, assign the original stream finish reason, otherwise it will get lost
None
,
None
,
);
);
...
@@ -428,6 +430,19 @@ impl JailedStream {
...
@@ -428,6 +430,19 @@ impl JailedStream {
JailedStreamBuilder
::
new
()
JailedStreamBuilder
::
new
()
}
}
/// Apply jail stream transformation with finish_reason fix
/// This is a convenience method that applies both apply() and fix_finish_reason()
pub
fn
apply_with_finish_reason
<
S
>
(
self
,
stream
:
S
,
)
->
impl
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
+
Send
where
S
:
Stream
<
Item
=
Annotated
<
NvCreateChatCompletionStreamResponse
>>
+
Send
+
'static
,
{
let
jailed_stream
=
self
.apply
(
stream
);
JailedStream
::
fix_finish_reason
(
jailed_stream
)
}
/// Apply the jail transformation to a stream of chat completion responses
/// Apply the jail transformation to a stream of chat completion responses
/// Consumes self and returns the transformed stream
/// Consumes self and returns the transformed stream
pub
fn
apply
<
S
>
(
pub
fn
apply
<
S
>
(
...
@@ -449,6 +464,7 @@ impl JailedStream {
...
@@ -449,6 +464,7 @@ impl JailedStream {
// Pin the stream for iteration (stack pinning is more efficient)
// Pin the stream for iteration (stack pinning is more efficient)
tokio
::
pin!
(
stream
);
tokio
::
pin!
(
stream
);
// Process each item in the stream
// Process each item in the stream
while
let
Some
(
response
)
=
stream
.next
()
.await
{
while
let
Some
(
response
)
=
stream
.next
()
.await
{
if
let
Some
(
chat_response
)
=
response
.data
.as_ref
()
{
if
let
Some
(
chat_response
)
=
response
.data
.as_ref
()
{
...
@@ -467,6 +483,9 @@ impl JailedStream {
...
@@ -467,6 +483,9 @@ impl JailedStream {
last_annotated_comment
=
response
.comment
.clone
();
last_annotated_comment
=
response
.comment
.clone
();
}
}
// Track actual stream finish reason in the choice state
choice_state
.stream_finish_reason
=
choice
.finish_reason
;
// Process this choice and get emissions
// Process this choice and get emissions
let
emissions
=
choice_state
.process_content
(
choice
,
content
,
&
self
)
.await
;
let
emissions
=
choice_state
.process_content
(
choice
,
content
,
&
self
)
.await
;
all_emissions
.extend
(
emissions
);
all_emissions
.extend
(
emissions
);
...
@@ -707,16 +726,16 @@ impl JailedStream {
...
@@ -707,16 +726,16 @@ impl JailedStream {
}),
}),
})
})
.collect
();
.collect
();
// Create choice with tool calls
// Create choice with tool calls
r
et
urn
create_choice_stream
(
l
et
choice
=
create_choice_stream
(
choice_index
,
choice_index
,
Some
(
Role
::
Assistant
),
Some
(
Role
::
Assistant
),
normal_text
.as_deref
()
.unwrap_or
(
""
),
normal_text
.as_deref
()
.unwrap_or
(
""
),
Some
(
tool_call_chunks
),
Some
(
tool_call_chunks
),
Some
(
FinishReason
::
ToolCalls
)
,
None
,
None
,
None
,
);
);
return
choice
;
}
}
// No tool calls found or parsing failed, return content choice
// No tool calls found or parsing failed, return content choice
...
@@ -725,7 +744,7 @@ impl JailedStream {
...
@@ -725,7 +744,7 @@ impl JailedStream {
Some
(
Role
::
Assistant
),
Some
(
Role
::
Assistant
),
accumulated_content
,
accumulated_content
,
None
,
None
,
None
,
base_choice
.finish_reason
,
base_choice
.logprobs
.clone
(),
base_choice
.logprobs
.clone
(),
)
)
}
}
...
@@ -745,6 +764,44 @@ impl JailedStream {
...
@@ -745,6 +764,44 @@ impl JailedStream {
}
}
false
false
}
}
/// Post-processor that rewrites a `Stop` finish reason to `ToolCalls` for
/// choices that have emitted tool calls.
///
/// Intended to be chained after `apply()`. The stream is forwarded chunk by
/// chunk; for every chunk the function
/// 1. records the index of each choice whose delta carries `tool_calls`, and
/// 2. replaces `finish_reason == Some(FinishReason::Stop)` with
///    `Some(FinishReason::ToolCalls)` on choices whose index has been recorded
///    (in this chunk or an earlier one).
///
/// Only `Stop` is rewritten: any other finish reason, and chunks without
/// `data`, pass through unchanged.
pub fn fix_finish_reason<S>(
    input_stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
    S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
    stream! {
        tokio::pin!(input_stream);

        // Choice index -> "this choice has emitted tool calls at some point".
        let mut has_tool_calls_per_choice: HashMap<u32, bool> = HashMap::new();

        while let Some(mut response) = input_stream.next().await {
            if let Some(data) = response.data.as_mut() {
                // Pass 1: record every choice in this chunk that carries tool calls.
                // Done for the whole chunk before any rewrite so that the rewrite
                // below sees tool calls emitted in the same chunk.
                for choice in &data.choices {
                    if choice.delta.tool_calls.is_some() {
                        has_tool_calls_per_choice.insert(choice.index, true);
                    }
                }

                // Pass 2: a `Stop` on a choice that produced tool calls becomes `ToolCalls`.
                // Comparing against `Some(Stop)` already implies the reason is set.
                for choice in &mut data.choices {
                    let emitted_tool_calls = has_tool_calls_per_choice
                        .get(&choice.index)
                        .copied()
                        .unwrap_or(false);
                    if emitted_tool_calls && choice.finish_reason == Some(FinishReason::Stop) {
                        choice.finish_reason = Some(FinishReason::ToolCalls);
                    }
                }
            }

            yield response;
        }
    }
}
}
}
/// Builder for configuring a JailedStream
/// Builder for configuring a JailedStream
...
...
lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_incomplete_tool.json
0 → 100644
View file @
04f7579b
{
"request_id"
:
"8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"expected_output"
:
{
"normal_content"
:
" the requested format.
\n
</think>
\n\n
<tool_call>
\n\n
{
\"
name
\"
:
\"
get"
},
"input_stream"
:
[
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" the"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" requested"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" format"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
".
\n
"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"</think>"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"
\n\n
"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"<tool_call>"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"
\n
"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"{
\"
"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"name"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"
\"
:"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"
\"
"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"get"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
},
"finish_reason"
:
"length"
}]}}
]
}
lib/llm/tests/data/vllm/qwen3-0.6B/chat_completion_stream_finish_length.json
0 → 100644
View file @
04f7579b
{
"request_id"
:
"8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"expected_output"
:
{
"normal_content"
:
"<think>
\n
Okay, the user is asking for the weather in San Francisco in"
},
"input_stream"
:
[
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"<think>"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"
\n
"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
"Okay"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
","
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" the"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" user"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" is"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" asking"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" for"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" the"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" weather"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
}}]}},
{
"data"
:{
"id"
:
"chatcmpl-8f33c28b-cb52-4272-9ac5-0cb9f80386d3"
,
"choices"
:[{
"index"
:
0
,
"delta"
:{
"content"
:
" in"
,
"function_call"
:
null
,
"tool_calls"
:
null
,
"role"
:
"assistant"
,
"refusal"
:
null
,
"reasoning_content"
:
null
},
"finish_reason"
:
"length"
}]}}
]
}
lib/llm/tests/test_jail.rs
View file @
04f7579b
...
@@ -179,6 +179,49 @@ mod tests {
...
@@ -179,6 +179,49 @@ mod tests {
}
}
}
}
/// Build a single response chunk holding one finish-only choice per index.
///
/// Each choice has an empty delta (no role, content, or tool calls), no
/// logprobs, and `finish_reason` set to `FinishReason::Stop`. The chunk is
/// wrapped in an `Annotated` with no id, event, or comment.
pub fn create_multi_choice_finish_chunk(
    choice_indices: Vec<u32>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
    let mut choices: Vec<ChatChoiceStream> = Vec::with_capacity(choice_indices.len());

    for index in choice_indices {
        // NOTE(review): presumably needed because `function_call` is a deprecated
        // field that must still be initialised — confirm against the type definition.
        #[allow(deprecated)]
        let finish_only = ChatChoiceStream {
            index,
            delta: ChatCompletionStreamResponseDelta {
                role: None,
                content: None,
                tool_calls: None,
                function_call: None,
                refusal: None,
                reasoning_content: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        };
        choices.push(finish_only);
    }

    Annotated {
        data: Some(NvCreateChatCompletionStreamResponse {
            id: String::from("test-id"),
            choices,
            created: 1234567890,
            model: String::from("test-model"),
            system_fingerprint: Some(String::from("test-fingerprint")),
            object: String::from("chat.completion.chunk"),
            usage: None,
            service_tier: None,
        }),
        id: None,
        event: None,
        comment: None,
    }
}
/// Helper to assert content in a result
/// Helper to assert content in a result
pub
fn
assert_content
(
pub
fn
assert_content
(
result
:
&
Annotated
<
NvCreateChatCompletionStreamResponse
>
,
result
:
&
Annotated
<
NvCreateChatCompletionStreamResponse
>
,
...
@@ -336,8 +379,7 @@ mod tests {
...
@@ -336,8 +379,7 @@ mod tests {
.jail_end_sequence
(
"</jail>"
)
.jail_end_sequence
(
"</jail>"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// We should only get 3 chunks now:
// We should only get 3 chunks now:
// 1. "Hello " (before jail)
// 1. "Hello " (before jail)
...
@@ -393,8 +435,7 @@ mod tests {
...
@@ -393,8 +435,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have jailed the content and parsed tool calls at the end
// Should have jailed the content and parsed tool calls at the end
assert
!
(
!
results
.is_empty
());
assert
!
(
!
results
.is_empty
());
...
@@ -431,8 +472,7 @@ mod tests {
...
@@ -431,8 +472,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// We should get 2 chunks:
// We should get 2 chunks:
// 1. "Normal text " (before jail)
// 1. "Normal text " (before jail)
...
@@ -475,8 +515,7 @@ mod tests {
...
@@ -475,8 +515,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 2 chunks: tool call + trailing content
// Should have exactly 2 chunks: tool call + trailing content
assert_eq!
(
assert_eq!
(
...
@@ -518,8 +557,7 @@ mod tests {
...
@@ -518,8 +557,7 @@ mod tests {
.jail_start_sequence
(
"<NOTPRESENT>"
)
.jail_start_sequence
(
"<NOTPRESENT>"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// === Verify chunk count ===
// === Verify chunk count ===
assert_eq!
(
assert_eq!
(
...
@@ -572,8 +610,7 @@ mod tests {
...
@@ -572,8 +610,7 @@ mod tests {
// Create JailedStream with Hermes parser
// Create JailedStream with Hermes parser
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 3 chunks: content + tool call + content
// Should have exactly 3 chunks: content + tool call + content
assert_eq!
(
assert_eq!
(
...
@@ -618,8 +655,7 @@ mod tests {
...
@@ -618,8 +655,7 @@ mod tests {
// Create JailedStream with Mistral parser
// Create JailedStream with Mistral parser
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 3 chunks: content + tool call + content
// Should have exactly 3 chunks: content + tool call + content
assert_eq!
(
assert_eq!
(
...
@@ -660,8 +696,7 @@ mod tests {
...
@@ -660,8 +696,7 @@ mod tests {
// Create JailedStream with Mistral parser
// Create JailedStream with Mistral parser
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 3 chunks: content + tool call + content
// Should have exactly 3 chunks: content + tool call + content
assert_eq!
(
assert_eq!
(
...
@@ -709,8 +744,7 @@ mod tests {
...
@@ -709,8 +744,7 @@ mod tests {
// Create JailedStream with Phi4 parser
// Create JailedStream with Phi4 parser
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"phi4"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"phi4"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 3 chunks: content + tool call + content
// Should have exactly 3 chunks: content + tool call + content
assert_eq!
(
assert_eq!
(
...
@@ -756,8 +790,7 @@ mod tests {
...
@@ -756,8 +790,7 @@ mod tests {
.tool_call_parser
(
"llama3_json"
)
.tool_call_parser
(
"llama3_json"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 3 chunks: content + tool call + content
// Should have exactly 3 chunks: content + tool call + content
assert_eq!
(
assert_eq!
(
...
@@ -797,8 +830,7 @@ mod tests {
...
@@ -797,8 +830,7 @@ mod tests {
// Create JailedStream with mistral parser (which specifically looks for [{ or [TOOL_CALLS] patterns)
// Create JailedStream with mistral parser (which specifically looks for [{ or [TOOL_CALLS] patterns)
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// The "{" pattern triggers jailing, so some chunks get combined
// The "{" pattern triggers jailing, so some chunks get combined
assert_eq!
(
results
.len
(),
2
);
assert_eq!
(
results
.len
(),
2
);
...
@@ -839,8 +871,7 @@ mod tests {
...
@@ -839,8 +871,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Jailing combines the tool call content into fewer chunks
// Jailing combines the tool call content into fewer chunks
assert_eq!
(
assert_eq!
(
...
@@ -884,8 +915,7 @@ mod tests {
...
@@ -884,8 +915,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should handle partial tool call gracefully - releases accumulated content on stream end
// Should handle partial tool call gracefully - releases accumulated content on stream end
assert_eq!
(
assert_eq!
(
...
@@ -924,8 +954,7 @@ mod tests {
...
@@ -924,8 +954,7 @@ mod tests {
.jail_end_sequence
(
"</jail>"
)
.jail_end_sequence
(
"</jail>"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// === Verify chunk count ===
// === Verify chunk count ===
assert_eq!
(
assert_eq!
(
...
@@ -979,8 +1008,7 @@ mod tests {
...
@@ -979,8 +1008,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// === Verify chunk count ===
// === Verify chunk count ===
assert_eq!
(
assert_eq!
(
...
@@ -1087,8 +1115,7 @@ mod tests {
...
@@ -1087,8 +1115,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should consolidate extreme fragmentation into 3 clean chunks
// Should consolidate extreme fragmentation into 3 clean chunks
// Input: "I'll process your request. " + 54-char tool call + " Processing complete!"
// Input: "I'll process your request. " + 54-char tool call + " Processing complete!"
...
@@ -1142,6 +1169,7 @@ mod tests {
...
@@ -1142,6 +1169,7 @@ mod tests {
create_mock_response_chunk
(
"
\"
arguments
\"
: {
\"
query
\"
:
\"
test
\"
}}"
.to_string
(),
0
),
create_mock_response_chunk
(
"
\"
arguments
\"
: {
\"
query
\"
:
\"
test
\"
}}"
.to_string
(),
0
),
create_mock_response_chunk
(
"</tool_call>"
.to_string
(),
0
),
create_mock_response_chunk
(
"</tool_call>"
.to_string
(),
0
),
create_mock_response_chunk
(
" Processing complete."
.to_string
(),
0
),
create_mock_response_chunk
(
" Processing complete."
.to_string
(),
0
),
test_utils
::
create_final_response_chunk
(
0
),
// Backend finish_reason chunk
];
];
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
...
@@ -1149,8 +1177,7 @@ mod tests {
...
@@ -1149,8 +1177,7 @@ mod tests {
// Create JailedStream with Hermes parser
// Create JailedStream with Hermes parser
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should get 3 chunks: before jail, tool call response, after jail
// Should get 3 chunks: before jail, tool call response, after jail
assert
!
(
assert
!
(
...
@@ -1159,14 +1186,14 @@ mod tests {
...
@@ -1159,14 +1186,14 @@ mod tests {
results
.len
()
results
.len
()
);
);
// Find the synthesized tool call response chunk
// Find the tool call chunk (the one with tool_calls, not the finish_reason chunk)
let
tool_call_chunk
=
results
let
tool_call_chunk
=
results
.iter
()
.iter
()
.find
(|
r
|
{
.find
(|
r
|
{
r
.data
r
.data
.as_ref
()
.as_ref
()
.and_then
(|
d
|
d
.choices
.first
())
.and_then
(|
d
|
d
.choices
.first
())
.map
(|
c
|
c
.
finish_reason
==
Some
(
FinishReason
::
T
ool
C
alls
))
.map
(|
c
|
c
.
delta.t
ool
_c
alls
.is_some
(
))
.unwrap_or
(
false
)
.unwrap_or
(
false
)
})
})
.expect
(
"Should have a tool call response chunk"
);
.expect
(
"Should have a tool call response chunk"
);
...
@@ -1232,8 +1259,7 @@ mod tests {
...
@@ -1232,8 +1259,7 @@ mod tests {
// Create JailedStream with Hermes parser
// Create JailedStream with Hermes parser
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should get 2 chunks: first chunk passes through, stream end releases accumulated
// Should get 2 chunks: first chunk passes through, stream end releases accumulated
assert_eq!
(
results
.len
(),
2
,
"Should have exactly 2 chunks"
);
assert_eq!
(
results
.len
(),
2
,
"Should have exactly 2 chunks"
);
...
@@ -1291,23 +1317,23 @@ mod tests {
...
@@ -1291,23 +1317,23 @@ mod tests {
),
),
create_mock_response_chunk
(
"{
\"
name
\"
:
\"
test
\"
,
\"
arguments
\"
: {}}"
.to_string
(),
0
),
create_mock_response_chunk
(
"{
\"
name
\"
:
\"
test
\"
,
\"
arguments
\"
: {}}"
.to_string
(),
0
),
create_mock_response_chunk
(
"</tool_call>"
.to_string
(),
0
),
create_mock_response_chunk
(
"</tool_call>"
.to_string
(),
0
),
test_utils
::
create_final_response_chunk
(
0
),
// Backend finish_reason chunk
];
];
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Find the tool call
response
// Find the tool call
chunk (the one with tool_calls, not the finish_reason chunk)
let
tool_call_chunk
=
results
let
tool_call_chunk
=
results
.iter
()
.iter
()
.find
(|
r
|
{
.find
(|
r
|
{
r
.data
r
.data
.as_ref
()
.as_ref
()
.and_then
(|
d
|
d
.choices
.first
())
.and_then
(|
d
|
d
.choices
.first
())
.map
(|
c
|
c
.
finish_reason
==
Some
(
FinishReason
::
T
ool
C
alls
))
.map
(|
c
|
c
.
delta.t
ool
_c
alls
.is_some
(
))
.unwrap_or
(
false
)
.unwrap_or
(
false
)
})
})
.expect
(
"Should have a tool call response chunk"
);
.expect
(
"Should have a tool call response chunk"
);
...
@@ -1352,8 +1378,7 @@ mod tests {
...
@@ -1352,8 +1378,7 @@ mod tests {
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// === Verify chunk count ===
// === Verify chunk count ===
assert_eq!
(
assert_eq!
(
...
@@ -1395,8 +1420,7 @@ mod tests {
...
@@ -1395,8 +1420,7 @@ mod tests {
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have exactly 3 chunks: content + tool call + trailing
// Should have exactly 3 chunks: content + tool call + trailing
assert_eq!
(
assert_eq!
(
...
@@ -1453,14 +1477,15 @@ mod tests {
...
@@ -1453,14 +1477,15 @@ mod tests {
(
"Done with B. "
.to_string
(),
1
),
// Choice 1 continues
(
"Done with B. "
.to_string
(),
1
),
// Choice 1 continues
(
"</tool_call>"
.to_string
(),
2
),
// Choice 2 unjails
(
"</tool_call>"
.to_string
(),
2
),
// Choice 2 unjails
]),
]),
// Chunk 6: Backend finish_reason chunks for all choices
test_utils
::
create_multi_choice_finish_chunk
(
vec!
[
0
,
1
,
2
]),
];
];
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// EXPECTED BEHAVIOR (will fail with current implementation):
// EXPECTED BEHAVIOR (will fail with current implementation):
// - Choice 1 should stream continuously (never jailed)
// - Choice 1 should stream continuously (never jailed)
...
@@ -1529,14 +1554,14 @@ mod tests {
...
@@ -1529,14 +1554,14 @@ mod tests {
2
,
2
,
),
),
]),
]),
test_utils
::
create_multi_choice_finish_chunk
(
vec!
[
0
,
1
,
2
]),
];
];
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Find all tool call responses
// Find all tool call responses
let
mut
tool_call_responses
:
Vec
<
_
>
=
results
let
mut
tool_call_responses
:
Vec
<
_
>
=
results
...
@@ -1559,25 +1584,30 @@ mod tests {
...
@@ -1559,25 +1584,30 @@ mod tests {
// Run this test multiple times to verify determinism
// Run this test multiple times to verify determinism
for
run
in
0
..
5
{
for
run
in
0
..
5
{
let
chunks
=
vec!
[
create_multi_choice_chunk
(
vec!
[
let
chunks
=
vec!
[
(
create_multi_choice_chunk
(
vec!
[
"<tool_call>{
\"
name
\"
:
\"
tool_0
\"
,
\"
arguments
\"
: {}}</tool_call>"
.to_string
(),
(
0
,
"<tool_call>{
\"
name
\"
:
\"
tool_0
\"
,
\"
arguments
\"
: {}}</tool_call>"
),
.to_string
(),
(
0
,
"<tool_call>{
\"
name
\"
:
\"
tool_1
\"
,
\"
arguments
\"
: {}}</tool_call>"
.to_string
(),
),
1
,
(
),
"<tool_call>{
\"
name
\"
:
\"
tool_1
\"
,
\"
arguments
\"
: {}}</tool_call>"
(
.to_string
(),
"<tool_call>{
\"
name
\"
:
\"
tool_2
\"
,
\"
arguments
\"
: {}}</tool_call>"
.to_string
(),
1
,
2
,
),
),
(
])];
"<tool_call>{
\"
name
\"
:
\"
tool_2
\"
,
\"
arguments
\"
: {}}</tool_call>"
.to_string
(),
2
,
),
]),
test_utils
::
create_multi_choice_finish_chunk
(
vec!
[
0
,
1
,
2
]),
];
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"hermes"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
run_results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
run_results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
let
run_responses
:
Vec
<
_
>
=
run_results
let
run_responses
:
Vec
<
_
>
=
run_results
.iter
()
.iter
()
...
@@ -1616,8 +1646,7 @@ mod tests {
...
@@ -1616,8 +1646,7 @@ mod tests {
let
jail
=
JailedStream
::
builder
()
.build
();
let
jail
=
JailedStream
::
builder
()
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// TODO: Once usage aggregation is implemented, verify:
// TODO: Once usage aggregation is implemented, verify:
// - Usage chunk has choices: [] (empty array)
// - Usage chunk has choices: [] (empty array)
...
@@ -1652,8 +1681,7 @@ mod tests {
...
@@ -1652,8 +1681,7 @@ mod tests {
.tool_call_parser
(
"nemotron_deci"
)
.tool_call_parser
(
"nemotron_deci"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// === Verify chunk count ===
// === Verify chunk count ===
assert_eq!
(
assert_eq!
(
...
@@ -1708,8 +1736,7 @@ mod tests {
...
@@ -1708,8 +1736,7 @@ mod tests {
.jail_end_sequence
(
"</TOOLCALL>"
)
.jail_end_sequence
(
"</TOOLCALL>"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// === Verify chunk count ===
// === Verify chunk count ===
assert_eq!
(
assert_eq!
(
...
@@ -1763,8 +1790,7 @@ mod tests {
...
@@ -1763,8 +1790,7 @@ mod tests {
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"harmony"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"harmony"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have at least one output containing both analysis text and parsed tool call
// Should have at least one output containing both analysis text and parsed tool call
assert
!
(
!
results
.is_empty
());
assert
!
(
!
results
.is_empty
());
...
@@ -1804,7 +1830,7 @@ mod tests {
...
@@ -1804,7 +1830,7 @@ mod tests {
let
jail
=
JailedStream
::
builder
()
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"deepseek_v3_1"
)
.tool_call_parser
(
"deepseek_v3_1"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
jailed_stream
=
jail
.apply
_with_finish_reason
(
input_stream
);
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have at least one output containing both analysis text and parsed tool call
// Should have at least one output containing both analysis text and parsed tool call
...
@@ -1878,7 +1904,7 @@ mod tests {
...
@@ -1878,7 +1904,7 @@ mod tests {
let
jail
=
JailedStream
::
builder
()
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"deepseek_v3_1"
)
.tool_call_parser
(
"deepseek_v3_1"
)
.build
();
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
jailed_stream
=
jail
.apply
_with_finish_reason
(
input_stream
);
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should have at least one output containing both analysis text and parsed tool call
// Should have at least one output containing both analysis text and parsed tool call
...
@@ -1920,8 +1946,7 @@ mod tests {
...
@@ -1920,8 +1946,7 @@ mod tests {
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
assert
!
(
results
.len
()
>=
2
);
assert
!
(
results
.len
()
>=
2
);
assert_content
(
&
results
[
0
],
"Hey How"
);
assert_content
(
&
results
[
0
],
"Hey How"
);
...
@@ -1956,8 +1981,7 @@ mod tests {
...
@@ -1956,8 +1981,7 @@ mod tests {
let
input_stream
=
stream
::
iter
(
chunks
);
let
input_stream
=
stream
::
iter
(
chunks
);
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jail
=
JailedStream
::
builder
()
.tool_call_parser
(
"mistral"
)
.build
();
let
jailed_stream
=
jail
.apply
(
input_stream
);
let
results
:
Vec
<
_
>
=
jail
.apply_with_finish_reason
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jailed_stream
.collect
()
.await
;
// Should preserve earlier content and also produce a tool call
// Should preserve earlier content and also produce a tool call
assert
!
(
results
.len
()
>=
2
);
assert
!
(
results
.len
()
>=
2
);
...
@@ -2130,7 +2154,7 @@ mod parallel_jail_tests {
...
@@ -2130,7 +2154,7 @@ mod parallel_jail_tests {
];
];
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
results
:
Vec
<
_
>
=
jail
.apply
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jail
.apply
_with_finish_reason
(
input_stream
)
.collect
()
.await
;
// Should have tool call results
// Should have tool call results
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
...
@@ -2203,7 +2227,7 @@ mod parallel_jail_tests {
...
@@ -2203,7 +2227,7 @@ mod parallel_jail_tests {
];
];
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
results
:
Vec
<
_
>
=
jail
.apply
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jail
.apply
_with_finish_reason
(
input_stream
)
.collect
()
.await
;
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
...
@@ -2240,7 +2264,7 @@ mod parallel_jail_tests {
...
@@ -2240,7 +2264,7 @@ mod parallel_jail_tests {
];
];
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
results
:
Vec
<
_
>
=
jail
.apply
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jail
.apply
_with_finish_reason
(
input_stream
)
.collect
()
.await
;
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
...
@@ -2310,7 +2334,7 @@ mod parallel_jail_tests {
...
@@ -2310,7 +2334,7 @@ mod parallel_jail_tests {
];
];
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
results
:
Vec
<
_
>
=
jail
.apply
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jail
.apply
_with_finish_reason
(
input_stream
)
.collect
()
.await
;
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
...
@@ -2548,7 +2572,7 @@ mod parallel_jail_tests {
...
@@ -2548,7 +2572,7 @@ mod parallel_jail_tests {
];
];
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
results
:
Vec
<
_
>
=
jail
.apply
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jail
.apply
_with_finish_reason
(
input_stream
)
.collect
()
.await
;
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
assert
!
(
!
results
.is_empty
(),
"Should have results"
);
...
@@ -2593,7 +2617,7 @@ mod parallel_jail_tests {
...
@@ -2593,7 +2617,7 @@ mod parallel_jail_tests {
];
];
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
input_stream
=
stream
::
iter
(
input_chunks
);
let
results
:
Vec
<
_
>
=
jail
.apply
(
input_stream
)
.collect
()
.await
;
let
results
:
Vec
<
_
>
=
jail
.apply
_with_finish_reason
(
input_stream
)
.collect
()
.await
;
// Should still handle the incomplete stream gracefully
// Should still handle the incomplete stream gracefully
assert
!
(
assert
!
(
...
...
lib/llm/tests/test_streaming_tool_parsers.rs
View file @
04f7579b
...
@@ -26,7 +26,7 @@ across backends.
...
@@ -26,7 +26,7 @@ across backends.
*/
*/
use
dynamo_async_openai
::
types
::
ChatChoiceStream
;
use
dynamo_async_openai
::
types
::
{
ChatChoiceStream
,
FinishReason
}
;
use
dynamo_llm
::
preprocessor
::
OpenAIPreprocessor
;
use
dynamo_llm
::
preprocessor
::
OpenAIPreprocessor
;
use
dynamo_llm
::
protocols
::
openai
::
chat_completions
::
NvCreateChatCompletionStreamResponse
;
use
dynamo_llm
::
protocols
::
openai
::
chat_completions
::
NvCreateChatCompletionStreamResponse
;
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
use
dynamo_runtime
::
protocols
::
annotated
::
Annotated
;
...
@@ -251,6 +251,71 @@ fn aggregate_content_from_chunks(
...
@@ -251,6 +251,71 @@ fn aggregate_content_from_chunks(
}
}
}
}
/// Helper function to validate finish_reason in the stream
/// Returns true if:
/// 1. There is exactly one finish_reason in the entire stream
/// 2. The finish_reason is in the last chunk
/// 3. The finish_reason matches the expected value
fn
validate_finish_reason
(
chunks
:
&
[
Annotated
<
NvCreateChatCompletionStreamResponse
>
],
expected_finish_reason
:
FinishReason
,
)
->
bool
{
let
mut
finish_reason_count
=
0
;
let
mut
last_chunk_index
=
None
;
let
mut
finish_reason_value
=
None
;
// Count finish_reason occurrences and track position
for
(
idx
,
chunk
)
in
chunks
.iter
()
.enumerate
()
{
if
let
Some
(
ref
response_data
)
=
chunk
.data
{
for
choice
in
&
response_data
.choices
{
if
let
Some
(
reason
)
=
choice
.finish_reason
{
finish_reason_count
+=
1
;
last_chunk_index
=
Some
(
idx
);
finish_reason_value
=
Some
(
reason
);
}
}
}
}
// Validate:
// 1. Exactly one finish_reason in the stream
if
finish_reason_count
!=
1
{
eprintln!
(
"Expected exactly 1 finish_reason, but found {}"
,
finish_reason_count
);
return
false
;
}
// 2. finish_reason is in the last chunk
if
let
Some
(
idx
)
=
last_chunk_index
{
if
idx
!=
chunks
.len
()
-
1
{
eprintln!
(
"Expected finish_reason in last chunk (index {}), but found at index {}"
,
chunks
.len
()
-
1
,
idx
);
return
false
;
}
}
else
{
eprintln!
(
"No finish_reason found in stream"
);
return
false
;
}
// 3. finish_reason matches expected value
if
let
Some
(
reason
)
=
finish_reason_value
&&
reason
!=
expected_finish_reason
{
eprintln!
(
"Expected finish_reason {:?}, but found {:?}"
,
expected_finish_reason
,
reason
);
return
false
;
}
true
}
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
*
;
use
super
::
*
;
...
@@ -304,6 +369,12 @@ mod tests {
...
@@ -304,6 +369,12 @@ mod tests {
aggregated
.has_tool_calls
,
expected_has_tool_calls
,
aggregated
.has_tool_calls
,
expected_has_tool_calls
,
"Tool calls presence should match expected value"
"Tool calls presence should match expected value"
);
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
Stop
),
"finish_reason validation failed for non-tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -360,6 +431,12 @@ mod tests {
...
@@ -360,6 +431,12 @@ mod tests {
// Verify tool calls
// Verify tool calls
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
ToolCalls
),
"finish_reason validation failed for tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -403,6 +480,12 @@ mod tests {
...
@@ -403,6 +480,12 @@ mod tests {
aggregated
.has_tool_calls
,
expected_has_tool_calls
,
aggregated
.has_tool_calls
,
expected_has_tool_calls
,
"Tool calls presence should match expected value"
"Tool calls presence should match expected value"
);
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
Stop
),
"finish_reason validation failed for non-tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -455,6 +538,12 @@ mod tests {
...
@@ -455,6 +538,12 @@ mod tests {
// Verify tool calls
// Verify tool calls
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
ToolCalls
),
"finish_reason validation failed for tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -511,6 +600,12 @@ mod tests {
...
@@ -511,6 +600,12 @@ mod tests {
);
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
Stop
),
"finish_reason validation failed for non-tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -567,6 +662,12 @@ mod tests {
...
@@ -567,6 +662,12 @@ mod tests {
);
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
ToolCalls
),
"finish_reason validation failed for tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -620,6 +721,12 @@ mod tests {
...
@@ -620,6 +721,12 @@ mod tests {
);
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Stop
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
Stop
),
"finish_reason validation failed for non-tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -674,6 +781,12 @@ mod tests {
...
@@ -674,6 +781,12 @@ mod tests {
"Tool calls presence should match expected value"
"Tool calls presence should match expected value"
);
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
ToolCalls
),
"finish_reason validation failed for tool call case"
);
}
}
#[tokio::test]
#[tokio::test]
...
@@ -726,5 +839,46 @@ mod tests {
...
@@ -726,5 +839,46 @@ mod tests {
// Verify tool calls
// Verify tool calls
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
assert_tool_calls
(
&
aggregated
.tool_calls
,
&
test_data
.expected_tool_calls
);
// Verify finish_reason is valid: exactly one occurrence, in last chunk, and is ToolCalls
assert
!
(
validate_finish_reason
(
&
output_chunks
,
FinishReason
::
ToolCalls
),
"finish_reason validation failed for tool call case"
);
}
#[tokio::test]
async fn test_qwen_finish_reason_length_vllm() {
    // Two recorded vLLM fixtures are run through the same checks; both are
    // expected to end with a single `Length` finish_reason in the last chunk.
    let fixture_names = [
        "chat_completion_stream_finish_length.json",
        "chat_completion_incomplete_tool.json",
    ];

    for fixture_name in fixture_names {
        let file_path = format!("{}/vllm/qwen3-0.6B/{}", DATA_ROOT_PATH, fixture_name);
        let test_data = load_test_data(&file_path);

        // Create a stream from the mock chunks
        let input_stream = stream::iter(test_data.stream_chunks);

        // Parse the response stream with tool parsing enabled
        let output_chunks = parse_response_stream(
            input_stream,
            true,
            false,
            Some("hermes".to_string()),
            None,
        )
        .await;

        // Verify we got output chunks. The fixture path is part of the
        // message so a failure identifies which file triggered it.
        assert!(
            !output_chunks.is_empty(),
            "Should have output chunks ({})",
            file_path
        );

        // Verify finish_reason is valid: exactly one occurrence, in last chunk, and is Length
        assert!(
            validate_finish_reason(&output_chunks, FinishReason::Length),
            "finish_reason validation failed for length finish case ({})",
            file_path
        );
    }
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment