Unverified Commit 49b7a0d9 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: record + analyze logprobs (#1957)

parent 6d2be143
...@@ -147,6 +147,15 @@ version = "1.0.0" ...@@ -147,6 +147,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca387cdc0a1f9c7a7c26556d584aa2d07fc529843082e4861003cde4ab914ed" checksum = "fca387cdc0a1f9c7a7c26556d584aa2d07fc529843082e4861003cde4ab914ed"
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "arbitrary" name = "arbitrary"
version = "1.4.1" version = "1.4.1"
...@@ -1773,6 +1782,7 @@ dependencies = [ ...@@ -1773,6 +1782,7 @@ dependencies = [
"akin", "akin",
"aligned-vec", "aligned-vec",
"anyhow", "anyhow",
"approx",
"assert_matches", "assert_matches",
"async-nats", "async-nats",
"async-openai", "async-openai",
......
...@@ -129,6 +129,7 @@ zeromq = "0.4.1" ...@@ -129,6 +129,7 @@ zeromq = "0.4.1"
rmp-serde = "1.3" rmp-serde = "1.3"
[dev-dependencies] [dev-dependencies]
approx = "0.5"
assert_matches = "1.5" assert_matches = "1.5"
criterion = { version = "0.3", features = ["html_reports"] } criterion = { version = "0.3", features = ["html_reports"] }
hf-hub = { workspace = true } hf-hub = { workspace = true }
......
...@@ -27,7 +27,7 @@ use crate::protocols::openai::chat_completions::{ ...@@ -27,7 +27,7 @@ use crate::protocols::openai::chat_completions::{
}; };
use crate::protocols::Annotated; use crate::protocols::Annotated;
use dynamo_runtime::engine::{ use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
}; };
/// Configuration for HTTP clients /// Configuration for HTTP clients
...@@ -226,43 +226,29 @@ impl BaseHttpClient { ...@@ -226,43 +226,29 @@ impl BaseHttpClient {
} }
/// Type alias for NV chat response stream /// Type alias for NV chat response stream
pub type NvChatResponseStream = Pin< pub type NvChatResponseStream =
Box< DataStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;
dyn Stream<Item = Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>
+ Send
+ Sync,
>,
>;
/// Type alias for generic BYOT response stream /// Type alias for generic BYOT response stream
pub type ByotResponseStream = Pin<Box<dyn Stream<Item = Result<Value, OpenAIError>> + Send + Sync>>; pub type ByotResponseStream = DataStream<Result<Value, OpenAIError>>;
/// Type alias for pure OpenAI chat response stream /// Type alias for pure OpenAI chat response stream
pub type OpenAIChatResponseStream = Pin< pub type OpenAIChatResponseStream =
Box< DataStream<Result<async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>>;
dyn Stream<
Item = Result<async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>,
> + Send
+ Sync,
>,
>;
/// A wrapped HTTP response stream that combines a stream with its context /// A wrapped HTTP response stream that combines a stream with its context
/// This provides a unified interface for HTTP client responses /// This provides a unified interface for HTTP client responses
#[derive(Dissolve)] #[derive(Dissolve)]
pub struct HttpResponseStream<T> { pub struct HttpResponseStream<T> {
/// The underlying stream of responses /// The underlying stream of responses
pub stream: Pin<Box<dyn Stream<Item = T> + Send>>, pub stream: DataStream<T>,
/// The context for this request /// The context for this request
pub context: Arc<dyn AsyncEngineContext>, pub context: Arc<dyn AsyncEngineContext>,
} }
impl<T> HttpResponseStream<T> { impl<T> HttpResponseStream<T> {
/// Create a new HttpResponseStream /// Create a new HttpResponseStream
pub fn new( pub fn new(stream: DataStream<T>, context: Arc<dyn AsyncEngineContext>) -> Self {
stream: Pin<Box<dyn Stream<Item = T> + Send>>,
context: Arc<dyn AsyncEngineContext>,
) -> Self {
Self { stream, context } Self { stream, context }
} }
} }
...@@ -299,7 +285,7 @@ impl<T: Data> HttpResponseStream<T> { ...@@ -299,7 +285,7 @@ impl<T: Data> HttpResponseStream<T> {
/// A wrapper that implements AsyncEngineStream for streams that are Send + Sync /// A wrapper that implements AsyncEngineStream for streams that are Send + Sync
struct AsyncEngineStreamWrapper<T> { struct AsyncEngineStreamWrapper<T> {
stream: Pin<Box<dyn Stream<Item = T> + Send>>, stream: DataStream<T>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
} }
...@@ -317,10 +303,6 @@ impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> { ...@@ -317,10 +303,6 @@ impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> {
} }
} }
// This is unsafe because we're claiming the stream is Sync when it might not be
// But this is needed for the AsyncEngineStream trait
unsafe impl<T> Sync for AsyncEngineStreamWrapper<T> {}
impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {} impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {}
impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> { impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> {
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
//! This module provides mechanisms to record streaming responses with minimal overhead //! This module provides mechanisms to record streaming responses with minimal overhead
//! during collection, then analyze the recorded data for performance insights. //! during collection, then analyze the recorded data for performance insights.
pub mod logprobs;
use futures::Stream; use futures::Stream;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
...@@ -339,7 +341,7 @@ pub fn record_response_stream<R: Data + Clone>( ...@@ -339,7 +341,7 @@ pub fn record_response_stream<R: Data + Clone>(
} }
#[cfg(test)] #[cfg(test)]
mod tests { pub mod tests {
use super::*; use super::*;
use dynamo_runtime::engine::ResponseStream; use dynamo_runtime::engine::ResponseStream;
use futures::stream; use futures::stream;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Module for recording logprobs from a streaming response.
//!
//! Logprobs are a bit easier than token counting and timing because they are
//! fully self-contained in the response chunk.
//!
//! In fact, if logprobs are given, they are a good way to count tokens; however,
//! the emission of logprobs is also more costly and generally not available unless
//! explicitly requested.
//!
//! The primary reason to record logprobs is to analyze the possible outputs of
//! a model as a function of sequence position.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use crate::perf::RecordedStream;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
/// The type of logprobs observed in the response.
///
/// Derives the same basic traits as the other public types in this module so
/// values can be logged, copied and compared.
///
/// NOTE(review): the variant docs previously said the logprobs themselves
/// "sum to 0"; that can only hold for the log-sum-exp, so the wording below
/// assumes that is what was meant — confirm with the author.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogprobType {
    /// If normalized, the reported "top_logprobs" describe a complete
    /// distribution: their probabilities (`exp(logprob)`) sum to 1, i.e. their
    /// log-sum-exp is 0.
    Normalized,
    /// If unnormalized, the reported "top_logprobs" are not normalized, so
    /// their probabilities (`exp(logprob)`) do not sum to 1.
    Unnormalized,
}
/// Represents a token with its logprob information
///
/// This is the module's own serde-serializable record; the `LogprobExtractor`
/// implementation in this file fills it by copying `token`, `logprob` and `bytes`
/// from each response entry.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenLogprob {
/// The token as a string
pub token: String,
/// The log probability of this token. The analysis code in this module converts it
/// to a linear probability with `exp()`.
pub logprob: f32,
/// Optional byte representation of the token
pub bytes: Option<Vec<u8>>,
}
/// Represents logprob information for a single position with selected and alternative tokens
///
/// All fields are private and are populated by `TokenLogProbs::new`; read them through
/// `selected_token`, `alternative_tokens` and `all_tokens`.
#[derive(Debug, Clone)]
pub struct TokenLogProbs {
// The token that was actually emitted at this position.
selected: TokenLogprob,
// The alternatives passed to `new`, sorted by logprob (highest first).
alternatives: Vec<TokenLogprob>,
// `alternatives` plus `selected` (unless an entry with the same token and a
// near-identical logprob is already among the alternatives), highest logprob first.
all_sorted: Vec<TokenLogprob>,
}
impl TokenLogProbs {
    /// Create a new `TokenLogProbs` from a selected token and its alternatives.
    ///
    /// `alternatives` is sorted by logprob, highest first. The merged list
    /// returned by [`all_tokens`](Self::all_tokens) is the sorted alternatives
    /// plus `selected`, inserted in front of the first alternative with a
    /// strictly lower logprob. `selected` is not inserted when an alternative
    /// with the same token text and (almost) the same logprob already exists.
    ///
    /// NaN logprobs never panic: they are ordered after every real number.
    pub fn new(selected: TokenLogprob, mut alternatives: Vec<TokenLogprob>) -> Self {
        // Stable sort, so alternatives with equal logprobs keep their input order.
        alternatives.sort_by(|a, b| Self::cmp_logprob_desc(a.logprob, b.logprob));

        // The exact-equality arm covers equal infinities, whose difference is NaN
        // and would otherwise never pass the tolerance check.
        let selected_in_alternatives = alternatives.iter().any(|alt| {
            alt.token == selected.token
                && (alt.logprob == selected.logprob
                    || (alt.logprob - selected.logprob).abs() < 1e-6)
        });

        let mut all_sorted = alternatives.clone();
        if !selected_in_alternatives {
            // First slot whose occupant is strictly less likely than `selected`
            // (NaN entries count as least likely); append when there is none.
            let insert_at = alternatives
                .iter()
                .position(|alt| alt.logprob.is_nan() || selected.logprob > alt.logprob)
                .unwrap_or(alternatives.len());
            all_sorted.insert(insert_at, selected.clone());
        }

        Self {
            selected,
            alternatives,
            all_sorted,
        }
    }

    /// Get the selected token
    pub fn selected_token(&self) -> &TokenLogprob {
        &self.selected
    }

    /// Get alternative tokens sorted by most likely first
    pub fn alternative_tokens(&self) -> &[TokenLogprob] {
        &self.alternatives
    }

    /// Get all tokens (selected merged with alternatives, unique) sorted by most likely first
    pub fn all_tokens(&self) -> &[TokenLogprob] {
        &self.all_sorted
    }

    /// Descending comparison of two logprobs that is total: NaN compares equal
    /// to NaN and sorts after every number, so sorting can never panic.
    fn cmp_logprob_desc(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal),
        }
    }
}
/// Trait for extracting logprob information from various response types
///
/// `analyze_logprob_sensitivity` calls this once per recorded response, so an
/// implementation should report only the logprobs carried by that one response.
pub trait LogprobExtractor {
/// Extract logprobs organized by choice index
/// Returns: HashMap<choice_index, Vec<TokenLogProbs>>
///
/// Each `Vec` holds one `TokenLogProbs` per token position found for that choice.
fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>>;
}
/// Implementation for NvCreateChatCompletionStreamResponse (our main streaming response type)
impl LogprobExtractor for NvCreateChatCompletionStreamResponse {
    /// Builds one entry per choice in the chunk. A choice that carries no
    /// logprob content still gets an entry, mapped to an empty list.
    fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>> {
        let mut by_choice = HashMap::new();

        for choice in &self.inner.choices {
            let mut positions = Vec::new();

            let content = choice
                .logprobs
                .as_ref()
                .and_then(|logprobs| logprobs.content.as_ref());

            if let Some(entries) = content {
                for entry in entries {
                    // Copy the reported top alternatives into our own format.
                    let mut alternatives = Vec::with_capacity(entry.top_logprobs.len());
                    for top in &entry.top_logprobs {
                        alternatives.push(TokenLogprob {
                            token: top.token.clone(),
                            logprob: top.logprob,
                            bytes: top.bytes.clone(),
                        });
                    }

                    // The token that was actually emitted at this position.
                    let selected = TokenLogprob {
                        token: entry.token.clone(),
                        logprob: entry.logprob,
                        bytes: entry.bytes.clone(),
                    };

                    positions.push(TokenLogProbs::new(selected, alternatives));
                }
            }

            by_choice.insert(choice.index, positions);
        }

        by_choice
    }
}
/// Validate and flatten a choice-indexed logprobs map into a `Vec`.
///
/// Succeeds only when the keys are exactly `0..=max_index`; the result is then
/// ordered by choice index. An empty map yields an empty `Vec`.
///
/// The map is consumed, so the per-choice vectors are moved into the result
/// rather than cloned.
///
/// # Errors
///
/// Returns a message listing the indices that were found (in ascending order)
/// when any index in `0..=max_index` is missing.
pub fn validate_and_flatten_choices(
    choice_logprobs: HashMap<u32, Vec<TokenLogProbs>>,
) -> Result<Vec<Vec<TokenLogProbs>>, String> {
    let mut entries: Vec<(u32, Vec<TokenLogProbs>)> = choice_logprobs.into_iter().collect();
    entries.sort_unstable_by_key(|(index, _)| *index);

    let max_choice = match entries.last() {
        Some((index, _)) => *index,
        None => return Ok(Vec::new()),
    };

    // Map keys are unique and none exceeds `max_choice`, so the indices are
    // contiguous from 0 exactly when there are `max_choice + 1` of them.
    // Widen to u64 so an index of `u32::MAX` cannot overflow.
    let expected_count = u64::from(max_choice) + 1;
    if entries.len() as u64 != expected_count {
        let found: Vec<u32> = entries.iter().map(|(index, _)| *index).collect();
        return Err(format!(
            "Missing choice indices: expected {} choices [0, {}), but found {} choices: {:?}",
            expected_count,
            expected_count,
            entries.len(),
            found
        ));
    }

    Ok(entries.into_iter().map(|(_, logprobs)| logprobs).collect())
}
/// Analysis focused on detecting close logprobs indicating model uncertainty
///
/// Produced by `analyze_logprob_sensitivity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityAnalysis {
/// Total number of responses analyzed
/// (the length of the recorded stream's response list, whether or not each
/// response carried logprobs).
pub total_responses: usize,
/// Analysis results per choice index
pub choice_analyses: HashMap<u32, ChoiceAnalysis>,
}
/// Analysis for a single choice
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChoiceAnalysis {
/// Choice index
pub choice_index: u32,
/// All positions with their closeness values, sorted by closeness
/// (`analyze_logprob_sensitivity` sorts ascending by `probability_difference`,
/// so the most uncertain positions come first).
pub position_closeness: Vec<PositionCloseness>,
/// Number of positions analyzed for this choice
/// (positions with fewer than two candidates are skipped and not counted).
pub positions_analyzed: usize,
}
/// Closeness information for a position
///
/// `analyze_logprob_sensitivity` creates one of these for every token position
/// that has at least two candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionCloseness {
/// Position in the stream (response index)
pub stream_position: usize,
/// Position within the token sequence
/// (a per-choice counter that also advances for positions skipped by the analysis).
pub token_position: usize,
/// Logprob difference between top 2 candidates (deprecated - use probability_difference)
pub logprob_difference: f32,
/// Probability difference between top 2 candidates (in linear space 0-1)
/// Computed as `exp(logprob[0]) - exp(logprob[1])`.
pub probability_difference: f32,
/// Probability mass not accounted for by all_tokens (1 - sum of all_tokens probabilities)
pub probability_remaining: f32,
/// All candidates at this position, sorted by logprob (highest first)
pub candidates: Vec<TokenLogprob>,
}
/// A position where top candidates have close probabilities
///
/// NOTE(review): nothing in the visible part of this module constructs or returns
/// this type; the query methods on `SensitivityAnalysis` hand out `PositionCloseness`
/// instead. Confirm whether it is still needed.
#[derive(Debug, Clone)]
pub struct ClosePosition {
/// Position in the stream (response index)
pub stream_position: usize,
/// Position within the token sequence
pub token_position: usize,
/// Logprob difference between top 2 candidates (deprecated - use probability_difference)
pub logprob_difference: f32,
/// Probability difference between top 2 candidates (in linear space 0-1)
pub probability_difference: f32,
/// Probability mass not accounted for by top_candidates (1 - sum of top_candidates probabilities)
pub probability_remaining: f32,
/// Top 2 candidates at this position
pub top_candidates: Vec<TokenLogprob>,
}
/// Analyzes logprobs from a recorded stream focusing on token similarity/closeness.
///
/// Every recorded response is asked for its logprobs per choice. For each token
/// position with at least two candidates, the gap between the two most likely
/// candidates is recorded in logprob space and in linear probability space,
/// together with the probability mass the reported candidates leave uncovered.
///
/// Token positions are counted per choice across the whole stream; a position
/// with fewer than two candidates advances that counter but is not analyzed.
///
/// Each choice's positions are returned sorted by ascending probability
/// difference (most uncertain first). A NaN difference (for example from
/// infinite logprobs) is ordered last instead of causing a panic.
pub fn analyze_logprob_sensitivity(
    recorded_stream: Arc<RecordedStream<impl LogprobExtractor>>,
) -> SensitivityAnalysis {
    // Ascending order on f32 that is total: NaN sorts after every number.
    fn ascending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
        }
    }

    let mut choice_analyses: HashMap<u32, ChoiceAnalysis> = HashMap::new();
    // Track cumulative sequence position per choice
    let mut choice_sequence_positions: HashMap<u32, usize> = HashMap::new();

    for (stream_pos, timestamped_response) in recorded_stream.responses().iter().enumerate() {
        let logprobs_by_choice = timestamped_response.response.extract_logprobs_by_choice();

        for (choice_index, choice_logprobs) in logprobs_by_choice {
            let choice_analysis =
                choice_analyses
                    .entry(choice_index)
                    .or_insert_with(|| ChoiceAnalysis {
                        choice_index,
                        position_closeness: Vec::new(),
                        positions_analyzed: 0,
                    });
            let current_seq_pos = choice_sequence_positions.entry(choice_index).or_insert(0);

            for token_logprobs in choice_logprobs {
                // Every position consumes a sequence slot, analyzed or not.
                let token_position = *current_seq_pos;
                *current_seq_pos += 1;

                // all_tokens is already sorted by logprob (highest first)
                let all_tokens = token_logprobs.all_tokens();
                if all_tokens.len() < 2 {
                    continue;
                }
                let sorted_candidates = all_tokens.to_vec();

                // Gap between the top two candidates, in logprob space...
                let logprob_difference =
                    sorted_candidates[0].logprob - sorted_candidates[1].logprob;
                // ...and in linear probability space, which is easier to interpret.
                let probability_difference =
                    sorted_candidates[0].logprob.exp() - sorted_candidates[1].logprob.exp();

                // Mass not covered by the reported candidates.
                let total_prob_sum: f32 =
                    sorted_candidates.iter().map(|t| t.logprob.exp()).sum();
                let probability_remaining = 1.0 - total_prob_sum;

                choice_analysis.position_closeness.push(PositionCloseness {
                    stream_position: stream_pos,
                    token_position,
                    logprob_difference,
                    probability_difference,
                    probability_remaining,
                    candidates: sorted_candidates,
                });
                choice_analysis.positions_analyzed += 1;
            }
        }
    }

    // Sort position closeness by probability difference (smallest first = most uncertain)
    for choice_analysis in choice_analyses.values_mut() {
        choice_analysis.position_closeness.sort_by(|a, b| {
            ascending_nan_last(a.probability_difference, b.probability_difference)
        });
    }

    SensitivityAnalysis {
        total_responses: recorded_stream.responses().len(),
        choice_analyses,
    }
}
impl SensitivityAnalysis {
/// Get positions below a threshold for a specific choice
/// Threshold is in probability space (0-1), where smaller values indicate closer probabilities
pub fn get_close_positions_for_choice(
&self,
choice_index: u32,
threshold: f32,
) -> Vec<&PositionCloseness> {
self.choice_analyses
.get(&choice_index)
.map(|analysis| {
analysis
.position_closeness
.iter()
.filter(|pos| pos.probability_difference <= threshold)
.collect()
})
.unwrap_or_default()
}
/// Get the closest N positions for a specific choice
pub fn get_closest_positions_for_choice(
&self,
choice_index: u32,
count: usize,
) -> Vec<&PositionCloseness> {
self.choice_analyses
.get(&choice_index)
.map(|analysis| analysis.position_closeness.iter().take(count).collect())
.unwrap_or_default()
}
/// Print a summary of the sensitivity analysis
pub fn print_summary(&self) {
println!("=== Logprob Sensitivity Analysis Summary ===");
println!("Total stream responses analyzed: {}", self.total_responses);
println!("Number of choices: {}", self.choice_analyses.len());
println!();
for (choice_index, choice_analysis) in &self.choice_analyses {
println!(
"Choice {}: {} positions analyzed",
choice_index, choice_analysis.positions_analyzed
);
if !choice_analysis.position_closeness.is_empty() {
println!(" Closest positions (smallest probability differences):");
for (j, pos) in choice_analysis
.position_closeness
.iter()
.take(3)
.enumerate()
{
let top_token = &pos.candidates[0].token;
let second_token = &pos.candidates[1].token;
let prob1 = pos.candidates[0].logprob.exp();
let prob2 = pos.candidates[1].logprob.exp();
println!(
" {}: Stream pos {}, token pos {} - '{}' ({:.1}%) vs '{}' ({:.1}%) (prob diff: {:.4})",
j + 1,
pos.stream_position,
pos.token_position,
top_token,
prob1 * 100.0,
second_token,
prob2 * 100.0,
pos.probability_difference
);
}
}
println!();
}
}
/// Get percentage of positions with close probabilities for a specific choice
/// Threshold is in probability space (0-1)
pub fn close_position_percentage_for_choice(&self, choice_index: u32, threshold: f32) -> f32 {
if let Some(analysis) = self.choice_analyses.get(&choice_index) {
if analysis.positions_analyzed == 0 {
return 0.0;
}
let close_count = analysis
.position_closeness
.iter()
.filter(|pos| pos.probability_difference <= threshold)
.count();
(close_count as f32 / analysis.positions_analyzed as f32) * 100.0
} else {
0.0
}
}
/// Check if multiple tokens are close (within threshold of each other)
pub fn detect_multiple_close_tokens(
&self,
choice_index: u32,
threshold: f32,
) -> Vec<MultipleCloseTokens> {
let mut results = Vec::new();
if let Some(analysis) = self.choice_analyses.get(&choice_index) {
for pos in &analysis.position_closeness {
let close_tokens = self.count_close_tokens_at_position(pos, threshold);
if close_tokens.close_count > 2 {
results.push(close_tokens);
}
}
}
results
}
/// Detect if greedy decoding was likely used by checking if selected tokens are always the most probable
/// Note: This is an approximation since we infer selection from the data structure
pub fn detect_likely_greedy_decoding(&self, choice_index: u32) -> bool {
if let Some(analysis) = self.choice_analyses.get(&choice_index) {
if analysis.positions_analyzed == 0 {
return true; // No evidence against greedy
}
// For greedy detection, we're looking for positions with moderate to large differences
// Very small differences (< 0.01) suggest equal alternatives - could be greedy or random
// Very large differences (> 0.05) suggest clear winners - likely greedy
let likely_greedy_positions = analysis
.position_closeness
.iter()
.filter(|pos| {
if pos.candidates.is_empty() {
return true; // No contradiction
}
// Either very close (tie - could be greedy) or clear difference (likely greedy)
pos.probability_difference < 0.01 || pos.probability_difference > 0.05
})
.count();
// If most positions show greedy-like patterns, consider it greedy
(likely_greedy_positions as f32 / analysis.positions_analyzed as f32) > 0.5
} else {
false
}
}
/// Get percentage of positions with greedy-like selection patterns
pub fn greedy_selection_percentage(&self, choice_index: u32) -> f32 {
if let Some(analysis) = self.choice_analyses.get(&choice_index) {
if analysis.positions_analyzed == 0 {
return 0.0;
}
let greedy_like_positions = analysis
.position_closeness
.iter()
.filter(|pos| {
// Same logic as detect_likely_greedy_decoding for consistency
pos.probability_difference < 0.01 || pos.probability_difference > 0.05
})
.count();
(greedy_like_positions as f32 / analysis.positions_analyzed as f32) * 100.0
} else {
0.0
}
}
/// Count how many tokens are close at a specific position
/// Threshold is in probability space (0-1)
fn count_close_tokens_at_position(
&self,
position: &PositionCloseness,
threshold: f32,
) -> MultipleCloseTokens {
let top_prob = position.candidates[0].logprob.exp();
let mut close_count = 1; // Top token is always included
let mut close_tokens = vec![position.candidates[0].clone()];
for candidate in &position.candidates[1..] {
let candidate_prob = candidate.logprob.exp();
let prob_diff = top_prob - candidate_prob;
if prob_diff <= threshold {
close_count += 1;
close_tokens.push(candidate.clone());
} else {
break; // Since candidates are sorted, no need to check further
}
}
let max_difference = if close_count > 1 {
let last_prob = close_tokens.last().unwrap().logprob.exp();
top_prob - last_prob
} else {
0.0
};
MultipleCloseTokens {
stream_position: position.stream_position,
token_position: position.token_position,
close_count,
close_tokens,
max_difference,
}
}
}
/// Information about multiple close tokens at a position
///
/// Built by `SensitivityAnalysis::count_close_tokens_at_position` and returned from
/// `detect_multiple_close_tokens`.
#[derive(Debug, Clone)]
pub struct MultipleCloseTokens {
/// Position in the stream (response index), copied from the analyzed position
pub stream_position: usize,
/// Position within the token sequence, copied from the analyzed position
pub token_position: usize,
/// Number of tokens counted as close, including the top token
pub close_count: usize,
/// The close tokens themselves, top token first
pub close_tokens: Vec<TokenLogprob>,
/// Probability gap between the top token and the last close token
/// (0.0 when only the top token qualifies)
pub max_difference: f32,
}
#[cfg(test)]
mod tests {
use super::*;
// Type aliases to simplify complex test data structures
// One alternative candidate: (token text, linear probability).
type TestTokenAlternative = (&'static str, f32);
// One position: (selected token text, its linear probability, alternatives).
type TestTokenData = (&'static str, f32, Vec<TestTokenAlternative>);
// A sequence of positions, as consumed by `create_analysis_with_mixed_sampling`.
type TestTokenDataVec = Vec<TestTokenData>;
use crate::perf::{record_stream_with_context, RecordingMode, TimestampedResponse};
use crate::protocols::codec::create_message_stream;
use crate::protocols::convert_sse_stream;
use approx::assert_abs_diff_eq;
use async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
ChatCompletionTokenLogprob, CreateChatCompletionStreamResponse, FinishReason, Role,
TopLogprobs,
};
use futures::StreamExt;
use std::sync::Arc;
use std::time::Instant;
// Absolute tolerance passed as `epsilon` to `assert_abs_diff_eq!` in the tests below.
const FLOAT_EPSILON: f32 = 1e-6;
#[test]
fn test_two_tokens_close() {
    // A near tie between two tokens: 45% selected vs 44% runner-up
    // (the remaining 11% belongs to tokens that are not reported).
    let entry = create_token_logprob_from_linear_probs("hello", 0.45, vec![("world", 0.44)]);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let near = analysis.get_close_positions_for_choice(0, 0.1);
    assert_eq!(near.len(), 1);
    let closest = near[0];

    // Linear gap: 0.45 - 0.44 = 0.01
    assert_abs_diff_eq!(
        closest.probability_difference,
        0.01,
        epsilon = FLOAT_EPSILON
    );
    // Log gap: ln(0.45) - ln(0.44) ≈ -0.798 - (-0.821) ≈ 0.023
    assert_abs_diff_eq!(closest.logprob_difference, 0.023, epsilon = 0.001);

    // With only two candidates nothing qualifies as "multiple" close tokens.
    let clusters = analysis.detect_multiple_close_tokens(0, 0.05);
    assert_eq!(clusters.len(), 0);
}
#[test]
fn test_three_tokens_close() {
    // A complete three-way distribution: 35% / 33% / 32%.
    let alternatives = vec![
        ("world", 0.33), // 0.02 behind the selected token
        ("there", 0.32), // 0.01 behind "world"
    ];
    let entry = create_token_logprob_from_linear_probs("hello", 0.35, alternatives);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let near = analysis.get_close_positions_for_choice(0, 0.025);
    assert_eq!(near.len(), 1);
    // Gap between the two most likely candidates: 0.35 - 0.33 = 0.02
    assert_abs_diff_eq!(
        near[0].probability_difference,
        0.02,
        epsilon = FLOAT_EPSILON
    );

    let clusters = analysis.detect_multiple_close_tokens(0, 0.04);
    assert_eq!(clusters.len(), 1);
    let cluster = &clusters[0];
    assert_eq!(cluster.close_count, 3);
    // Spread across the cluster: 0.35 - 0.32 = 0.03
    assert_abs_diff_eq!(cluster.max_difference, 0.03, epsilon = FLOAT_EPSILON);
}
#[test]
fn test_four_tokens_close() {
    // A complete four-way distribution: 27% / 26% / 25% / 22%.
    let alternatives = vec![
        ("world", 0.26),  // 0.01 behind the selected token
        ("there", 0.25),  // 0.01 behind "world"
        ("friend", 0.22), // 0.03 behind "there"
    ];
    let entry = create_token_logprob_from_linear_probs("hello", 0.27, alternatives);
    let analysis = create_analysis_with_logprobs(vec![entry]);

    let near = analysis.get_close_positions_for_choice(0, 0.02);
    assert_eq!(near.len(), 1);
    // Gap between the two most likely candidates: 0.27 - 0.26 = 0.01
    assert_abs_diff_eq!(
        near[0].probability_difference,
        0.01,
        epsilon = FLOAT_EPSILON
    );

    let clusters = analysis.detect_multiple_close_tokens(0, 0.06);
    assert_eq!(clusters.len(), 1);
    let cluster = &clusters[0];
    assert_eq!(cluster.close_count, 4);
    // Spread across the cluster: 0.27 - 0.22 = 0.05
    assert_abs_diff_eq!(cluster.max_difference, 0.05, epsilon = FLOAT_EPSILON);
}
#[test]
fn test_multiple_choices_analysis() {
    // Choice 0 has a clear leader (70% vs 25%, 5% unreported).
    let first_choice = vec![create_token_logprob_from_linear_probs(
        "hello",
        0.7,
        vec![("world", 0.25)],
    )];
    // Choice 1 is almost a coin flip (50.5% vs 49.5%).
    let second_choice = vec![create_token_logprob_from_linear_probs(
        "hi",
        0.505,
        vec![("there", 0.495)],
    )];
    let analysis = create_analysis_with_multiple_choices(vec![first_choice, second_choice]);
    assert_eq!(analysis.choice_analyses.len(), 2);

    // Choice 0 gap: 0.7 - 0.25 = 0.45
    let near_in_first = analysis.get_close_positions_for_choice(0, 0.5);
    assert_eq!(near_in_first.len(), 1);
    let first_gap = near_in_first[0].probability_difference;
    assert_abs_diff_eq!(first_gap, 0.45, epsilon = FLOAT_EPSILON);

    // Choice 1 gap: 0.505 - 0.495 = 0.01
    let near_in_second = analysis.get_close_positions_for_choice(1, 0.5);
    assert_eq!(near_in_second.len(), 1);
    let second_gap = near_in_second[0].probability_difference;
    assert_abs_diff_eq!(second_gap, 0.01, epsilon = FLOAT_EPSILON);

    // The coin-flip choice must come out as the more uncertain one.
    assert!(second_gap < first_gap);
}
#[test]
fn test_edge_case_single_token() {
    // One token holding 100% of the mass and no alternatives at all.
    let lone_entry = create_token_logprob_from_linear_probs("hello", 1.0, Vec::new());
    let analysis = create_analysis_with_logprobs(vec![lone_entry]);

    // Even the loosest threshold finds nothing: a single candidate has no runner-up.
    let near = analysis.get_close_positions_for_choice(0, 1.0);
    assert_eq!(near.len(), 0);
}
#[test]
fn test_threshold_filtering() {
    // First position is fairly close (55% vs 45%), second is far apart (80% vs 20%).
    let close_entry =
        create_token_logprob_from_linear_probs("token1", 0.55, vec![("token2", 0.45)]);
    let far_entry = create_token_logprob_from_linear_probs("token3", 0.8, vec![("token4", 0.2)]);
    let analysis = create_analysis_with_logprobs(vec![close_entry, far_entry]);

    // A 0.15 threshold keeps only the first position (gap = 0.1).
    let strict = analysis.get_close_positions_for_choice(0, 0.15);
    assert_eq!(strict.len(), 1);
    assert_abs_diff_eq!(
        strict[0].probability_difference,
        0.1,
        epsilon = FLOAT_EPSILON
    );

    // A 0.7 threshold keeps both positions.
    let permissive = analysis.get_close_positions_for_choice(0, 0.7);
    assert_eq!(permissive.len(), 2);

    // Results arrive ordered by closeness (0.1 before 0.6).
    let smaller_gap = permissive[0].probability_difference;
    let larger_gap = permissive[1].probability_difference;
    assert!(smaller_gap < larger_gap);
}
#[test]
fn test_percentage_calculation() {
    let entries = vec![
        // Gap 0.2 (60% vs 40%): within the threshold
        create_token_logprob_from_linear_probs("token1", 0.6, vec![("token2", 0.4)]),
        // Gap 0.8 (90% vs 10%): outside the threshold
        create_token_logprob_from_linear_probs("token3", 0.9, vec![("token4", 0.1)]),
        // Gap 0.04 (52% vs 48%): within the threshold
        create_token_logprob_from_linear_probs("token5", 0.52, vec![("token6", 0.48)]),
    ];
    let analysis = create_analysis_with_logprobs(entries);

    // Two of the three positions are close.
    let percentage = analysis.close_position_percentage_for_choice(0, 0.25);
    let deviation = (percentage - 66.67).abs();
    assert!(deviation < 0.01);
}
#[test]
fn test_real_vllm_equal_logprobs() {
    // Taken from a real vLLM response in which "Ġblock" and "Ġchunk" both carry
    // logprob -0.9078922271728516, i.e. exp(..) ≈ 0.403 = 40.3% each
    // (80.6% together, 19.4% unreported).
    let tied_entry = create_token_logprob_from_linear_probs(
        "Ġblock",
        0.403,
        vec![("Ġchunk", 0.403)], // exactly as likely as the selected token
    );
    let analysis = create_analysis_with_logprobs(vec![tied_entry]);

    // A perfect tie must pass even a very tight threshold (gap = 0.0).
    let near = analysis.get_close_positions_for_choice(0, 0.001);
    assert_eq!(near.len(), 1);
    let tie = &near[0];
    assert_abs_diff_eq!(tie.probability_difference, 0.0, epsilon = FLOAT_EPSILON);

    assert_eq!(tie.candidates.len(), 2);
    let first = &tie.candidates[0];
    let second = &tie.candidates[1];

    // Both tokens must be present; their relative order is unspecified for a tie.
    let seen: Vec<&str> = tie.candidates.iter().map(|c| c.token.as_str()).collect();
    assert!(seen.contains(&"Ġblock"));
    assert!(seen.contains(&"Ġchunk"));

    // Identical logprobs (ln(0.403) ≈ -0.907892)...
    assert_abs_diff_eq!(first.logprob, second.logprob, epsilon = FLOAT_EPSILON);

    // ...which map back to 40.3% each.
    let first_prob = first.logprob.exp();
    let second_prob = second.logprob.exp();
    assert_abs_diff_eq!(first_prob, 0.403, epsilon = 0.001);
    assert_abs_diff_eq!(second_prob, 0.403, epsilon = 0.001);
}
// Helper functions for creating test data
/// Runs the sensitivity analysis over a one-response recorded stream whose
/// single response is built from `token_logprobs`.
fn create_analysis_with_logprobs(
    token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> SensitivityAnalysis {
    let started = Instant::now();
    let recorded = vec![TimestampedResponse::new(
        create_mock_response_with_logprobs(token_logprobs),
        0,
    )];
    analyze_logprob_sensitivity(Arc::new(RecordedStream::new(
        recorded,
        started,
        Instant::now(),
    )))
}
/// Runs the sensitivity analysis over a one-response recorded stream whose
/// single response is built from one logprob list per choice.
fn create_analysis_with_multiple_choices(
    choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
) -> SensitivityAnalysis {
    let started = Instant::now();
    let recorded = vec![TimestampedResponse::new(
        create_mock_response_with_multiple_choices(choices_logprobs),
        0,
    )];
    analyze_logprob_sensitivity(Arc::new(RecordedStream::new(
        recorded,
        started,
        Instant::now(),
    )))
}
/// Runs the sensitivity analysis over positions given as
/// (selected token, linear probability, alternatives) tuples.
fn create_analysis_with_mixed_sampling(mixed_data: TestTokenDataVec) -> SensitivityAnalysis {
    let started = Instant::now();

    // Turn every tuple into the OpenAI logprob entry the mock response expects.
    let mut entries: Vec<ChatCompletionTokenLogprob> = Vec::with_capacity(mixed_data.len());
    for (selected_token, selected_prob, alternatives) in mixed_data {
        entries.push(create_token_logprob_from_linear_probs(
            selected_token,
            selected_prob,
            alternatives,
        ));
    }

    let recorded = vec![TimestampedResponse::new(
        create_mock_response_with_logprobs(entries),
        0,
    )];
    analyze_logprob_sensitivity(Arc::new(RecordedStream::new(
        recorded,
        started,
        Instant::now(),
    )))
}
/// Analyzes a single position whose selected token (15%) is not listed among
/// its `top_logprobs` and is less likely than both listed alternatives
/// (40% and 30%), as happens with non-greedy sampling.
fn create_analysis_with_missing_selected_token() -> SensitivityAnalysis {
    let began = Instant::now();

    // Alternatives are written as (token, linear probability) and stored as ln(p).
    let alternative = |name: &str, prob: f32| TopLogprobs {
        token: name.to_string(),
        logprob: prob.ln(),
        bytes: None,
    };

    let position = ChatCompletionTokenLogprob {
        token: "unlikely_selection".to_string(),
        // Selected, but not the most probable candidate.
        logprob: 0.15_f32.ln(),
        bytes: None,
        top_logprobs: vec![
            alternative("best_option", 0.4),
            alternative("second_best", 0.3),
        ],
    };

    let only_entry =
        TimestampedResponse::new(create_mock_response_with_logprobs(vec![position]), 0);
    let recording = RecordedStream::new(vec![only_entry], began, Instant::now());
    analyze_logprob_sensitivity(Arc::new(recording))
}
/// Builds a `ChatCompletionTokenLogprob` from linear probabilities in `[0, 1]`.
///
/// `token`/`prob` describe the selected token; `top_probs` lists the
/// alternatives as `(token, probability)` pairs. Every probability is stored
/// as its natural logarithm, so a probability of exactly `0.0` becomes
/// negative infinity.
///
/// # Panics
///
/// Panics if `prob` or any alternative probability lies outside `[0, 1]`
/// (NaN included), or if the selected probability plus all alternatives
/// exceeds `1.001` (a small allowance for floating point error).
fn create_token_logprob_from_linear_probs(
    token: &str,
    prob: f32,
    top_probs: Vec<(&str, f32)>,
) -> ChatCompletionTokenLogprob {
    // Range-check every individual probability *before* looking at the sum, so
    // an out-of-range alternative is reported as such rather than as an
    // excessive total (or, when negative, hidden by a deceptively small total).
    assert!(
        (0.0..=1.0).contains(&prob),
        "Probability must be in [0, 1]: {}",
        prob
    );
    for (_, p) in &top_probs {
        assert!(
            (0.0..=1.0).contains(p),
            "Probability must be in [0, 1]: {}",
            p
        );
    }

    // The listed tokens may cover only part of the distribution, but never more than all of it.
    let total_prob = prob + top_probs.iter().map(|(_, p)| p).sum::<f32>();
    assert!(
        total_prob <= 1.001,
        "Total probability mass exceeds 1: {}",
        total_prob
    ); // Allow small floating point error

    ChatCompletionTokenLogprob {
        token: token.to_string(),
        logprob: prob.ln(),
        bytes: None,
        top_logprobs: top_probs
            .into_iter()
            .map(|(t, p)| TopLogprobs {
                token: t.to_string(),
                logprob: p.ln(),
                bytes: None,
            })
            .collect(),
    }
}
/// Builds a streaming chunk with a single choice (index 0) whose logprobs
/// content is `token_logprobs`.
fn create_mock_response_with_logprobs(
    token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> NvCreateChatCompletionStreamResponse {
    // The choice literal is the part that needs the deprecation expectation
    // (presumably because of the delta's `function_call` field).
    #[expect(deprecated)]
    let only_choice = ChatChoiceStream {
        index: 0,
        delta: ChatCompletionStreamResponseDelta {
            content: Some("test".to_string()),
            function_call: None,
            tool_calls: None,
            role: Some(Role::Assistant),
            refusal: None,
        },
        finish_reason: Some(FinishReason::Stop),
        logprobs: Some(ChatChoiceLogprobs {
            content: Some(token_logprobs),
            refusal: None,
        }),
    };

    NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices: vec![only_choice],
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
    }
}
/// Builds a streaming chunk with one choice per entry of `choices_logprobs`;
/// the choice index is the entry's position in the input vector.
fn create_mock_response_with_multiple_choices(
    choices_logprobs: Vec<Vec<ChatCompletionTokenLogprob>>,
) -> NvCreateChatCompletionStreamResponse {
    // Turns an enumerated `(position, logprobs)` pair into a finished choice.
    #[expect(deprecated)]
    let build_choice = |(position, token_logprobs): (usize, Vec<ChatCompletionTokenLogprob>)| {
        ChatChoiceStream {
            index: position as u32,
            delta: ChatCompletionStreamResponseDelta {
                content: Some("test".to_string()),
                function_call: None,
                tool_calls: None,
                role: Some(Role::Assistant),
                refusal: None,
            },
            finish_reason: Some(FinishReason::Stop),
            logprobs: Some(ChatChoiceLogprobs {
                content: Some(token_logprobs),
                refusal: None,
            }),
        }
    };

    NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices: choices_logprobs
                .into_iter()
                .enumerate()
                .map(build_choice)
                .collect(),
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
    }
}
#[test]
fn test_sensitivity_analysis() {
    // Smoke test: one response without any choices still yields an analysis.
    let began = Instant::now();
    let only_entry = TimestampedResponse::new(create_mock_response(), 0);
    let recording = Arc::new(RecordedStream::new(vec![only_entry], began, Instant::now()));

    let analysis = analyze_logprob_sensitivity(recording);

    assert_eq!(analysis.total_responses, 1);
    let close_percentage = analysis.close_position_percentage_for_choice(0, 0.5);
    assert!(close_percentage >= 0.0);
}
#[test]
fn test_extract_logprobs_by_choice_empty() {
    // The mock response has no choices, so extraction must yield either no
    // entries at all or at least one entry that is empty.
    let response = create_mock_response();
    let by_choice = response.extract_logprobs_by_choice();

    let has_empty_entry = by_choice.values().any(|entries| entries.is_empty());
    assert!(by_choice.is_empty() || has_empty_entry);
}
#[test]
fn test_token_logprobs_struct() {
    // Entries are written as (token, linear probability) and stored as ln(p).
    let entry = |name: &str, prob: f32| TokenLogprob {
        token: name.to_string(),
        logprob: prob.ln(),
        bytes: None,
    };

    // The selected token (70%) does not appear among the alternatives (20%, 10%).
    let selected = entry("selected", 0.7);
    let token_logprobs = TokenLogProbs::new(
        selected.clone(),
        vec![entry("alt1", 0.2), entry("alt2", 0.1)],
    );

    assert_eq!(token_logprobs.selected_token(), &selected);

    let alt_tokens = token_logprobs.alternative_tokens();
    let all_tokens = token_logprobs.all_tokens();
    assert_eq!(alt_tokens.len(), 2);
    assert_eq!(all_tokens.len(), 3);

    // `all_tokens` is ordered by logprob, highest first, and includes the selected token.
    for (idx, expected) in ["selected", "alt1", "alt2"].iter().enumerate() {
        assert_eq!(all_tokens[idx].token, *expected);
    }

    // The alternatives alone follow the same ordering.
    for (idx, expected) in ["alt1", "alt2"].iter().enumerate() {
        assert_eq!(alt_tokens[idx].token, *expected);
    }
}
#[test]
fn test_token_logprobs_selected_in_alternatives() {
    // Entries are written as (token, linear probability) and stored as ln(p).
    let entry = |name: &str, prob: f32| TokenLogprob {
        token: name.to_string(),
        logprob: prob.ln(),
        bytes: None,
    };

    // The selected token (40%) is repeated verbatim as the first alternative.
    let token_logprobs = TokenLogProbs::new(
        entry("token", 0.4),
        vec![entry("token", 0.4), entry("other", 0.3)],
    );

    // The combined view must list the selected token only once.
    let all_tokens = token_logprobs.all_tokens();
    assert_eq!(all_tokens.len(), 2);
    for (idx, expected) in ["token", "other"].iter().enumerate() {
        assert_eq!(all_tokens[idx].token, *expected);
    }
}
#[test]
fn test_validate_and_flatten_choices() {
    // Contiguous indices 0..=2 are accepted and produce one entry per choice.
    {
        let contiguous = HashMap::from([(0, vec![]), (1, vec![]), (2, vec![])]);
        let outcome = validate_and_flatten_choices(contiguous);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().len(), 3);
    }

    // A gap (index 1 absent while index 2 is present) is rejected with a descriptive message.
    {
        let with_gap = HashMap::from([(0, vec![]), (2, vec![])]);
        let outcome = validate_and_flatten_choices(with_gap);
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("Missing choice indices"));
        assert!(message.contains("expected 3 choices"));
    }

