"tests/vscode:/vscode.git/clone" did not exist on "836e8ef6eeafcd1e24b25c990da6331f48a95fd2"
Unverified Commit c63cceaa authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: JailedStream (#3034)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 4c56c8ae
...@@ -92,7 +92,6 @@ generated-values.yaml ...@@ -92,7 +92,6 @@ generated-values.yaml
**/.devcontainer/.env **/.devcontainer/.env
TensorRT-LLM TensorRT-LLM
# Ruler Generated Files # Ruler Generated Files
/.cursor/instructions.md /.cursor/instructions.md
/.cursor/instructions.md.bak /.cursor/instructions.md.bak
......
...@@ -1981,6 +1981,7 @@ name = "dynamo-llm" ...@@ -1981,6 +1981,7 @@ name = "dynamo-llm"
version = "0.5.0" version = "0.5.0"
dependencies = [ dependencies = [
"ahash", "ahash",
"aho-corasick",
"akin", "akin",
"aligned-vec", "aligned-vec",
"anyhow", "anyhow",
......
# JailedStream Implementation
## Overview
The `JailedStream` is a standalone implementation for handling "jail" detection in token streams. It provides a clean, builder-based API for accumulating tokens when certain sequences are detected, then releasing them as a single chunk when the jail ends.
## Key Features
- **Builder Pattern**: Clean configuration API using the builder pattern
- **Configurable Sequences**: Support for multiple start/end jail sequences
- **Tool Call Parsing**: Integrated tool call detection and parsing
- **Stream Macro**: Uses `async-stream::stream!` for clean async implementation
- **Standalone**: Completely independent of existing code
- **Annotations**: Preserves annotations for observability
## Implementation
### Location
- Main implementation: `lib/llm/src/protocols/openai/chat_completions/jail.rs`
- Examples: `lib/llm/src/protocols/openai/chat_completions/jail_example.rs`
### Usage
```rust
use crate::protocols::openai::chat_completions::jail::JailedStream;
use dynamo_runtime::engine::{AsyncEngineContextProvider, ResponseStream};
// Get your ResponseStream with context
let response_stream: Pin<Box<ResponseStream<_>>> = get_stream_from_engine();
// Extract context BEFORE passing to apply
let context = response_stream.context();
// Apply jail transformation (ResponseStream implements Stream)
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(response_stream);
// Re-wrap with context when needed for engine consumption
let final_stream = ResponseStream::new(Box::pin(jailed_stream), context);
```
### Advanced Configuration
```rust
// With custom jail sequences
let jail = JailedStream::builder()
.jail_start_sequence("<TOOLCALL>")
.jail_end_sequence("</TOOLCALL>")
.tool_call_parser("nemotron_deci")
.build();
// With multiple sequences
let jail = JailedStream::builder()
.jail_start_sequences(vec!["<TOOLCALL>", "<FUNCTION>"])
.jail_end_sequences(vec!["</TOOLCALL>", "</FUNCTION>"])
.tool_call_parser("harmony")
.build();
```
## How It Works
1. **Detection**: When a jail start sequence (or tool call start) is detected, the stream enters "jail" mode
2. **Accumulation**: While jailed, tokens are accumulated in memory instead of being yielded
3. **Annotations**: Empty chunks with annotations are sent downstream for observability
4. **Release**: When a jail end sequence is detected OR the stream ends:
- Accumulated content is parsed for tool calls
- A single chunk with the parsed content is yielded
5. **Pass-through**: Non-jailed content passes through unchanged
## Testing
The implementation includes comprehensive tests:
- `test_jailed_stream_with_start_end_sequences`: Tests explicit jail sequences
- `test_jailed_stream_with_tool_calls`: Tests tool call detection and parsing
- `test_jailed_stream_no_jailing`: Tests normal pass-through behavior
Run tests with:
```bash
cargo test -p dynamo-llm jail --lib
```
## Benefits
1. **Standalone**: No modifications to existing code required
2. **Clean API**: Builder pattern makes configuration intuitive
3. **Flexible**: Supports multiple jail detection strategies
4. **Maintainable**: Uses `stream!` macro for cleaner async code
5. **Testable**: Comprehensive test suite with shared utilities
6. **Efficient**: No unnecessary boxing or context handling in the library
7. **Composable**: Can chain multiple stream transformers before re-adding context
## Performance Optimizations
- **No Boxing in Library**: Returns `impl Stream` instead of `Pin<Box<ResponseStream>>`
- **Stack Pinning**: Uses `tokio::pin!()` instead of `Box::pin()` for better performance
- **No Context Overhead**: JailedStream doesn't manage AsyncEngineContext
- **Lazy Evaluation**: Only processes what's needed
- **Efficient State Management**: Minimal cloning, only when entering jail state
## Integration Options
To replace the existing `apply_tool_calling_jail_internal` function:
```rust
// In preprocessor.rs
pub fn apply_tool_calling_jail_with_parser(
&self,
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
let jail = JailedStream::builder()
.tool_call_parser(self.tool_call_parser.clone())
.build();
jail.apply(stream)
}
```
## Future Enhancements
- Add support for regex patterns for jail sequences
- Add metrics/telemetry for jail detection
- Support for partial sequence matching across chunk boundaries
- Configurable accumulation limits
- Support for nested jails
\ No newline at end of file
...@@ -1382,6 +1382,7 @@ name = "dynamo-llm" ...@@ -1382,6 +1382,7 @@ name = "dynamo-llm"
version = "0.5.0" version = "0.5.0"
dependencies = [ dependencies = [
"ahash", "ahash",
"aho-corasick",
"akin", "akin",
"anyhow", "anyhow",
"async-nats", "async-nats",
......
...@@ -52,6 +52,7 @@ required-features = ["block-manager", "testing-cuda"] ...@@ -52,6 +52,7 @@ required-features = ["block-manager", "testing-cuda"]
dynamo-runtime = { workspace = true } dynamo-runtime = { workspace = true }
# workspace # workspace
aho-corasick = "1.1"
anyhow = { workspace = true } anyhow = { workspace = true }
dynamo-async-openai = { workspace = true } dynamo-async-openai = { workspace = true }
dynamo-parsers = { workspace = true} dynamo-parsers = { workspace = true}
......
...@@ -37,6 +37,7 @@ pub mod request_template; ...@@ -37,6 +37,7 @@ pub mod request_template;
pub mod tokenizers; pub mod tokenizers;
pub mod tokens; pub mod tokens;
pub mod types; pub mod types;
pub mod utils;
#[cfg(feature = "block-manager")] #[cfg(feature = "block-manager")]
pub mod block_manager; pub mod block_manager;
......
...@@ -15,18 +15,14 @@ pub mod prompt; ...@@ -15,18 +15,14 @@ pub mod prompt;
pub mod tools; pub mod tools;
use anyhow::Result; use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionToolChoiceOption; use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, EncodingFormat};
use dynamo_async_openai::types::EncodingFormat; use futures::Stream;
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use prompt::OAIPromptFormatter; use prompt::OAIPromptFormatter;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing; use tracing;
use dynamo_parsers::tool_calling::{
parsers::detect_tool_call_start, try_tool_call_parse_aggregate,
};
use crate::model_card::{ModelDeploymentCard, ModelInfo}; use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::PreprocessedRequestBuilder; use crate::protocols::common::preprocessor::PreprocessedRequestBuilder;
...@@ -42,7 +38,9 @@ use crate::protocols::{ ...@@ -42,7 +38,9 @@ use crate::protocols::{
common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider}, common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
openai::{ openai::{
DeltaGeneratorExt, DeltaGeneratorExt,
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, jail::JailedStream,
},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
nvext::NvExtProvider, nvext::NvExtProvider,
...@@ -60,7 +58,6 @@ use crate::protocols::common::llm_backend::EmbeddingsEngineOutput; ...@@ -60,7 +58,6 @@ use crate::protocols::common::llm_backend::EmbeddingsEngineOutput;
pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt"; pub const ANNOTATION_FORMATTED_PROMPT: &str = "formatted_prompt";
pub const ANNOTATION_TOKEN_IDS: &str = "token_ids"; pub const ANNOTATION_TOKEN_IDS: &str = "token_ids";
pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics"; pub const ANNOTATION_LLM_METRICS: &str = "llm_metrics";
pub const ANNOTATION_POSSIBLE_TOOL_CALL: &str = "possible_tool_call";
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMMetricAnnotation { pub struct LLMMetricAnnotation {
pub input_tokens: usize, pub input_tokens: usize,
...@@ -68,30 +65,6 @@ pub struct LLMMetricAnnotation { ...@@ -68,30 +65,6 @@ pub struct LLMMetricAnnotation {
pub chunk_tokens: usize, pub chunk_tokens: usize,
} }
#[derive(Debug)]
pub struct JailState {
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
is_jailed: bool,
tool_call_parser: Option<String>,
accumulated_content: HashMap<u32, String>, // choice index -> accumulated content
last_response_metadata: Option<NvCreateChatCompletionStreamResponse>, // for response structure
finished: bool, // Add this flag to track if stream is finished
}
pub fn maybe_enable_tool_call(
parser_str: Option<&str>,
request: &NvCreateChatCompletionRequest,
) -> bool {
// Enable tool call if the below two conditions are satisfied
// 1. parser_str is not None
// 2. tool_choice is not None
parser_str.is_some()
&& !matches!(
request.inner.tool_choice,
Some(ChatCompletionToolChoiceOption::None)
)
}
impl LLMMetricAnnotation { impl LLMMetricAnnotation {
/// Convert this metrics struct to an Annotated event /// Convert this metrics struct to an Annotated event
pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> { pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
...@@ -120,41 +93,6 @@ impl LLMMetricAnnotation { ...@@ -120,41 +93,6 @@ impl LLMMetricAnnotation {
} }
} }
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PossibleToolCallAnnotation {
pub possible_tokens: usize,
pub possible_content: String,
pub parser_used: Option<String>,
}
impl PossibleToolCallAnnotation {
/// Convert this possible tool call annotation to an Annotated event
pub fn to_annotation<T>(&self) -> Result<Annotated<T>, serde_json::Error> {
Annotated::from_annotation(ANNOTATION_POSSIBLE_TOOL_CALL, self)
}
/// Extract possible tool call info from an Annotated event, if present
pub fn from_annotation<T>(
annotation: &Annotated<T>,
) -> Result<Option<PossibleToolCallAnnotation>, Box<dyn std::error::Error>> {
if annotation.event.is_none() {
return Ok(None);
}
if annotation.event.as_ref().unwrap() != ANNOTATION_POSSIBLE_TOOL_CALL {
return Ok(None);
}
let comments = annotation
.comment
.as_ref()
.ok_or("missing comments block")?;
if comments.len() != 1 {
return Err("malformed comments block - expected exactly 1 comment".into());
}
let possible_info: PossibleToolCallAnnotation = serde_json::from_str(&comments[0])?;
Ok(Some(possible_info))
}
}
pub struct OpenAIPreprocessor { pub struct OpenAIPreprocessor {
mdcsum: String, mdcsum: String,
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
...@@ -482,36 +420,43 @@ impl OpenAIPreprocessor { ...@@ -482,36 +420,43 @@ impl OpenAIPreprocessor {
Ok((builder.build()?, annotations)) Ok((builder.build()?, annotations))
} }
pub fn transform_postprocessor_stream<Resp: Send + Sync + 'static + std::fmt::Debug>( pub fn transform_postprocessor_stream<S, Resp>(
stream: ManyOut<Annotated<BackendOutput>>, stream: S,
generator: Box<dyn DeltaGeneratorExt<Resp>>, generator: Box<dyn DeltaGeneratorExt<Resp>>,
) -> ManyOut<Annotated<Resp>> { context: Arc<dyn AsyncEngineContext>,
let context = stream.context(); ) -> impl Stream<Item = Annotated<Resp>> + Send
where
struct State<Resp: Send + Sync + 'static + std::fmt::Debug> { S: Stream<Item = Annotated<BackendOutput>> + Send + 'static,
response_stream: ManyOut<Annotated<BackendOutput>>, Resp: Send + Sync + 'static + std::fmt::Debug,
{
struct State<Resp>
where
Resp: Send + Sync + 'static + std::fmt::Debug,
{
response_stream: Pin<Box<dyn Stream<Item = Annotated<BackendOutput>> + Send>>,
response_generator: Box<dyn DeltaGeneratorExt<Resp>>, response_generator: Box<dyn DeltaGeneratorExt<Resp>>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
cancelled: bool, cancelled: bool,
cumulative_output_tokens: usize, cumulative_output_tokens: usize,
finish_reason_sent: bool, finish_reason_sent: bool,
usage_chunk_sent: bool, usage_chunk_sent: bool,
finished: bool, // Add this flag to track if stream is finished finished: bool,
} }
let state = State { let state = State {
response_stream: stream, response_stream: Box::pin(stream),
response_generator: generator, response_generator: generator,
context: context.clone(), context: context.clone(),
cancelled: false, cancelled: false,
cumulative_output_tokens: 0, cumulative_output_tokens: 0,
finish_reason_sent: false, finish_reason_sent: false,
usage_chunk_sent: false, usage_chunk_sent: false,
finished: false, // Initialize as not finished finished: false,
}; };
// transform the common response stream into a chat response stream // transform the common response stream into a chat response stream
let stream = stream::unfold(state, |mut inner| {
stream::unfold(state, |mut inner| {
async move { async move {
// If already finished, return None immediately // If already finished, return None immediately
if inner.finished { if inner.finished {
...@@ -628,19 +573,18 @@ impl OpenAIPreprocessor { ...@@ -628,19 +573,18 @@ impl OpenAIPreprocessor {
} }
} }
} }
}); })
ResponseStream::new(Box::pin(stream), context)
} }
/// Transform engine embedding output stream to OpenAI embedding response stream /// Transform engine embedding output stream to OpenAI embedding response stream
pub fn transform_embedding_postprocessor_stream( pub fn transform_embedding_postprocessor_stream<S>(
stream: ManyOut<Annotated<EmbeddingsEngineOutput>>, stream: S,
original_request: NvCreateEmbeddingRequest, original_request: NvCreateEmbeddingRequest,
) -> ManyOut<Annotated<NvCreateEmbeddingResponse>> { ) -> impl Stream<Item = Annotated<NvCreateEmbeddingResponse>> + Send
let context = stream.context(); where
S: Stream<Item = Annotated<EmbeddingsEngineOutput>> + Send + 'static,
let transformed_stream = stream.map(move |output| { {
stream.map(move |output| {
output.map_data(|engine_output| { output.map_data(|engine_output| {
// Convert engine output to OpenAI response format // Convert engine output to OpenAI response format
let embeddings: Vec<dynamo_async_openai::types::Embedding> = engine_output let embeddings: Vec<dynamo_async_openai::types::Embedding> = engine_output
...@@ -668,262 +612,62 @@ impl OpenAIPreprocessor { ...@@ -668,262 +612,62 @@ impl OpenAIPreprocessor {
Ok(response) Ok(response)
}) })
}); })
ResponseStream::new(Box::pin(transformed_stream), context)
}
/// Apply tool calling jail to the stream using the preprocessor's tool call parser
pub async fn apply_tool_calling_jail_with_parser(
&self,
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
apply_tool_calling_jail_internal(stream, self.tool_call_parser.clone()).await
}
}
/// Apply tool calling jail to the stream - stops/jails the stream under certain conditions
/// When jailed, the stream will be unjailed when the input stream ends
pub async fn apply_tool_calling_jail_internal(
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
tool_call_parser: Option<String>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
let context = stream.context();
let jail_state = JailState {
stream,
is_jailed: false,
tool_call_parser,
accumulated_content: HashMap::new(),
last_response_metadata: None,
finished: false,
};
// Transform the stream using unfold to maintain state
// Input: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
// Returns None if the stream is finished
// Returns Some((Annotated<NvCreateChatCompletionStreamResponse>, JailState)) if the stream is not finished
// End output: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>
let jailed_stream = stream::unfold(jail_state, |mut state| async move {
// If already finished, return None immediately
if state.finished {
return None;
}
if let Some(response) = state.stream.next().await {
// Check if we should jail the stream
if !state.is_jailed {
// Handle the case where response.data is Option<T>
if let Some(ref chat_response) = response.data {
// Store metadata for potential tool call parsing later
state.last_response_metadata = Some(chat_response.clone());
// Extract text content from the response
if let Some(choice) = chat_response.choices.first()
&& let Some(ref content) = choice.delta.content
{
// Check for tool call start
match detect_tool_call_start(content, state.tool_call_parser.as_deref()) {
Ok(should_jail) => {
if should_jail {
tracing::debug!("Tool call detected, jailing stream");
state.is_jailed = true;
// Start accumulating content for this choice
state
.accumulated_content
.insert(choice.index, content.clone());
// Create possible tool call annotation with token information
let possible_annotation = PossibleToolCallAnnotation {
possible_tokens: 1, // This chunk contains tokens being processed
possible_content: content.clone(),
parser_used: state.tool_call_parser.clone(),
};
// Create annotated response instead of empty response
let mut annotated_response = response.clone();
if let Ok(possible_annotated) =
possible_annotation
.to_annotation::<NvCreateChatCompletionStreamResponse>()
{
// Set annotation event and comment
annotated_response.event = possible_annotated.event;
annotated_response.comment = possible_annotated.comment;
}
// Modify the response to have empty content but keep metadata
annotated_response =
annotated_response.map_data(|mut chat_response| {
// Clear the content but keep choice structure for ITL measurement
for choice in &mut chat_response.choices {
choice.delta.content = Some(String::new()); // Empty content
} }
Ok(chat_response)
});
return Some((annotated_response, state)); /// Determine if we should apply the tool calling jail based on configuration
} /// Returns Ok(true) if jail should be applied, Ok(false) if not, or Err if invalid config
} pub fn should_apply_tool_jail(
Err(e) => { tool_call_parser: Option<&String>,
tracing::warn!("Error detecting tool call start: {}", e); tool_choice: Option<&ChatCompletionToolChoiceOption>,
} has_tools: bool,
} ) -> std::result::Result<bool, Error> {
match (tool_call_parser, tool_choice, has_tools) {
// No parser but tools requested - error cases
(None, Some(ChatCompletionToolChoiceOption::Required), true) => {
tracing::warn!(
"Tool choice 'required' specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
} }
(None, Some(ChatCompletionToolChoiceOption::Auto), true) => {
tracing::warn!(
"Tool choice 'auto' specified but no tool parser configured; proceeding without jailing"
);
Ok(false)
} }
} else if state.is_jailed { (None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => {
// If already jailed, continue to jail but with annotations and accumulate content tracing::warn!(
if let Some(ref chat_response) = response.data { "Named tool choice specified but no tool parser configured; proceeding without jailing"
// Extract content for annotation and accumulation );
for choice in &chat_response.choices { Ok(false)
if let Some(ref content) = choice.delta.content
&& !content.is_empty()
{
// Accumulate content for this choice
state
.accumulated_content
.entry(choice.index)
.or_default()
.push_str(content);
// Create possible tool call annotation
let possible_annotation = PossibleToolCallAnnotation {
possible_tokens: 1,
possible_content: content.clone(),
parser_used: state.tool_call_parser.clone(),
};
// Create annotated response
let mut annotated_response = response.clone();
if let Ok(possible_annotated) = possible_annotation
.to_annotation::<NvCreateChatCompletionStreamResponse>(
) {
annotated_response.event = possible_annotated.event;
annotated_response.comment = possible_annotated.comment;
} }
// Clear content but keep structure // Parser exists and tools might be called
annotated_response = (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => {
annotated_response.map_data(|mut chat_response| { Ok(false) // Explicitly disabled
for choice in &mut chat_response.choices {
choice.delta.content = Some(String::new());
} }
Ok(chat_response) (Some(_), Some(_), true) => Ok(true), // Any other tool_choice with tools
}); (Some(_), None, true) => Ok(true), // Default behavior when tools present
return Some((annotated_response, state)); // No tools or no parser
} _ => Ok(false),
}
} }
} }
// If not jailed or jailing condition not met, return the response as-is /// Apply tool calling jail to the stream if needed
Some((response, state)) pub fn apply_tool_calling_jail<S>(
} else { tool_call_parser: String,
// Stream ended - if we were jailed, we should unjail now and parse tool calls stream: S,
if state.is_jailed { ) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
tracing::debug!("Stream ended, unjailing and parsing accumulated content"); where
state.is_jailed = false; S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
// Parse accumulated content for tool calls
if !state.accumulated_content.is_empty()
&& let Some(base_response) = state.last_response_metadata.take()
{ {
// Try to parse tool calls from accumulated content for each choice let jail = JailedStream::builder()
let mut final_response = base_response.clone(); .tool_call_parser(tool_call_parser)
.build();
for (choice_index, accumulated_text) in &state.accumulated_content { jail.apply(stream)
if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate(
accumulated_text,
state.tool_call_parser.as_deref(),
)
.await
{
// Found tool calls, create a final response with them
tracing::debug!(
"Parsed {} tool calls from accumulated content",
tool_calls.len()
);
for tool_call in &tool_calls {
tracing::debug!(
tool_call_id = %tool_call.id,
function_name = %tool_call.function.name,
arguments = %tool_call.function.arguments,
"Parsed structured tool call from accumulated content in jail"
);
}
// Convert ChatCompletionMessageToolCall to ChatCompletionMessageToolCallChunk for streaming
let tool_call_chunks: Vec<
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk,
> = tool_calls
.into_iter()
.enumerate()
.map(|(idx, tool_call)| {
dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
index: idx as u32,
id: Some(tool_call.id),
r#type: Some(tool_call.r#type),
function: Some(
dynamo_async_openai::types::FunctionCallStream {
name: Some(tool_call.function.name),
arguments: Some(tool_call.function.arguments),
},
),
}
})
.collect();
// Create a choice with tool calls
#[allow(deprecated)]
let final_choice = dynamo_async_openai::types::ChatChoiceStream {
index: *choice_index,
delta:
dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
role: Some(dynamo_async_openai::types::Role::Assistant),
content: normal_text.filter(|t| !t.is_empty()),
tool_calls: Some(tool_call_chunks.clone()),
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(
dynamo_async_openai::types::FinishReason::ToolCalls,
),
logprobs: None,
};
// Update the response choices
final_response.choices = vec![final_choice];
// Create final annotated response
let final_annotated = Annotated {
data: Some(final_response),
id: None,
event: None,
comment: None,
};
state.finished = true; // Mark as finished before returning
return Some((final_annotated, state));
} }
}
}
}
state.finished = true; // Mark as finished
None
}
});
// Jailed Stream contains empty content chunks with annotation event "possible_tool_call" whenever the stream is jailed
// This is a bad UX for the user, as they have to see a lot of empty content chunks
// Filter out the empty content chunks with annotation event "possible_tool_call"
let filtered_stream = jailed_stream.filter(|annotated| {
let keep = annotated.event.as_deref() != Some(ANNOTATION_POSSIBLE_TOOL_CALL);
async move { keep }
});
ResponseStream::new(Box::pin(filtered_stream), context)
} }
// for pals, we do not want to add the generation prompt to the formatted prompt // for pals, we do not want to add the generation prompt to the formatted prompt
...@@ -952,14 +696,14 @@ impl ...@@ -952,14 +696,14 @@ impl
// create a response generator // create a response generator
let response_generator = request.response_generator(context.id().to_string()); let response_generator = request.response_generator(context.id().to_string());
// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
let mut response_generator = Box::new(response_generator); let mut response_generator = Box::new(response_generator);
// set the runtime configuration // set the runtime configuration
response_generator.set_reasoning_parser(self.runtime_config.clone()); response_generator.set_reasoning_parser(self.runtime_config.clone());
let enable_tool_calling =
maybe_enable_tool_call(self.tool_call_parser.as_deref(), &request);
// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
// update isl // update isl
response_generator.update_isl(common_request.token_ids.len() as u32); response_generator.update_isl(common_request.token_ids.len() as u32);
...@@ -977,21 +721,43 @@ impl ...@@ -977,21 +721,43 @@ impl
// forward the common completion request to the next operator // forward the common completion request to the next operator
let response_stream = next.generate(common_request).await?; let response_stream = next.generate(common_request).await?;
// transform the postprocessor stream // Extract context once
let stream = Self::transform_postprocessor_stream(response_stream, response_generator); let context = response_stream.context();
// transform the postprocessor stream (no boxing yet)
let stream = Self::transform_postprocessor_stream(
response_stream,
response_generator,
context.clone(),
);
// Check if tools are present and if we should apply jail
let has_tools =
request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty();
// Context was already extracted above from response_stream
// Determine if we should apply jail (do this before moving request)
let should_jail = Self::should_apply_tool_jail(
self.tool_call_parser.as_ref(),
request.inner.tool_choice.as_ref(),
has_tools,
)?;
// Apply tool calling jail to the stream if tool call parser is present // Apply jail conditionally
let stream = if enable_tool_calling { let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
self.apply_tool_calling_jail_with_parser(stream).await if let Some(parser) = self.tool_call_parser.clone() {
Box::pin(Self::apply_tool_calling_jail(parser, stream))
} else { } else {
stream Box::pin(stream) // Should not happen due to should_jail check
}
} else {
Box::pin(stream)
}; };
let context = stream.context();
// prepend the annotations to the response stream // prepend the annotations to the response stream
let stream = annotations_stream.chain(stream); let stream = annotations_stream.chain(stream);
// return the response stream // return the response stream - single boxing at the end
Ok(ResponseStream::new(Box::pin(stream), context)) Ok(ResponseStream::new(Box::pin(stream), context))
} }
} }
...@@ -1039,9 +805,15 @@ impl ...@@ -1039,9 +805,15 @@ impl
// forward the common completion request to the next operator // forward the common completion request to the next operator
let response_stream = next.generate(common_request).await?; let response_stream = next.generate(common_request).await?;
// Extract context once
let context = response_stream.context();
// transform the postprocessor stream // transform the postprocessor stream
let stream = Self::transform_postprocessor_stream(response_stream, response_generator); let stream = Self::transform_postprocessor_stream(
let context = stream.context(); response_stream,
response_generator,
context.clone(),
);
// prepend the annotations to the response stream // prepend the annotations to the response stream
let stream = annotations_stream.chain(stream); let stream = annotations_stream.chain(stream);
...@@ -1082,9 +854,11 @@ impl ...@@ -1082,9 +854,11 @@ impl
let preprocessed_request = context.map(|_| preprocessed_request); let preprocessed_request = context.map(|_| preprocessed_request);
let response_stream = next.generate(preprocessed_request).await?; let response_stream = next.generate(preprocessed_request).await?;
// Extract context once
let context = response_stream.context();
// Transform response stream back to OpenAI format // Transform response stream back to OpenAI format
let stream = Self::transform_embedding_postprocessor_stream(response_stream, request); let stream = Self::transform_embedding_postprocessor_stream(response_stream, request);
let context = stream.context();
// Prepend annotations // Prepend annotations
let annotations_stream = stream::iter( let annotations_stream = stream::iter(
...@@ -1098,3 +872,5 @@ impl ...@@ -1098,3 +872,5 @@ impl
Ok(ResponseStream::new(Box::pin(combined_stream), context)) Ok(ResponseStream::new(Box::pin(combined_stream), context))
} }
} }
// Note: tests for jailing and parser detection live in `lib/llm/tests/test_jail.rs`
...@@ -19,6 +19,7 @@ use super::{ ...@@ -19,6 +19,7 @@ use super::{
pub mod aggregator; pub mod aggregator;
mod delta; mod delta;
pub mod jail;
pub use aggregator::DeltaAggregator; pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator; pub use delta::DeltaGenerator;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_stream::stream;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionMessageToolCallChunk, ChatCompletionStreamResponseDelta,
FinishReason, FunctionCallStream, Role,
};
use dynamo_parsers::tool_calling::parsers::get_tool_parser_map;
use dynamo_parsers::tool_calling::{detect_tool_call_start, try_tool_call_parse_aggregate};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::{Stream, StreamExt};
use crate::utils::{MarkerMatcher, MatchResult};
use super::NvCreateChatCompletionStreamResponse;
/// Represents what a choice wants to emit after processing content
#[derive(Debug, Clone)]
pub enum ChoiceEmission {
/// Pass through content unchanged (choice is not jailed)
PassThrough(ChatChoiceStream),
/// Emit parsed tool calls (choice finished jailing with tool calls)
ToolCall(ChatChoiceStream),
/// Emit accumulated content (choice finished jailing without tool calls)
Content(ChatChoiceStream),
/// Emit trailing content after tool call end (choice has trailing after unjail)
Trailing(ChatChoiceStream),
}
impl ChoiceEmission {
/// Extract the ChatChoiceStream from any emission type
pub fn into_choice(self) -> ChatChoiceStream {
match self {
ChoiceEmission::PassThrough(choice) => choice,
ChoiceEmission::ToolCall(choice) => choice,
ChoiceEmission::Content(choice) => choice,
ChoiceEmission::Trailing(choice) => choice,
}
}
/// Get the choice index
pub fn index(&self) -> u32 {
match self {
ChoiceEmission::PassThrough(choice) => choice.index,
ChoiceEmission::ToolCall(choice) => choice.index,
ChoiceEmission::Content(choice) => choice.index,
ChoiceEmission::Trailing(choice) => choice.index,
}
}
}
/// Configuration for jail detection and parsing
#[derive(Debug, Clone)]
pub struct JailConfig<'a> {
pub jail_start_sequences: &'a [String],
pub jail_end_sequences: &'a [String],
pub tool_call_parser: Option<&'a str>,
}
/// State tracking for an individual choice during jail processing
#[derive(Debug, Clone)]
struct ChoiceJailState {
/// The choice index (0, 1, 2, ...)
index: u32,
/// Whether this choice is currently jailed
is_jailed: bool,
/// Accumulated content for this choice while jailed
accumulated_content: String,
/// Buffer for partial marker matches across chunks
partial_match_buffer: String,
}
impl ChoiceJailState {
/// Create a new jail state for a choice
fn new(index: u32) -> Self {
Self {
index,
is_jailed: false,
accumulated_content: String::new(),
partial_match_buffer: String::new(),
}
}
/// Add content to this choice's accumulation
fn accumulate(&mut self, content: &str) {
if self.is_jailed {
self.accumulated_content.push_str(content);
}
}
/// End jailing and return the accumulated content
fn end_jail(&mut self) -> String {
self.is_jailed = false;
std::mem::take(&mut self.accumulated_content)
}
/// Process incoming content and return what should be emitted (if anything)
async fn process_content(
&mut self,
choice: &ChatChoiceStream,
content: &str,
jail_stream: &JailedStream,
) -> Vec<ChoiceEmission> {
let mut emissions = Vec::new();
if !self.is_jailed {
// Use the marker matcher to detect complete/partial markers
match jail_stream
.marker_matcher
.process_chunk(content, &self.partial_match_buffer)
{
MatchResult::Complete {
prefix,
marker,
suffix,
..
} => {
// Emit prefix if any
if !prefix.is_empty() {
#[allow(deprecated)]
let prefix_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(prefix),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
}
// Build the potential full content
let full_content = format!("{}{}", marker, suffix);
// Check if this already contains the end marker
let (should_end, split_pos) = jail_stream.should_end_jail(&full_content).await;
if should_end {
// Complete tool call found in this chunk
tracing::debug!(
"Choice {} complete tool call detected in single chunk",
choice.index
);
let (jailed_part, trailing_part) = full_content.split_at(split_pos);
// Create the tool call choice
let tool_choice = jail_stream
.create_tool_call_choice(choice.index, jailed_part, choice)
.await;
if tool_choice.delta.tool_calls.is_some() {
emissions.push(ChoiceEmission::ToolCall(tool_choice));
} else {
emissions.push(ChoiceEmission::Content(tool_choice));
}
// Handle trailing content if any
if !trailing_part.is_empty() {
#[allow(deprecated)]
let trailing_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(trailing_part.to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
emissions.push(ChoiceEmission::Trailing(trailing_choice));
}
} else {
// Start jailing with the marker and suffix
tracing::debug!(
"Choice {} start marker '{}' detected, starting jail",
choice.index,
marker
);
self.is_jailed = true;
self.accumulated_content = full_content;
}
self.partial_match_buffer.clear();
}
MatchResult::Partial {
prefix,
partial,
possible_patterns,
} => {
// Emit the safe prefix
if !prefix.is_empty() {
#[allow(deprecated)]
let prefix_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(prefix),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
emissions.push(ChoiceEmission::PassThrough(prefix_choice));
}
// Hold the partial for next chunk
self.partial_match_buffer = partial;
tracing::trace!(
"Choice {} holding partial '{}' for patterns: {:?}",
choice.index,
self.partial_match_buffer,
possible_patterns
);
}
MatchResult::None { content } => {
// Check if this content (combined with partial buffer) should start jailing
let combined_content = if self.partial_match_buffer.is_empty() {
content.clone()
} else {
format!("{}{}", self.partial_match_buffer, content)
};
if jail_stream.should_start_jail(&combined_content) {
// Start jailing with the combined content
tracing::debug!(
"Choice {} tool call start detected via parser, starting jail",
choice.index
);
self.is_jailed = true;
self.accumulated_content = combined_content;
self.partial_match_buffer.clear();
} else {
// No markers - emit everything
if !content.is_empty() {
#[allow(deprecated)]
let pass_through_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(content),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
}
self.partial_match_buffer.clear();
}
}
}
} else {
// Already jailed - accumulate and check for unjail
self.accumulate(content);
let (should_end, split_pos) =
jail_stream.should_end_jail(&self.accumulated_content).await;
if should_end {
tracing::debug!(
"Choice {} jail exit detected, releasing accumulated content",
choice.index
);
// Split the content
let (jailed_part, trailing_part) = self.accumulated_content.split_at(split_pos);
// Create the unjailed choice
let unjailed_choice = jail_stream
.create_tool_call_choice(choice.index, jailed_part, choice)
.await;
// Determine emission type based on whether tool calls were parsed
if unjailed_choice.delta.tool_calls.is_some() {
emissions.push(ChoiceEmission::ToolCall(unjailed_choice));
} else {
emissions.push(ChoiceEmission::Content(unjailed_choice));
}
// Handle trailing content if any
if !trailing_part.is_empty() {
#[allow(deprecated)]
let trailing_choice = ChatChoiceStream {
index: choice.index,
delta: ChatCompletionStreamResponseDelta {
role: choice.delta.role,
content: Some(trailing_part.to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: choice.logprobs.clone(),
};
emissions.push(ChoiceEmission::Trailing(trailing_choice));
}
// End jailing
self.end_jail();
}
// If not unjailing, don't emit anything (still accumulating)
}
emissions
}
/// Finalize any remaining content when stream ends
async fn finalize(&mut self, jail_stream: &JailedStream) -> Option<ChoiceEmission> {
if self.is_jailed && !self.accumulated_content.is_empty() {
tracing::debug!(
"Choice {} stream ended while jailed, releasing accumulated content",
self.index
);
// Create a dummy choice for the method call
#[allow(deprecated)]
let dummy_choice = ChatChoiceStream {
index: self.index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: None,
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
};
let final_choice = jail_stream
.create_tool_call_choice(self.index, &self.accumulated_content, &dummy_choice)
.await;
// End jailing
self.end_jail();
// Determine emission type
if final_choice.delta.tool_calls.is_some() {
Some(ChoiceEmission::ToolCall(final_choice))
} else {
Some(ChoiceEmission::Content(final_choice))
}
} else {
None
}
}
}
/// Collection of choice jail states with deterministic ordering
#[derive(Debug, Clone)]
struct ChoiceJailStateCollection {
/// Vec of states, always kept sorted by choice index for deterministic iteration
states: Vec<ChoiceJailState>,
}
impl ChoiceJailStateCollection {
/// Create a new empty collection
fn new() -> Self {
Self { states: Vec::new() }
}
/// Get or create state for a choice index
fn get_or_create_state(&mut self, index: u32) -> &mut ChoiceJailState {
// Find the position where this index should be
match self.states.binary_search_by_key(&index, |s| s.index) {
Ok(pos) => {
// Found existing state
&mut self.states[pos]
}
Err(insert_pos) => {
// Need to create new state
let new_state = ChoiceJailState::new(index);
self.states.insert(insert_pos, new_state);
&mut self.states[insert_pos]
}
}
}
}
/// Emission mode for handling multiple choices
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmissionMode {
/// Pack multiple choices in the same chunk (default, matches original behavior)
Packed,
/// Emit one choice per chunk for OpenAI compatibility
SingleChoicePerChunk,
}
impl Default for EmissionMode {
fn default() -> Self {
Self::Packed
}
}
/// A stream transformer that can "jail" tokens based on configurable start/end sequences
/// When jailed, tokens are accumulated rather than yielded immediately
/// When the jail ends (via end sequence or stream completion), accumulated content is processed and released
pub struct JailedStream {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
emission_mode: EmissionMode,
marker_matcher: MarkerMatcher,
}
impl JailedStream {
/// Create a new builder for configuring a JailedStream
pub fn builder() -> JailedStreamBuilder {
JailedStreamBuilder::new()
}
/// Apply the jail transformation to a stream of chat completion responses
/// Consumes self and returns the transformed stream
pub fn apply<S>(
self,
stream: S,
) -> impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send
where
S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
// Use the stream! macro for cleaner async stream processing
stream! {
// State variables - clean architecture with choice state collection
let mut choice_states = ChoiceJailStateCollection::new();
// Track Annotated metadata for preservation
let mut last_annotated_id: Option<String> = None;
let mut last_annotated_event: Option<String> = None;
let mut last_annotated_comment: Option<Vec<String>> = None;
// Pin the stream for iteration (stack pinning is more efficient)
tokio::pin!(stream);
// Process each item in the stream
while let Some(response) = stream.next().await {
if let Some(chat_response) = response.data.as_ref() {
let mut all_emissions = Vec::new();
// Process each choice independently using the new architecture
for choice in &chat_response.choices {
if let Some(ref content) = choice.delta.content {
let choice_state = choice_states.get_or_create_state(choice.index);
// Store metadata when any choice becomes jailed (first time only)
if !choice_state.is_jailed && self.should_start_jail(content)
&& last_annotated_id.is_none() {
last_annotated_id = response.id.clone();
last_annotated_event = response.event.clone();
last_annotated_comment = response.comment.clone();
}
// Process this choice and get emissions
let emissions = choice_state.process_content(choice, content, &self).await;
all_emissions.extend(emissions);
} else {
// Handle choices without content (e.g., final chunks with finish_reason)
// These should always pass through
let pass_through_choice = ChatChoiceStream {
index: choice.index,
delta: choice.delta.clone(),
finish_reason: choice.finish_reason,
logprobs: choice.logprobs.clone(),
};
all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice));
}
}
// Emit all results based on emission mode
if !all_emissions.is_empty() {
// Group emissions by type for proper ordering and separation
let mut tool_content_emissions = Vec::new();
let mut trailing_emissions = Vec::new();
let mut passthrough_emissions = Vec::new();
for emission in all_emissions {
match emission {
ChoiceEmission::PassThrough(_) => passthrough_emissions.push(emission),
ChoiceEmission::ToolCall(_) | ChoiceEmission::Content(_) => {
tool_content_emissions.push(emission);
}
ChoiceEmission::Trailing(_) => {
trailing_emissions.push(emission);
}
}
}
// Emit tool calls and content with preserved metadata
if !tool_content_emissions.is_empty() {
let preserved_metadata = (
last_annotated_id.clone(),
last_annotated_event.clone(),
last_annotated_comment.clone(),
);
let responses = self.emit_choice_emissions(tool_content_emissions, chat_response, preserved_metadata);
for emitted_response in responses {
yield emitted_response;
}
}
// Emit trailing content separately (always as individual chunks)
if !trailing_emissions.is_empty() {
let preserved_metadata = (
last_annotated_id.clone(),
last_annotated_event.clone(),
last_annotated_comment.clone(),
);
let responses = self.emit_choice_emissions(trailing_emissions, chat_response, preserved_metadata);
for emitted_response in responses {
yield emitted_response;
}
}
// Emit pass-through content with current metadata
if !passthrough_emissions.is_empty() {
let current_metadata = (response.id.clone(), response.event.clone(), response.comment.clone());
let responses = self.emit_choice_emissions(passthrough_emissions, chat_response, current_metadata);
for emitted_response in responses {
yield emitted_response;
}
}
}
} else {
// No response data, pass through as-is
yield response;
}
}
// Stream ended - finalize any remaining jailed choices
let mut final_emissions = Vec::new();
for state in choice_states.states.iter_mut() {
if let Some(emission) = state.finalize(&self).await {
final_emissions.push(emission);
}
}
if !final_emissions.is_empty() {
tracing::debug!("Stream ended while jailed, releasing accumulated content");
// Create a dummy response for finalization
let dummy_response = NvCreateChatCompletionStreamResponse {
id: "stream-end".to_string(),
object: "chat.completion.chunk".to_string(),
created: 0,
model: "unknown".to_string(),
choices: Vec::new(),
usage: None,
service_tier: None,
system_fingerprint: None,
};
let final_metadata = (last_annotated_id, last_annotated_event, last_annotated_comment);
let responses = self.emit_choice_emissions(final_emissions, &dummy_response, final_metadata);
for emitted_response in responses {
yield emitted_response;
}
}
}
}
/// Emit choice emissions based on the configured emission mode
fn emit_choice_emissions(
&self,
emissions: Vec<ChoiceEmission>,
base_response: &NvCreateChatCompletionStreamResponse,
annotated_metadata: (Option<String>, Option<String>, Option<Vec<String>>),
) -> Vec<Annotated<NvCreateChatCompletionStreamResponse>> {
if emissions.is_empty() {
return Vec::new();
}
let (id, event, comment) = annotated_metadata;
match self.emission_mode {
EmissionMode::Packed => {
// Pack all choices into a single response
let mut response = base_response.clone();
response.choices = emissions.into_iter().map(|e| e.into_choice()).collect();
vec![Annotated {
data: Some(response),
id,
event,
comment,
}]
}
EmissionMode::SingleChoicePerChunk => {
// Emit each choice in a separate response
emissions
.into_iter()
.map(|emission| {
let mut response = base_response.clone();
response.choices = vec![emission.into_choice()];
Annotated {
data: Some(response),
id: id.clone(),
event: event.clone(),
comment: comment.clone(),
}
})
.collect()
}
}
}
/// Check if content matches any jail start patterns
fn should_start_jail(&self, content: &str) -> bool {
// Path 1: Check configured start sequences
let sequence_match = !self.jail_start_sequences.is_empty()
&& self
.jail_start_sequences
.iter()
.any(|seq| content.contains(seq));
// Path 2: Check for tool call start pattern
let tool_call_match = self.tool_call_parser.is_some()
&& detect_tool_call_start(content, self.tool_call_parser.as_deref()).unwrap_or(false);
sequence_match || tool_call_match
}
/// Check if accumulated content should end jail
async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) {
// Path 1: End sequence detected
let end_marker_info = if !self.jail_end_sequences.is_empty() {
self.jail_end_sequences.iter().find_map(|seq| {
accumulated_content
.find(seq)
.map(|pos| (pos + seq.len(), seq.clone()))
})
} else {
None
};
// Path 2: Complete tool call(s) can be parsed (early exit)
let early_exit = self.should_exit_jail_early(accumulated_content).await;
if let Some((end_pos, _)) = end_marker_info {
(true, end_pos)
} else if early_exit {
// For early exit, find where the complete tool call ends
if let Some(parser) = &self.tool_call_parser {
if let Ok((_, _)) =
try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await
{
let split_pos = self.find_tool_call_end_position(accumulated_content, parser);
(true, split_pos)
} else {
(false, accumulated_content.len())
}
} else {
(false, accumulated_content.len())
}
} else {
(false, accumulated_content.len())
}
}
/// Parse tool calls from accumulated content and create choice
async fn create_tool_call_choice(
&self,
choice_index: u32,
accumulated_content: &str,
base_choice: &ChatChoiceStream,
) -> ChatChoiceStream {
if let Ok((tool_calls, normal_text)) =
try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref())
.await
&& !tool_calls.is_empty()
{
// Convert to streaming format
let tool_call_chunks: Vec<ChatCompletionMessageToolCallChunk> = tool_calls
.into_iter()
.enumerate()
.map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk {
index: idx as u32,
id: Some(tool_call.id),
r#type: Some(tool_call.r#type),
function: Some(FunctionCallStream {
name: Some(tool_call.function.name),
arguments: Some(tool_call.function.arguments),
}),
})
.collect();
// Create choice with tool calls
#[allow(deprecated)]
return ChatChoiceStream {
index: choice_index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: normal_text.filter(|t| !t.is_empty()),
tool_calls: Some(tool_call_chunks),
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(FinishReason::ToolCalls),
logprobs: None,
};
}
// No tool calls found or parsing failed, return content choice
#[allow(deprecated)]
ChatChoiceStream {
index: choice_index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(accumulated_content.to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: base_choice.logprobs.clone(),
}
}
/// Check if accumulated content contains complete tool calls that can be parsed
/// Returns true if we should exit the jail early
async fn should_exit_jail_early(&self, accumulated: &str) -> bool {
if let Some(ref parser) = self.tool_call_parser {
// Try to parse - if successful and we have complete tool calls, exit early
if let Ok((tool_calls, _)) =
try_tool_call_parse_aggregate(accumulated, Some(parser)).await
{
return !tool_calls.is_empty();
}
}
false
}
/// Find the exact position where the tool call ends for splitting content
/// This handles the early exit case where we have trailing content after the tool call
fn find_tool_call_end_position(&self, content: &str, parser: &str) -> usize {
match parser {
"hermes" => {
// For Hermes, look for </tool_call> marker
if let Some(pos) = content.find("</tool_call>") {
pos + "</tool_call>".len()
} else {
content.len()
}
}
"nemotron_deci" => {
// For Nemotron, look for </TOOLCALL> marker
if let Some(pos) = content.find("</TOOLCALL>") {
pos + "</TOOLCALL>".len()
} else {
content.len()
}
}
"mistral" => {
// For Mistral, look for [/TOOL_CALLS] marker or end of JSON array
if let Some(pos) = content.find("[/TOOL_CALLS]") {
pos + "[/TOOL_CALLS]".len()
} else if let Some(pos) = content.rfind(']') {
// Find the last ] which should be the end of the tool calls array
pos + 1
} else {
content.len()
}
}
"phi4" => {
// For Phi4, look for <|tool_call|> end marker
if let Some(pos) = content.rfind("<|tool_call|>") {
// Look for the next occurrence after this position
if let Some(end_pos) = content[pos..].find(">") {
pos + end_pos + 1
} else {
content.len()
}
} else {
content.len()
}
}
"llama3_json" => {
// For Llama3 JSON, there's no explicit end marker
// The end is determined by complete JSON parsing
// Return full content length to avoid early splitting
content.len()
}
_ => {
// Unknown parser, default to full content
content.len()
}
}
}
}
/// Builder for configuring a JailedStream
pub struct JailedStreamBuilder {
jail_start_sequences: Vec<String>,
jail_end_sequences: Vec<String>,
tool_call_parser: Option<String>,
emission_mode: EmissionMode,
}
impl JailedStreamBuilder {
/// Create a new builder with default settings
pub fn new() -> Self {
Self {
jail_start_sequences: Vec::new(),
jail_end_sequences: Vec::new(),
tool_call_parser: None,
emission_mode: EmissionMode::default(),
}
}
/// Add a sequence that triggers jailing when detected
pub fn jail_start_sequence(mut self, sequence: impl Into<String>) -> Self {
self.jail_start_sequences.push(sequence.into());
self
}
/// Add multiple sequences that trigger jailing when detected
pub fn jail_start_sequences(
mut self,
sequences: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.jail_start_sequences
.extend(sequences.into_iter().map(Into::into));
self
}
/// Add a sequence that ends jailing when detected
pub fn jail_end_sequence(mut self, sequence: impl Into<String>) -> Self {
self.jail_end_sequences.push(sequence.into());
self
}
/// Add multiple sequences that end jailing when detected
pub fn jail_end_sequences(
mut self,
sequences: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.jail_end_sequences
.extend(sequences.into_iter().map(Into::into));
self
}
/// Set the tool call parser to use for detection and parsing
pub fn tool_call_parser(mut self, parser: impl Into<String>) -> Self {
self.tool_call_parser = Some(parser.into());
self
}
/// Set the emission mode for handling multiple choices
pub fn emission_mode(mut self, mode: EmissionMode) -> Self {
self.emission_mode = mode;
self
}
/// Enable single choice per chunk emission for OpenAI compatibility
pub fn single_choice_per_chunk(mut self) -> Self {
self.emission_mode = EmissionMode::SingleChoicePerChunk;
self
}
/// Enable packed emission mode (multiple choices per chunk)
pub fn packed_emission(mut self) -> Self {
self.emission_mode = EmissionMode::Packed;
self
}
/// Build the configured JailedStream
pub fn build(mut self) -> JailedStream {
// Auto-populate jail sequences from parser config if not manually configured
if let Some(ref parser_name) = self.tool_call_parser {
let parser_map = get_tool_parser_map();
if let Some(config) = parser_map.get(parser_name.as_str()) {
// Auto-populate start sequences if none configured
if self.jail_start_sequences.is_empty() {
self.jail_start_sequences = config.json.tool_call_start_tokens.clone();
}
// Auto-populate end sequences if none configured
if self.jail_end_sequences.is_empty() {
self.jail_end_sequences = config
.json
.tool_call_end_tokens
.iter()
.filter(|&s| !s.is_empty())
.cloned()
.collect();
}
}
}
// Collect all possible marker patterns for the MarkerMatcher
let mut all_patterns = Vec::new();
// Add configured start sequences (now auto-populated if needed)
all_patterns.extend(self.jail_start_sequences.clone());
// Add patterns from tool call parser if configured (for redundancy)
if let Some(ref parser_name) = self.tool_call_parser {
let parser_map = get_tool_parser_map();
if let Some(config) = parser_map.get(parser_name.as_str()) {
// Add start tokens from the parser config
all_patterns.extend(config.json.tool_call_start_tokens.clone());
}
}
// Add common tool call markers to ensure we detect all formats
// Only include these when a specific parser is NOT configured,
// to avoid unexpected false positives for explicit formats
if self.tool_call_parser.is_none() {
let common_markers = vec![
"<TOOLCALL>".to_string(), // nemotron_deci format
"<tool_call>".to_string(), // hermes format
"[TOOL_CALLS]".to_string(), // mistral format
"<|python_tag|>".to_string(), // llama3_json format
"functools[".to_string(), // phi4 format
// Add JSON start patterns for Mistral-style tool calls
"[{".to_string(),
"{".to_string(),
// Note: Harmony parser uses JSON patterns, covered by "{" above
];
for marker in common_markers {
if !all_patterns.contains(&marker) {
all_patterns.push(marker);
}
}
}
// Create the marker matcher (fallback to empty patterns if none configured)
let marker_matcher = if all_patterns.is_empty() {
// If no patterns, create a dummy matcher that never matches
MarkerMatcher::new(vec!["__NEVER_MATCH__".to_string()])
.expect("Failed to create dummy MarkerMatcher")
} else {
MarkerMatcher::new(all_patterns)
.expect("Failed to create MarkerMatcher with configured patterns")
};
JailedStream {
jail_start_sequences: self.jail_start_sequences,
jail_end_sequences: self.jail_end_sequences,
tool_call_parser: self.tool_call_parser,
emission_mode: self.emission_mode,
marker_matcher,
}
}
}
impl Default for JailedStreamBuilder {
fn default() -> Self {
Self::new()
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod prefix_matcher;
pub use prefix_matcher::{MarkerMatcher, MatchResult};
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Efficient multi-pattern marker detection with partial suffix matching
//!
//! This module provides utilities for detecting complete and partial marker patterns
//! in streaming text, with support for detecting markers split across chunk boundaries.
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
use std::collections::HashMap;
/// Result of processing a chunk with potential marker detection
#[derive(Debug, Clone, PartialEq)]
pub enum MatchResult {
/// Complete marker found
Complete {
/// Content before the marker (safe to emit)
prefix: String,
/// The complete marker matched
marker: String,
/// Start position of the marker in the input
marker_start: usize,
/// Remaining content after the marker
suffix: String,
},
/// Partial marker at end of chunk
Partial {
/// Content before the partial (safe to emit)
prefix: String,
/// The partial match to hold
partial: String,
/// Which patterns this could match
possible_patterns: Vec<String>,
},
/// No markers detected
None {
/// All content is safe to emit
content: String,
},
}
/// Efficient multi-pattern matcher with partial suffix detection
pub struct MarkerMatcher {
/// All patterns we're looking for
patterns: Vec<String>,
/// Aho-Corasick matcher for complete patterns
complete_matcher: AhoCorasick,
/// Trie for partial matching
prefix_trie: PrefixTrie,
/// Maximum pattern length (for buffer limits)
max_pattern_len: usize,
}
impl MarkerMatcher {
/// Create a new matcher with the given patterns
pub fn new(patterns: Vec<String>) -> Result<Self, String> {
if patterns.is_empty() {
return Err("Cannot create MarkerMatcher with empty patterns".to_string());
}
let complete_matcher = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostFirst)
.build(&patterns)
.map_err(|e| format!("Failed to build Aho-Corasick matcher: {}", e))?;
let max_pattern_len = patterns.iter().map(|p| p.len()).max().unwrap_or(0);
let prefix_trie = PrefixTrie::new(&patterns);
Ok(Self {
patterns,
complete_matcher,
prefix_trie,
max_pattern_len,
})
}
/// Get the maximum pattern length
pub fn max_pattern_len(&self) -> usize {
self.max_pattern_len
}
/// Safe UTF-8 slicing that ensures we only slice at character boundaries
fn safe_slice(text: &str, start_byte: usize, end_byte: usize) -> String {
// Clamp indices to valid boundaries
let start = text
.char_indices()
.find(|(i, _)| *i >= start_byte)
.map(|(i, _)| i)
.unwrap_or(text.len());
let end = text
.char_indices()
.find(|(i, _)| *i >= end_byte)
.map(|(i, _)| i)
.unwrap_or(text.len());
text[start..end].to_string()
}
/// Process a chunk with an optional partial buffer from previous chunk
pub fn process_chunk(&self, chunk: &str, partial_buffer: &str) -> MatchResult {
// Combine buffer with new chunk
let combined = if partial_buffer.is_empty() {
chunk.to_string()
} else {
format!("{}{}", partial_buffer, chunk)
};
// First check for complete markers
if let Some(mat) = self.complete_matcher.find(&combined) {
let marker = &self.patterns[mat.pattern().as_usize()];
return MatchResult::Complete {
prefix: Self::safe_slice(&combined, 0, mat.start()),
marker: marker.clone(),
marker_start: mat.start(),
suffix: Self::safe_slice(&combined, mat.end(), combined.len()),
};
}
// No complete match - check for partial at ANY suffix position
// This is the key: check "n<T" → finds "<T" as partial
if let Some((partial_start, partial, patterns)) = self.find_partial_suffix(&combined) {
return MatchResult::Partial {
prefix: Self::safe_slice(&combined, 0, partial_start),
partial: partial.to_string(),
possible_patterns: patterns,
};
}
// No matches at all
MatchResult::None { content: combined }
}
/// Find the longest partial match in any suffix of the input
///
/// This scans from left to right to find the EARLIEST partial match,
/// ensuring we emit as much content as possible while holding only the minimal partial.
fn find_partial_suffix<'a>(&self, text: &'a str) -> Option<(usize, &'a str, Vec<String>)> {
// Start from the beginning to find the EARLIEST partial match
// This ensures we emit as much as possible
// Use char_indices to get valid UTF-8 boundaries
for (i, _) in text.char_indices() {
let suffix = &text[i..];
if let Some(patterns) = self.prefix_trie.find_prefix_match(suffix) {
// This suffix is a prefix of one or more patterns
return Some((i, suffix, patterns));
}
}
None
}
}
/// Trie structure for efficient prefix matching
struct PrefixTrie {
root: TrieNode,
}
#[derive(Debug)]
struct TrieNode {
children: HashMap<char, TrieNode>,
/// Patterns that have this exact prefix
matching_patterns: Vec<String>,
/// Is this node a complete pattern?
is_complete: bool,
}
impl PrefixTrie {
fn new(patterns: &[String]) -> Self {
let mut root = TrieNode {
children: HashMap::new(),
matching_patterns: Vec::new(),
is_complete: false,
};
// Build trie
for pattern in patterns {
let mut current = &mut root;
let chars: Vec<char> = pattern.chars().collect();
for (i, &ch) in chars.iter().enumerate() {
current = current.children.entry(ch).or_insert(TrieNode {
children: HashMap::new(),
matching_patterns: Vec::new(),
is_complete: false,
});
// Add this pattern to all prefix nodes
if !current.matching_patterns.contains(pattern) {
current.matching_patterns.push(pattern.clone());
}
// Mark complete if we're at the end
if i == chars.len() - 1 {
current.is_complete = true;
}
}
}
PrefixTrie { root }
}
/// Check if text is a prefix of any pattern (but not a complete pattern)
fn find_prefix_match(&self, text: &str) -> Option<Vec<String>> {
let mut current = &self.root;
for ch in text.chars() {
if let Some(node) = current.children.get(&ch) {
current = node;
} else {
// Not a prefix of any pattern
return None;
}
}
// If we matched the entire text and it's a prefix of something (but not complete)
if !current.matching_patterns.is_empty() && !current.is_complete {
Some(current.matching_patterns.clone())
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_complete_match() {
let patterns = vec!["<TOOLCALL>".to_string(), "<tool_call>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("<TOOLCALL>data", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result
{
assert_eq!(prefix, "");
assert_eq!(marker, "<TOOLCALL>");
assert_eq!(suffix, "data");
} else {
panic!("Expected complete match");
}
}
#[test]
fn test_partial_match_suffix() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test the key case: "n<T" should detect "<T" as partial
let result = matcher.process_chunk("n<T", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
assert_eq!(prefix, "n");
assert_eq!(partial, "<T");
assert_eq!(possible_patterns, vec!["<TOOLCALL>"]);
} else {
panic!("Expected partial match, got: {:?}", result);
}
}
#[test]
fn test_no_false_positive() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test case: "n < 5" should not trigger partial match
let result = matcher.process_chunk("n < 5", "");
if let MatchResult::None { content } = result {
assert_eq!(content, "n < 5");
} else {
panic!("Expected no match, got: {:?}", result);
}
}
#[test]
fn test_partial_buffer_combination() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// First chunk: partial "<"
let result1 = matcher.process_chunk("<", "");
let partial = if let MatchResult::Partial { partial, .. } = result1 {
partial
} else {
panic!("Expected partial match");
};
// Second chunk: "TOOLCALL>" completes the pattern
let result2 = matcher.process_chunk("TOOLCALL>", &partial);
if let MatchResult::Complete { marker, .. } = result2 {
assert_eq!(marker, "<TOOLCALL>");
} else {
panic!("Expected complete match, got: {:?}", result2);
}
}
#[test]
fn test_prefix_with_content() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("text before <TOOLCALL> after", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result
{
assert_eq!(prefix, "text before ");
assert_eq!(marker, "<TOOLCALL>");
assert_eq!(suffix, " after");
} else {
panic!("Expected complete match");
}
}
#[test]
fn test_empty_patterns() {
let result = MarkerMatcher::new(vec![]);
assert!(result.is_err());
}
#[test]
fn test_multiple_patterns() {
let patterns = vec![
"<TOOLCALL>".to_string(),
"[TOOL_CALLS]".to_string(),
"<tool_call>".to_string(),
];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test different patterns
let result1 = matcher.process_chunk("[TOOL_CALLS]", "");
if let MatchResult::Complete { marker, .. } = result1 {
assert_eq!(marker, "[TOOL_CALLS]");
} else {
panic!("Expected complete match for [TOOL_CALLS]");
}
// Test partial for different pattern
let result2 = matcher.process_chunk("text<to", "");
if let MatchResult::Partial {
partial,
possible_patterns,
..
} = result2
{
assert_eq!(partial, "<to");
assert!(possible_patterns.contains(&"<tool_call>".to_string()));
} else {
panic!("Expected partial match for <tool_call>");
}
}
#[test]
fn test_multiple_partial_matches_edge_case() {
// Test scenario: Multiple patterns where one looks like a prefix but isn't valid
// Patterns: ["FooBar", "<TOOLCALL>"]
// Input: "This is FooBaz which is a no, but <TOO"
// Key insight: "FooBa" from "FooBaz" is NOT a valid partial because the 'z'
// doesn't match the expected 'r' in "FooBar"
// Expected: Hold "<TOO" as partial, emit "This is FooBaz which is a no, but "
let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("This is FooBaz which is a no, but <TOO", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
// The algorithm correctly skips "FooBaz" (not a valid prefix) and finds "<TOO"
assert_eq!(partial, "<TOO");
assert_eq!(prefix, "This is FooBaz which is a no, but ");
assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
} else {
panic!("Expected partial match for '<TOO>', got: {:?}", result);
}
}
#[test]
fn test_earliest_valid_partial_match() {
// Test that the algorithm finds the earliest VALID partial match
// Patterns: ["FooBar", "<TOOLCALL>"]
// Input: "Some text FooBa and then <TO"
// Analysis: "FooBa and then <TO" is not a valid prefix of "FooBar" because
// after "FooBa" we have " " (space) but "FooBar" expects "r"
// Expected: Skip invalid "FooBa..." and find valid "<TO" partial
let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("Some text FooBa and then <TO", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
// Should find "<TO" as the valid partial match
assert_eq!(partial, "<TO");
assert_eq!(prefix, "Some text FooBa and then ");
assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
} else {
panic!("Expected partial match for '<TO>', got: {:?}", result);
}
}
#[test]
fn test_partial_at_exact_end() {
// Test case where a valid partial is exactly at the end
// Patterns: ["FooBar", "<TOOLCALL>"]
// Input: "Some text ending with FooBa"
// Expected: Hold "FooBa" as partial (valid prefix of "FooBar")
let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("Some text ending with FooBa", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
// Should find "FooBa" as a valid partial match at the end
assert_eq!(partial, "FooBa");
assert_eq!(prefix, "Some text ending with ");
assert!(possible_patterns.contains(&"FooBar".to_string()));
} else {
panic!("Expected partial match for 'FooBa', got: {:?}", result);
}
}
#[test]
fn test_unicode_complete_match() {
// Test complete pattern matching with unicode content
// Use patterns with ASCII markers but unicode content
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test with emoji and multi-byte characters
let result = matcher.process_chunk("Hello 👋 world <TOOLCALL>data 🚀", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result
{
assert_eq!(prefix, "Hello 👋 world ");
assert_eq!(marker, "<TOOLCALL>");
assert_eq!(suffix, "data 🚀");
} else {
panic!("Expected complete match, got: {:?}", result);
}
}
#[test]
fn test_unicode_partial_match() {
// Test partial matching where the partial might occur after unicode content
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test partial after multi-byte characters
let result = matcher.process_chunk("Text with 中文字符 and <TO", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
assert_eq!(prefix, "Text with 中文字符 and ");
assert_eq!(partial, "<TO");
assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
} else {
panic!("Expected partial match, got: {:?}", result);
}
}
#[test]
fn test_unicode_no_false_positive() {
// Test that unicode content doesn't create false positives
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test with unicode that might look similar to ASCII patterns
let result = matcher.process_chunk("Unicode test <TOOLCALL> full-width", "");
if let MatchResult::None { content } = result {
assert_eq!(content, "Unicode test <TOOLCALL> full-width");
} else {
panic!(
"Expected no match for full-width characters, got: {:?}",
result
);
}
}
#[test]
fn test_unicode_pattern_itself() {
// Test patterns that contain unicode characters
let patterns = vec!["🔧工具".to_string(), "📞call".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test complete match with unicode pattern
let result1 = matcher.process_chunk("Start 🔧工具 end", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result1
{
assert_eq!(prefix, "Start ");
assert_eq!(marker, "🔧工具");
assert_eq!(suffix, " end");
} else {
panic!(
"Expected complete match for unicode pattern, got: {:?}",
result1
);
}
// Test partial match with unicode pattern
let result2 = matcher.process_chunk("Text 🔧工", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result2
{
assert_eq!(prefix, "Text ");
assert_eq!(partial, "🔧工");
assert!(possible_patterns.contains(&"🔧工具".to_string()));
} else {
panic!(
"Expected partial match for unicode pattern, got: {:?}",
result2
);
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason, Role,
};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream;
use dynamo_runtime::protocols::annotated::Annotated;
#[cfg(test)]
mod tests {
use super::*;
use futures::StreamExt;
use futures::stream;
// Test utilities module - shared test infrastructure
pub(crate) mod test_utils {
use super::*;
/// Helper function to create a mock chat response chunk
pub fn create_mock_response_chunk(
content: String,
index: u32,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
#[allow(deprecated)]
let choice = ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(content),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
};
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![choice],
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id: None,
event: None,
comment: None,
}
}
/// Helper function to create a final response chunk with finish reason
pub fn create_final_response_chunk(
index: u32,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
#[allow(deprecated)]
let choice = ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: None,
content: None,
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(FinishReason::Stop),
logprobs: None,
};
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![choice],
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id: None,
event: None,
comment: None,
}
}
/// Helper function to create a mock chat response chunk with metadata
pub fn create_annotated_chunk(
content: String,
index: u32,
id: Option<String>,
event: Option<String>,
comment: Option<Vec<String>>,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
#[allow(deprecated)]
let choice = ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(content),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
};
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![choice],
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id,
event,
comment,
}
}
/// Helper function to create a multi-choice chunk
pub fn create_multi_choice_chunk(
choices_content: Vec<(String, u32)>, // (content, index)
) -> Annotated<NvCreateChatCompletionStreamResponse> {
let choices: Vec<ChatChoiceStream> = choices_content
.into_iter()
.map(|(content, index)| {
#[allow(deprecated)]
ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(content),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
}
})
.collect();
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices,
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id: None,
event: None,
comment: None,
}
}
/// Helper to assert content in a result
pub fn assert_content(
result: &Annotated<NvCreateChatCompletionStreamResponse>,
expected: &str,
) {
let content = result
.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
.expect("Expected content in result");
assert_eq!(
content, expected,
"Content mismatch: expected '{}', got '{}'",
expected, content
);
}
/// Helper to assert a tool call in a result
pub fn assert_tool_call(
result: &Annotated<NvCreateChatCompletionStreamResponse>,
name: &str,
args: serde_json::Value,
) {
let tool_calls = result
.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.tool_calls.as_ref())
.expect("Expected tool calls in result");
assert!(!tool_calls.is_empty(), "Expected at least one tool call");
let tool_call = &tool_calls[0];
let function = tool_call
.function
.as_ref()
.expect("Expected function in tool call");
assert_eq!(
function.name.as_deref(),
Some(name),
"Tool call name mismatch: expected '{}', got '{:?}'",
name,
function.name
);
if let Some(arguments_str) = &function.arguments {
let parsed_args: serde_json::Value = serde_json::from_str(arguments_str)
.expect("Tool call arguments should be valid JSON");
assert_eq!(
parsed_args, args,
"Tool call arguments mismatch: expected {}, got {}",
args, parsed_args
);
} else if !args.is_null() {
panic!("Expected tool call arguments {} but got None", args);
}
}
/// Helper to assert no content or tool calls (for accumulated chunks)
#[allow(dead_code)]
pub fn assert_empty_emission(result: &Annotated<NvCreateChatCompletionStreamResponse>) {
if let Some(data) = &result.data
&& let Some(choice) = data.choices.first()
{
assert!(
choice.delta.content.is_none()
|| choice.delta.content.as_ref().unwrap().is_empty(),
"Expected no content but got: {:?}",
choice.delta.content
);
assert!(
choice.delta.tool_calls.is_none()
|| choice.delta.tool_calls.as_ref().unwrap().is_empty(),
"Expected no tool calls but got: {:?}",
choice.delta.tool_calls
);
}
}
/// Helper to reconstruct all content from results
pub fn reconstruct_content(
results: &[Annotated<NvCreateChatCompletionStreamResponse>],
) -> String {
results
.iter()
.filter_map(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
})
.cloned()
.collect::<Vec<_>>()
.join("")
}
/// Helper to extract content from a single result (for negative assertions)
pub fn extract_content(result: &Annotated<NvCreateChatCompletionStreamResponse>) -> String {
result
.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
.cloned()
.unwrap_or_default()
}
/// Helper to check if result contains a tool call
pub fn has_tool_call(result: &Annotated<NvCreateChatCompletionStreamResponse>) -> bool {
result
.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.tool_calls.as_ref())
.map(|tc| !tc.is_empty())
.unwrap_or(false)
}
/// Helper to check if result contains content
#[allow(dead_code)]
pub fn has_content(result: &Annotated<NvCreateChatCompletionStreamResponse>) -> bool {
result
.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
.map(|content| !content.is_empty())
.unwrap_or(false)
}
}
use serde_json::json;
use test_utils::*;
#[tokio::test]
async fn test_jailed_stream_with_start_end_sequences() {
// Create chunks with jail start/end markers
let chunks = vec![
create_mock_response_chunk("Hello ".to_string(), 0),
create_mock_response_chunk("<jail>".to_string(), 0),
create_mock_response_chunk("This is jailed ".to_string(), 0),
create_mock_response_chunk("content".to_string(), 0),
create_mock_response_chunk("</jail>".to_string(), 0),
create_mock_response_chunk(" World".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with start/end sequences
let jail = JailedStream::builder()
.jail_start_sequence("<jail>")
.jail_end_sequence("</jail>")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// We should only get 3 chunks now:
// 1. "Hello " (before jail)
// 2. Accumulated jailed content when jail ends
// 3. " World" (after jail)
assert_eq!(results.len(), 3);
// First chunk should pass through
assert_eq!(
results[0].data.as_ref().unwrap().choices[0]
.delta
.content
.as_deref(),
Some("Hello ")
);
// When jail ends, accumulated content should be released
let unjailed_content = &results[1].data.as_ref().unwrap().choices[0].delta.content;
assert!(unjailed_content.is_some());
assert!(
unjailed_content
.as_ref()
.unwrap()
.contains("<jail>This is jailed content</jail>")
);
// Last chunk should pass through normally
assert_eq!(
results[2].data.as_ref().unwrap().choices[0]
.delta
.content
.as_deref(),
Some(" World")
);
}
#[tokio::test]
async fn test_jailed_stream_with_tool_calls() {
// Create chunks representing a tool call
let chunks = vec![
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk(
"[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"SF\"}}]".to_string(),
0,
),
create_mock_response_chunk("</TOOLCALL>".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with tool call parser
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have jailed the content and parsed tool calls at the end
assert!(!results.is_empty());
// Check if tool calls were parsed
if let Some(last_result) = results.last()
&& let Some(ref response_data) = last_result.data
&& let Some(ref tool_calls) = response_data.choices[0].delta.tool_calls
{
assert!(!tool_calls.as_slice().is_empty());
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name.as_deref(),
Some("get_weather")
);
}
}
#[tokio::test]
async fn test_jailed_stream_dual_entry_paths() {
// Test that BOTH sequence AND tool call detection can trigger jail
let chunks = vec![
create_mock_response_chunk("Normal text ".to_string(), 0),
create_mock_response_chunk("<jail><TOOLCALL>".to_string(), 0), // Both triggers
create_mock_response_chunk("Jailed content".to_string(), 0),
create_mock_response_chunk("</jail>".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Configure with both sequences AND tool call parser
let jail = JailedStream::builder()
.jail_start_sequence("<jail>")
.jail_end_sequence("</jail>")
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// We should get 2 chunks:
// 1. "Normal text " (before jail)
// 2. Accumulated jailed content when jail ends via </jail>
assert_eq!(results.len(), 2);
// First chunk should pass through
assert_eq!(
results[0].data.as_ref().unwrap().choices[0]
.delta
.content
.as_deref(),
Some("Normal text ")
);
// Second chunk should contain the accumulated jailed content
let jailed = results[1].data.as_ref().unwrap().choices[0]
.delta
.content
.as_ref()
.expect("Expected accumulated jailed content");
assert!(jailed.contains("<jail><TOOLCALL>Jailed content</jail>"));
}
#[tokio::test]
async fn test_jailed_stream_early_exit() {
// Tests detection of complete tool call with unjail in same chunk as the end marker
// Input: "<TOOLCALL>" + "[{\"name\": \"test\", " + "\"arguments\": {}}]" + "</TOOLCALL>More text"
// Expected output: 2 chunks [ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk("[{\"name\": \"test\", ".to_string(), 0),
create_mock_response_chunk("\"arguments\": {}}]".to_string(), 0),
create_mock_response_chunk("</TOOLCALL>More text".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 2 chunks: tool call + trailing content
assert_eq!(
results.len(),
2,
"Should have tool call and trailing content"
);
// Verify exact output structure: [ToolCall(), Content()]
test_utils::assert_tool_call(&results[0], "test", serde_json::json!({}));
test_utils::assert_content(&results[1], "More text");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "More text");
}
#[tokio::test]
async fn test_jailed_stream_no_jailing() {
// Input chunks:
// [0] "Hello "
// [1] "World"
// [2] [final chunk]
//
// Expected output (pass-through):
// [0] Content("Hello ")
// [1] Content("World")
// [2] [final chunk]
let chunks = vec![
create_mock_response_chunk("Hello ".to_string(), 0),
create_mock_response_chunk("World".to_string(), 0),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with sequences that won't match
let jail = JailedStream::builder()
.jail_start_sequence("<NOTPRESENT>")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count ===
assert_eq!(
results.len(),
3,
"Should pass through all 3 chunks unchanged"
);
// === Verify individual chunks ===
assert_content(&results[0], "Hello ");
assert_content(&results[1], "World");
// results[2] is the final chunk - no content to verify
// === Verify negative assertions ===
for (i, result) in results.iter().take(2).enumerate() {
assert!(
!has_tool_call(result),
"Chunk {} should not contain tool calls when no patterns match",
i
);
}
// === Verify content reconstruction ===
assert_eq!(
reconstruct_content(&results),
"Hello World",
"Content should pass through unchanged when no jailing occurs"
);
}
#[tokio::test]
async fn test_jailed_stream_hermes_parser() {
// Tests Hermes format tool call parsing with <tool_call> markers
// Input: "I'll help you with that. " + "<tool_call>{\"name\": \"search_web\", \"arguments\": {\"query\": \"weather today\"}}</tool_call>" + " Let me search for that."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("I'll help you with that. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk("{\"name\": \"search_web\", ".to_string(), 0),
create_mock_response_chunk(
"\"arguments\": {\"query\": \"weather today\"}}".to_string(),
0,
),
create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" Let me search for that.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with Hermes parser
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()]
test_utils::assert_content(&results[0], "I'll help you with that. ");
test_utils::assert_tool_call(
&results[1],
"search_web",
serde_json::json!({"query": "weather today"}),
);
test_utils::assert_content(&results[2], " Let me search for that.");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"I'll help you with that. Let me search for that."
);
}
#[tokio::test]
async fn test_jailed_stream_mistral_parser() {
// Tests Mistral format tool call parsing with [{ pattern
// Input: "Sure, I can help. " + "[{\"name\": \"calculate\", \"arguments\": {\"expression\": \"2+2\"}}]" + " The calculation is done."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("Sure, I can help. ".to_string(), 0),
create_mock_response_chunk("[{".to_string(), 0),
create_mock_response_chunk("\"name\": \"calculate\", ".to_string(), 0),
create_mock_response_chunk("\"arguments\": {\"expression\": \"2+2\"}".to_string(), 0),
create_mock_response_chunk("}]".to_string(), 0),
create_mock_response_chunk(" The calculation is done.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with Mistral parser
let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()]
test_utils::assert_content(&results[0], "Sure, I can help. ");
test_utils::assert_tool_call(
&results[1],
"calculate",
serde_json::json!({"expression": "2+2"}),
);
test_utils::assert_content(&results[2], " The calculation is done.");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "Sure, I can help. The calculation is done.");
}
#[tokio::test]
async fn test_jailed_stream_mistral_parser_with_tool_calls_marker() {
// Tests Mistral format tool call parsing with explicit [TOOL_CALLS] marker
// Input: "Let me check that for you. " + "[TOOL_CALLS][{\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]" + " Here's the time."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("Let me check that for you. ".to_string(), 0),
create_mock_response_chunk("[TOOL_CALLS]".to_string(), 0),
create_mock_response_chunk("[{\"name\": \"get_time\", ".to_string(), 0),
create_mock_response_chunk("\"arguments\": {\"timezone\": \"UTC\"}}]".to_string(), 0),
create_mock_response_chunk(" Here's the time.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with Mistral parser
let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()]
test_utils::assert_content(&results[0], "Let me check that for you. ");
test_utils::assert_tool_call(
&results[1],
"get_time",
serde_json::json!({"timezone": "UTC"}),
);
test_utils::assert_content(&results[2], " Here's the time.");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"Let me check that for you. Here's the time."
);
}
#[tokio::test]
async fn test_jailed_stream_phi4_parser() {
// Tests Phi4 format tool call parsing with functools[ pattern
// Input: "I'll analyze this data. " + "functools[{\"name\": \"analyze_data\", \"arguments\": {\"dataset\": \"sales_data\"}}]" + " Analysis complete."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("I'll analyze this data. ".to_string(), 0),
create_mock_response_chunk("functools[".to_string(), 0),
create_mock_response_chunk("{\"name\": \"analyze_data\", ".to_string(), 0),
create_mock_response_chunk(
"\"arguments\": {\"dataset\": \"sales_data\"}}".to_string(),
0,
),
create_mock_response_chunk("]".to_string(), 0),
create_mock_response_chunk(" Analysis complete.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with Phi4 parser
let jail = JailedStream::builder().tool_call_parser("phi4").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()]
test_utils::assert_content(&results[0], "I'll analyze this data. ");
test_utils::assert_tool_call(
&results[1],
"analyze_data",
serde_json::json!({"dataset": "sales_data"}),
);
test_utils::assert_content(&results[2], " Analysis complete.");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "I'll analyze this data. Analysis complete.");
}
#[tokio::test]
async fn test_jailed_stream_llama3_json_parser() {
// Tests Llama3 JSON format tool call parsing with <|python_tag|> pattern
// Input: "Let me run some code. " + "<|python_tag|>{\"name\": \"execute_code\", \"arguments\": {\"code\": \"print('Hello')\"}}" + " Done executing."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("Let me run some code. ".to_string(), 0),
create_mock_response_chunk("<|python_tag|>".to_string(), 0),
create_mock_response_chunk("{\"name\": \"execute_code\", ".to_string(), 0),
create_mock_response_chunk(
"\"arguments\": {\"code\": \"print('Hello')\"}}".to_string(),
0,
),
create_mock_response_chunk(" Done executing.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with llama3_json parser
let jail = JailedStream::builder()
.tool_call_parser("llama3_json")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
println!("results: {:?}", results);
// Should have exactly 3 chunks: content + tool call + content
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()]
test_utils::assert_content(&results[0], "Let me run some code. ");
test_utils::assert_tool_call(
&results[1],
"execute_code",
serde_json::json!({"code": "print('Hello')"}),
);
test_utils::assert_content(&results[2], " Done executing.");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(reconstructed, "Let me run some code. Done executing.");
}
#[tokio::test]
async fn test_jailed_stream_false_positive_json() {
// Tests that JSON-like content doesn't trigger false positive tool call detection
// Input: "I can explain JSON format. " + "Here's an example: { \"key\": \"value\" }" + " is a simple JSON object. " + "Hope that helps!"
// Expected output: 4 chunks [Content(), Content(), Content(), Content()] - no jailing
let chunks = vec![
create_mock_response_chunk("I can explain JSON format. ".to_string(), 0),
create_mock_response_chunk("Here's an example: { \"key\": \"value\" }".to_string(), 0),
create_mock_response_chunk(" is a simple JSON object. ".to_string(), 0),
create_mock_response_chunk("Hope that helps!".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with mistral parser (which specifically looks for [{ or [TOOL_CALLS] patterns)
let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
println!("results: {:?}", results);
// The "{" pattern triggers jailing, so some chunks get combined
assert_eq!(results.len(), 2);
// Verify exact output structure: content chunks
test_utils::assert_content(&results[0], "I can explain JSON format. ");
test_utils::assert_content(
&results[1],
"Here's an example: { \"key\": \"value\" } is a simple JSON object. Hope that helps!",
);
// Verify no tool calls were detected and all content preserved
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"I can explain JSON format. Here's an example: { \"key\": \"value\" } is a simple JSON object. Hope that helps!"
);
}
#[tokio::test]
async fn test_jailed_stream_malformed_tool_call() {
// Tests graceful handling of malformed JSON within tool call markers
// Input: "Let me call a function. " + "<TOOLCALL>[{\"name\": \"broken_func\", \"arguments\": {\"param\": incomplete</TOOLCALL>" + " Function call attempt finished."
// Expected output: 3 chunks [Content(), Content(malformed), Content()] - parser fails gracefully
let chunks = vec![
create_mock_response_chunk("Let me call a function. ".to_string(), 0),
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk("[{\"name\": \"broken_func\", ".to_string(), 0),
create_mock_response_chunk("\"arguments\": {\"param\": incomplete".to_string(), 0), // Malformed JSON
create_mock_response_chunk("</TOOLCALL>".to_string(), 0),
create_mock_response_chunk(" Function call attempt finished.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with nemotron_deci parser
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Jailing combines the tool call content into fewer chunks
assert_eq!(
results.len(),
3,
"Should handle malformed JSON gracefully and jail appropriately"
);
// Verify exact output structure: [Content(), Content(complete jailed content)]
test_utils::assert_content(&results[0], "Let me call a function. ");
test_utils::assert_content(
&results[1],
"<TOOLCALL>[{\"name\": \"broken_func\", \"arguments\": {\"param\": incomplete</TOOLCALL>",
);
// Verify malformed content is preserved as text (including markers when parsing fails)
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"Let me call a function. <TOOLCALL>[{\"name\": \"broken_func\", \"arguments\": {\"param\": incomplete</TOOLCALL> Function call attempt finished."
);
}
#[tokio::test]
async fn test_jailed_stream_partial_tool_call() {
// Tests handling of incomplete tool call when stream ends abruptly
// Input: "Starting function call. " + "<TOOLCALL>[{\"name\": \"incomplete_func\", \"arguments\": {" (no end marker)
// Expected output: 2 chunks [Content(), Content(partial)] - partial accumulated content released on stream end
let chunks = vec![
create_mock_response_chunk("Starting function call. ".to_string(), 0),
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk("[{\"name\": \"incomplete_func\", ".to_string(), 0),
create_mock_response_chunk("\"arguments\": {".to_string(), 0),
// Stream ends abruptly without closing the tool call
];
let input_stream = stream::iter(chunks);
// Create JailedStream with nemotron_deci parser
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should handle partial tool call gracefully - releases accumulated content on stream end
assert_eq!(
results.len(),
2,
"Should handle partial tool call and release content"
);
// Verify exact output structure: [Content(), Content(accumulated partial)]
test_utils::assert_content(&results[0], "Starting function call. ");
test_utils::assert_content(
&results[1],
"<TOOLCALL>[{\"name\": \"incomplete_func\", \"arguments\": {",
);
// Verify partial content is preserved as text since no valid tool call could be parsed
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"Starting function call. <TOOLCALL>[{\"name\": \"incomplete_func\", \"arguments\": {"
);
}
#[tokio::test]
async fn test_jailed_stream_empty_stream() {
// Input chunks: []
//
// Expected output: []
let chunks: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = vec![];
let input_stream = stream::iter(chunks);
// Create JailedStream
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.jail_start_sequence("<jail>")
.jail_end_sequence("</jail>")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count ===
assert_eq!(
results.len(),
0,
"Empty stream should produce exactly 0 results"
);
// === Verify content reconstruction ===
assert_eq!(
reconstruct_content(&results),
"",
"Empty stream should reconstruct to empty string"
);
}
#[tokio::test]
async fn test_jailed_stream_multiple_tool_calls() {
// Input chunks: 9 chunks for 2 tool calls with content between
//
// Expected output:
// [0] Content("I'll help with multiple tasks. ")
// [1] ToolCall("get_weather", {"city": "NYC"})
// [2] Content(" Now let me get the time. ")
// [3] ToolCall("get_time", {"timezone": "EST"})
// [4] Content(" Both tasks completed!")
let chunks = vec![
create_mock_response_chunk("I'll help with multiple tasks. ".to_string(), 0),
// First tool call
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk(
"[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"NYC\"}}]".to_string(),
0,
),
create_mock_response_chunk("</TOOLCALL>".to_string(), 0),
create_mock_response_chunk(" Now let me get the time. ".to_string(), 0),
// Second tool call
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk(
"[{\"name\": \"get_time\", \"arguments\": {\"timezone\": \"EST\"}}]".to_string(),
0,
),
create_mock_response_chunk("</TOOLCALL>".to_string(), 0),
create_mock_response_chunk(" Both tasks completed!".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count ===
assert_eq!(
results.len(),
5,
"Should emit exactly 5 chunks as documented above"
);
// === Verify individual chunks ===
assert_content(&results[0], "I'll help with multiple tasks. ");
assert_tool_call(&results[1], "get_weather", json!({"city": "NYC"}));
assert_content(&results[2], " Now let me get the time. ");
assert_tool_call(&results[3], "get_time", json!({"timezone": "EST"}));
assert_content(&results[4], " Both tasks completed!");
// === Verify content reconstruction ===
let expected_content =
"I'll help with multiple tasks. Now let me get the time. Both tasks completed!";
assert_eq!(
reconstruct_content(&results),
expected_content,
"Content reconstruction should exclude tool calls and preserve text flow"
);
}
#[tokio::test]
async fn test_jailed_stream_tool_call_across_many_chunks() {
// Tests extreme fragmentation: tool call split across 65 individual character chunks
// Input: "I'll process your request. " + "<TOOLCALL>[{"name": "process_data", "arguments": {}}]</TOOLCALL>" + " Processing complete!"
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("I'll process your request. ".to_string(), 0),
create_mock_response_chunk("<".to_string(), 0),
create_mock_response_chunk("T".to_string(), 0),
create_mock_response_chunk("O".to_string(), 0),
create_mock_response_chunk("O".to_string(), 0),
create_mock_response_chunk("L".to_string(), 0),
create_mock_response_chunk("C".to_string(), 0),
create_mock_response_chunk("A".to_string(), 0),
create_mock_response_chunk("L".to_string(), 0),
create_mock_response_chunk("L".to_string(), 0),
create_mock_response_chunk(">".to_string(), 0),
create_mock_response_chunk("[".to_string(), 0),
create_mock_response_chunk("{".to_string(), 0),
create_mock_response_chunk("\"".to_string(), 0),
create_mock_response_chunk("n".to_string(), 0),
create_mock_response_chunk("a".to_string(), 0),
create_mock_response_chunk("m".to_string(), 0),
create_mock_response_chunk("e".to_string(), 0),
create_mock_response_chunk("\"".to_string(), 0),
create_mock_response_chunk(":".to_string(), 0),
create_mock_response_chunk(" ".to_string(), 0),
create_mock_response_chunk("\"".to_string(), 0),
create_mock_response_chunk("p".to_string(), 0),
create_mock_response_chunk("r".to_string(), 0),
create_mock_response_chunk("o".to_string(), 0),
create_mock_response_chunk("c".to_string(), 0),
create_mock_response_chunk("e".to_string(), 0),
create_mock_response_chunk("s".to_string(), 0),
create_mock_response_chunk("s".to_string(), 0),
create_mock_response_chunk("_".to_string(), 0),
create_mock_response_chunk("d".to_string(), 0),
create_mock_response_chunk("a".to_string(), 0),
create_mock_response_chunk("t".to_string(), 0),
create_mock_response_chunk("a".to_string(), 0),
create_mock_response_chunk("\"".to_string(), 0),
create_mock_response_chunk(",".to_string(), 0),
create_mock_response_chunk(" ".to_string(), 0),
create_mock_response_chunk("\"".to_string(), 0),
create_mock_response_chunk("a".to_string(), 0),
create_mock_response_chunk("r".to_string(), 0),
create_mock_response_chunk("g".to_string(), 0),
create_mock_response_chunk("u".to_string(), 0),
create_mock_response_chunk("m".to_string(), 0),
create_mock_response_chunk("e".to_string(), 0),
create_mock_response_chunk("n".to_string(), 0),
create_mock_response_chunk("t".to_string(), 0),
create_mock_response_chunk("s".to_string(), 0),
create_mock_response_chunk("\"".to_string(), 0),
create_mock_response_chunk(":".to_string(), 0),
create_mock_response_chunk(" ".to_string(), 0),
create_mock_response_chunk("{".to_string(), 0),
create_mock_response_chunk("}".to_string(), 0),
create_mock_response_chunk("}".to_string(), 0),
create_mock_response_chunk("]".to_string(), 0),
create_mock_response_chunk("<".to_string(), 0),
create_mock_response_chunk("/".to_string(), 0),
create_mock_response_chunk("T".to_string(), 0),
create_mock_response_chunk("O".to_string(), 0),
create_mock_response_chunk("O".to_string(), 0),
create_mock_response_chunk("L".to_string(), 0),
create_mock_response_chunk("C".to_string(), 0),
create_mock_response_chunk("A".to_string(), 0),
create_mock_response_chunk("L".to_string(), 0),
create_mock_response_chunk("L".to_string(), 0),
create_mock_response_chunk(">".to_string(), 0),
create_mock_response_chunk(" Processing complete!".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should consolidate extreme fragmentation into 3 clean chunks
// Input: "I'll process your request. " + 54-char tool call + " Processing complete!"
// Expected output: [Content(), ToolCall(), Content()]
assert_eq!(
results.len(),
3,
"Should consolidate fragments into 3 chunks"
);
// Verify exact output structure
test_utils::assert_content(&results[0], "I'll process your request. ");
test_utils::assert_tool_call(&results[1], "process_data", serde_json::json!({}));
test_utils::assert_content(&results[2], " Processing complete!");
// Verify content reconstruction excludes tool calls
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"I'll process your request. Processing complete!"
);
}
#[tokio::test]
async fn test_jailed_stream_preserves_metadata() {
// Test metadata preservation through jail processing
let test_id = Some("correlation-id-123".to_string());
let test_event = Some("request-processing".to_string());
let test_comment = Some(vec![
"upstream-correlation".to_string(),
"debug-info".to_string(),
]);
// Create chunks with specific metadata for the jail trigger
let chunks = vec![
create_annotated_chunk(
"I'll help you with that. ".to_string(),
0,
None, // No metadata on first chunk
None,
None,
),
create_annotated_chunk(
"<tool_call>".to_string(),
0,
test_id.clone(), // Metadata on jail trigger chunk
test_event.clone(),
test_comment.clone(),
),
create_mock_response_chunk("{\"name\": \"search_web\", ".to_string(), 0),
create_mock_response_chunk("\"arguments\": {\"query\": \"test\"}}".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0),
create_mock_response_chunk(" Processing complete.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with Hermes parser
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should get 3 chunks: before jail, tool call response, after jail
assert!(
results.len() >= 3,
"Should have at least 3 chunks, got {}",
results.len()
);
// Find the synthesized tool call response chunk
let tool_call_chunk = results
.iter()
.find(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.map(|c| c.finish_reason == Some(FinishReason::ToolCalls))
.unwrap_or(false)
})
.expect("Should have a tool call response chunk");
// Verify metadata is preserved
assert_eq!(
tool_call_chunk.id, test_id,
"ID should be preserved from jail trigger chunk"
);
assert_eq!(
tool_call_chunk.event, test_event,
"Event should be preserved from jail trigger chunk"
);
assert_eq!(
tool_call_chunk.comment, test_comment,
"Comment should be preserved from jail trigger chunk"
);
// Verify tool call was parsed correctly
let tool_calls = &tool_call_chunk.data.as_ref().unwrap().choices[0]
.delta
.tool_calls;
assert!(tool_calls.is_some(), "Should have tool calls");
let tool_calls = tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1, "Should have exactly one tool call");
assert_eq!(
tool_calls[0]
.function
.as_ref()
.unwrap()
.name
.as_ref()
.unwrap(),
"search_web"
);
}
#[tokio::test]
async fn test_jailed_stream_preserves_metadata_on_stream_end() {
// Test metadata preservation when stream ends while jailed
let test_id = Some("end-correlation-456".to_string());
let test_event = Some("stream-termination".to_string());
let test_comment = Some(vec!["incomplete-processing".to_string()]);
// Create chunks that end while jailed (no explicit end marker)
let chunks = vec![
create_mock_response_chunk("Starting function call: ".to_string(), 0),
create_annotated_chunk(
"<tool_call>".to_string(), // This chunk triggers jail and has metadata
0,
test_id.clone(),
test_event.clone(),
test_comment.clone(),
),
create_mock_response_chunk(
"{\"name\": \"incomplete_call\"".to_string(), // No closing brace
0,
),
];
let input_stream = stream::iter(chunks);
// Create JailedStream with Hermes parser
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should get 2 chunks: first chunk passes through, stream end releases accumulated
assert_eq!(results.len(), 2, "Should have exactly 2 chunks");
// The second chunk is the accumulated content released when stream ended
let accumulated_chunk = &results[1];
// Verify metadata is preserved from the jail trigger
assert_eq!(
accumulated_chunk.id, test_id,
"ID should be preserved when stream ends while jailed"
);
assert_eq!(
accumulated_chunk.event, test_event,
"Event should be preserved when stream ends while jailed"
);
assert_eq!(
accumulated_chunk.comment, test_comment,
"Comment should be preserved when stream ends while jailed"
);
// Verify accumulated content is returned
let content = &accumulated_chunk.data.as_ref().unwrap().choices[0]
.delta
.content;
assert!(content.is_some(), "Should have accumulated content");
let content = content.as_ref().unwrap();
assert!(
content.contains("<tool_call>"),
"Should contain jail start marker in accumulated content"
);
assert!(
content.contains("incomplete_call"),
"Should contain accumulated incomplete content"
);
}
#[tokio::test]
async fn test_jailed_stream_metadata_edge_cases() {
// Test edge cases: empty metadata, partial metadata, etc.
let chunks = vec![
create_annotated_chunk(
"Text with ".to_string(),
0,
Some("".to_string()), // Empty string ID
None, // No event
Some(vec![]), // Empty comment vector
),
create_annotated_chunk(
"<tool_call>".to_string(),
0,
None, // No ID
Some("partial-metadata".to_string()), // Only event
None, // No comment
),
create_mock_response_chunk("{\"name\": \"test\", \"arguments\": {}}".to_string(), 0),
create_mock_response_chunk("</tool_call>".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Find the tool call response
let tool_call_chunk = results
.iter()
.find(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.map(|c| c.finish_reason == Some(FinishReason::ToolCalls))
.unwrap_or(false)
})
.expect("Should have a tool call response chunk");
// Verify partial metadata is preserved correctly
assert_eq!(tool_call_chunk.id, None, "Should preserve None ID");
assert_eq!(
tool_call_chunk.event,
Some("partial-metadata".to_string()),
"Should preserve event"
);
assert_eq!(
tool_call_chunk.comment, None,
"Should preserve None comment"
);
}
#[tokio::test]
async fn test_jailed_stream_trailing_content_same_chunk() {
// Input chunks:
// [0] "I'll help you. "
// [1] "<tool_call>"
// [2] "{\"name\": \"search\", \"arguments\": {}}"
// [3] "</tool_call>trailing text that should not be lost"
//
// Expected output:
// [0] Content("I'll help you. ")
// [1] ToolCall("search", {})
// [2] Content("trailing text that should not be lost")
let chunks = vec![
create_mock_response_chunk("I'll help you. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk("{\"name\": \"search\", \"arguments\": {}}".to_string(), 0),
// This chunk contains both the end marker AND trailing content
create_mock_response_chunk(
"</tool_call>trailing text that should not be lost".to_string(),
0,
),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count ===
assert_eq!(
results.len(),
3,
"Should emit exactly 3 chunks as documented above"
);
// === Verify individual chunks ===
assert_content(&results[0], "I'll help you. ");
assert_tool_call(&results[1], "search", json!({}));
assert_content(&results[2], "trailing text that should not be lost");
// === Verify content reconstruction ===
let expected_content = "I'll help you. trailing text that should not be lost";
assert_eq!(
reconstruct_content(&results),
expected_content,
"Content reconstruction should preserve initial and trailing text"
);
}
#[tokio::test]
async fn test_jailed_stream_early_exit_with_trailing() {
// Tests early exit when complete tool call is detected in chunk that also contains trailing content
// Input: "Starting task: " + "<tool_call>{\"name\": \"complete_task\", \"arguments\": {}}" + "</tool_call> Task completed successfully."
// Expected output: 3 chunks [Content(), ToolCall(), Content()]
let chunks = vec![
create_mock_response_chunk("Starting task: ".to_string(), 0),
create_mock_response_chunk(
"<tool_call>{\"name\": \"complete_task\", \"arguments\": {}}".to_string(),
0,
),
// Early exit should happen here, but we also have trailing content
create_mock_response_chunk("</tool_call> Task completed successfully.".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have exactly 3 chunks: content + tool call + trailing
assert_eq!(
results.len(),
3,
"Should have content, tool call, and trailing content"
);
// Verify exact output structure: [Content(), ToolCall(), Content()]
test_utils::assert_content(&results[0], "Starting task: ");
test_utils::assert_tool_call(&results[1], "complete_task", serde_json::json!({}));
test_utils::assert_content(&results[2], " Task completed successfully.");
// Verify content reconstruction excludes tool calls but preserves trailing
let reconstructed = test_utils::reconstruct_content(&results);
assert_eq!(
reconstructed,
"Starting task: Task completed successfully."
);
}
#[tokio::test]
async fn test_multiple_choices_independent_jailing() {
// Test that different choices can jail and unjail independently
// This test will FAIL with the current HashMap-based implementation
let chunks = vec![
// Chunk 1: All choices start normally
create_multi_choice_chunk(vec![
("Starting task A. ".to_string(), 0),
("Starting task B. ".to_string(), 1),
("Starting task C. ".to_string(), 2),
]),
// Chunk 2: Choice 0 starts tool call (gets jailed), others continue
create_multi_choice_chunk(vec![
("<tool_call>".to_string(), 0), // Choice 0 jailed
("Continuing B. ".to_string(), 1), // Choice 1 continues
("Continuing C. ".to_string(), 2), // Choice 2 continues
]),
// Chunk 3: Choice 0 still jailed, Choice 2 starts tool call
create_multi_choice_chunk(vec![
("{\"name\": \"tool_a\"".to_string(), 0), // Choice 0 still jailed
("More B content. ".to_string(), 1), // Choice 1 continues
("<tool_call>".to_string(), 2), // Choice 2 now jailed
]),
// Chunk 4: Choice 0 finishes tool call, Choice 2 continues tool call
create_multi_choice_chunk(vec![
(", \"arguments\": {}}</tool_call>".to_string(), 0), // Choice 0 unjails
("Final B. ".to_string(), 1), // Choice 1 continues
("{\"name\": \"tool_c\", \"arguments\": {}}".to_string(), 2), // Choice 2 still jailed
]),
// Chunk 5: Choice 2 finishes tool call
create_multi_choice_chunk(vec![
("After tool A. ".to_string(), 0), // Choice 0 continues after unjail
("Done with B. ".to_string(), 1), // Choice 1 continues
("</tool_call>".to_string(), 2), // Choice 2 unjails
]),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// EXPECTED BEHAVIOR (will fail with current implementation):
// - Choice 1 should stream continuously (never jailed)
// - Choice 0 should jail from chunk 2 until chunk 4
// - Choice 2 should jail from chunk 3 until chunk 5
// - Each choice should emit independently
// Verify choice 1 was never interrupted (should have ~5 chunks of content)
let choice_1_chunks: Vec<_> = results
.iter()
.filter_map(|r| r.data.as_ref())
.flat_map(|d| &d.choices)
.filter(|c| c.index == 1 && c.delta.content.is_some())
.collect();
assert!(
choice_1_chunks.len() >= 4,
"Choice 1 should have multiple continuous chunks, got {}",
choice_1_chunks.len()
);
// Verify choice 0 has a tool call response
let choice_0_tool_calls: Vec<_> = results
.iter()
.filter_map(|r| r.data.as_ref())
.flat_map(|d| &d.choices)
.filter(|c| c.index == 0 && c.finish_reason == Some(FinishReason::ToolCalls))
.collect();
assert!(
!choice_0_tool_calls.is_empty(),
"Choice 0 should have tool call response"
);
// Verify choice 2 has a tool call response
let choice_2_tool_calls: Vec<_> = results
.iter()
.filter_map(|r| r.data.as_ref())
.flat_map(|d| &d.choices)
.filter(|c| c.index == 2 && c.finish_reason == Some(FinishReason::ToolCalls))
.collect();
assert!(
!choice_2_tool_calls.is_empty(),
"Choice 2 should have tool call response"
);
}
#[tokio::test]
async fn test_deterministic_choice_ordering() {
// Test that choices are processed in deterministic order (0, 1, 2...)
// This test will FAIL with the current HashMap implementation
let chunks = vec![
// All choices have tool calls that complete at the same time
create_multi_choice_chunk(vec![
(
"<tool_call>{\"name\": \"tool_0\", \"arguments\": {}}</tool_call>".to_string(),
0,
),
(
"<tool_call>{\"name\": \"tool_1\", \"arguments\": {}}</tool_call>".to_string(),
1,
),
(
"<tool_call>{\"name\": \"tool_2\", \"arguments\": {}}</tool_call>".to_string(),
2,
),
]),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Find all tool call responses
let mut tool_call_responses: Vec<_> = results
.iter()
.filter_map(|r| r.data.as_ref())
.flat_map(|d| &d.choices)
.filter(|c| c.finish_reason == Some(FinishReason::ToolCalls))
.collect();
// Sort by the order they appear in the results
// With HashMap, this order will be non-deterministic
// With Vec, this should always be [0, 1, 2]
tool_call_responses.sort_by_key(|c| c.index);
assert_eq!(
tool_call_responses.len(),
3,
"Should have 3 tool call responses"
);
// Run this test multiple times to verify determinism
for run in 0..5 {
let chunks = vec![create_multi_choice_chunk(vec![
(
"<tool_call>{\"name\": \"tool_0\", \"arguments\": {}}</tool_call>".to_string(),
0,
),
(
"<tool_call>{\"name\": \"tool_1\", \"arguments\": {}}</tool_call>".to_string(),
1,
),
(
"<tool_call>{\"name\": \"tool_2\", \"arguments\": {}}</tool_call>".to_string(),
2,
),
])];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("hermes").build();
let jailed_stream = jail.apply(input_stream);
let run_results: Vec<_> = jailed_stream.collect().await;
let run_responses: Vec<_> = run_results
.iter()
.filter_map(|r| r.data.as_ref())
.flat_map(|d| &d.choices)
.filter(|c| c.finish_reason == Some(FinishReason::ToolCalls))
.collect();
// The order should be consistent across runs
// This will fail with HashMap due to non-deterministic iteration
let indices: Vec<u32> = run_responses.iter().map(|c| c.index).collect();
assert_eq!(
indices,
vec![0, 1, 2],
"Choice processing order should be deterministic on run {}",
run
);
}
}
#[tokio::test]
async fn test_multiple_choices_usage_aggregation() {
// Test that usage is correctly aggregated across multiple choices
// This test demonstrates how usage should work with n>1
// For now, this test just documents expected behavior
// It will need to be expanded once usage aggregation is implemented
let chunks = vec![create_multi_choice_chunk(vec![
("Response A with many tokens".to_string(), 0), // ~5 tokens
("Response B".to_string(), 1), // ~2 tokens
("Response C has even more tokens than A".to_string(), 2), // ~8 tokens
])];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// TODO: Once usage aggregation is implemented, verify:
// - Usage chunk has choices: [] (empty array)
// - completion_tokens = sum of all choices (~15 total)
// - prompt_tokens counted once
// - total_tokens = prompt_tokens + completion_tokens
// For now, just verify we got some results
assert!(!results.is_empty(), "Should have some results");
}
#[tokio::test]
async fn test_partial_matching_false_positive_prevention() {
// Input chunks:
// [0] "n "
// [1] "<"
// [2] " 5"
//
// Expected output:
// [0] Content("n ")
// [1] Content("< 5") // "<" held as partial, then combined with " 5" when pattern doesn't match
let chunks = vec![
create_mock_response_chunk("n ".to_string(), 0),
create_mock_response_chunk("<".to_string(), 0),
create_mock_response_chunk(" 5".to_string(), 0),
];
let input_stream = stream::iter(chunks);
// Use nemotron parser which has <TOOLCALL> as a pattern
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count ===
assert_eq!(
results.len(),
2,
"Should emit exactly 2 chunks: 'n ' and '< 5'"
);
// === Verify individual chunks ===
assert_content(&results[0], "n ");
assert_content(&results[1], "< 5");
// === Verify negative assertions ===
// Verify NO tool calls were detected
for (i, result) in results.iter().enumerate() {
assert!(
!has_tool_call(result),
"Chunk {} should not contain tool calls in mathematical expression",
i
);
}
// === Verify content reconstruction ===
assert_eq!(
reconstruct_content(&results),
"n < 5",
"Content reconstruction should preserve the complete mathematical expression"
);
}
#[tokio::test]
async fn test_partial_matching_suffix_detection() {
// Input chunks:
// [0] "text<TO"
// [1] "OLCALL>[{\"name\": \"test\", \"arguments\": {}}]</TOOLCALL>"
//
// Expected output:
// [0] Content("text") // "<TO" held as partial
// [1] ToolCall("test", {})
let chunks = vec![
create_mock_response_chunk("text<TO".to_string(), 0),
create_mock_response_chunk(
"OLCALL>[{\"name\": \"test\", \"arguments\": {}}]</TOOLCALL>".to_string(),
0,
),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.jail_end_sequence("</TOOLCALL>")
.build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// === Verify chunk count ===
assert_eq!(
results.len(),
2,
"Should emit exactly 2 chunks: [0] 'text' content, [1] tool call"
);
// === Verify individual chunks ===
assert_content(&results[0], "text");
assert_tool_call(&results[1], "test", json!({}));
// === Verify negative assertions ===
// Verify '<' was not emitted in first chunk (held as partial)
let first_content = extract_content(&results[0]);
assert!(
!first_content.contains('<'),
"First chunk should not contain '<' as it's part of partial match '<TO'"
);
// === Verify content reconstruction ===
assert_eq!(
reconstruct_content(&results),
"text",
"Content reconstruction should only include 'text' (tool call parsed separately)"
);
}
#[tokio::test]
async fn test_jailed_stream_harmony_parser() {
// Harmony format with analysis text and a tool call encoded in special tags
let chunks = vec![
create_mock_response_chunk(
"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>"
.to_string(),
0,
),
create_mock_response_chunk("<|start|>".to_string(), 0),
create_mock_response_chunk("assistant".to_string(), 0),
create_mock_response_chunk("<|channel|>".to_string(), 0),
create_mock_response_chunk(
"commentary to=functions.get_current_weather <|constrain|>json".to_string(),
0,
),
create_mock_response_chunk(
"<|message|>{\"location\":\"San Francisco\"}<|call|>".to_string(),
0,
),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("harmony").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should have at least one output containing both analysis text and parsed tool call
assert!(!results.is_empty());
println!("results: {:?}", results);
// Verify the analysis text appears as content in one of the outputs
let has_analysis_text = results.iter().any(|r| {
r.data
.as_ref()
.and_then(|d| d.choices.first())
.and_then(|c| c.delta.content.as_ref())
.map(|content| content.contains("Need to use function get_current_weather."))
.unwrap_or(false)
});
assert!(has_analysis_text, "Should contain extracted analysis text");
// Verify a tool call was parsed with expected name and args
let tool_call_idx = results
.iter()
.position(test_utils::has_tool_call)
.expect("Should have a tool call result");
test_utils::assert_tool_call(
&results[tool_call_idx],
"get_current_weather",
json!({"location": "San Francisco"}),
);
}
#[tokio::test]
async fn test_jailed_stream_mistral_false_positive_curly() {
// Curly brace in normal text should not trigger tool call detection for mistral
let chunks = vec![
create_mock_response_chunk("Hey How".to_string(), 0),
create_mock_response_chunk("are { you? ".to_string(), 0),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
println!("results: {:?}", results);
assert!(results.len() >= 2);
assert_content(&results[0], "Hey How");
assert!(
results.iter().any(|r| extract_content(r) == "are { you? "),
"Should preserve the literal text with curly brace"
);
for (i, r) in results.iter().enumerate() {
assert!(
!has_tool_call(r),
"Result {} should not contain tool calls for false-positive text",
i
);
}
}
#[tokio::test]
#[ignore]
// TODO: This needs to be fixed in parser library. P1 priority.
async fn test_jailed_stream_mistral_false_positive_then_tool_calls_marker() {
// Normal text with curly brace followed by explicit [TOOL_CALLS] marker should parse tool call
let chunks = vec![
create_mock_response_chunk("Hey How".to_string(), 0),
create_mock_response_chunk("are { you? ".to_string(), 0),
create_mock_response_chunk("[TOOL_CALLS]".to_string(), 0),
create_mock_response_chunk(
"[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"San Francisco\", \"unit\": \"fahrenheit\"}}]"
.to_string(),
0,
),
];
let input_stream = stream::iter(chunks);
let jail = JailedStream::builder().tool_call_parser("mistral").build();
let jailed_stream = jail.apply(input_stream);
let results: Vec<_> = jailed_stream.collect().await;
// Should preserve earlier content and also produce a tool call
assert!(results.len() >= 2);
assert!(
results.iter().any(|r| extract_content(r) == "Hey How"),
"Should include initial content"
);
assert!(
results.iter().any(|r| extract_content(r) == "{ you? "),
"Should include content preceding the marker"
);
let tool_call_idx = results
.iter()
.position(test_utils::has_tool_call)
.expect("Should have a tool call result");
test_utils::assert_tool_call(
&results[tool_call_idx],
"get_weather",
json!({"location": "San Francisco", "unit": "fahrenheit"}),
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_trait::async_trait;
use dynamo_async_openai::types::ChatCompletionToolChoiceOption;
use dynamo_async_openai::types::CreateChatCompletionRequest;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason as OAIFinishReason, Role,
};
use dynamo_llm::preprocessor::{
ANNOTATION_POSSIBLE_TOOL_CALL, PossibleToolCallAnnotation, apply_tool_calling_jail_internal,
maybe_enable_tool_call,
};
use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use dynamo_parsers::tool_calling::parsers::detect_tool_call_start;
use dynamo_runtime::pipeline::ResponseStream;
use dynamo_runtime::protocols::annotated::Annotated;
use futures::stream::{self, StreamExt};
use std::sync::Arc;
#[allow(deprecated)]
// Helper function to create a mock chat response chunk
fn create_mock_response_chunk(
content: String,
index: u32,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
let choice = ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some(content),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
};
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![choice],
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id: None,
event: None,
comment: None,
}
}
#[allow(deprecated)]
// Helper function to create a final response chunk with finish reason
fn create_final_response_chunk(index: u32) -> Annotated<NvCreateChatCompletionStreamResponse> {
let choice = ChatChoiceStream {
index,
delta: ChatCompletionStreamResponseDelta {
role: None,
content: None,
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: Some(OAIFinishReason::Stop),
logprobs: None,
};
let response = NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![choice],
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: Some("test-fingerprint".to_string()),
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
};
Annotated {
data: Some(response),
id: None,
event: None,
comment: None,
}
}
// Mock async engine context for testing
#[derive(Debug)]
struct MockAsyncEngineContext {
id: String,
stopped: std::sync::atomic::AtomicBool,
}
impl MockAsyncEngineContext {
fn new(id: String) -> Self {
Self {
id,
stopped: std::sync::atomic::AtomicBool::new(false),
}
}
}
#[async_trait]
impl dynamo_runtime::pipeline::AsyncEngineContext for MockAsyncEngineContext {
fn id(&self) -> &str {
&self.id
}
fn stop(&self) {
self.stopped
.store(true, std::sync::atomic::Ordering::Relaxed);
}
fn stop_generating(&self) {
self.stopped
.store(true, std::sync::atomic::Ordering::Relaxed);
}
fn kill(&self) {
self.stopped
.store(true, std::sync::atomic::Ordering::Relaxed);
}
fn is_stopped(&self) -> bool {
self.stopped.load(std::sync::atomic::Ordering::Relaxed)
}
fn is_killed(&self) -> bool {
self.stopped.load(std::sync::atomic::Ordering::Relaxed)
}
async fn stopped(&self) {
// No-op for testing
}
async fn killed(&self) {
// No-op for testing
}
fn link_child(&self, _: Arc<dyn dynamo_runtime::pipeline::AsyncEngineContext>) {
// No-op for testing
}
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_with_tool_call_detection() {
// Create a stream with tool call content that SHOULD trigger jailing
let mock_context = Arc::new(MockAsyncEngineContext::new("test-request-id".to_string()));
// Create chunks that represent a tool call being generated
let chunks = vec![
create_mock_response_chunk("<TOOLCALL>".to_string(), 0),
create_mock_response_chunk("[{\"name\": \"get_weather\", ".to_string(), 0),
create_mock_response_chunk(
"\"arguments\": {\"location\": \"San Francisco\"}}]".to_string(),
0,
),
create_mock_response_chunk("</TOOLCALL>".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
// Apply the jail with nemotron_deci parser - should trigger jailing on first chunk
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())).await;
// Collect all results
let results: Vec<_> = jailed_stream.collect().await;
// Verify that jailing was triggered
assert!(!results.is_empty(), "Should have some results");
// Results should be of length 1
// First Stream: [{"name": "get_weather", "arguments":"{"location": "San Francisco"}}]"
assert_eq!(results.len(), 1);
assert!(
results[0].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.is_some()
);
let tools = results[0].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tools.len(), 1);
let name = tools[0].function.as_ref().unwrap().name.as_ref().unwrap();
let arguments = serde_json::from_str::<serde_json::Value>(
tools[0]
.function
.as_ref()
.unwrap()
.arguments
.as_ref()
.unwrap(),
)
.unwrap();
assert_eq!(name, "get_weather");
assert_eq!(arguments["location"], "San Francisco");
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_no_tool_calls() {
// Create a stream with regular content that should NOT trigger jailing
let mock_context = Arc::new(MockAsyncEngineContext::new("test-request-id-2".to_string()));
let chunks = vec![
create_mock_response_chunk("Hello, ".to_string(), 0),
create_mock_response_chunk("how can I ".to_string(), 0),
create_mock_response_chunk("help you today?".to_string(), 0),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
// Apply the jail with nemotron_deci parser - regular text should NOT be jailed
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("nemotron_deci".to_string())).await;
// Collect all results
let results: Vec<_> = jailed_stream.collect().await;
// Should have results and they should NOT be jailed (content should be preserved)
assert!(!results.is_empty(), "Should have results");
assert_eq!(results.len(), 4, "Should have all 4 chunks");
// Verify that content is NOT jailed - first few chunks should have their original content
for (i, result) in results.iter().take(3).enumerate() {
if let Some(ref response_data) = result.data {
let expected_content = match i {
0 => "Hello, ",
1 => "how can I ",
2 => "help you today?",
_ => unreachable!(),
};
assert_eq!(
response_data.choices[0].delta.content.as_deref(),
Some(expected_content),
"Chunk {} should have original content, not be jailed",
i
);
// Should NOT have annotation events for regular content
assert!(
result.event.is_none(),
"Regular content should not have annotation events"
);
}
}
// Last chunk should be the final response with finish reason
if let Some(last_result) = results.last()
&& let Some(ref response_data) = last_result.data
{
assert_eq!(
response_data.choices[0].finish_reason,
Some(OAIFinishReason::Stop)
);
}
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_with_empty_stream() {
let mock_context = Arc::new(MockAsyncEngineContext::new("test-request-id-3".to_string()));
let chunks: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = vec![];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream = apply_tool_calling_jail_internal(response_stream, None).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(results.is_empty(), "Empty stream should produce no results");
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_with_different_parsers() {
let mock_context = Arc::new(MockAsyncEngineContext::new("test-request-id-4".to_string()));
// Test with hermes parser format
let chunks = vec![
create_mock_response_chunk("<tool_call>".to_string(), 0),
create_mock_response_chunk(
"{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Tokyo\"}}".to_string(),
0,
),
create_mock_response_chunk("</tool_call>".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(!results.is_empty(), "Should have results for hermes parser");
}
#[tokio::test]
async fn test_detect_tool_call_start_different_parsers() {
// Test nemotron_deci parser
assert!(detect_tool_call_start("<TOOLCALL>", Some("nemotron_deci")).unwrap());
assert!(!detect_tool_call_start("Hello world", Some("nemotron_deci")).unwrap());
assert!(!detect_tool_call_start("<tool_call>", Some("nemotron_deci")).unwrap()); // Wrong format
// Test hermes parser - now also detects JSON patterns
assert!(detect_tool_call_start("<tool_call>", Some("hermes")).unwrap());
assert!(detect_tool_call_start("{\"name\": \"test\"}", Some("hermes")).unwrap()); // JSON detection
assert!(!detect_tool_call_start("Hello world", Some("hermes")).unwrap());
assert!(!detect_tool_call_start("<TOOLCALL>", Some("hermes")).unwrap()); // Wrong format
// Test phi4 parser
assert!(detect_tool_call_start("functools[", Some("phi4")).unwrap());
assert!(detect_tool_call_start("{\"name\": \"test\"}", Some("phi4")).unwrap()); // JSON detection
assert!(!detect_tool_call_start("Hello world", Some("phi4")).unwrap());
// Test mistral parser
assert!(detect_tool_call_start("[{", Some("mistral")).unwrap());
assert!(detect_tool_call_start("[TOOL_CALLS]", Some("mistral")).unwrap());
assert!(!detect_tool_call_start("Hello world", Some("mistral")).unwrap());
// Test llama3_json parser
assert!(detect_tool_call_start("<|python_tag|>", Some("llama3_json")).unwrap());
assert!(detect_tool_call_start("{\"name\": \"test\"}", Some("llama3_json")).unwrap()); // JSON detection
// Test default parser (should behave like nemotron_deci)
assert!(detect_tool_call_start("<TOOLCALL>", None).unwrap());
assert!(detect_tool_call_start("{\"name\": \"test\"}", None).unwrap()); // JSON detection
assert!(!detect_tool_call_start("Hello world", None).unwrap());
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_hermes_parser() {
// Test with hermes parser format
let mock_context = Arc::new(MockAsyncEngineContext::new(
"test-request-id-hermes".to_string(),
));
let chunks = vec![
create_mock_response_chunk("I'll help you with that. ".to_string(), 0),
create_mock_response_chunk("<tool_call>".to_string(), 0), // This should trigger jailing
create_mock_response_chunk(
"{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Tokyo\"}}".to_string(),
0,
),
create_mock_response_chunk("</tool_call>".to_string(), 0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("hermes".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(!results.is_empty(), "Should have results for hermes parser");
// Results should be of length 2
// First Stream : I'll help you with that.
// Second Stream : [{"name": "get_weather", "arguments":"{"location": "Tokyo"}}]" (jailed)
assert_eq!(results.len(), 2);
assert_eq!(
results[0].data.as_ref().unwrap().choices[0].delta.content,
Some("I'll help you with that. ".to_string())
);
assert!(
results[1].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.is_some()
);
let tools = results[1].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tools.len(), 1);
let name = tools[0].function.as_ref().unwrap().name.as_ref().unwrap();
let arguments = serde_json::from_str::<serde_json::Value>(
tools[0]
.function
.as_ref()
.unwrap()
.arguments
.as_ref()
.unwrap(),
)
.unwrap();
assert_eq!(name, "get_weather");
assert_eq!(arguments["location"], "Tokyo");
}
#[tokio::test]
async fn test_possible_tool_call_annotation_serialization() {
let annotation = PossibleToolCallAnnotation {
possible_tokens: 5,
possible_content: "test content".to_string(),
parser_used: Some("nemotron_deci".to_string()),
};
let annotated_result = annotation.to_annotation::<NvCreateChatCompletionStreamResponse>();
assert!(
annotated_result.is_ok(),
"Should be able to create annotation"
);
let annotated = annotated_result.unwrap();
assert_eq!(
annotated.event,
Some(ANNOTATION_POSSIBLE_TOOL_CALL.to_string())
);
assert!(annotated.comment.is_some(), "Should have comment");
// Test deserialization
let parsed_annotation = PossibleToolCallAnnotation::from_annotation(&annotated);
assert!(
parsed_annotation.is_ok(),
"Should be able to parse annotation"
);
let parsed = parsed_annotation.unwrap();
assert!(parsed.is_some(), "Should have parsed annotation");
let parsed = parsed.unwrap();
assert_eq!(parsed.possible_tokens, 5);
assert_eq!(parsed.possible_content, "test content");
assert_eq!(parsed.parser_used, Some("nemotron_deci".to_string()));
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_mistral_parser_with_no_tool_call_start_token() {
let mock_context = Arc::new(MockAsyncEngineContext::new(
"test-request-id-mistral".to_string(),
));
let chunks = vec![
create_mock_response_chunk("Hey How".to_string(), 0),
create_mock_response_chunk("are you? ".to_string(), 0),
create_mock_response_chunk(r#"[{"name": "get_weather", "arguments":"#.to_string(), 0),
create_mock_response_chunk(
r#"{"location": "San Francisco", "unit": "fahrenheit"}}]"#.to_string(),
0,
),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
!results.is_empty(),
"Should have results for mistral parser"
);
// Results should be of length 4
// First Stream : Hey How
// Second Stream : are you?
// Third Stream : None (final response chunk)
// Fourth Stream : [{"name": "get_weather", "arguments":"{"location": "San Francisco", "unit": "fahrenheit"}}]" (jailed)
assert_eq!(results.len(), 4);
// First two normal text
assert_eq!(
results[0].data.as_ref().unwrap().choices[0].delta.content,
Some("Hey How".to_string())
);
assert_eq!(
results[1].data.as_ref().unwrap().choices[0].delta.content,
Some("are you? ".to_string())
);
assert_eq!(
results[2].data.as_ref().unwrap().choices[0].delta.content,
None
);
// Final tool call
assert!(
results[3].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.is_some()
);
let tools = results[3].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tools.len(), 1);
let name = tools[0].function.as_ref().unwrap().name.as_ref().unwrap();
let arguments = serde_json::from_str::<serde_json::Value>(
tools[0]
.function
.as_ref()
.unwrap()
.arguments
.as_ref()
.unwrap(),
)
.unwrap();
assert_eq!(name, "get_weather");
assert_eq!(arguments["location"], "San Francisco");
assert_eq!(arguments["unit"], "fahrenheit");
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positive_tool_start() {
let mock_context = Arc::new(MockAsyncEngineContext::new(
"test-request-id-mistral".to_string(),
));
let chunks = vec![
create_mock_response_chunk("Hey How".to_string(), 0),
create_mock_response_chunk("are { you? ".to_string(), 0),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
!results.is_empty(),
"Should have results for mistral parser"
);
// Results should be of length 3
// First Stream : Hey How
// Second Stream : None (final response chunk)
// Third Stream : are { you? (normal text field from tool-call-parse-aggregate)
assert_eq!(results.len(), 3);
assert_eq!(
results[0].data.as_ref().unwrap().choices[0].delta.content,
Some("Hey How".to_string())
);
assert_eq!(
results[1].data.as_ref().unwrap().choices[0].delta.content,
None
);
assert_eq!(
results[2].data.as_ref().unwrap().choices[0].delta.content,
Some("are { you?".to_string())
);
}
#[tokio::test]
async fn test_apply_tool_calling_jail_internal_mistral_parser_with_false_positive_tool_start_and_tool_call_token()
{
let mock_context = Arc::new(MockAsyncEngineContext::new(
"test-request-id-mistral".to_string(),
));
let chunks = vec![
create_mock_response_chunk("Hey How".to_string(), 0),
create_mock_response_chunk("are { you? ".to_string(), 0),
create_mock_response_chunk(
r#"[TOOL_CALLS][{"name": "get_weather", "arguments":"#.to_string(),
0,
),
create_mock_response_chunk(
r#"{"location": "San Francisco", "unit": "fahrenheit"}}]"#.to_string(),
0,
),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("mistral".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
!results.is_empty(),
"Should have results for mistral parser"
);
// Results should be of length 3
// First Stream : Hey How
// Second Stream : None (final response chunk)
// Third Stream : Content: are { you? , Tool Calls: [{"name": "get_weather", "arguments":"{"location": "San Francisco", "unit": "fahrenheit"}}]"
assert_eq!(results.len(), 3);
assert_eq!(
results[0].data.as_ref().unwrap().choices[0].delta.content,
Some("Hey How".to_string())
);
assert_eq!(
results[1].data.as_ref().unwrap().choices[0].delta.content,
None
);
assert_eq!(
results[2].data.as_ref().unwrap().choices[0].delta.content,
Some("are { you?".to_string())
);
assert!(
results[2].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.is_some()
);
let tools = results[2].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tools.len(), 1);
let name = tools[0].function.as_ref().unwrap().name.as_ref().unwrap();
let arguments = serde_json::from_str::<serde_json::Value>(
tools[0]
.function
.as_ref()
.unwrap()
.arguments
.as_ref()
.unwrap(),
)
.unwrap();
assert_eq!(name, "get_weather");
assert_eq!(arguments["location"], "San Francisco");
assert_eq!(arguments["unit"], "fahrenheit");
}
#[tokio::test]
async fn test_tool_calling_jail_internal_with_harmony_parser() {
let mock_context = Arc::new(MockAsyncEngineContext::new(
"test-request-id-harmony".to_string(),
));
// Harmony Format:
// <|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>
// <|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json
// <|message|>{"location":"San Francisco"}<|call|>
let chunks = vec![
create_mock_response_chunk(
"<|channel|>analysis<|message|>Need to use function get_current_weather.<|end|>"
.to_string(),
0,
),
create_mock_response_chunk("<|start|>".to_string(), 0),
create_mock_response_chunk("assistant".to_string(), 0),
create_mock_response_chunk("<|channel|>".to_string(), 0),
create_mock_response_chunk(
"commentary to=functions.get_current_weather <|constrain|>json".to_string(),
0,
),
create_mock_response_chunk(
"<|message|>{\"location\":\"San Francisco\"}<|call|>".to_string(),
0,
),
create_final_response_chunk(0),
];
let input_stream = stream::iter(chunks);
let response_stream = ResponseStream::new(Box::pin(input_stream), mock_context.clone());
let jailed_stream =
apply_tool_calling_jail_internal(response_stream, Some("harmony".to_string())).await;
let results: Vec<_> = jailed_stream.collect().await;
assert!(
!results.is_empty(),
"Should have results for harmony parser"
);
assert_eq!(results.len(), 2);
assert_eq!(
results[1].data.as_ref().unwrap().choices[0].delta.content,
Some("Need to use function get_current_weather.".to_string())
);
assert!(
results[1].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.is_some()
);
let tools = results[1].data.as_ref().unwrap().choices[0]
.delta
.tool_calls
.as_ref()
.unwrap();
assert_eq!(tools.len(), 1);
let name = tools[0].function.as_ref().unwrap().name.as_ref().unwrap();
let arguments = serde_json::from_str::<serde_json::Value>(
tools[0]
.function
.as_ref()
.unwrap()
.arguments
.as_ref()
.unwrap(),
)
.unwrap();
assert_eq!(name, "get_current_weather");
assert_eq!(arguments["location"], "San Francisco");
}
#[test]
fn test_enable_tool_call() {
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::Auto),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(maybe_enable_tool_call(Some("nemotron_deci"), &request));
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::None),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(!maybe_enable_tool_call(Some("nemotron_deci"), &request));
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::Required),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(maybe_enable_tool_call(Some("nemotron_deci"), &request));
let request = NvCreateChatCompletionRequest {
inner: CreateChatCompletionRequest {
tool_choice: Some(ChatCompletionToolChoiceOption::Auto),
..Default::default()
},
common: Default::default(),
nvext: None,
chat_template_args: None,
};
assert!(!maybe_enable_tool_call(None, &request));
}
...@@ -159,11 +159,14 @@ async fn test_streaming_without_usage() { ...@@ -159,11 +159,14 @@ async fn test_streaming_without_usage() {
// Create mock backend stream // Create mock backend stream
let ctx = Arc::new(MockContext::new()); let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx); let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream // Transform the stream
let transformed_stream = let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator); backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks // Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await; let chunks: Vec<_> = transformed_stream.collect().await;
...@@ -197,11 +200,14 @@ async fn test_streaming_with_usage_compliance() { ...@@ -197,11 +200,14 @@ async fn test_streaming_with_usage_compliance() {
// Create mock backend stream // Create mock backend stream
let ctx = Arc::new(MockContext::new()); let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx); let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream // Transform the stream
let transformed_stream = let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator); backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks // Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await; let chunks: Vec<_> = transformed_stream.collect().await;
...@@ -267,11 +273,14 @@ async fn test_streaming_with_usage_false() { ...@@ -267,11 +273,14 @@ async fn test_streaming_with_usage_false() {
// Create mock backend stream // Create mock backend stream
let ctx = Arc::new(MockContext::new()); let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx); let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream // Transform the stream
let transformed_stream = let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator); backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks // Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await; let chunks: Vec<_> = transformed_stream.collect().await;
......
...@@ -162,7 +162,6 @@ pub async fn parse_tool_calls_harmony( ...@@ -162,7 +162,6 @@ pub async fn parse_tool_calls_harmony(
} }
Ok((res, Some(normal_text.to_string()))) Ok((res, Some(normal_text.to_string())))
} }
/// Parse tool calls from a complete Harmony Format text chunk using direct token parsing. /// Parse tool calls from a complete Harmony Format text chunk using direct token parsing.
/// ///
/// This function is optimized for parsing complete text chunks where the entire content /// This function is optimized for parsing complete text chunks where the entire content
......
...@@ -13,7 +13,7 @@ pub mod tools; ...@@ -13,7 +13,7 @@ pub mod tools;
pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType}; pub use config::{JsonParserConfig, ToolCallConfig, ToolCallParserType};
pub use harmony::{parse_tool_calls_harmony, parse_tool_calls_harmony_complete}; pub use harmony::{parse_tool_calls_harmony, parse_tool_calls_harmony_complete};
pub use json::try_tool_call_parse_json; pub use json::try_tool_call_parse_json;
pub use parsers::{detect_and_parse_tool_call, try_tool_call_parse}; pub use parsers::{detect_and_parse_tool_call, detect_tool_call_start, try_tool_call_parse};
pub use pythonic::try_tool_call_parse_pythonic; pub use pythonic::try_tool_call_parse_pythonic;
pub use response::{CalledFunction, ToolCallResponse, ToolCallType}; pub use response::{CalledFunction, ToolCallResponse, ToolCallType};
pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream}; pub use tools::{try_tool_call_parse_aggregate, try_tool_call_parse_stream};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment