Unverified commit c63cceaa, authored by Ryan Olson, committed by GitHub

feat: JailedStream (#3034)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 4c56c8ae
......@@ -92,7 +92,6 @@ generated-values.yaml
**/.devcontainer/.env
TensorRT-LLM
# Ruler Generated Files
/.cursor/instructions.md
/.cursor/instructions.md.bak
......
......@@ -1981,6 +1981,7 @@ name = "dynamo-llm"
version = "0.5.0"
dependencies = [
"ahash",
"aho-corasick",
"akin",
"aligned-vec",
"anyhow",
......
# JailedStream Implementation
## Overview
The `JailedStream` is a standalone implementation for handling "jail" detection in token streams. It provides a clean, builder-based API for accumulating tokens when certain sequences are detected, then releasing them as a single chunk when the jail ends.
## Key Features
- **Builder Pattern**: Clean configuration API using the builder pattern
- **Configurable Sequences**: Support for multiple start/end jail sequences
- **Tool Call Parsing**: Integrated tool call detection and parsing
- **Stream Macro**: Uses `async-stream::stream!` for clean async implementation
- **Standalone**: Completely independent of existing code
- **Annotations**: Preserves annotations for observability
## Implementation
### Location
- Main implementation: `lib/llm/src/protocols/openai/chat_completions/jail.rs`
- Examples: `lib/llm/src/protocols/openai/chat_completions/jail_example.rs`
### Usage
```rust
use crate::protocols::openai::chat_completions::jail::JailedStream;
use dynamo_runtime::engine::{AsyncEngineContextProvider, ResponseStream};
// Get your ResponseStream with context
let response_stream: Pin<Box<ResponseStream<_>>> = get_stream_from_engine();
// Extract context BEFORE passing to apply
let context = response_stream.context();
// Apply jail transformation (ResponseStream implements Stream)
let jail = JailedStream::builder()
.tool_call_parser("nemotron_deci")
.build();
let jailed_stream = jail.apply(response_stream);
// Re-wrap with context when needed for engine consumption
let final_stream = ResponseStream::new(Box::pin(jailed_stream), context);
```
### Advanced Configuration
```rust
// With custom jail sequences
let jail = JailedStream::builder()
.jail_start_sequence("<TOOLCALL>")
.jail_end_sequence("</TOOLCALL>")
.tool_call_parser("nemotron_deci")
.build();
// With multiple sequences
let jail = JailedStream::builder()
.jail_start_sequences(vec!["<TOOLCALL>", "<FUNCTION>"])
.jail_end_sequences(vec!["</TOOLCALL>", "</FUNCTION>"])
.tool_call_parser("harmony")
.build();
```
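As a rough illustration of how a builder with both single-sequence and multi-sequence setters can be structured, here is a minimal sketch. `JailConfig` and `JailConfigBuilder` are hypothetical names invented for this example; the method names mirror the calls shown above, but this is not the code in `jail.rs`.

```rust
/// Hypothetical configuration type produced by the builder sketch.
#[derive(Debug, Default, PartialEq)]
struct JailConfig {
    start_sequences: Vec<String>,
    end_sequences: Vec<String>,
    tool_call_parser: Option<String>,
}

/// Hypothetical builder; each setter consumes and returns `self` so calls chain.
#[derive(Default)]
struct JailConfigBuilder {
    config: JailConfig,
}

impl JailConfigBuilder {
    /// Add one start sequence.
    fn jail_start_sequence(mut self, seq: impl Into<String>) -> Self {
        self.config.start_sequences.push(seq.into());
        self
    }

    /// Add several start sequences at once.
    fn jail_start_sequences<I, S>(mut self, seqs: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.config.start_sequences.extend(seqs.into_iter().map(Into::into));
        self
    }

    /// Add one end sequence.
    fn jail_end_sequence(mut self, seq: impl Into<String>) -> Self {
        self.config.end_sequences.push(seq.into());
        self
    }

    /// Record which tool call parser to use.
    fn tool_call_parser(mut self, name: impl Into<String>) -> Self {
        self.config.tool_call_parser = Some(name.into());
        self
    }

    /// Finish building and hand back the configuration.
    fn build(self) -> JailConfig {
        self.config
    }
}
```

Taking `impl Into<String>` lets callers pass either string literals or owned `String` values, which is one way the two call styles shown in this document could both be accepted.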
## How It Works
1. **Detection**: When a jail start sequence (or tool call start) is detected, the stream enters "jail" mode
2. **Accumulation**: While jailed, tokens are accumulated in memory instead of being yielded
3. **Annotations**: Empty chunks with annotations are sent downstream for observability
4. **Release**: When a jail end sequence is detected OR the stream ends:
- Accumulated content is parsed for tool calls
- A single chunk with the parsed content is yielded
5. **Pass-through**: Non-jailed content passes through unchanged
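The steps above can be sketched as a small state machine over plain string chunks. This is a simplified illustration, not the `JailedStream` implementation: `SimpleJail` and `Emitted` are hypothetical names, there is no tool call parsing or annotation handling, and markers split across chunk boundaries are not detected while un-jailed. It also assumes the start and end sequences are non-empty and that the end sequence does not occur inside the start sequence.

```rust
/// What the sketch hands downstream.
#[derive(Debug, PartialEq)]
enum Emitted {
    /// Non-jailed text, passed through unchanged.
    PassThrough(String),
    /// Everything accumulated while jailed, released as one chunk.
    Released(String),
}

/// Hypothetical, simplified jail over string chunks.
struct SimpleJail {
    start: String,
    end: String,
    jailed: bool,
    buffer: String,
}

impl SimpleJail {
    fn new(start: &str, end: &str) -> Self {
        assert!(!start.is_empty() && !end.is_empty(), "markers must be non-empty");
        Self {
            start: start.to_string(),
            end: end.to_string(),
            jailed: false,
            buffer: String::new(),
        }
    }

    /// Feed one chunk and return whatever may be emitted now.
    fn push(&mut self, chunk: &str) -> Vec<Emitted> {
        let mut out = Vec::new();
        let mut pending = chunk.to_string();
        while !pending.is_empty() {
            if self.jailed {
                // Accumulation: hold text until the end sequence appears.
                self.buffer.push_str(&pending);
                pending.clear();
                if let Some(pos) = self.buffer.find(self.end.as_str()) {
                    // Release: one chunk up to and including the end sequence;
                    // anything after it is processed as normal text.
                    pending = self.buffer.split_off(pos + self.end.len());
                    out.push(Emitted::Released(std::mem::take(&mut self.buffer)));
                    self.jailed = false;
                }
            } else if let Some(pos) = pending.find(self.start.as_str()) {
                // Detection: text before the start sequence passes through.
                let jailed_part = pending.split_off(pos);
                if !pending.is_empty() {
                    out.push(Emitted::PassThrough(std::mem::take(&mut pending)));
                }
                pending = jailed_part;
                self.jailed = true;
            } else {
                // Pass-through: no marker in this chunk.
                out.push(Emitted::PassThrough(std::mem::take(&mut pending)));
            }
        }
        out
    }

    /// Stream ended: release anything still held.
    fn finish(&mut self) -> Option<Emitted> {
        if self.jailed && !self.buffer.is_empty() {
            self.jailed = false;
            Some(Emitted::Released(std::mem::take(&mut self.buffer)))
        } else {
            None
        }
    }
}
```

Calling `push` once per incoming chunk and `finish` once at end of stream reproduces the detect, accumulate, release, and pass-through phases in order.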
## Testing
The implementation includes tests for each mode of operation:
- `test_jailed_stream_with_start_end_sequences`: Tests explicit jail sequences
- `test_jailed_stream_with_tool_calls`: Tests tool call detection and parsing
- `test_jailed_stream_no_jailing`: Tests normal pass-through behavior
Run tests with:
```bash
cargo test -p dynamo-llm jail --lib
```
## Benefits
1. **Standalone**: No modifications to existing code required
2. **Clean API**: Builder pattern makes configuration intuitive
3. **Flexible**: Supports multiple jail detection strategies
4. **Maintainable**: Uses `stream!` macro for cleaner async code
5. **Testable**: Comprehensive test suite with shared utilities
6. **Efficient**: No unnecessary boxing or context handling in the library
7. **Composable**: Can chain multiple stream transformers before re-adding context
## Performance Optimizations
- **No Boxing in Library**: Returns `impl Stream` instead of `Pin<Box<ResponseStream>>`
- **Stack Pinning**: Uses `tokio::pin!()` instead of `Box::pin()` for better performance
- **No Context Overhead**: JailedStream doesn't manage AsyncEngineContext
- **Lazy Evaluation**: Only processes what's needed
- **Efficient State Management**: Minimal cloning, only when entering jail state
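The "no boxing in the library" idea can be illustrated with the standard library alone. The sketch below uses `Iterator` as a stand-in for `Stream` (an assumption made so the example needs no async runtime); `trim_chunks`, `drop_empty`, and `pipeline` are hypothetical names. Each transformer returns an unboxed `impl Iterator`, and the caller boxes once at the boundary where a concrete trait object is required.

```rust
/// Library-side transformer: trims each chunk, returns an unboxed iterator.
fn trim_chunks<I>(chunks: I) -> impl Iterator<Item = String>
where
    I: Iterator<Item = String>,
{
    chunks.map(|c| c.trim().to_string())
}

/// Library-side transformer: drops chunks that are empty.
fn drop_empty<I>(chunks: I) -> impl Iterator<Item = String>
where
    I: Iterator<Item = String>,
{
    chunks.filter(|c| !c.is_empty())
}

/// Caller boundary: compose the transformers, then box exactly once.
fn pipeline(chunks: Vec<String>) -> Box<dyn Iterator<Item = String>> {
    Box::new(drop_empty(trim_chunks(chunks.into_iter())))
}
```

Because the intermediate stages are statically typed, chaining more transformers adds no extra allocations; only the final `Box::new` allocates.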
## Integration Options
To replace the existing `apply_tool_calling_jail_internal` function:
```rust
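// In preprocessor.rs
pub fn apply_tool_calling_jail_with_parser(
&self,
stream: ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> ManyOut<Annotated<NvCreateChatCompletionStreamResponse>> {
// Extract the context BEFORE `apply` consumes the stream
let context = stream.context();
let jail = JailedStream::builder()
.tool_call_parser(self.tool_call_parser.clone())
.build();
// `apply` returns a plain `impl Stream`, so re-wrap it with the context
// (same pattern as in the Usage section)
ResponseStream::new(Box::pin(jail.apply(stream)), context)
}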
```
## Future Enhancements
- Add support for regex patterns for jail sequences
- Add metrics/telemetry for jail detection
- Support for partial sequence matching across chunk boundaries
- Configurable accumulation limits
- Support for nested jails
\ No newline at end of file
......@@ -1382,6 +1382,7 @@ name = "dynamo-llm"
version = "0.5.0"
dependencies = [
"ahash",
"aho-corasick",
"akin",
"anyhow",
"async-nats",
......
......@@ -52,6 +52,7 @@ required-features = ["block-manager", "testing-cuda"]
dynamo-runtime = { workspace = true }
# workspace
aho-corasick = "1.1"
anyhow = { workspace = true }
dynamo-async-openai = { workspace = true }
dynamo-parsers = { workspace = true}
......
......@@ -37,6 +37,7 @@ pub mod request_template;
pub mod tokenizers;
pub mod tokens;
pub mod types;
pub mod utils;
#[cfg(feature = "block-manager")]
pub mod block_manager;
......
This diff is collapsed.
......@@ -19,6 +19,7 @@ use super::{
pub mod aggregator;
mod delta;
pub mod jail;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
pub mod prefix_matcher;
pub use prefix_matcher::{MarkerMatcher, MatchResult};
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Efficient multi-pattern marker detection with partial suffix matching
//!
//! This module provides utilities for detecting complete and partial marker patterns
//! in streaming text, with support for detecting markers split across chunk boundaries.
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
use std::collections::HashMap;
/// Result of processing a chunk with potential marker detection
#[derive(Debug, Clone, PartialEq)]
pub enum MatchResult {
/// Complete marker found
Complete {
/// Content before the marker (safe to emit)
prefix: String,
/// The complete marker matched
marker: String,
/// Start position of the marker in the input
marker_start: usize,
/// Remaining content after the marker
suffix: String,
},
/// Partial marker at end of chunk
Partial {
/// Content before the partial (safe to emit)
prefix: String,
/// The partial match to hold
partial: String,
/// Which patterns this could match
possible_patterns: Vec<String>,
},
/// No markers detected
None {
/// All content is safe to emit
content: String,
},
}
/// Efficient multi-pattern matcher with partial suffix detection
pub struct MarkerMatcher {
/// All patterns we're looking for
patterns: Vec<String>,
/// Aho-Corasick matcher for complete patterns
complete_matcher: AhoCorasick,
/// Trie for partial matching
prefix_trie: PrefixTrie,
/// Maximum pattern length (for buffer limits)
max_pattern_len: usize,
}
impl MarkerMatcher {
/// Create a new matcher with the given patterns
pub fn new(patterns: Vec<String>) -> Result<Self, String> {
if patterns.is_empty() {
return Err("Cannot create MarkerMatcher with empty patterns".to_string());
}
let complete_matcher = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostFirst)
.build(&patterns)
.map_err(|e| format!("Failed to build Aho-Corasick matcher: {}", e))?;
let max_pattern_len = patterns.iter().map(|p| p.len()).max().unwrap_or(0);
let prefix_trie = PrefixTrie::new(&patterns);
Ok(Self {
patterns,
complete_matcher,
prefix_trie,
max_pattern_len,
})
}
/// Get the maximum pattern length
pub fn max_pattern_len(&self) -> usize {
self.max_pattern_len
}
/// Safe UTF-8 slicing that ensures we only slice at character boundaries
fn safe_slice(text: &str, start_byte: usize, end_byte: usize) -> String {
// Clamp indices to valid boundaries
let start = text
.char_indices()
.find(|(i, _)| *i >= start_byte)
.map(|(i, _)| i)
.unwrap_or(text.len());
let end = text
.char_indices()
.find(|(i, _)| *i >= end_byte)
.map(|(i, _)| i)
.unwrap_or(text.len());
text[start..end].to_string()
}
/// Process a chunk with an optional partial buffer from previous chunk
pub fn process_chunk(&self, chunk: &str, partial_buffer: &str) -> MatchResult {
// Combine buffer with new chunk
let combined = if partial_buffer.is_empty() {
chunk.to_string()
} else {
format!("{}{}", partial_buffer, chunk)
};
// First check for complete markers
if let Some(mat) = self.complete_matcher.find(&combined) {
let marker = &self.patterns[mat.pattern().as_usize()];
return MatchResult::Complete {
prefix: Self::safe_slice(&combined, 0, mat.start()),
marker: marker.clone(),
marker_start: mat.start(),
suffix: Self::safe_slice(&combined, mat.end(), combined.len()),
};
}
// No complete match - check for partial at ANY suffix position
// This is the key: check "n<T" → finds "<T" as partial
if let Some((partial_start, partial, patterns)) = self.find_partial_suffix(&combined) {
return MatchResult::Partial {
prefix: Self::safe_slice(&combined, 0, partial_start),
partial: partial.to_string(),
possible_patterns: patterns,
};
}
// No matches at all
MatchResult::None { content: combined }
}
/// Find the longest suffix of the input that is a proper prefix of some pattern
///
/// Start positions are scanned from left to right, so the first hit is the
/// EARLIEST start and therefore the LONGEST partial. No marker can begin before
/// that position, so everything preceding it is safe to emit.
fn find_partial_suffix<'a>(&self, text: &'a str) -> Option<(usize, &'a str, Vec<String>)> {
// Try each suffix, longest first; the first suffix that is a prefix of a
// pattern is the one to hold back
// Use char_indices to get valid UTF-8 boundaries
for (i, _) in text.char_indices() {
let suffix = &text[i..];
if let Some(patterns) = self.prefix_trie.find_prefix_match(suffix) {
// This suffix is a prefix of one or more patterns
return Some((i, suffix, patterns));
}
}
None
}
}
/// Trie structure for efficient prefix matching
struct PrefixTrie {
root: TrieNode,
}
#[derive(Debug)]
struct TrieNode {
children: HashMap<char, TrieNode>,
/// Patterns that have this exact prefix
matching_patterns: Vec<String>,
/// Is this node a complete pattern?
is_complete: bool,
}
impl PrefixTrie {
fn new(patterns: &[String]) -> Self {
let mut root = TrieNode {
children: HashMap::new(),
matching_patterns: Vec::new(),
is_complete: false,
};
// Build trie
for pattern in patterns {
let mut current = &mut root;
let chars: Vec<char> = pattern.chars().collect();
for (i, &ch) in chars.iter().enumerate() {
current = current.children.entry(ch).or_insert(TrieNode {
children: HashMap::new(),
matching_patterns: Vec::new(),
is_complete: false,
});
// Add this pattern to all prefix nodes
if !current.matching_patterns.contains(pattern) {
current.matching_patterns.push(pattern.clone());
}
// Mark complete if we're at the end
if i == chars.len() - 1 {
current.is_complete = true;
}
}
}
PrefixTrie { root }
}
/// Check if text is a prefix of any pattern (but not a complete pattern)
fn find_prefix_match(&self, text: &str) -> Option<Vec<String>> {
let mut current = &self.root;
for ch in text.chars() {
if let Some(node) = current.children.get(&ch) {
current = node;
} else {
// Not a prefix of any pattern
return None;
}
}
// If we matched the entire text and it's a prefix of something (but not complete)
if !current.matching_patterns.is_empty() && !current.is_complete {
Some(current.matching_patterns.clone())
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_complete_match() {
let patterns = vec!["<TOOLCALL>".to_string(), "<tool_call>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("<TOOLCALL>data", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result
{
assert_eq!(prefix, "");
assert_eq!(marker, "<TOOLCALL>");
assert_eq!(suffix, "data");
} else {
panic!("Expected complete match");
}
}
#[test]
fn test_partial_match_suffix() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test the key case: "n<T" should detect "<T" as partial
let result = matcher.process_chunk("n<T", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
assert_eq!(prefix, "n");
assert_eq!(partial, "<T");
assert_eq!(possible_patterns, vec!["<TOOLCALL>"]);
} else {
panic!("Expected partial match, got: {:?}", result);
}
}
#[test]
fn test_no_false_positive() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test case: "n < 5" should not trigger partial match
let result = matcher.process_chunk("n < 5", "");
if let MatchResult::None { content } = result {
assert_eq!(content, "n < 5");
} else {
panic!("Expected no match, got: {:?}", result);
}
}
#[test]
fn test_partial_buffer_combination() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// First chunk: partial "<"
let result1 = matcher.process_chunk("<", "");
let partial = if let MatchResult::Partial { partial, .. } = result1 {
partial
} else {
panic!("Expected partial match");
};
// Second chunk: "TOOLCALL>" completes the pattern
let result2 = matcher.process_chunk("TOOLCALL>", &partial);
if let MatchResult::Complete { marker, .. } = result2 {
assert_eq!(marker, "<TOOLCALL>");
} else {
panic!("Expected complete match, got: {:?}", result2);
}
}
#[test]
fn test_prefix_with_content() {
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("text before <TOOLCALL> after", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result
{
assert_eq!(prefix, "text before ");
assert_eq!(marker, "<TOOLCALL>");
assert_eq!(suffix, " after");
} else {
panic!("Expected complete match");
}
}
#[test]
fn test_empty_patterns() {
let result = MarkerMatcher::new(vec![]);
assert!(result.is_err());
}
#[test]
fn test_multiple_patterns() {
let patterns = vec![
"<TOOLCALL>".to_string(),
"[TOOL_CALLS]".to_string(),
"<tool_call>".to_string(),
];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test different patterns
let result1 = matcher.process_chunk("[TOOL_CALLS]", "");
if let MatchResult::Complete { marker, .. } = result1 {
assert_eq!(marker, "[TOOL_CALLS]");
} else {
panic!("Expected complete match for [TOOL_CALLS]");
}
// Test partial for different pattern
let result2 = matcher.process_chunk("text<to", "");
if let MatchResult::Partial {
partial,
possible_patterns,
..
} = result2
{
assert_eq!(partial, "<to");
assert!(possible_patterns.contains(&"<tool_call>".to_string()));
} else {
panic!("Expected partial match for <tool_call>");
}
}
#[test]
fn test_multiple_partial_matches_edge_case() {
// Test scenario: Multiple patterns where one looks like a prefix but isn't valid
// Patterns: ["FooBar", "<TOOLCALL>"]
// Input: "This is FooBaz which is a no, but <TOO"
// Key insight: "FooBa" from "FooBaz" is NOT a valid partial because the 'z'
// doesn't match the expected 'r' in "FooBar"
// Expected: Hold "<TOO" as partial, emit "This is FooBaz which is a no, but "
let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("This is FooBaz which is a no, but <TOO", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
// The algorithm correctly skips "FooBaz" (not a valid prefix) and finds "<TOO"
assert_eq!(partial, "<TOO");
assert_eq!(prefix, "This is FooBaz which is a no, but ");
assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
} else {
panic!("Expected partial match for '<TOO', got: {:?}", result);
}
}
#[test]
fn test_earliest_valid_partial_match() {
// Test that the algorithm finds the earliest VALID partial match
// Patterns: ["FooBar", "<TOOLCALL>"]
// Input: "Some text FooBa and then <TO"
// Analysis: "FooBa and then <TO" is not a valid prefix of "FooBar" because
// after "FooBa" we have " " (space) but "FooBar" expects "r"
// Expected: Skip invalid "FooBa..." and find valid "<TO" partial
let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("Some text FooBa and then <TO", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
// Should find "<TO" as the valid partial match
assert_eq!(partial, "<TO");
assert_eq!(prefix, "Some text FooBa and then ");
assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
} else {
panic!("Expected partial match for '<TO', got: {:?}", result);
}
}
#[test]
fn test_partial_at_exact_end() {
// Test case where a valid partial is exactly at the end
// Patterns: ["FooBar", "<TOOLCALL>"]
// Input: "Some text ending with FooBa"
// Expected: Hold "FooBa" as partial (valid prefix of "FooBar")
let patterns = vec!["FooBar".to_string(), "<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
let result = matcher.process_chunk("Some text ending with FooBa", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
// Should find "FooBa" as a valid partial match at the end
assert_eq!(partial, "FooBa");
assert_eq!(prefix, "Some text ending with ");
assert!(possible_patterns.contains(&"FooBar".to_string()));
} else {
panic!("Expected partial match for 'FooBa', got: {:?}", result);
}
}
#[test]
fn test_unicode_complete_match() {
// Test complete pattern matching with unicode content
// Use patterns with ASCII markers but unicode content
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test with emoji and multi-byte characters
let result = matcher.process_chunk("Hello 👋 world <TOOLCALL>data 🚀", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result
{
assert_eq!(prefix, "Hello 👋 world ");
assert_eq!(marker, "<TOOLCALL>");
assert_eq!(suffix, "data 🚀");
} else {
panic!("Expected complete match, got: {:?}", result);
}
}
#[test]
fn test_unicode_partial_match() {
// Test partial matching where the partial might occur after unicode content
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test partial after multi-byte characters
let result = matcher.process_chunk("Text with 中文字符 and <TO", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result
{
assert_eq!(prefix, "Text with 中文字符 and ");
assert_eq!(partial, "<TO");
assert!(possible_patterns.contains(&"<TOOLCALL>".to_string()));
} else {
panic!("Expected partial match, got: {:?}", result);
}
}
#[test]
fn test_unicode_no_false_positive() {
// Test that unicode content doesn't create false positives
let patterns = vec!["<TOOLCALL>".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test with full-width brackets (U+FF1C, U+FF1E) that look similar to the ASCII pattern
let result = matcher.process_chunk("Unicode test \u{FF1C}TOOLCALL\u{FF1E} full-width", "");
if let MatchResult::None { content } = result {
assert_eq!(content, "Unicode test \u{FF1C}TOOLCALL\u{FF1E} full-width");
} else {
panic!(
"Expected no match for full-width characters, got: {:?}",
result
);
}
}
#[test]
fn test_unicode_pattern_itself() {
// Test patterns that contain unicode characters
let patterns = vec!["🔧工具".to_string(), "📞call".to_string()];
let matcher = MarkerMatcher::new(patterns).unwrap();
// Test complete match with unicode pattern
let result1 = matcher.process_chunk("Start 🔧工具 end", "");
if let MatchResult::Complete {
prefix,
marker,
suffix,
..
} = result1
{
assert_eq!(prefix, "Start ");
assert_eq!(marker, "🔧工具");
assert_eq!(suffix, " end");
} else {
panic!(
"Expected complete match for unicode pattern, got: {:?}",
result1
);
}
// Test partial match with unicode pattern
let result2 = matcher.process_chunk("Text 🔧工", "");
if let MatchResult::Partial {
prefix,
partial,
possible_patterns,
} = result2
{
assert_eq!(prefix, "Text ");
assert_eq!(partial, "🔧工");
assert!(possible_patterns.contains(&"🔧工具".to_string()));
} else {
panic!(
"Expected partial match for unicode pattern, got: {:?}",
result2
);
}
}
}
This diff is collapsed.
This diff is collapsed.
......@@ -159,11 +159,14 @@ async fn test_streaming_without_usage() {
// Create mock backend stream
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx);
let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream
let transformed_stream =
OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await;
......@@ -197,11 +200,14 @@ async fn test_streaming_with_usage_compliance() {
// Create mock backend stream
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx);
let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream
let transformed_stream =
OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await;
......@@ -267,11 +273,14 @@ async fn test_streaming_with_usage_false() {
// Create mock backend stream
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx);
let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream
let transformed_stream =
OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator);
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await;
......
This diff is collapsed.