    // No choices at all is valid and flattens to nothing.
    {
        let outcome = validate_and_flatten_choices(HashMap::new());
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().len(), 0);
    }
}
#[test]
fn test_probability_remaining_calculation() {
    // Each case: (selected probability, alternatives, expected unlisted mass).
    let cases = [
        // 40% + 30% + 10% = 80% listed, so 20% is unaccounted for.
        (0.4, vec![("alt1", 0.3), ("alt2", 0.1)], 0.2),
        // 50% + 30% + 20% = 100% listed, so nothing remains.
        (0.5, vec![("alt1", 0.3), ("alt2", 0.2)], 0.0),
    ];

    for (selected_prob, alternatives, expected_remaining) in cases {
        let analysis = create_analysis_with_logprobs(vec![
            create_token_logprob_from_linear_probs("token", selected_prob, alternatives),
        ]);

        // Threshold 1.0 admits every position, so the single position is returned.
        let positions = analysis.get_close_positions_for_choice(0, 1.0);
        assert_eq!(positions.len(), 1);
        assert_abs_diff_eq!(
            positions[0].probability_remaining,
            expected_remaining,
            epsilon = 0.01
        );
    }
}
#[test]
fn test_position_closeness_ordering() {
    // Three positions supplied out of order with respect to their top-two gap:
    //   far    85% vs 15% -> 0.7
    //   close  51% vs 49% -> 0.02
    //   medium 70% vs 30% -> 0.4
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("far", 0.85, vec![("alt", 0.15)]),
        create_token_logprob_from_linear_probs("close", 0.51, vec![("alt", 0.49)]),
        create_token_logprob_from_linear_probs("medium", 0.7, vec![("alt", 0.3)]),
    ]);

    let positions = &analysis.choice_analyses.get(&0).unwrap().position_closeness;
    assert_eq!(positions.len(), 3);

    // Smallest difference first: every neighbouring pair must be non-decreasing.
    for idx in 1..3 {
        assert!(
            positions[idx - 1].probability_difference <= positions[idx].probability_difference
        );
    }

    // And the concrete gaps, in sorted order.
    let expected_differences = [0.02, 0.4, 0.7];
    for (idx, expected) in expected_differences.iter().enumerate() {
        assert_abs_diff_eq!(
            positions[idx].probability_difference,
            *expected,
            epsilon = FLOAT_EPSILON
        );
    }
}
#[test]
fn test_multiple_close_tokens_edge_cases() {
    // 34% / 33% / 32% sit within 0.02 of each other; the 1% token is far away.
    let alternatives = vec![("alt1", 0.33), ("alt2", 0.32), ("alt3", 0.01)];
    let position = create_token_logprob_from_linear_probs("token", 0.34, alternatives);
    let analysis = create_analysis_with_logprobs(vec![position]);

    // With a 0.025 threshold exactly the three clustered tokens count as close.
    let flagged = analysis.detect_multiple_close_tokens(0, 0.025);
    assert_eq!(flagged.len(), 1);
    assert_eq!(flagged[0].close_count, 3);
}
#[test]
fn test_choice_analysis_independence() {
    // Choice 0: two positions (gaps 0.1 and 0.8); only the first is under 0.5.
    let first_choice = vec![
        create_token_logprob_from_linear_probs("token1", 0.55, vec![("alt1", 0.45)]),
        create_token_logprob_from_linear_probs("token2", 0.9, vec![("alt2", 0.1)]),
    ];
    // Choice 1: a single, almost tied position (gap 0.002).
    let second_choice = vec![create_token_logprob_from_linear_probs(
        "token3",
        0.501,
        vec![("alt3", 0.499)],
    )];
    let analysis = create_analysis_with_multiple_choices(vec![first_choice, second_choice]);

    // Each choice is analyzed on its own, with its own position count.
    assert_eq!(analysis.choice_analyses.len(), 2);
    for (choice, expected_positions) in [(0, 2), (1, 1)] {
        assert_eq!(
            analysis
                .choice_analyses
                .get(&choice)
                .unwrap()
                .positions_analyzed,
            expected_positions
        );
    }

    let close_in_first = analysis.get_close_positions_for_choice(0, 0.5);
    let close_in_second = analysis.get_close_positions_for_choice(1, 0.5);
    assert_eq!(close_in_first.len(), 1);
    assert_eq!(close_in_second.len(), 1);

    // The near tie in choice 1 must be closer than the 0.1 gap in choice 0.
    assert!(close_in_second[0].probability_difference < close_in_first[0].probability_difference);
}
#[test]
fn test_get_closest_positions_boundary() {
    // Two positions are available for choice 0.
    let analysis = create_analysis_with_logprobs(vec![
        create_token_logprob_from_linear_probs("token1", 0.6, vec![("alt1", 0.4)]),
        create_token_logprob_from_linear_probs("token2", 0.75, vec![("alt2", 0.25)]),
    ]);

    // (requested, returned): more than available, exactly available, fewer.
    for (requested, expected_len) in [(10, 2), (2, 2), (1, 1)] {
        let closest = analysis.get_closest_positions_for_choice(0, requested);
        assert_eq!(closest.len(), expected_len);
    }
}
#[test]
fn test_zero_threshold() {
    // A perfect 50/50 tie has a gap of exactly zero.
    let tie = create_token_logprob_from_linear_probs("token", 0.5, vec![("alt", 0.5)]);
    let analysis = create_analysis_with_logprobs(vec![tie]);

    // Even a zero threshold must still report the tied position.
    let matches = analysis.get_close_positions_for_choice(0, 0.0);
    assert_eq!(matches.len(), 1);

    let gap = matches[0].probability_difference;
    assert_abs_diff_eq!(gap, 0.0, epsilon = FLOAT_EPSILON);
}
#[test]
fn test_nonexistent_choice() {
    // Only choice 0 exists in this analysis.
    let position = create_token_logprob_from_linear_probs("token", 0.6, vec![("alt", 0.4)]);
    let analysis = create_analysis_with_logprobs(vec![position]);

    // Every query against choice 5 must come back empty / zero.
    assert!(analysis.get_close_positions_for_choice(5, 0.1).is_empty());
    assert!(analysis.get_closest_positions_for_choice(5, 3).is_empty());
    assert_eq!(analysis.close_position_percentage_for_choice(5, 0.1), 0.0);
}
#[test]
fn test_logprob_extractor_with_missing_data() {
    // A single choice whose `logprobs` field is absent.
    #[expect(deprecated)]
    let choice_without_logprobs = ChatChoiceStream {
        index: 0,
        delta: ChatCompletionStreamResponseDelta {
            content: Some("test".to_string()),
            function_call: None,
            tool_calls: None,
            role: Some(Role::Assistant),
            refusal: None,
        },
        finish_reason: Some(FinishReason::Stop),
        logprobs: None,
    };
    let response = NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices: vec![choice_without_logprobs],
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
    };

    // The choice is still reported, but with nothing extracted for it.
    let by_choice = response.extract_logprobs_by_choice();
    assert_eq!(by_choice.len(), 1);
    assert!(by_choice.values().any(|entries| entries.is_empty()));
}
#[test]
fn test_print_summary_no_panic() {
    let position = create_token_logprob_from_linear_probs("token", 0.6, vec![("alt", 0.4)]);
    let analysis = create_analysis_with_logprobs(vec![position]);

    // The only expectation is that producing the summary completes without panicking.
    analysis.print_summary();
}
#[test]
fn test_greedy_decoding_detection() {
    // In both positions the selected token is the most probable candidate.
    // 80% vs 15% vs 5%
    let first_position = create_token_logprob_from_linear_probs(
        "best",
        0.8,
        vec![("second", 0.15), ("third", 0.05)],
    );
    // 70% vs 20% vs 10%
    let second_position = create_token_logprob_from_linear_probs(
        "optimal",
        0.7,
        vec![("suboptimal", 0.2), ("bad", 0.1)],
    );
    let analysis = create_analysis_with_logprobs(vec![first_position, second_position]);

    // Always picking the top candidate must be classified as greedy ...
    assert!(analysis.detect_likely_greedy_decoding(0));

    // ... with a greedy share above 90%.
    assert!(analysis.greedy_selection_percentage(0) > 90.0);
}
#[test]
fn test_non_greedy_decoding_detection() {
    // Position 1: a clear winner was selected (60% vs 40%).
    let clear_winner = ("selected_best", 0.6, vec![("alternative", 0.4)]);
    // Position 2: three nearly tied candidates (35% vs 33% vs 32%).
    let near_tie = (
        "close_choice",
        0.35,
        vec![("very_close", 0.33), ("also_close", 0.32)],
    );
    let analysis = create_analysis_with_mixed_sampling(vec![clear_winner, near_tie]);

    // The detector is invoked but its verdict is deliberately not asserted.
    let _is_greedy = analysis.detect_likely_greedy_decoding(0);

    // Only the range of the reported percentage is checked.
    let greedy_share = analysis.greedy_selection_percentage(0);
    assert!((0.0..=100.0).contains(&greedy_share));
}
#[test]
fn test_selected_token_not_in_top_logprobs() {
    // The selected token (15%) is absent from top_logprobs, which list 40% and 30%.
    let analysis = create_analysis_with_missing_selected_token();

    // The analysis must still produce a percentage within the valid range.
    let greedy_share = analysis.greedy_selection_percentage(0);
    assert!((0.0..=100.0).contains(&greedy_share));
}
#[test]
fn test_equal_logprobs_greedy_detection() {
    // Reproduces the vLLM example: two candidates with the identical probability 0.403.
    let tied_position =
        create_token_logprob_from_linear_probs("Ġblock", 0.403, vec![("Ġchunk", 0.403)]);
    let analysis = create_analysis_with_logprobs(vec![tied_position]);

    // A tie must be reported as close even under a very tight threshold.
    let tight_matches = analysis.get_close_positions_for_choice(0, 0.001);
    assert_eq!(tight_matches.len(), 1);

    // With no strictly better alternative, the selection still counts as greedy.
    assert!(analysis.detect_likely_greedy_decoding(0));
}
#[tokio::test]
async fn test_real_sse_stream_analysis() {
    // Load the captured SSE replay that contains logprobs.
    let raw_sse = std::fs::read_to_string(
        "tests/data/replays/deepseek-r1-distill-llama-8b/chat-completions.stream.1",
    )
    .expect("Failed to read test data file");

    // SSE text -> messages -> typed responses; entries without a data payload are dropped.
    let messages = create_message_stream(&raw_sse);
    let typed_responses =
        convert_sse_stream::<NvCreateChatCompletionStreamResponse>(Box::pin(messages));
    let payloads = typed_responses.filter_map(|annotated| async move { annotated.data });

    // Record the stream under a mock context, then drain it so the recording completes.
    let ctx = Arc::new(MockContext::new());
    let (recorded_stream, recording_rx) =
        record_stream_with_context(Box::pin(payloads), ctx, RecordingMode::Sink);
    let _drained: Vec<_> = recorded_stream.collect().await;

    let recorded = recording_rx
        .await
        .expect("Failed to receive recorded stream");
    let recorded_count = recorded.response_count();
    assert!(recorded_count > 0, "No responses recorded");
    println!("Recorded {} responses", recorded.response_count());

    // Analyze the recording and print the human-readable summary.
    let analysis = analyze_logprob_sensitivity(Arc::new(recorded));
    analysis.print_summary();

    // The replay must have produced logprob data for at least one choice and position.
    assert!(
        !analysis.choice_analyses.is_empty(),
        "No choice analyses found"
    );
    let any_position_analyzed = analysis
        .choice_analyses
        .values()
        .any(|a| a.positions_analyzed > 0);
    assert!(any_position_analyzed, "No positions analyzed");

    // The replay contains the vLLM tie ("Ġblock" vs "Ġchunk"), so a tight
    // threshold must still match at least one position.
    let tight_matches = analysis.get_close_positions_for_choice(0, 0.001);
    assert!(!tight_matches.is_empty(), "No close positions found");

    // Informational only: how many of those are effectively exact ties.
    let equal_positions = tight_matches
        .iter()
        .filter(|pos| pos.probability_difference < 0.0001)
        .count();
    if equal_positions > 0 {
        println!(
            "Found {} positions with nearly equal probabilities",
            equal_positions
        );
    }

    // Remaining query helpers: bounded result size and a valid percentage.
    let closest_3 = analysis.get_closest_positions_for_choice(0, 3);
    assert!(
        closest_3.len() <= 3,
        "Should return at most 3 closest positions"
    );
    let percentage = analysis.close_position_percentage_for_choice(0, 0.1);
    assert!(
        (0.0..=100.0).contains(&percentage),
        "Percentage should be valid"
    );

    // Greedy detection is reported but not asserted.
    let is_greedy = analysis.detect_likely_greedy_decoding(0);
    let greedy_percentage = analysis.greedy_selection_percentage(0);
    println!(
        "Greedy detection: {} ({}% greedy-like)",
        is_greedy, greedy_percentage
    );

    // Likewise for positions with several close candidates.
    let multiple_close = analysis.detect_multiple_close_tokens(0, 0.05);
    if !multiple_close.is_empty() {
        println!(
            "Found {} positions with multiple close tokens",
            multiple_close.len()
        );
    }
}
/// Builds a minimal streaming chunk with an empty `choices` list, i.e. a
/// response that carries no logprob data at all.
fn create_mock_response() -> NvCreateChatCompletionStreamResponse {
    use async_openai::types::CreateChatCompletionStreamResponse;

    NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: String::from("test_id"),
            choices: Vec::new(),
            created: 1234567890,
            model: String::from("test-model"),
            service_tier: None,
            system_fingerprint: None,
            object: String::from("chat.completion.chunk"),
            usage: None,
        },
    }
}
// Mock context for testing
/// Minimal engine-context stand-in used when recording streams in tests.
#[derive(Debug)]
struct MockContext {
    /// Identifier handed out by the context's `id()` accessor.
    id: String,
}

impl MockContext {
    /// Creates a context with the fixed identifier `"test-context"`.
    fn new() -> Self {
        MockContext {
            id: String::from("test-context"),
        }
    }
}
/// Inert context: control requests are ignored, the context never reports
/// being stopped or killed, and the `stopped()` / `killed()` futures resolve
/// immediately.
#[async_trait::async_trait]
impl dynamo_runtime::engine::AsyncEngineContext for MockContext {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    // Control requests are accepted and ignored.
    fn stop(&self) {}

    fn stop_generating(&self) {}

    fn kill(&self) {}

    // State queries always answer "still running".
    fn is_stopped(&self) -> bool {
        false
    }

    fn is_killed(&self) -> bool {
        false
    }

    // Nothing to wait for; both futures complete right away.
    async fn stopped(&self) {}

    async fn killed(&self) {}
}
}
...@@ -19,9 +19,7 @@ ...@@ -19,9 +19,7 @@
//! both publicly via the HTTP API and internally between Dynamo components. //! both publicly via the HTTP API and internally between Dynamo components.
//! //!
use std::pin::Pin; use futures::StreamExt;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod codec; pub mod codec;
...@@ -30,7 +28,7 @@ pub mod openai; ...@@ -30,7 +28,7 @@ pub mod openai;
/// The token ID type /// The token ID type
pub type TokenIdType = u32; pub type TokenIdType = u32;
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; pub use dynamo_runtime::engine::DataStream;
// TODO: This is an awkward dependency that we need to address // TODO: This is an awkward dependency that we need to address
// Originally, all the Annotated/SSE Codec bits where in the LLM protocol module; however, [Annotated] // Originally, all the Annotated/SSE Codec bits where in the LLM protocol module; however, [Annotated]
......
...@@ -13,9 +13,8 @@ ...@@ -13,9 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, pin::Pin}; use futures::StreamExt;
use std::collections::HashMap;
use futures::{Stream, StreamExt};
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
...@@ -24,7 +23,7 @@ use crate::protocols::{ ...@@ -24,7 +23,7 @@ use crate::protocols::{
}; };
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. /// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
/// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses /// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses
......
...@@ -13,18 +13,14 @@ ...@@ -13,18 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::pin::Pin;
use futures::{Stream, StreamExt};
use super::NvCreateEmbeddingResponse; use super::NvCreateEmbeddingResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. use dynamo_runtime::engine::DataStream;
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; use futures::StreamExt;
/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
/// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler /// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
......
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":"Okay","tool_calls":[]},"logprobs":{"content":[{"token":"Okay","logprob":-0.5292773246765137,"bytes":[79,107,97,121],"top_logprobs":[{"token":"Okay","logprob":-0.5292773246765137,"bytes":[79,107,97,121]},{"token":"Alright","logprob":-0.9042773246765137,"bytes":[65,108,114,105,103,104,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":",","tool_calls":[]},"logprobs":{"content":[{"token":",","logprob":-0.000017165990357170813,"bytes":[44],"top_logprobs":[{"token":",","logprob":-0.000017165990357170813,"bytes":[44]},{"token":"Ġso","logprob":-11.812517166137695,"bytes":[196,160,115,111]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" so","tool_calls":[]},"logprobs":{"content":[{"token":"Ġso","logprob":-0.10039777308702469,"bytes":[196,160,115,111],"top_logprobs":[{"token":"Ġso","logprob":-0.10039777308702469,"bytes":[196,160,115,111]},{"token":"Ġthe","logprob":-2.600397825241089,"bytes":[196,160,116,104,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" I","tool_calls":[]},"logprobs":{"content":[{"token":"ĠI","logprob":-0.07118851691484451,"bytes":[196,160,73],"top_logprobs":[{"token":"ĠI","logprob":-0.07118851691484451,"bytes":[196,160,73]},{"token":"Ġthe","logprob":-2.696188449859619,"bytes":[196,160,116,104,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":"'m","tool_calls":[]},"logprobs":{"content":[{"token":"'m","logprob":-0.5393549799919128,"bytes":[39,109],"top_logprobs":[{"token":"'m","logprob":-0.5393549799919128,"bytes":[39,109]},{"token":"'ve","logprob":-2.2893550395965576,"bytes":[39,118,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" trying","tool_calls":[]},"logprobs":{"content":[{"token":"Ġtrying","logprob":-0.2027934193611145,"bytes":[196,160,116,114,121,105,110,103],"top_logprobs":[{"token":"Ġtrying","logprob":-0.2027934193611145,"bytes":[196,160,116,114,121,105,110,103]},{"token":"Ġlooking","logprob":-1.8277933597564697,"bytes":[196,160,108,111,111,107,105,110,103]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" to","tool_calls":[]},"logprobs":{"content":[{"token":"Ġto","logprob":-1.5497195136049413e-6,"bytes":[196,160,116,111],"top_logprobs":[{"token":"Ġto","logprob":-1.5497195136049413e-6,"bytes":[196,160,116,111]},{"token":"Ġout","logprob":-14.187501907348633,"bytes":[196,160,111,117,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" figure","tool_calls":[]},"logprobs":{"content":[{"token":"Ġfigure","logprob":-0.42643895745277405,"bytes":[196,160,102,105,103,117,114,101],"top_logprobs":[{"token":"Ġfigure","logprob":-0.42643895745277405,"bytes":[196,160,102,105,103,117,114,101]},{"token":"Ġunderstand","logprob":-1.1764389276504517,"bytes":[196,160,117,110,100,101,114,115,116,97,110,100]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" out","tool_calls":[]},"logprobs":{"content":[{"token":"Ġout","logprob":-0.00021181246847845614,"bytes":[196,160,111,117,116],"top_logprobs":[{"token":"Ġout","logprob":-0.00021181246847845614,"bytes":[196,160,111,117,116]},{"token":"Ġthis","logprob":-8.500211715698242,"bytes":[196,160,116,104,105,115]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" how","tool_calls":[]},"logprobs":{"content":[{"token":"Ġhow","logprob":-0.5830438137054443,"bytes":[196,160,104,111,119],"top_logprobs":[{"token":"Ġhow","logprob":-0.5830438137054443,"bytes":[196,160,104,111,119]},{"token":"Ġwhat","logprob":-1.0830438137054443,"bytes":[196,160,119,104,97,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" to","tool_calls":[]},"logprobs":{"content":[{"token":"Ġto","logprob":-0.0042633600533008575,"bytes":[196,160,116,111],"top_logprobs":[{"token":"Ġto","logprob":-0.0042633600533008575,"bytes":[196,160,116,111]},{"token":"Ġthis","logprob":-6.004263401031494,"bytes":[196,160,116,104,105,115]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" respond","tool_calls":[]},"logprobs":{"content":[{"token":"Ġrespond","logprob":-0.5788105726242065,"bytes":[196,160,114,101,115,112,111,110,100],"top_logprobs":[{"token":"Ġrespond","logprob":-0.5788105726242065,"bytes":[196,160,114,101,115,112,111,110,100]},{"token":"Ġapproach","logprob":-1.2038105726242065,"bytes":[196,160,97,112,112,114,111,97,99,104]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" to","tool_calls":[]},"logprobs":{"content":[{"token":"Ġto","logprob":-0.0014138950500637293,"bytes":[196,160,116,111],"top_logprobs":[{"token":"Ġto","logprob":-0.0014138950500637293,"bytes":[196,160,116,111]},{"token":"Ġhelp","logprob":-7.751413822174072,"bytes":[196,160,104,101,108,112]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" this","tool_calls":[]},"logprobs":{"content":[{"token":"Ġthis","logprob":-0.16383041441440582,"bytes":[196,160,116,104,105,115],"top_logprobs":[{"token":"Ġthis","logprob":-0.16383041441440582,"bytes":[196,160,116,104,105,115]},{"token":"Ġthe","logprob":-1.9138303995132446,"bytes":[196,160,116,104,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" user","tool_calls":[]},"logprobs":{"content":[{"token":"Ġuser","logprob":-0.342995822429657,"bytes":[196,160,117,115,101,114],"top_logprobs":[{"token":"Ġuser","logprob":-0.342995822429657,"bytes":[196,160,117,115,101,114]},{"token":"Ġmessage","logprob":-2.4054958820343018,"bytes":[196,160,109,101,115,115,97,103,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":"'s","tool_calls":[]},"logprobs":{"content":[{"token":"'s","logprob":-0.6218149065971375,"bytes":[39,115],"top_logprobs":[{"token":"'s","logprob":-0.6218149065971375,"bytes":[39,115]},{"token":".","logprob":-1.2468149662017822,"bytes":[46]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" message","tool_calls":[]},"logprobs":{"content":[{"token":"Ġmessage","logprob":-0.49226677417755127,"bytes":[196,160,109,101,115,115,97,103,101],"top_logprobs":[{"token":"Ġmessage","logprob":-0.49226677417755127,"bytes":[196,160,109,101,115,115,97,103,101]},{"token":"Ġquery","logprob":-1.1172667741775513,"bytes":[196,160,113,117,101,114,121]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":".","tool_calls":[]},"logprobs":{"content":[{"token":".","logprob":-0.004951002076268196,"bytes":[46],"top_logprobs":[{"token":".","logprob":-0.004951002076268196,"bytes":[46]},{"token":"Ġthat","logprob":-5.879951000213623,"bytes":[196,160,116,104,97,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" They","tool_calls":[]},"logprobs":{"content":[{"token":"ĠThey","logprob":-0.038852665573358536,"bytes":[196,160,84,104,101,121],"top_logprobs":[{"token":"ĠThey","logprob":-0.038852665573358536,"bytes":[196,160,84,104,101,121]},{"token":"ĠLet","logprob":-3.7888526916503906,"bytes":[196,160,76,101,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" provided","tool_calls":[]},"logprobs":{"content":[{"token":"Ġprovided","logprob":-0.4865674376487732,"bytes":[196,160,112,114,111,118,105,100,101,100],"top_logprobs":[{"token":"Ġprovided","logprob":-0.4865674376487732,"bytes":[196,160,112,114,111,118,105,100,101,100]},{"token":"'ve","logprob":-1.736567497253418,"bytes":[39,118,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" a","tool_calls":[]},"logprobs":{"content":[{"token":"Ġa","logprob":-0.08489075303077698,"bytes":[196,160,97],"top_logprobs":[{"token":"Ġa","logprob":-0.08489075303077698,"bytes":[196,160,97]},{"token":"Ġsome","logprob":-2.584890842437744,"bytes":[196,160,115,111,109,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" block","tool_calls":[]},"logprobs":{"content":[{"token":"Ġblock","logprob":-0.9078922271728516,"bytes":[196,160,98,108,111,99,107],"top_logprobs":[{"token":"Ġblock","logprob":-0.9078922271728516,"bytes":[196,160,98,108,111,99,107]},{"token":"Ġchunk","logprob":-0.9078922271728516,"bytes":[196,160,99,104,117,110,107]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" of","tool_calls":[]},"logprobs":{"content":[{"token":"Ġof","logprob":-4.172316494077677e-6,"bytes":[196,160,111,102],"top_logprobs":[{"token":"Ġof","logprob":-4.172316494077677e-6,"bytes":[196,160,111,102]},{"token":"Ġthat","logprob":-13.062503814697266,"bytes":[196,160,116,104,97,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" text","tool_calls":[]},"logprobs":{"content":[{"token":"Ġtext","logprob":-0.1239960640668869,"bytes":[196,160,116,101,120,116],"top_logprobs":[{"token":"Ġtext","logprob":-0.1239960640668869,"bytes":[196,160,116,101,120,116]},{"token":"ĠLorem","logprob":-2.7489960193634033,"bytes":[196,160,76,111,114,101,109]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" that","tool_calls":[]},"logprobs":{"content":[{"token":"Ġthat","logprob":-0.021982578560709953,"bytes":[196,160,116,104,97,116],"top_logprobs":[{"token":"Ġthat","logprob":-0.021982578560709953,"bytes":[196,160,116,104,97,116]},{"token":"Ġwhich","logprob":-4.646982669830322,"bytes":[196,160,119,104,105,99,104]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" looks","tool_calls":[]},"logprobs":{"content":[{"token":"Ġlooks","logprob":-0.6330966353416443,"bytes":[196,160,108,111,111,107,115],"top_logprobs":[{"token":"Ġlooks","logprob":-0.6330966353416443,"bytes":[196,160,108,111,111,107,115]},{"token":"'s","logprob":-1.133096694946289,"bytes":[39,115]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" like","tool_calls":[]},"logprobs":{"content":[{"token":"Ġlike","logprob":-0.001482222112827003,"bytes":[196,160,108,105,107,101],"top_logprobs":[{"token":"Ġlike","logprob":-0.001482222112827003,"bytes":[196,160,108,105,107,101]},{"token":"Ġa","logprob":-7.001482009887695,"bytes":[196,160,97]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" Lorem","tool_calls":[]},"logprobs":{"content":[{"token":"ĠLorem","logprob":-0.2382608950138092,"bytes":[196,160,76,111,114,101,109],"top_logprobs":[{"token":"ĠLorem","logprob":-0.2382608950138092,"bytes":[196,160,76,111,114,101,109]},{"token":"Ġplaceholder","logprob":-2.6132609844207764,"bytes":[196,160,112,108,97,99,101,104,111,108,100,101,114]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" Ipsum","tool_calls":[]},"logprobs":{"content":[{"token":"ĠIpsum","logprob":-0.22565951943397522,"bytes":[196,160,73,112,115,117,109],"top_logprobs":[{"token":"ĠIpsum","logprob":-0.22565951943397522,"bytes":[196,160,73,112,115,117,109]},{"token":"Ġipsum","logprob":-1.6006594896316528,"bytes":[196,160,105,112,115,117,109]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":",","tool_calls":[]},"logprobs":{"content":[{"token":",","logprob":-0.02414931170642376,"bytes":[44],"top_logprobs":[{"token":",","logprob":-0.02414931170642376,"bytes":[44]},{"token":"Ġand","logprob":-4.399149417877197,"bytes":[196,160,97,110,100]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" which","tool_calls":[]},"logprobs":{"content":[{"token":"Ġwhich","logprob":-0.02117946185171604,"bytes":[196,160,119,104,105,99,104],"top_logprobs":[{"token":"Ġwhich","logprob":-0.02117946185171604,"bytes":[196,160,119,104,105,99,104]},{"token":"Ġand","logprob":-4.271179676055908,"bytes":[196,160,97,110,100]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" is","tool_calls":[]},"logprobs":{"content":[{"token":"Ġis","logprob":-0.43066686391830444,"bytes":[196,160,105,115],"top_logprobs":[{"token":"Ġis","logprob":-0.43066686391830444,"bytes":[196,160,105,115]},{"token":"ĠI","logprob":-1.0556669235229492,"bytes":[196,160,73]}]}]},"finish_reason":"length"}]}
data: [DONE]
Captured from build 0.9.0.2.dev22+gbc825748a.d20250715.precompiled.
Script used to generate `deepseek-r1-distill-llama-8b/chat-completions.stream.logprobs.1`:
```
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [{"role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."}],
"max_tokens": 32,
"temperature": 0.0,
"top_p": 0.001,"stream":true,"logprobs":1,"top_logprobs":2
}'
```
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for logprob analysis functionality
use std::sync::Arc;
use std::time::Instant;
use dynamo_llm::perf::logprobs::analyze_logprob_sensitivity;
use dynamo_llm::perf::{RecordedStream, TimestampedResponse};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
ChatCompletionTokenLogprob, CreateChatCompletionStreamResponse, FinishReason, Role,
TopLogprobs,
};
// Type aliases to simplify complex test data structures.
// All `f32` values are linear probabilities; the response builders in this file
// convert them to log-probabilities with `ln()`.
/// An alternative token and its linear probability.
type TokenAlternative = (&'static str, f32);
/// A selected token, its linear probability, and its alternatives.
type TokenData = (&'static str, f32, Vec<TokenAlternative>);
/// All token positions carried by one single-choice response.
type TokenDataVec = Vec<TokenData>;
// Type aliases for multi-choice test data (using String instead of &str),
// which lets callers pass tokens built at runtime with `format!`.
/// Owned-string counterpart of `TokenAlternative`.
type StringTokenAlternative = (String, f32);
/// Owned-string counterpart of `TokenData`.
type StringTokenData = (String, f32, Vec<StringTokenAlternative>);
/// All token positions belonging to one choice of a response.
type ChoiceTokenData = Vec<StringTokenData>;
/// Per-choice token data for one response; the vector position becomes the choice index.
type MultiChoiceData = Vec<ChoiceTokenData>;
/// Runs the full analysis over the three-response stream built by
/// `create_realistic_stream` and checks counts, ordering and the query helpers.
#[test]
fn test_realistic_streaming_analysis() {
    let analysis = analyze_logprob_sensitivity(create_realistic_stream());

    // Three responses were recorded, all for the single choice 0.
    assert_eq!(analysis.total_responses, 3);
    assert_eq!(analysis.choice_analyses.len(), 1);
    assert_eq!(
        analysis.choice_analyses.get(&0).unwrap().positions_analyzed,
        3
    );

    // Each position must be at least as far apart as the one before it,
    // i.e. the list is ordered from closest to least close.
    let ordered = &analysis.choice_analyses.get(&0).unwrap().position_closeness;
    for (earlier, later) in ordered.iter().zip(ordered.iter().skip(1)) {
        assert!(earlier.probability_difference <= later.probability_difference);
    }

    // Query helpers: at least one position falls within 0.2, and the
    // percentage is a valid value between 0 and 100.
    let within_threshold = analysis.get_close_positions_for_choice(0, 0.2);
    assert!(!within_threshold.is_empty());

    let close_pct = analysis.close_position_percentage_for_choice(0, 0.2);
    assert!((0.0..=100.0).contains(&close_pct));
}
/// Checks that a stream carrying two choices per response is analyzed per choice.
#[test]
fn test_multiple_choices_independent_analysis() {
    let analysis = analyze_logprob_sensitivity(create_multi_choice_stream());
    let per_choice = &analysis.choice_analyses;

    // Both choices must be present.
    assert_eq!(per_choice.len(), 2);

    // Two responses were recorded, so each choice contributes two positions.
    let first_choice_positions = per_choice.get(&0).unwrap().positions_analyzed;
    let second_choice_positions = per_choice.get(&1).unwrap().positions_analyzed;
    assert_eq!(first_choice_positions, 2);
    assert_eq!(second_choice_positions, 2);

    // The fixture gives choice 1 the tighter probability gaps, so it must have
    // at least as many positions within the 0.3 threshold as choice 0.
    let close_in_first = analysis.get_close_positions_for_choice(0, 0.3);
    let close_in_second = analysis.get_close_positions_for_choice(1, 0.3);
    assert!(close_in_second.len() >= close_in_first.len());
}
/// Checks that a position with several near-tied candidates is reported.
#[test]
fn test_multiple_close_tokens_detection() {
    let analysis = analyze_logprob_sensitivity(create_stream_with_multiple_close_tokens());

    // At least one position must be flagged at the 0.05 threshold.
    let flagged = analysis.detect_multiple_close_tokens(0, 0.05);
    assert!(!flagged.is_empty());

    let leading = &flagged[0];
    assert!(leading.close_count >= 3);
    assert!(leading.max_difference <= 0.05);

    // Re-derive the gap in linear probability space: every reported token
    // after the first must lie within the threshold of the first one.
    let tokens = &leading.close_tokens;
    for candidate in tokens.iter().skip(1) {
        let top_probability = tokens[0].logprob.exp();
        let candidate_probability = candidate.logprob.exp();
        assert!(top_probability - candidate_probability <= 0.05);
    }
}
/// Covers degenerate inputs: a stream with no responses and a stream whose
/// only position has no alternatives.
#[test]
fn test_edge_cases() {
    // No responses at all: nothing is counted and no choice is reported.
    {
        let analysis = analyze_logprob_sensitivity(create_empty_stream());
        assert_eq!(analysis.total_responses, 0);
        assert!(analysis.choice_analyses.is_empty());
    }

    // A lone token without alternatives has nothing to be close to, so even
    // the widest threshold (1.0) must report no close positions.
    {
        let analysis = analyze_logprob_sensitivity(create_single_token_stream());
        let close_positions = analysis.get_close_positions_for_choice(0, 1.0);
        assert!(close_positions.is_empty());
    }
}
/// Checks that widening the threshold never reduces the number (or share)
/// of positions reported as close.
#[test]
fn test_threshold_sensitivity() {
    let analysis = analyze_logprob_sensitivity(create_graduated_closeness_stream());

    // Count matches at three increasingly generous thresholds.
    let strict_count = analysis.get_close_positions_for_choice(0, 0.01).len();
    let permissive_count = analysis.get_close_positions_for_choice(0, 0.1).len();
    let very_permissive_count = analysis.get_close_positions_for_choice(0, 0.5).len();

    // The counts must be non-decreasing as the threshold grows.
    assert!(strict_count <= permissive_count);
    assert!(permissive_count <= very_permissive_count);

    // The percentage view must follow the same ordering.
    let strict_pct = analysis.close_position_percentage_for_choice(0, 0.01);
    let permissive_pct = analysis.close_position_percentage_for_choice(0, 0.1);
    assert!(strict_pct <= permissive_pct);
}
/// Analyzes a 100-response, 5-choice stream and checks both the elapsed
/// wall-clock time and the resulting counts.
#[test]
fn test_large_dataset_performance() {
    let stream = create_large_stream(100, 5); // 100 positions, 5 choices

    // Time only the analysis itself, not the fixture construction.
    let timer = Instant::now();
    let analysis = analyze_logprob_sensitivity(stream);
    let elapsed = timer.elapsed();

    // The analysis is expected to finish in under 100 ms of wall-clock time.
    assert!(elapsed.as_millis() < 100);

    // Every response is counted and every choice saw all 100 positions.
    assert_eq!(analysis.total_responses, 100);
    assert_eq!(analysis.choice_analyses.len(), 5);
    for choice in 0..5u32 {
        let per_choice = analysis.choice_analyses.get(&choice).unwrap();
        assert_eq!(per_choice.choice_index, choice);
        assert_eq!(per_choice.positions_analyzed, 100);
    }
}
// Helper functions for creating test data
fn create_realistic_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let responses = vec![
TimestampedResponse::new(
create_response_with_linear_probs(
"Hello",
vec![("Hello", 0.6, vec![("Hi", 0.3), ("Hey", 0.1)])], // Moderate differences
),
0,
),
TimestampedResponse::new(
create_response_with_linear_probs(
" world",
vec![(" world", 0.55, vec![(" there", 0.4), (" everyone", 0.05)])], // Close competition
),
1,
),
TimestampedResponse::new(
create_response_with_linear_probs(
"!",
vec![("!", 0.8, vec![(".", 0.15), ("?", 0.05)])],
), // Clear winner
2,
),
];
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
fn create_multi_choice_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let responses = vec![
TimestampedResponse::new(
create_multi_choice_response(vec![
// Choice 0: moderate closeness (65% vs 35%)
vec![("token1".to_string(), 0.65, vec![("alt1".to_string(), 0.35)])],
// Choice 1: very close logprobs (51% vs 49%)
vec![("token2".to_string(), 0.51, vec![("alt2".to_string(), 0.49)])],
]),
0,
),
TimestampedResponse::new(
create_multi_choice_response(vec![
// Choice 0: not close (80% vs 20%)
vec![("token3".to_string(), 0.8, vec![("alt3".to_string(), 0.2)])],
// Choice 1: close (53% vs 47%)
vec![("token4".to_string(), 0.53, vec![("alt4".to_string(), 0.47)])],
]),
1,
),
];
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
// fn create_stream_from_recorded_sse_stream(
// file: &str,
// ) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
// let data = std::fs::read_to_string(file).unwrap();
// let sse_stream = create_message_stream(&data);
// let response_stream =
// convert_sse_stream::<NvCreateChatCompletionStreamResponse>(Box::pin(sse_stream));
// let context = Arc::new(MockContext::new());
// let response_stream = record_stream_with_context(response_stream, context, RecordingMode::Sink);
// let filtered_stream = response_stream.filter_map(|annotated| async move { annotated.data });
// let (recorded_stream, recording_rx) =
// record_stream_with_context(Box::pin(filtered_stream), ctx, RecordingMode::Sink);
// }
fn create_stream_with_multiple_close_tokens(
) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let responses = vec![TimestampedResponse::new(
create_response_with_linear_probs(
"test",
vec![(
"test",
0.27,
vec![
("best", 0.26), // diff = 0.01
("rest", 0.25), // diff = 0.01 from best, 0.02 from test
("nest", 0.22), // diff = 0.03 from rest, 0.05 from test (sum = 1.0)
],
)],
),
0,
)];
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
/// Builds a recorded stream that contains no responses.
fn create_empty_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();
    Arc::new(RecordedStream::new(
        Vec::new(),
        start_time,
        Instant::now(),
    ))
}
fn create_single_token_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let responses = vec![TimestampedResponse::new(
create_response_with_linear_probs(
"only",
vec![
("only", 1.0, vec![]), // 100% probability, no alternatives
],
),
0,
)];
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
fn create_graduated_closeness_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>>
{
let start_time = Instant::now();
let responses = vec![TimestampedResponse::new(
create_response_with_linear_probs(
"test",
vec![
("very_close", 0.501, vec![("alt1", 0.499)]), // diff = 0.002 (very close)
("close", 0.55, vec![("alt2", 0.45)]), // diff = 0.1 (close)
("medium", 0.7, vec![("alt3", 0.3)]), // diff = 0.4 (medium)
("far", 0.9, vec![("alt4", 0.1)]), // diff = 0.8 (far)
],
),
0,
)];
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
fn create_large_stream(
positions: usize,
choices: usize,
) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
let start_time = Instant::now();
let mut responses = Vec::new();
for i in 0..positions {
let mut choice_data = Vec::new();
for j in 0..choices {
let token = format!("token_{}_{}", i, j);
let alt = format!("alt_{}_{}", i, j);
// Create varied but realistic probability distributions
let prob = 0.5 + (i as f32 * 0.001) + (j as f32 * 0.01); // Range: ~0.5-0.6
let alt_prob = 1.0 - prob - 0.05; // Ensure sum < 1, remaining ~5-15% for other tokens
let alt_prob = alt_prob.max(0.1); // Ensure alt_prob is reasonable
choice_data.push(vec![(token, prob, vec![(alt, alt_prob)])]);
}
responses.push(TimestampedResponse::new(
create_multi_choice_response(choice_data),
i,
));
}
let stream = RecordedStream::new(responses, start_time, Instant::now());
Arc::new(stream)
}
/// Helper function to create response with linear probabilities [0, 1]
/// This ensures realistic probability distributions that sum to ≤ 1
fn create_response_with_linear_probs(
_content: &str,
token_data: TokenDataVec,
) -> NvCreateChatCompletionStreamResponse {
let token_logprobs = token_data
.into_iter()
.map(|(token, prob, alternatives)| {
// Validate probabilities
assert!(
(0.0..=1.0).contains(&prob),
"Probability must be in [0, 1]: {}",
prob
);
let total_prob = prob + alternatives.iter().map(|(_, p)| p).sum::<f32>();
assert!(
total_prob <= 1.001,
"Total probability mass exceeds 1: {}",
total_prob
);
let top_logprobs = alternatives
.into_iter()
.map(|(alt_token, alt_prob)| {
assert!(
(0.0..=1.0).contains(&alt_prob),
"Probability must be in [0, 1]: {}",
alt_prob
);
TopLogprobs {
token: alt_token.to_string(),
logprob: alt_prob.ln(),
bytes: None,
}
})
.collect();
ChatCompletionTokenLogprob {
token: token.to_string(),
logprob: prob.ln(),
bytes: None,
top_logprobs,
}
})
.collect();
let choice = ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: Some(_content.to_string()),
#[expect(deprecated)]
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: Some(FinishReason::Stop),
logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs),
refusal: None,
}),
};
let inner = CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
choices: vec![choice],
created: 1234567890,
model: "test-model".to_string(),
service_tier: None,
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
};
NvCreateChatCompletionStreamResponse { inner }
}
fn create_multi_choice_response(
choices_data: MultiChoiceData,
) -> NvCreateChatCompletionStreamResponse {
let choices = choices_data
.into_iter()
.enumerate()
.map(|(choice_idx, token_data)| {
let token_logprobs = token_data
.into_iter()
.map(|(token, prob, alternatives)| {
// Validate probabilities
assert!(
(0.0..=1.0).contains(&prob),
"Probability must be in [0, 1]: {}",
prob
);
let total_prob = prob + alternatives.iter().map(|(_, p)| p).sum::<f32>();
assert!(
total_prob <= 1.001,
"Total probability mass exceeds 1: {}",
total_prob
);
let top_logprobs = alternatives
.into_iter()
.map(|(alt_token, alt_prob)| {
assert!(
(0.0..=1.0).contains(&alt_prob),
"Probability must be in [0, 1]: {}",
alt_prob
);
TopLogprobs {
token: alt_token,
logprob: alt_prob.ln(),
bytes: None,
}
})
.collect();
ChatCompletionTokenLogprob {
token,
logprob: prob.ln(),
bytes: None,
top_logprobs,
}
})
.collect();
ChatChoiceStream {
index: choice_idx as u32,
delta: ChatCompletionStreamResponseDelta {
content: Some("test".to_string()),
#[expect(deprecated)]
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
refusal: None,
},
finish_reason: Some(FinishReason::Stop),
logprobs: Some(ChatChoiceLogprobs {
content: Some(token_logprobs),
refusal: None,
}),
}
})
.collect();
let inner = CreateChatCompletionStreamResponse {
id: "test_id".to_string(),
choices,
created: 1234567890,
model: "test-model".to_string(),
service_tier: None,
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
};
NvCreateChatCompletionStreamResponse { inner }
}
...@@ -92,8 +92,8 @@ impl<T: Send + Sync + 'static> Data for T {} ...@@ -92,8 +92,8 @@ impl<T: Send + Sync + 'static> Data for T {}
/// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`] /// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`]
/// by associating it with a [`AsyncEngineContext`]. /// by associating it with a [`AsyncEngineContext`].
pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>; pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send>>;
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>; pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>;
pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>; pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>;
...@@ -174,7 +174,7 @@ pub trait AsyncEngineContextProvider: Send + Debug { ...@@ -174,7 +174,7 @@ pub trait AsyncEngineContextProvider: Send + Debug {
/// This trait combines `Future` semantics with context provider capabilities, /// This trait combines `Future` semantics with context provider capabilities,
/// representing a single async operation that produces one result. /// representing a single async operation that produces one result.
pub trait AsyncEngineUnary<Resp: Data>: pub trait AsyncEngineUnary<Resp: Data>:
Future<Output = Resp> + AsyncEngineContextProvider + Send + Sync Future<Output = Resp> + AsyncEngineContextProvider + Send
{ {
} }
...@@ -183,7 +183,7 @@ pub trait AsyncEngineUnary<Resp: Data>: ...@@ -183,7 +183,7 @@ pub trait AsyncEngineUnary<Resp: Data>:
/// This trait combines `Stream` semantics with context provider capabilities, /// This trait combines `Stream` semantics with context provider capabilities,
/// representing a continuous async operation that produces multiple results over time. /// representing a continuous async operation that produces multiple results over time.
pub trait AsyncEngineStream<Resp: Data>: pub trait AsyncEngineStream<Resp: Data>:
Stream<Item = Resp> + AsyncEngineContextProvider + Send + Sync Stream<Item = Resp> + AsyncEngineContextProvider + Send
{ {
} }
...@@ -204,7 +204,7 @@ pub trait AsyncEngineStream<Resp: Data>: ...@@ -204,7 +204,7 @@ pub trait AsyncEngineStream<Resp: Data>:
/// Implementations should ensure proper error handling and resource management. /// Implementations should ensure proper error handling and resource management.
/// The `generate` method should be cancellable via the response's context provider. /// The `generate` method should be cancellable via the response's context provider.
#[async_trait] #[async_trait]
pub trait AsyncEngine<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data>: pub trait AsyncEngine<Req: Send + Sync + 'static, Resp: AsyncEngineContextProvider, E: Data>:
Send + Sync Send + Sync
{ {
/// Generate a stream of completion responses. /// Generate a stream of completion responses.
......
...@@ -69,7 +69,7 @@ pub type ServerStreamingEngine<T, U> = ServiceEngine<SingleIn<T>, ManyOut<U>>; ...@@ -69,7 +69,7 @@ pub type ServerStreamingEngine<T, U> = ServiceEngine<SingleIn<T>, ManyOut<U>>;
/// are considered independent of each other; however, they could be constrained to be related. /// are considered independent of each other; however, they could be constrained to be related.
pub type BidirectionalStreamingEngine<T, U> = ServiceEngine<ManyIn<T>, ManyOut<U>>; pub type BidirectionalStreamingEngine<T, U> = ServiceEngine<ManyIn<T>, ManyOut<U>>;
pub trait AsyncTransportEngine<T: PipelineIO, U: PipelineIO>: pub trait AsyncTransportEngine<T: Data + PipelineIO, U: Data + PipelineIO>:
AsyncEngine<T, U, Error> + Send + Sync + 'static AsyncEngine<T, U, Error> + Send + Sync + 'static
{ {
} }
...@@ -97,7 +97,7 @@ mod sealed { ...@@ -97,7 +97,7 @@ mod sealed {
} }
} }
pub trait PipelineIO: Data + sealed::Connectable + AsyncEngineContextProvider { pub trait PipelineIO: sealed::Connectable + AsyncEngineContextProvider + 'static {
fn id(&self) -> String; fn id(&self) -> String;
} }
......
...@@ -280,7 +280,7 @@ pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> { ...@@ -280,7 +280,7 @@ pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
segment: OnceLock<Arc<SegmentSource<Req, Resp>>>, segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
} }
impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> { impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
pub fn new() -> Arc<Self> { pub fn new() -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
segment: OnceLock::new(), segment: OnceLock::new(),
......
...@@ -221,8 +221,8 @@ where ...@@ -221,8 +221,8 @@ where
impl<UpIn, UpOut, DownIn, DownOut> AsyncEngine<UpIn, UpOut, Error> impl<UpIn, UpOut, DownIn, DownOut> AsyncEngine<UpIn, UpOut, Error>
for PipelineOperator<UpIn, UpOut, DownIn, DownOut> for PipelineOperator<UpIn, UpOut, DownIn, DownOut>
where where
UpIn: PipelineIO, UpIn: PipelineIO + Sync,
DownIn: PipelineIO, DownIn: PipelineIO + Sync,
DownOut: PipelineIO, DownOut: PipelineIO,
UpOut: PipelineIO, UpOut: PipelineIO,
{ {
...@@ -235,8 +235,8 @@ where ...@@ -235,8 +235,8 @@ where
impl<UpIn, UpOut, DownIn, DownOut> Sink<UpIn> impl<UpIn, UpOut, DownIn, DownOut> Sink<UpIn>
for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut> for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>
where where
UpIn: PipelineIO, UpIn: PipelineIO + Sync,
DownIn: PipelineIO, DownIn: PipelineIO + Sync,
DownOut: PipelineIO, DownOut: PipelineIO,
UpOut: PipelineIO, UpOut: PipelineIO,
{ {
......
...@@ -26,7 +26,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> { ...@@ -26,7 +26,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> {
} }
#[async_trait] #[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Sink<Req> for ServiceBackend<Req, Resp> { impl<Req: PipelineIO + Sync, Resp: PipelineIO> Sink<Req> for ServiceBackend<Req, Resp> {
async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> { async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> {
let stream = self.engine.generate(data).await?; let stream = self.engine.generate(data).await?;
self.on_next(stream, Token).await self.on_next(stream, Token).await
......
...@@ -38,7 +38,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> Default for SegmentSink<Req, Resp> { ...@@ -38,7 +38,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> Default for SegmentSink<Req, Resp> {
} }
#[async_trait] #[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Sink<Req> for SegmentSink<Req, Resp> { impl<Req: PipelineIO + Sync, Resp: PipelineIO> Sink<Req> for SegmentSink<Req, Resp> {
async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> { async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> {
let stream = self let stream = self
.engine .engine
......
...@@ -68,7 +68,7 @@ impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out> for ...@@ -68,7 +68,7 @@ impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out> for
} }
#[async_trait] #[async_trait]
impl<In: PipelineIO, Out: PipelineIO> AsyncEngine<In, Out, Error> for Frontend<In, Out> { impl<In: PipelineIO + Sync, Out: PipelineIO> AsyncEngine<In, Out, Error> for Frontend<In, Out> {
async fn generate(&self, request: In) -> Result<Out, Error> { async fn generate(&self, request: In) -> Result<Out, Error> {
let (tx, rx) = oneshot::channel::<Out>(); let (tx, rx) = oneshot::channel::<Out>();
{ {
......
...@@ -48,7 +48,9 @@ macro_rules! impl_frontend { ...@@ -48,7 +48,9 @@ macro_rules! impl_frontend {
} }
#[async_trait] #[async_trait]
impl<In: PipelineIO, Out: PipelineIO> AsyncEngine<In, Out, Error> for $type<In, Out> { impl<In: PipelineIO + Sync, Out: PipelineIO> AsyncEngine<In, Out, Error>
for $type<In, Out>
{
async fn generate(&self, request: In) -> Result<Out, Error> { async fn generate(&self, request: In) -> Result<Out, Error> {
self.inner.generate(request).await self.inner.generate(request).await
} }
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![allow(dead_code)]
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
...@@ -159,12 +161,14 @@ impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> ...@@ -159,12 +161,14 @@ impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
for MockNetworkEgress<SingleIn<T>, ManyOut<U>> for MockNetworkEgress<SingleIn<T>, ManyOut<U>>
where where
T: Data + Serialize, T: Data + Serialize,
U: for<'de> Deserialize<'de> + Data, U: for<'de> Deserialize<'de> + Data + Send + Sync + 'static,
Self: Send + Sync,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
let ctrl_tx = self.ctrl_tx.clone();
let id = request.id().to_string(); let id = request.id().to_string();
// serialze the request // serialize the request
let request = request.try_map(|req| serde_json::to_vec(&req))?; let request = request.try_map(|req| serde_json::to_vec(&req))?;
// transfer the request context to a stream context // transfer the request context to a stream context
...@@ -172,14 +176,11 @@ where ...@@ -172,14 +176,11 @@ where
let context = Arc::new(StreamContext::from(context)); let context = Arc::new(StreamContext::from(context));
// subscribe to the response stream // subscribe to the response stream
// but in this case, we are doing a mock, so we are going to be more explicit // in this mock, we use a channel for the data plane
// since we are transferring data over a channel instead of the networ, creating the channel
// is the same as subscribing to the response stream
let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16); let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16);
let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx); let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx);
// prepare the stateful objects that will be used to monitor the response stream // prepare the stateful objects that will be used to monitor the response stream
// finish_rx is a oneshot channel that will be used to signal the natural termination of the stream
let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>(); let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>();
let stream_monitor = ResponseMonitor { let stream_monitor = ResponseMonitor {
ctx: context.clone(), ctx: context.clone(),
...@@ -187,9 +188,6 @@ where ...@@ -187,9 +188,6 @@ where
}; };
// create the control plane request // create the control plane request
// when this is issued, control is handed off to the control plane and the downstream segment
// sometimes we might include the local server address and port for the response find its way home
// todo(design) this will be part of the generalization error for multiple transport types
let request = ControlPlaneRequest { let request = ControlPlaneRequest {
id, id,
request: data, request: data,
...@@ -197,108 +195,83 @@ where ...@@ -197,108 +195,83 @@ where
}; };
// send the request to the control plane // send the request to the control plane
self.ctrl_tx ctrl_tx
.send(MockNetworkControlEvents::ControlPlaneRequest(request)) .send(MockNetworkControlEvents::ControlPlaneRequest(request))
.await .await
.map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?; .map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?;
// the first message from the remote publisher on the data plane needs to be a handshake message // the first message from the remote publisher on the data plane needs to be a handshake message
// the handshake will indicate to what stream the data belongs to and if the remote segment was
// able to process the request.
//
// note: in the case of the mock transport, the handshaking of the request id is not strictly
// because the channel is specific to the request. this is similar to other transports like nats
// where we will subscribe to a response stream on a subject unique to the stream.
match byte_stream.next().await { match byte_stream.next().await {
Some(DataPlaneMessage { headers, body }) => { Some(DataPlaneMessage { headers, body }) => {
if !body.is_empty() { if !body.is_empty() {
Err(PipelineError::ControlPlaneRequestError( return Err(PipelineError::ControlPlaneRequestError(
"Expected an empty body for the handshake message".to_string(), "Expected an empty body for the handshake message".to_string(),
))?; )
.into());
} }
match headers { match headers {
Some(header) => { Some(header) => match header {
match header { MockNetworkDataPlaneHeaders::Handshake(handshake) => {
MockNetworkDataPlaneHeaders::Handshake(handshake) => { match handshake.status {
match handshake.status { Status::Ok => {}
Status::Ok => {} Status::Error(e) => {
Status::Error(e) => { return Err(PipelineError::ControlPlaneRequestError(format!(
// todo(metrics): increment metric counter for failed handshakes "remote segment was unable to process request: {}",
Err(PipelineError::ControlPlaneRequestError(format!( e
"remote segment was unable to process request: {}", ))
e .into());
)))?;
}
} }
} }
_ => {
Err(PipelineError::ControlPlaneRequestError(format!(
"Expected a handshake message; got: {:?}",
header
)))?;
}
} }
} _ => {
return Err(PipelineError::ControlPlaneRequestError(format!(
"Expected a handshake message; got: {:?}",
header
))
.into());
}
},
_ => { _ => {
Err(PipelineError::ControlPlaneRequestError( return Err(PipelineError::ControlPlaneRequestError(
"Failed to receive properly formatted handshake on data plane" "Failed to receive properly formatted handshake on data plane"
.to_string(), .to_string(),
))?; )
.into());
} }
} }
} }
None => { None => {
// todo(metrics): increment metric counter for failed requests return Err(PipelineError::ControlPlaneRequestError(
Err(PipelineError::ControlPlaneRequestError(
"Failed data plane connection closed before receiving handshake".to_string(), "Failed data plane connection closed before receiving handshake".to_string(),
))?; )
.into());
} }
} }
let decoded = byte_stream let decoded = byte_stream
// .inspect(|_item| {
// // todo(metrics) increment the metrics counter by the number of bytes
// })
.scan(Some(stream_monitor), move |_stream_monitor, item| { .scan(Some(stream_monitor), move |_stream_monitor, item| {
// we could check the kill state of the context and terminate the stream here
// if our transport needs a heartbeat, trigger a heartbeat here the monitor
if let Some(headers) = &item.headers { if let Some(headers) = &item.headers {
match headers { match headers {
MockNetworkDataPlaneHeaders::HeartBeat => { MockNetworkDataPlaneHeaders::HeartBeat => {
// todo(metrics): increment metric counter for heartbeats // Heartbeat received, do nothing special
// send a heartbeat to the control plane
// this is a good place to send a heartbeat to the control plane
// to keep the connection alive
} }
MockNetworkDataPlaneHeaders::Sentinel => { MockNetworkDataPlaneHeaders::Sentinel => {
// todo(metrics): increment metric counter for sentinels // End of stream
// the stream has ended
// send a sentinel to the control plane
// this is a good place to send a sentinel to the control plane
// to indicate the end of the stream
return futures::future::ready(None); return futures::future::ready(None);
} }
_ => {} _ => {}
} }
} }
futures::future::ready(Some(item)) futures::future::ready(Some(item))
}) })
// decode the response
.map(move |item| { .map(move |item| {
serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response") serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response")
}); });
// cancellation can be tricky and is transport / protocol specific
// in this case, our channel for this is both ordered and 1:1, thus we can
// use that fact to first send the request, then forward any cancellation requests
// this ensures the downstream node should register the context/request id before any
// cancellation requests are sent
// create the cancellation monitor object // create the cancellation monitor object
let cancellation_monitor = CancellationMonitor { let cancellation_monitor = CancellationMonitor {
ctx: context.clone(), ctx: context.clone(),
ctrl_tx: self.ctrl_tx.clone(), ctrl_tx,
finish_tx: finished_tx, finish_tx: finished_tx,
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment