Unverified Commit 49b7a0d9 authored by Ryan Olson's avatar Ryan Olson Committed by GitHub
Browse files

feat: record + analyze logprobs (#1957)

parent 6d2be143
...@@ -147,6 +147,15 @@ version = "1.0.0" ...@@ -147,6 +147,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fca387cdc0a1f9c7a7c26556d584aa2d07fc529843082e4861003cde4ab914ed" checksum = "fca387cdc0a1f9c7a7c26556d584aa2d07fc529843082e4861003cde4ab914ed"
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "arbitrary" name = "arbitrary"
version = "1.4.1" version = "1.4.1"
...@@ -1773,6 +1782,7 @@ dependencies = [ ...@@ -1773,6 +1782,7 @@ dependencies = [
"akin", "akin",
"aligned-vec", "aligned-vec",
"anyhow", "anyhow",
"approx",
"assert_matches", "assert_matches",
"async-nats", "async-nats",
"async-openai", "async-openai",
......
...@@ -129,6 +129,7 @@ zeromq = "0.4.1" ...@@ -129,6 +129,7 @@ zeromq = "0.4.1"
rmp-serde = "1.3" rmp-serde = "1.3"
[dev-dependencies] [dev-dependencies]
approx = "0.5"
assert_matches = "1.5" assert_matches = "1.5"
criterion = { version = "0.3", features = ["html_reports"] } criterion = { version = "0.3", features = ["html_reports"] }
hf-hub = { workspace = true } hf-hub = { workspace = true }
......
...@@ -27,7 +27,7 @@ use crate::protocols::openai::chat_completions::{ ...@@ -27,7 +27,7 @@ use crate::protocols::openai::chat_completions::{
}; };
use crate::protocols::Annotated; use crate::protocols::Annotated;
use dynamo_runtime::engine::{ use dynamo_runtime::engine::{
AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, AsyncEngineContext, AsyncEngineContextProvider, AsyncEngineStream, Data, DataStream,
}; };
/// Configuration for HTTP clients /// Configuration for HTTP clients
...@@ -226,43 +226,29 @@ impl BaseHttpClient { ...@@ -226,43 +226,29 @@ impl BaseHttpClient {
} }
/// Type alias for NV chat response stream /// Type alias for NV chat response stream
pub type NvChatResponseStream = Pin< pub type NvChatResponseStream =
Box< DataStream<Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>;
dyn Stream<Item = Result<Annotated<NvCreateChatCompletionStreamResponse>, OpenAIError>>
+ Send
+ Sync,
>,
>;
/// Type alias for generic BYOT response stream /// Type alias for generic BYOT response stream
pub type ByotResponseStream = Pin<Box<dyn Stream<Item = Result<Value, OpenAIError>> + Send + Sync>>; pub type ByotResponseStream = DataStream<Result<Value, OpenAIError>>;
/// Type alias for pure OpenAI chat response stream /// Type alias for pure OpenAI chat response stream
pub type OpenAIChatResponseStream = Pin< pub type OpenAIChatResponseStream =
Box< DataStream<Result<async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>>;
dyn Stream<
Item = Result<async_openai::types::CreateChatCompletionStreamResponse, OpenAIError>,
> + Send
+ Sync,
>,
>;
/// A wrapped HTTP response stream that combines a stream with its context /// A wrapped HTTP response stream that combines a stream with its context
/// This provides a unified interface for HTTP client responses /// This provides a unified interface for HTTP client responses
#[derive(Dissolve)] #[derive(Dissolve)]
pub struct HttpResponseStream<T> { pub struct HttpResponseStream<T> {
/// The underlying stream of responses /// The underlying stream of responses
pub stream: Pin<Box<dyn Stream<Item = T> + Send>>, pub stream: DataStream<T>,
/// The context for this request /// The context for this request
pub context: Arc<dyn AsyncEngineContext>, pub context: Arc<dyn AsyncEngineContext>,
} }
impl<T> HttpResponseStream<T> { impl<T> HttpResponseStream<T> {
/// Create a new HttpResponseStream /// Create a new HttpResponseStream
pub fn new( pub fn new(stream: DataStream<T>, context: Arc<dyn AsyncEngineContext>) -> Self {
stream: Pin<Box<dyn Stream<Item = T> + Send>>,
context: Arc<dyn AsyncEngineContext>,
) -> Self {
Self { stream, context } Self { stream, context }
} }
} }
...@@ -299,7 +285,7 @@ impl<T: Data> HttpResponseStream<T> { ...@@ -299,7 +285,7 @@ impl<T: Data> HttpResponseStream<T> {
/// A wrapper that implements AsyncEngineStream for streams that are Send + Sync /// A wrapper that implements AsyncEngineStream for streams that are Send + Sync
struct AsyncEngineStreamWrapper<T> { struct AsyncEngineStreamWrapper<T> {
stream: Pin<Box<dyn Stream<Item = T> + Send>>, stream: DataStream<T>,
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
} }
...@@ -317,10 +303,6 @@ impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> { ...@@ -317,10 +303,6 @@ impl<T: Data> AsyncEngineContextProvider for AsyncEngineStreamWrapper<T> {
} }
} }
// This is unsafe because we're claiming the stream is Sync when it might not be
// But this is needed for the AsyncEngineStream trait
unsafe impl<T> Sync for AsyncEngineStreamWrapper<T> {}
impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {} impl<T: Data> AsyncEngineStream<T> for AsyncEngineStreamWrapper<T> {}
impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> { impl<T> std::fmt::Debug for AsyncEngineStreamWrapper<T> {
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
//! This module provides mechanisms to record streaming responses with minimal overhead //! This module provides mechanisms to record streaming responses with minimal overhead
//! during collection, then analyze the recorded data for performance insights. //! during collection, then analyze the recorded data for performance insights.
pub mod logprobs;
use futures::Stream; use futures::Stream;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
...@@ -339,7 +341,7 @@ pub fn record_response_stream<R: Data + Clone>( ...@@ -339,7 +341,7 @@ pub fn record_response_stream<R: Data + Clone>(
} }
#[cfg(test)] #[cfg(test)]
mod tests { pub mod tests {
use super::*; use super::*;
use dynamo_runtime::engine::ResponseStream; use dynamo_runtime::engine::ResponseStream;
use futures::stream; use futures::stream;
......
This diff is collapsed.
...@@ -19,9 +19,7 @@ ...@@ -19,9 +19,7 @@
//! both publicly via the HTTP API and internally between Dynamo components. //! both publicly via the HTTP API and internally between Dynamo components.
//! //!
use std::pin::Pin; use futures::StreamExt;
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub mod codec; pub mod codec;
...@@ -30,7 +28,7 @@ pub mod openai; ...@@ -30,7 +28,7 @@ pub mod openai;
/// The token ID type /// The token ID type
pub type TokenIdType = u32; pub type TokenIdType = u32;
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; pub use dynamo_runtime::engine::DataStream;
// TODO: This is an awkward dependency that we need to address // TODO: This is an awkward dependency that we need to address
// Originally, all the Annotated/SSE Codec bits were in the LLM protocol module; however, [Annotated] // Originally, all the Annotated/SSE Codec bits were in the LLM protocol module; however, [Annotated]
......
...@@ -13,9 +13,8 @@ ...@@ -13,9 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, pin::Pin}; use futures::StreamExt;
use std::collections::HashMap;
use futures::{Stream, StreamExt};
use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse}; use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
use crate::protocols::{ use crate::protocols::{
...@@ -24,7 +23,7 @@ use crate::protocols::{ ...@@ -24,7 +23,7 @@ use crate::protocols::{
}; };
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. /// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; use dynamo_runtime::engine::DataStream;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
/// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses /// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses
......
...@@ -13,18 +13,14 @@ ...@@ -13,18 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::pin::Pin;
use futures::{Stream, StreamExt};
use super::NvCreateEmbeddingResponse; use super::NvCreateEmbeddingResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
}; };
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`. use dynamo_runtime::engine::DataStream;
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; use futures::StreamExt;
/// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single /// Aggregates a stream of [`NvCreateEmbeddingResponse`]s into a single
/// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler /// [`NvCreateEmbeddingResponse`]. For embeddings, this is typically simpler
......
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":"Okay","tool_calls":[]},"logprobs":{"content":[{"token":"Okay","logprob":-0.5292773246765137,"bytes":[79,107,97,121],"top_logprobs":[{"token":"Okay","logprob":-0.5292773246765137,"bytes":[79,107,97,121]},{"token":"Alright","logprob":-0.9042773246765137,"bytes":[65,108,114,105,103,104,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":",","tool_calls":[]},"logprobs":{"content":[{"token":",","logprob":-0.000017165990357170813,"bytes":[44],"top_logprobs":[{"token":",","logprob":-0.000017165990357170813,"bytes":[44]},{"token":"Ġso","logprob":-11.812517166137695,"bytes":[196,160,115,111]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" so","tool_calls":[]},"logprobs":{"content":[{"token":"Ġso","logprob":-0.10039777308702469,"bytes":[196,160,115,111],"top_logprobs":[{"token":"Ġso","logprob":-0.10039777308702469,"bytes":[196,160,115,111]},{"token":"Ġthe","logprob":-2.600397825241089,"bytes":[196,160,116,104,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" I","tool_calls":[]},"logprobs":{"content":[{"token":"ĠI","logprob":-0.07118851691484451,"bytes":[196,160,73],"top_logprobs":[{"token":"ĠI","logprob":-0.07118851691484451,"bytes":[196,160,73]},{"token":"Ġthe","logprob":-2.696188449859619,"bytes":[196,160,116,104,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":"'m","tool_calls":[]},"logprobs":{"content":[{"token":"'m","logprob":-0.5393549799919128,"bytes":[39,109],"top_logprobs":[{"token":"'m","logprob":-0.5393549799919128,"bytes":[39,109]},{"token":"'ve","logprob":-2.2893550395965576,"bytes":[39,118,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" trying","tool_calls":[]},"logprobs":{"content":[{"token":"Ġtrying","logprob":-0.2027934193611145,"bytes":[196,160,116,114,121,105,110,103],"top_logprobs":[{"token":"Ġtrying","logprob":-0.2027934193611145,"bytes":[196,160,116,114,121,105,110,103]},{"token":"Ġlooking","logprob":-1.8277933597564697,"bytes":[196,160,108,111,111,107,105,110,103]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" to","tool_calls":[]},"logprobs":{"content":[{"token":"Ġto","logprob":-1.5497195136049413e-6,"bytes":[196,160,116,111],"top_logprobs":[{"token":"Ġto","logprob":-1.5497195136049413e-6,"bytes":[196,160,116,111]},{"token":"Ġout","logprob":-14.187501907348633,"bytes":[196,160,111,117,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" figure","tool_calls":[]},"logprobs":{"content":[{"token":"Ġfigure","logprob":-0.42643895745277405,"bytes":[196,160,102,105,103,117,114,101],"top_logprobs":[{"token":"Ġfigure","logprob":-0.42643895745277405,"bytes":[196,160,102,105,103,117,114,101]},{"token":"Ġunderstand","logprob":-1.1764389276504517,"bytes":[196,160,117,110,100,101,114,115,116,97,110,100]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" out","tool_calls":[]},"logprobs":{"content":[{"token":"Ġout","logprob":-0.00021181246847845614,"bytes":[196,160,111,117,116],"top_logprobs":[{"token":"Ġout","logprob":-0.00021181246847845614,"bytes":[196,160,111,117,116]},{"token":"Ġthis","logprob":-8.500211715698242,"bytes":[196,160,116,104,105,115]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" how","tool_calls":[]},"logprobs":{"content":[{"token":"Ġhow","logprob":-0.5830438137054443,"bytes":[196,160,104,111,119],"top_logprobs":[{"token":"Ġhow","logprob":-0.5830438137054443,"bytes":[196,160,104,111,119]},{"token":"Ġwhat","logprob":-1.0830438137054443,"bytes":[196,160,119,104,97,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" to","tool_calls":[]},"logprobs":{"content":[{"token":"Ġto","logprob":-0.0042633600533008575,"bytes":[196,160,116,111],"top_logprobs":[{"token":"Ġto","logprob":-0.0042633600533008575,"bytes":[196,160,116,111]},{"token":"Ġthis","logprob":-6.004263401031494,"bytes":[196,160,116,104,105,115]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" respond","tool_calls":[]},"logprobs":{"content":[{"token":"Ġrespond","logprob":-0.5788105726242065,"bytes":[196,160,114,101,115,112,111,110,100],"top_logprobs":[{"token":"Ġrespond","logprob":-0.5788105726242065,"bytes":[196,160,114,101,115,112,111,110,100]},{"token":"Ġapproach","logprob":-1.2038105726242065,"bytes":[196,160,97,112,112,114,111,97,99,104]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" to","tool_calls":[]},"logprobs":{"content":[{"token":"Ġto","logprob":-0.0014138950500637293,"bytes":[196,160,116,111],"top_logprobs":[{"token":"Ġto","logprob":-0.0014138950500637293,"bytes":[196,160,116,111]},{"token":"Ġhelp","logprob":-7.751413822174072,"bytes":[196,160,104,101,108,112]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" this","tool_calls":[]},"logprobs":{"content":[{"token":"Ġthis","logprob":-0.16383041441440582,"bytes":[196,160,116,104,105,115],"top_logprobs":[{"token":"Ġthis","logprob":-0.16383041441440582,"bytes":[196,160,116,104,105,115]},{"token":"Ġthe","logprob":-1.9138303995132446,"bytes":[196,160,116,104,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" user","tool_calls":[]},"logprobs":{"content":[{"token":"Ġuser","logprob":-0.342995822429657,"bytes":[196,160,117,115,101,114],"top_logprobs":[{"token":"Ġuser","logprob":-0.342995822429657,"bytes":[196,160,117,115,101,114]},{"token":"Ġmessage","logprob":-2.4054958820343018,"bytes":[196,160,109,101,115,115,97,103,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":"'s","tool_calls":[]},"logprobs":{"content":[{"token":"'s","logprob":-0.6218149065971375,"bytes":[39,115],"top_logprobs":[{"token":"'s","logprob":-0.6218149065971375,"bytes":[39,115]},{"token":".","logprob":-1.2468149662017822,"bytes":[46]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" message","tool_calls":[]},"logprobs":{"content":[{"token":"Ġmessage","logprob":-0.49226677417755127,"bytes":[196,160,109,101,115,115,97,103,101],"top_logprobs":[{"token":"Ġmessage","logprob":-0.49226677417755127,"bytes":[196,160,109,101,115,115,97,103,101]},{"token":"Ġquery","logprob":-1.1172667741775513,"bytes":[196,160,113,117,101,114,121]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":".","tool_calls":[]},"logprobs":{"content":[{"token":".","logprob":-0.004951002076268196,"bytes":[46],"top_logprobs":[{"token":".","logprob":-0.004951002076268196,"bytes":[46]},{"token":"Ġthat","logprob":-5.879951000213623,"bytes":[196,160,116,104,97,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" They","tool_calls":[]},"logprobs":{"content":[{"token":"ĠThey","logprob":-0.038852665573358536,"bytes":[196,160,84,104,101,121],"top_logprobs":[{"token":"ĠThey","logprob":-0.038852665573358536,"bytes":[196,160,84,104,101,121]},{"token":"ĠLet","logprob":-3.7888526916503906,"bytes":[196,160,76,101,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" provided","tool_calls":[]},"logprobs":{"content":[{"token":"Ġprovided","logprob":-0.4865674376487732,"bytes":[196,160,112,114,111,118,105,100,101,100],"top_logprobs":[{"token":"Ġprovided","logprob":-0.4865674376487732,"bytes":[196,160,112,114,111,118,105,100,101,100]},{"token":"'ve","logprob":-1.736567497253418,"bytes":[39,118,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" a","tool_calls":[]},"logprobs":{"content":[{"token":"Ġa","logprob":-0.08489075303077698,"bytes":[196,160,97],"top_logprobs":[{"token":"Ġa","logprob":-0.08489075303077698,"bytes":[196,160,97]},{"token":"Ġsome","logprob":-2.584890842437744,"bytes":[196,160,115,111,109,101]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" block","tool_calls":[]},"logprobs":{"content":[{"token":"Ġblock","logprob":-0.9078922271728516,"bytes":[196,160,98,108,111,99,107],"top_logprobs":[{"token":"Ġblock","logprob":-0.9078922271728516,"bytes":[196,160,98,108,111,99,107]},{"token":"Ġchunk","logprob":-0.9078922271728516,"bytes":[196,160,99,104,117,110,107]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" of","tool_calls":[]},"logprobs":{"content":[{"token":"Ġof","logprob":-4.172316494077677e-6,"bytes":[196,160,111,102],"top_logprobs":[{"token":"Ġof","logprob":-4.172316494077677e-6,"bytes":[196,160,111,102]},{"token":"Ġthat","logprob":-13.062503814697266,"bytes":[196,160,116,104,97,116]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" text","tool_calls":[]},"logprobs":{"content":[{"token":"Ġtext","logprob":-0.1239960640668869,"bytes":[196,160,116,101,120,116],"top_logprobs":[{"token":"Ġtext","logprob":-0.1239960640668869,"bytes":[196,160,116,101,120,116]},{"token":"ĠLorem","logprob":-2.7489960193634033,"bytes":[196,160,76,111,114,101,109]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" that","tool_calls":[]},"logprobs":{"content":[{"token":"Ġthat","logprob":-0.021982578560709953,"bytes":[196,160,116,104,97,116],"top_logprobs":[{"token":"Ġthat","logprob":-0.021982578560709953,"bytes":[196,160,116,104,97,116]},{"token":"Ġwhich","logprob":-4.646982669830322,"bytes":[196,160,119,104,105,99,104]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" looks","tool_calls":[]},"logprobs":{"content":[{"token":"Ġlooks","logprob":-0.6330966353416443,"bytes":[196,160,108,111,111,107,115],"top_logprobs":[{"token":"Ġlooks","logprob":-0.6330966353416443,"bytes":[196,160,108,111,111,107,115]},{"token":"'s","logprob":-1.133096694946289,"bytes":[39,115]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" like","tool_calls":[]},"logprobs":{"content":[{"token":"Ġlike","logprob":-0.001482222112827003,"bytes":[196,160,108,105,107,101],"top_logprobs":[{"token":"Ġlike","logprob":-0.001482222112827003,"bytes":[196,160,108,105,107,101]},{"token":"Ġa","logprob":-7.001482009887695,"bytes":[196,160,97]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" Lorem","tool_calls":[]},"logprobs":{"content":[{"token":"ĠLorem","logprob":-0.2382608950138092,"bytes":[196,160,76,111,114,101,109],"top_logprobs":[{"token":"ĠLorem","logprob":-0.2382608950138092,"bytes":[196,160,76,111,114,101,109]},{"token":"Ġplaceholder","logprob":-2.6132609844207764,"bytes":[196,160,112,108,97,99,101,104,111,108,100,101,114]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" Ipsum","tool_calls":[]},"logprobs":{"content":[{"token":"ĠIpsum","logprob":-0.22565951943397522,"bytes":[196,160,73,112,115,117,109],"top_logprobs":[{"token":"ĠIpsum","logprob":-0.22565951943397522,"bytes":[196,160,73,112,115,117,109]},{"token":"Ġipsum","logprob":-1.6006594896316528,"bytes":[196,160,105,112,115,117,109]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":",","tool_calls":[]},"logprobs":{"content":[{"token":",","logprob":-0.02414931170642376,"bytes":[44],"top_logprobs":[{"token":",","logprob":-0.02414931170642376,"bytes":[44]},{"token":"Ġand","logprob":-4.399149417877197,"bytes":[196,160,97,110,100]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" which","tool_calls":[]},"logprobs":{"content":[{"token":"Ġwhich","logprob":-0.02117946185171604,"bytes":[196,160,119,104,105,99,104],"top_logprobs":[{"token":"Ġwhich","logprob":-0.02117946185171604,"bytes":[196,160,119,104,105,99,104]},{"token":"Ġand","logprob":-4.271179676055908,"bytes":[196,160,97,110,100]}]}]}}]}
data: {"id":"chatcmpl-33523fb1b0d24e93b89686d88c3284d1","object":"chat.completion.chunk","created":1752563089,"model":"deepseek-ai/DeepSeek-R1-Distill-Llama-8B","choices":[{"index":0,"delta":{"content":" is","tool_calls":[]},"logprobs":{"content":[{"token":"Ġis","logprob":-0.43066686391830444,"bytes":[196,160,105,115],"top_logprobs":[{"token":"Ġis","logprob":-0.43066686391830444,"bytes":[196,160,105,115]},{"token":"ĠI","logprob":-1.0556669235229492,"bytes":[196,160,73]}]}]},"finish_reason":"length"}]}
data: [DONE]
Captured from version 0.9.0.2.dev22+gbc825748a.d20250715.precompiled.
Script used to generate deepseek-r1-distill-llama-8b/chat-completions.stream.logprobs.1:
```
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [{"role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."}],
"max_tokens": 32,
"temperature": 0.0,
"top_p": 0.001,"stream":true,"logprobs":1,"top_logprobs":2
}'
```
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for logprob analysis functionality
use std::sync::Arc;
use std::time::Instant;
use dynamo_llm::perf::logprobs::analyze_logprob_sensitivity;
use dynamo_llm::perf::{RecordedStream, TimestampedResponse};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
use async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
ChatCompletionTokenLogprob, CreateChatCompletionStreamResponse, FinishReason, Role,
TopLogprobs,
};
// Type aliases to simplify complex test data structures
/// One alternative token at a position: `(token text, value)`.
/// NOTE(review): the `f32` is presumably a probability or log-probability,
/// depending on the helper that consumes it — confirm against that helper.
type TokenAlternative = (&'static str, f32);
/// One position's data: `(token text, value, alternatives for that position)`.
type TokenData = (&'static str, f32, Vec<TokenAlternative>);
/// A sequence of [`TokenData`] entries.
type TokenDataVec = Vec<TokenData>;
// Type aliases for multi-choice test data (using String instead of &str)
/// Owned-string counterpart of [`TokenAlternative`].
type StringTokenAlternative = (String, f32);
/// Owned-string counterpart of [`TokenData`].
type StringTokenData = (String, f32, Vec<StringTokenAlternative>);
/// The token entries belonging to a single choice.
type ChoiceTokenData = Vec<StringTokenData>;
/// Token entries for several choices, one [`ChoiceTokenData`] per choice.
type MultiChoiceData = Vec<ChoiceTokenData>;
/// Runs the analysis end to end over a small three-response stream and checks
/// the summary counts, the ordering of per-position results, and the
/// threshold-based query helpers.
#[test]
fn test_realistic_streaming_analysis() {
    let analysis = analyze_logprob_sensitivity(create_realistic_stream());

    // Three recorded responses, all attributed to a single choice (index 0).
    assert_eq!(analysis.total_responses, 3);
    assert_eq!(analysis.choice_analyses.len(), 1);

    let first_choice = analysis.choice_analyses.get(&0).unwrap();
    assert_eq!(first_choice.positions_analyzed, 3);

    // Entries must be ordered by non-decreasing probability difference.
    let ordered = &first_choice.position_closeness;
    for neighbours in ordered.windows(2) {
        assert!(neighbours[0].probability_difference <= neighbours[1].probability_difference);
    }

    // Query helpers: at least one position is within 0.2, and the percentage
    // is a valid value in [0, 100].
    let within_threshold = analysis.get_close_positions_for_choice(0, 0.2);
    assert!(!within_threshold.is_empty());

    let close_pct = analysis.close_position_percentage_for_choice(0, 0.2);
    assert!((0.0..=100.0).contains(&close_pct));
}
/// Checks that a stream carrying two choices yields one independent analysis
/// per choice.
#[test]
fn test_multiple_choices_independent_analysis() {
    let analysis = analyze_logprob_sensitivity(create_multi_choice_stream());

    // Exactly two choices are expected in the result.
    assert_eq!(analysis.choice_analyses.len(), 2);

    // Look up how many positions were analyzed for a given choice index.
    let analyzed_for = |choice: u32| {
        analysis
            .choice_analyses
            .get(&choice)
            .unwrap()
            .positions_analyzed
    };
    let (first_count, second_count) = (analyzed_for(0), analyzed_for(1));
    assert_eq!(first_count, 2);
    assert_eq!(second_count, 2);

    // With the fixture data, choice 1 is expected to have at least as many
    // close positions as choice 0 at a 0.3 threshold.
    let close_in_first = analysis.get_close_positions_for_choice(0, 0.3);
    let close_in_second = analysis.get_close_positions_for_choice(1, 0.3);
    assert!(close_in_second.len() >= close_in_first.len());
}
/// Checks that positions where several candidate tokens are nearly tied are
/// reported, and that the reported tokens really are within the threshold.
#[test]
fn test_multiple_close_tokens_detection() {
    let analysis = analyze_logprob_sensitivity(create_stream_with_multiple_close_tokens());

    // At least one position must be flagged at a 0.05 threshold.
    let flagged = analysis.detect_multiple_close_tokens(0, 0.05);
    assert!(!flagged.is_empty());

    let first_hit = &flagged[0];
    assert!(first_hit.close_count >= 3);
    assert!(first_hit.max_difference <= 0.05);

    // Convert log-probabilities back to probabilities and compare every
    // remaining token against the first one in the list.
    if let Some((top, others)) = first_hit.close_tokens.split_first() {
        let top_probability = top.logprob.exp();
        for candidate in others {
            let gap = top_probability - candidate.logprob.exp();
            assert!(gap <= 0.05);
        }
    }
}
/// Covers degenerate inputs: a stream with no responses, and a stream whose
/// positions carry only one token each.
#[test]
fn test_edge_cases() {
    // No responses at all: nothing counted, nothing analyzed.
    let empty_analysis = analyze_logprob_sensitivity(create_empty_stream());
    assert_eq!(empty_analysis.total_responses, 0);
    assert!(empty_analysis.choice_analyses.is_empty());

    // One token per position (no alternatives): even with a threshold of 1.0
    // no position can be reported as close.
    let single_token_analysis = analyze_logprob_sensitivity(create_single_token_stream());
    let reported = single_token_analysis.get_close_positions_for_choice(0, 1.0);
    assert!(reported.is_empty());
}
/// Checks that loosening the closeness threshold never reduces the number
/// (or percentage) of positions reported as close.
#[test]
fn test_threshold_sensitivity() {
    let analysis = analyze_logprob_sensitivity(create_graduated_closeness_stream());

    // Count close positions at three increasingly permissive thresholds.
    let at_strict = analysis.get_close_positions_for_choice(0, 0.01).len();
    let at_permissive = analysis.get_close_positions_for_choice(0, 0.1).len();
    let at_very_permissive = analysis.get_close_positions_for_choice(0, 0.5).len();

    // The counts must be non-decreasing as the threshold grows.
    assert!(at_permissive >= at_strict);
    assert!(at_very_permissive >= at_permissive);

    // The same monotonicity must hold for the percentage view.
    let pct_strict = analysis.close_position_percentage_for_choice(0, 0.01);
    let pct_permissive = analysis.close_position_percentage_for_choice(0, 0.1);
    assert!(pct_strict <= pct_permissive);
}
/// Checks that analyzing a 100-position, 5-choice stream completes in under 100 ms
/// and that every choice reports all 100 positions.
#[test]
fn test_large_dataset_performance() {
    // Build the input before starting the clock so only the analysis is timed.
    let stream = create_large_stream(100, 5); // 100 positions, 5 choices
    let timer = Instant::now();
    let analysis = analyze_logprob_sensitivity(stream);
    let elapsed = timer.elapsed();

    // Wall-clock budget for the analysis itself.
    assert!(elapsed < std::time::Duration::from_millis(100));

    // Correctness: one response per position, one analysis per choice.
    assert_eq!(analysis.total_responses, 100);
    assert_eq!(analysis.choice_analyses.len(), 5);
    for choice in 0..5u32 {
        let choice_analysis = analysis.choice_analyses.get(&choice).unwrap();
        assert_eq!(choice_analysis.choice_index, choice);
        assert_eq!(choice_analysis.positions_analyzed, 100);
    }
}
// Helper functions for creating test data
/// Builds a three-response recorded stream ("Hello", " world", "!") where each
/// response carries one selected token and two alternatives, expressed as linear
/// probabilities.
fn create_realistic_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    // (delta content, token data) for each response, in emission order.
    let steps = [
        // Moderate differences
        (
            "Hello",
            vec![("Hello", 0.6, vec![("Hi", 0.3), ("Hey", 0.1)])],
        ),
        // Close competition
        (
            " world",
            vec![(" world", 0.55, vec![(" there", 0.4), (" everyone", 0.05)])],
        ),
        // Clear winner
        ("!", vec![("!", 0.8, vec![(".", 0.15), ("?", 0.05)])]),
    ];

    // The position of each step in the table doubles as its sequence number.
    let responses: Vec<_> = steps
        .into_iter()
        .enumerate()
        .map(|(sequence, (content, token_data))| {
            TimestampedResponse::new(
                create_response_with_linear_probs(content, token_data),
                sequence,
            )
        })
        .collect();

    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// Builds a two-response recorded stream with two choices per response.
///
/// Choice 0 has a comfortable margin between the selected token and its alternative
/// at both positions; choice 1 is nearly tied at both positions.
fn create_multi_choice_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    // Token data for one choice: a single position with one selected token and one
    // alternative, both given as linear probabilities.
    let single_position = |token: &str, prob: f32, alt: &str, alt_prob: f32| {
        vec![(token.to_string(), prob, vec![(alt.to_string(), alt_prob)])]
    };

    let first = TimestampedResponse::new(
        create_multi_choice_response(vec![
            // Choice 0: moderate closeness (65% vs 35%)
            single_position("token1", 0.65, "alt1", 0.35),
            // Choice 1: very close logprobs (51% vs 49%)
            single_position("token2", 0.51, "alt2", 0.49),
        ]),
        0,
    );
    let second = TimestampedResponse::new(
        create_multi_choice_response(vec![
            // Choice 0: not close (80% vs 20%)
            single_position("token3", 0.8, "alt3", 0.2),
            // Choice 1: close (53% vs 47%)
            single_position("token4", 0.53, "alt4", 0.47),
        ]),
        1,
    );

    let responses = vec![first, second];
    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
// fn create_stream_from_recorded_sse_stream(
// file: &str,
// ) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
// let data = std::fs::read_to_string(file).unwrap();
// let sse_stream = create_message_stream(&data);
// let response_stream =
// convert_sse_stream::<NvCreateChatCompletionStreamResponse>(Box::pin(sse_stream));
// let context = Arc::new(MockContext::new());
// let response_stream = record_stream_with_context(response_stream, context, RecordingMode::Sink);
// let filtered_stream = response_stream.filter_map(|annotated| async move { annotated.data });
// let (recorded_stream, recording_rx) =
// record_stream_with_context(Box::pin(filtered_stream), ctx, RecordingMode::Sink);
// }
/// Builds a one-response recorded stream whose single position has four candidates
/// within 0.05 of each other in linear probability (0.27, 0.26, 0.25, 0.22).
fn create_stream_with_multiple_close_tokens(
) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    // Alternatives to the selected token "test" (0.27); all four sum to 1.0.
    let alternatives = vec![
        ("best", 0.26), // 0.01 below "test"
        ("rest", 0.25), // 0.01 below "best", 0.02 below "test"
        ("nest", 0.22), // 0.03 below "rest", 0.05 below "test"
    ];
    let token_data = vec![("test", 0.27, alternatives)];

    let response = create_response_with_linear_probs("test", token_data);
    let responses = vec![TimestampedResponse::new(response, 0)];
    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// Builds a recorded stream that contains no responses at all.
fn create_empty_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let started = Instant::now();
    let finished = Instant::now();
    Arc::new(RecordedStream::new(Vec::new(), started, finished))
}
/// Builds a one-response recorded stream whose single position has exactly one
/// candidate token ("only") with probability 1.0 and no alternatives.
fn create_single_token_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    let start_time = Instant::now();

    // 100% probability, no alternatives
    let token_data = vec![("only", 1.0, vec![])];
    let response = create_response_with_linear_probs("only", token_data);

    let responses = vec![TimestampedResponse::new(response, 0)];
    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// Builds a one-response recorded stream with four positions whose gap between the
/// selected token and its single alternative widens from 0.002 to 0.8 (in linear
/// probability).
fn create_graduated_closeness_stream() -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>>
{
    let start_time = Instant::now();

    // One entry per position, ordered from the tightest gap to the widest.
    let token_data = vec![
        ("very_close", 0.501, vec![("alt1", 0.499)]), // gap 0.002
        ("close", 0.55, vec![("alt2", 0.45)]),        // gap 0.1
        ("medium", 0.7, vec![("alt3", 0.3)]),         // gap 0.4
        ("far", 0.9, vec![("alt4", 0.1)]),            // gap 0.8
    ];
    let response = create_response_with_linear_probs("test", token_data);

    let responses = vec![TimestampedResponse::new(response, 0)];
    Arc::new(RecordedStream::new(responses, start_time, Instant::now()))
}
/// Builds a recorded stream of `positions` responses, each carrying `choices` choices
/// with one selected token and one alternative per choice.
///
/// The selected-token probability starts at 0.5 and grows by 0.001 per position and
/// 0.01 per choice, so the data is varied but deterministic. It is capped at 0.85 so
/// that, for any `positions`/`choices`, every probability stays in [0, 1] and the
/// per-position mass stays at or below 1, which is what `create_multi_choice_response`
/// asserts. Without the cap, large inputs would trip those assertions.
fn create_large_stream(
    positions: usize,
    choices: usize,
) -> Arc<RecordedStream<NvCreateChatCompletionStreamResponse>> {
    // Upper bound for the selected token's probability. Together with the reserved
    // mass below it guarantees at least `MIN_ALT_PROB` is left for the alternative.
    const MAX_PROB: f32 = 0.85;
    // Probability mass deliberately left unassigned (i.e. "all other tokens").
    const RESERVED_MASS: f32 = 0.05;
    // Floor for the alternative's probability; guards against float rounding at the cap.
    const MIN_ALT_PROB: f32 = 0.1;

    let start_time = Instant::now();
    let mut responses = Vec::with_capacity(positions);
    for i in 0..positions {
        let choice_data: Vec<_> = (0..choices)
            .map(|j| {
                let token = format!("token_{}_{}", i, j);
                let alt = format!("alt_{}_{}", i, j);
                // Varied but bounded distribution: 0.5 plus small per-position and
                // per-choice increments, never above MAX_PROB.
                let prob = (0.5 + (i as f32 * 0.001) + (j as f32 * 0.01)).min(MAX_PROB);
                // The alternative takes what is left after the reserved mass.
                let alt_prob = (1.0 - prob - RESERVED_MASS).max(MIN_ALT_PROB);
                vec![(token, prob, vec![(alt, alt_prob)])]
            })
            .collect();
        responses.push(TimestampedResponse::new(
            create_multi_choice_response(choice_data),
            i,
        ));
    }
    let stream = RecordedStream::new(responses, start_time, Instant::now());
    Arc::new(stream)
}
/// Builds a single-choice stream response from linear-space probabilities.
///
/// Every `(token, prob, alternatives)` entry of `token_data` becomes one token
/// logprob entry whose `logprob` is `prob.ln()`; each `(alt_token, alt_prob)`
/// alternative becomes a `TopLogprobs` entry in the same way. `_content` is used as
/// the delta content of the single choice (index 0).
///
/// # Panics
///
/// Panics if any probability is outside [0, 1], or if a token's probability plus the
/// sum of its alternatives exceeds 1.001.
fn create_response_with_linear_probs(
    _content: &str,
    token_data: TokenDataVec,
) -> NvCreateChatCompletionStreamResponse {
    let mut token_logprobs = Vec::with_capacity(token_data.len());
    for (token, prob, alternatives) in token_data {
        // Validate the selected token's probability and the total mass first.
        assert!(
            (0.0..=1.0).contains(&prob),
            "Probability must be in [0, 1]: {}",
            prob
        );
        let total_prob = prob + alternatives.iter().map(|(_, p)| p).sum::<f32>();
        assert!(
            total_prob <= 1.001,
            "Total probability mass exceeds 1: {}",
            total_prob
        );

        // Convert each alternative to log space, validating it on the way.
        let mut top_logprobs = Vec::with_capacity(alternatives.len());
        for (alt_token, alt_prob) in alternatives {
            assert!(
                (0.0..=1.0).contains(&alt_prob),
                "Probability must be in [0, 1]: {}",
                alt_prob
            );
            top_logprobs.push(TopLogprobs {
                token: alt_token.to_string(),
                logprob: alt_prob.ln(),
                bytes: None,
            });
        }

        token_logprobs.push(ChatCompletionTokenLogprob {
            token: token.to_string(),
            logprob: prob.ln(),
            bytes: None,
            top_logprobs,
        });
    }

    let delta = ChatCompletionStreamResponseDelta {
        content: Some(_content.to_string()),
        #[expect(deprecated)]
        function_call: None,
        tool_calls: None,
        role: Some(Role::Assistant),
        refusal: None,
    };
    let logprobs = ChatChoiceLogprobs {
        content: Some(token_logprobs),
        refusal: None,
    };
    let choice = ChatChoiceStream {
        index: 0,
        delta,
        finish_reason: Some(FinishReason::Stop),
        logprobs: Some(logprobs),
    };

    NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices: vec![choice],
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
    }
}
/// Builds a multi-choice stream response from linear-space probabilities.
///
/// Each element of `choices_data` describes one choice; its position in the vector
/// becomes the choice index. Within a choice, every `(token, prob, alternatives)`
/// entry becomes one token logprob entry with `logprob = prob.ln()`, and every
/// alternative becomes a `TopLogprobs` entry the same way. All choices use "test" as
/// their delta content.
///
/// # Panics
///
/// Panics if any probability is outside [0, 1], or if a token's probability plus the
/// sum of its alternatives exceeds 1.001.
fn create_multi_choice_response(
    choices_data: MultiChoiceData,
) -> NvCreateChatCompletionStreamResponse {
    let mut choices = Vec::with_capacity(choices_data.len());
    for (choice_idx, token_data) in choices_data.into_iter().enumerate() {
        let mut token_logprobs = Vec::with_capacity(token_data.len());
        for (token, prob, alternatives) in token_data {
            // Validate the selected token's probability and the total mass first.
            assert!(
                (0.0..=1.0).contains(&prob),
                "Probability must be in [0, 1]: {}",
                prob
            );
            let total_prob = prob + alternatives.iter().map(|(_, p)| p).sum::<f32>();
            assert!(
                total_prob <= 1.001,
                "Total probability mass exceeds 1: {}",
                total_prob
            );

            // Convert each alternative to log space, validating it on the way.
            let mut top_logprobs = Vec::with_capacity(alternatives.len());
            for (alt_token, alt_prob) in alternatives {
                assert!(
                    (0.0..=1.0).contains(&alt_prob),
                    "Probability must be in [0, 1]: {}",
                    alt_prob
                );
                top_logprobs.push(TopLogprobs {
                    token: alt_token,
                    logprob: alt_prob.ln(),
                    bytes: None,
                });
            }

            token_logprobs.push(ChatCompletionTokenLogprob {
                token,
                logprob: prob.ln(),
                bytes: None,
                top_logprobs,
            });
        }

        let delta = ChatCompletionStreamResponseDelta {
            content: Some("test".to_string()),
            #[expect(deprecated)]
            function_call: None,
            tool_calls: None,
            role: Some(Role::Assistant),
            refusal: None,
        };
        choices.push(ChatChoiceStream {
            index: choice_idx as u32,
            delta,
            finish_reason: Some(FinishReason::Stop),
            logprobs: Some(ChatChoiceLogprobs {
                content: Some(token_logprobs),
                refusal: None,
            }),
        });
    }

    NvCreateChatCompletionStreamResponse {
        inner: CreateChatCompletionStreamResponse {
            id: "test_id".to_string(),
            choices,
            created: 1234567890,
            model: "test-model".to_string(),
            service_tier: None,
            system_fingerprint: None,
            object: "chat.completion.chunk".to_string(),
            usage: None,
        },
    }
}
...@@ -92,8 +92,8 @@ impl<T: Send + Sync + 'static> Data for T {} ...@@ -92,8 +92,8 @@ impl<T: Send + Sync + 'static> Data for T {}
/// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`] /// [`DataStream`] is a type alias for a stream of [`Data`] items. This can be adapted to a [`ResponseStream`]
/// by associating it with a [`AsyncEngineContext`]. /// by associating it with a [`AsyncEngineContext`].
pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>; pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send>>;
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>; pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;
pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>; pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>;
pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>; pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>;
...@@ -174,7 +174,7 @@ pub trait AsyncEngineContextProvider: Send + Debug { ...@@ -174,7 +174,7 @@ pub trait AsyncEngineContextProvider: Send + Debug {
/// This trait combines `Future` semantics with context provider capabilities, /// This trait combines `Future` semantics with context provider capabilities,
/// representing a single async operation that produces one result. /// representing a single async operation that produces one result.
pub trait AsyncEngineUnary<Resp: Data>: pub trait AsyncEngineUnary<Resp: Data>:
Future<Output = Resp> + AsyncEngineContextProvider + Send + Sync Future<Output = Resp> + AsyncEngineContextProvider + Send
{ {
} }
...@@ -183,7 +183,7 @@ pub trait AsyncEngineUnary<Resp: Data>: ...@@ -183,7 +183,7 @@ pub trait AsyncEngineUnary<Resp: Data>:
/// This trait combines `Stream` semantics with context provider capabilities, /// This trait combines `Stream` semantics with context provider capabilities,
/// representing a continuous async operation that produces multiple results over time. /// representing a continuous async operation that produces multiple results over time.
pub trait AsyncEngineStream<Resp: Data>: pub trait AsyncEngineStream<Resp: Data>:
Stream<Item = Resp> + AsyncEngineContextProvider + Send + Sync Stream<Item = Resp> + AsyncEngineContextProvider + Send
{ {
} }
...@@ -204,7 +204,7 @@ pub trait AsyncEngineStream<Resp: Data>: ...@@ -204,7 +204,7 @@ pub trait AsyncEngineStream<Resp: Data>:
/// Implementations should ensure proper error handling and resource management. /// Implementations should ensure proper error handling and resource management.
/// The `generate` method should be cancellable via the response's context provider. /// The `generate` method should be cancellable via the response's context provider.
#[async_trait] #[async_trait]
pub trait AsyncEngine<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data>: pub trait AsyncEngine<Req: Send + Sync + 'static, Resp: AsyncEngineContextProvider, E: Data>:
Send + Sync Send + Sync
{ {
/// Generate a stream of completion responses. /// Generate a stream of completion responses.
......
...@@ -69,7 +69,7 @@ pub type ServerStreamingEngine<T, U> = ServiceEngine<SingleIn<T>, ManyOut<U>>; ...@@ -69,7 +69,7 @@ pub type ServerStreamingEngine<T, U> = ServiceEngine<SingleIn<T>, ManyOut<U>>;
/// are considered independent of each other; however, they could be constrained to be related. /// are considered independent of each other; however, they could be constrained to be related.
pub type BidirectionalStreamingEngine<T, U> = ServiceEngine<ManyIn<T>, ManyOut<U>>; pub type BidirectionalStreamingEngine<T, U> = ServiceEngine<ManyIn<T>, ManyOut<U>>;
pub trait AsyncTransportEngine<T: PipelineIO, U: PipelineIO>: pub trait AsyncTransportEngine<T: Data + PipelineIO, U: Data + PipelineIO>:
AsyncEngine<T, U, Error> + Send + Sync + 'static AsyncEngine<T, U, Error> + Send + Sync + 'static
{ {
} }
...@@ -97,7 +97,7 @@ mod sealed { ...@@ -97,7 +97,7 @@ mod sealed {
} }
} }
pub trait PipelineIO: Data + sealed::Connectable + AsyncEngineContextProvider { pub trait PipelineIO: sealed::Connectable + AsyncEngineContextProvider + 'static {
fn id(&self) -> String; fn id(&self) -> String;
} }
......
...@@ -280,7 +280,7 @@ pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> { ...@@ -280,7 +280,7 @@ pub struct Ingress<Req: PipelineIO, Resp: PipelineIO> {
segment: OnceLock<Arc<SegmentSource<Req, Resp>>>, segment: OnceLock<Arc<SegmentSource<Req, Resp>>>,
} }
impl<Req: PipelineIO, Resp: PipelineIO> Ingress<Req, Resp> { impl<Req: PipelineIO + Sync, Resp: PipelineIO> Ingress<Req, Resp> {
pub fn new() -> Arc<Self> { pub fn new() -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
segment: OnceLock::new(), segment: OnceLock::new(),
......
...@@ -221,8 +221,8 @@ where ...@@ -221,8 +221,8 @@ where
impl<UpIn, UpOut, DownIn, DownOut> AsyncEngine<UpIn, UpOut, Error> impl<UpIn, UpOut, DownIn, DownOut> AsyncEngine<UpIn, UpOut, Error>
for PipelineOperator<UpIn, UpOut, DownIn, DownOut> for PipelineOperator<UpIn, UpOut, DownIn, DownOut>
where where
UpIn: PipelineIO, UpIn: PipelineIO + Sync,
DownIn: PipelineIO, DownIn: PipelineIO + Sync,
DownOut: PipelineIO, DownOut: PipelineIO,
UpOut: PipelineIO, UpOut: PipelineIO,
{ {
...@@ -235,8 +235,8 @@ where ...@@ -235,8 +235,8 @@ where
impl<UpIn, UpOut, DownIn, DownOut> Sink<UpIn> impl<UpIn, UpOut, DownIn, DownOut> Sink<UpIn>
for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut> for PipelineOperatorForwardEdge<UpIn, UpOut, DownIn, DownOut>
where where
UpIn: PipelineIO, UpIn: PipelineIO + Sync,
DownIn: PipelineIO, DownIn: PipelineIO + Sync,
DownOut: PipelineIO, DownOut: PipelineIO,
UpOut: PipelineIO, UpOut: PipelineIO,
{ {
......
...@@ -26,7 +26,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> { ...@@ -26,7 +26,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> ServiceBackend<Req, Resp> {
} }
#[async_trait] #[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Sink<Req> for ServiceBackend<Req, Resp> { impl<Req: PipelineIO + Sync, Resp: PipelineIO> Sink<Req> for ServiceBackend<Req, Resp> {
async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> { async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> {
let stream = self.engine.generate(data).await?; let stream = self.engine.generate(data).await?;
self.on_next(stream, Token).await self.on_next(stream, Token).await
......
...@@ -38,7 +38,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> Default for SegmentSink<Req, Resp> { ...@@ -38,7 +38,7 @@ impl<Req: PipelineIO, Resp: PipelineIO> Default for SegmentSink<Req, Resp> {
} }
#[async_trait] #[async_trait]
impl<Req: PipelineIO, Resp: PipelineIO> Sink<Req> for SegmentSink<Req, Resp> { impl<Req: PipelineIO + Sync, Resp: PipelineIO> Sink<Req> for SegmentSink<Req, Resp> {
async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> { async fn on_data(&self, data: Req, _: Token) -> Result<(), Error> {
let stream = self let stream = self
.engine .engine
......
...@@ -68,7 +68,7 @@ impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out> for ...@@ -68,7 +68,7 @@ impl<In: PipelineIO, Out: PipelineIO + AsyncEngineContextProvider> Sink<Out> for
} }
#[async_trait] #[async_trait]
impl<In: PipelineIO, Out: PipelineIO> AsyncEngine<In, Out, Error> for Frontend<In, Out> { impl<In: PipelineIO + Sync, Out: PipelineIO> AsyncEngine<In, Out, Error> for Frontend<In, Out> {
async fn generate(&self, request: In) -> Result<Out, Error> { async fn generate(&self, request: In) -> Result<Out, Error> {
let (tx, rx) = oneshot::channel::<Out>(); let (tx, rx) = oneshot::channel::<Out>();
{ {
......
...@@ -48,7 +48,9 @@ macro_rules! impl_frontend { ...@@ -48,7 +48,9 @@ macro_rules! impl_frontend {
} }
#[async_trait] #[async_trait]
impl<In: PipelineIO, Out: PipelineIO> AsyncEngine<In, Out, Error> for $type<In, Out> { impl<In: PipelineIO + Sync, Out: PipelineIO> AsyncEngine<In, Out, Error>
for $type<In, Out>
{
async fn generate(&self, request: In) -> Result<Out, Error> { async fn generate(&self, request: In) -> Result<Out, Error> {
self.inner.generate(request).await self.inner.generate(request).await
} }
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#![allow(dead_code)]
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
...@@ -159,12 +161,14 @@ impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> ...@@ -159,12 +161,14 @@ impl<T: Data, U: Data> AsyncEngine<SingleIn<T>, ManyOut<U>, Error>
for MockNetworkEgress<SingleIn<T>, ManyOut<U>> for MockNetworkEgress<SingleIn<T>, ManyOut<U>>
where where
T: Data + Serialize, T: Data + Serialize,
U: for<'de> Deserialize<'de> + Data, U: for<'de> Deserialize<'de> + Data + Send + Sync + 'static,
Self: Send + Sync,
{ {
async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> { async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
let ctrl_tx = self.ctrl_tx.clone();
let id = request.id().to_string(); let id = request.id().to_string();
// serialze the request // serialize the request
let request = request.try_map(|req| serde_json::to_vec(&req))?; let request = request.try_map(|req| serde_json::to_vec(&req))?;
// transfer the request context to a stream context // transfer the request context to a stream context
...@@ -172,14 +176,11 @@ where ...@@ -172,14 +176,11 @@ where
let context = Arc::new(StreamContext::from(context)); let context = Arc::new(StreamContext::from(context));
// subscribe to the response stream // subscribe to the response stream
// but in this case, we are doing a mock, so we are going to be more explicit // in this mock, we use a channel for the data plane
// since we are transferring data over a channel instead of the networ, creating the channel
// is the same as subscribing to the response stream
let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16); let (data_tx, data_rx) = mpsc::channel::<DataPlaneMessage>(16);
let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx); let mut byte_stream = tokio_stream::wrappers::ReceiverStream::new(data_rx);
// prepare the stateful objects that will be used to monitor the response stream // prepare the stateful objects that will be used to monitor the response stream
// finish_rx is a oneshot channel that will be used to signal the natural termination of the stream
let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>(); let (finished_tx, finished_rx) = tokio::sync::oneshot::channel::<()>();
let stream_monitor = ResponseMonitor { let stream_monitor = ResponseMonitor {
ctx: context.clone(), ctx: context.clone(),
...@@ -187,9 +188,6 @@ where ...@@ -187,9 +188,6 @@ where
}; };
// create the control plane request // create the control plane request
// when this is issued, control is handed off to the control plane and the downstream segment
// sometimes we might include the local server address and port for the response find its way home
// todo(design) this will be part of the generalization error for multiple transport types
let request = ControlPlaneRequest { let request = ControlPlaneRequest {
id, id,
request: data, request: data,
...@@ -197,108 +195,83 @@ where ...@@ -197,108 +195,83 @@ where
}; };
// send the request to the control plane // send the request to the control plane
self.ctrl_tx ctrl_tx
.send(MockNetworkControlEvents::ControlPlaneRequest(request)) .send(MockNetworkControlEvents::ControlPlaneRequest(request))
.await .await
.map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?; .map_err(|e| PipelineError::ControlPlaneRequestError(e.to_string()))?;
// the first message from the remote publisher on the data plane needs to be a handshake message // the first message from the remote publisher on the data plane needs to be a handshake message
// the handshake will indicate to what stream the data belongs to and if the remote segment was
// able to process the request.
//
// note: in the case of the mock transport, the handshaking of the request id is not strictly
// because the channel is specific to the request. this is similar to other transports like nats
// where we will subscribe to a response stream on a subject unique to the stream.
match byte_stream.next().await { match byte_stream.next().await {
Some(DataPlaneMessage { headers, body }) => { Some(DataPlaneMessage { headers, body }) => {
if !body.is_empty() { if !body.is_empty() {
Err(PipelineError::ControlPlaneRequestError( return Err(PipelineError::ControlPlaneRequestError(
"Expected an empty body for the handshake message".to_string(), "Expected an empty body for the handshake message".to_string(),
))?; )
.into());
} }
match headers { match headers {
Some(header) => { Some(header) => match header {
match header {
MockNetworkDataPlaneHeaders::Handshake(handshake) => { MockNetworkDataPlaneHeaders::Handshake(handshake) => {
match handshake.status { match handshake.status {
Status::Ok => {} Status::Ok => {}
Status::Error(e) => { Status::Error(e) => {
// todo(metrics): increment metric counter for failed handshakes return Err(PipelineError::ControlPlaneRequestError(format!(
Err(PipelineError::ControlPlaneRequestError(format!(
"remote segment was unable to process request: {}", "remote segment was unable to process request: {}",
e e
)))?; ))
.into());
} }
} }
} }
_ => { _ => {
Err(PipelineError::ControlPlaneRequestError(format!( return Err(PipelineError::ControlPlaneRequestError(format!(
"Expected a handshake message; got: {:?}", "Expected a handshake message; got: {:?}",
header header
)))?; ))
} .into());
}
} }
},
_ => { _ => {
Err(PipelineError::ControlPlaneRequestError( return Err(PipelineError::ControlPlaneRequestError(
"Failed to receive properly formatted handshake on data plane" "Failed to receive properly formatted handshake on data plane"
.to_string(), .to_string(),
))?; )
.into());
} }
} }
} }
None => { None => {
// todo(metrics): increment metric counter for failed requests return Err(PipelineError::ControlPlaneRequestError(
Err(PipelineError::ControlPlaneRequestError(
"Failed data plane connection closed before receiving handshake".to_string(), "Failed data plane connection closed before receiving handshake".to_string(),
))?; )
.into());
} }
} }
let decoded = byte_stream let decoded = byte_stream
// .inspect(|_item| {
// // todo(metrics) increment the metrics counter by the number of bytes
// })
.scan(Some(stream_monitor), move |_stream_monitor, item| { .scan(Some(stream_monitor), move |_stream_monitor, item| {
// we could check the kill state of the context and terminate the stream here
// if our transport needs a heartbeat, trigger a heartbeat here the monitor
if let Some(headers) = &item.headers { if let Some(headers) = &item.headers {
match headers { match headers {
MockNetworkDataPlaneHeaders::HeartBeat => { MockNetworkDataPlaneHeaders::HeartBeat => {
// todo(metrics): increment metric counter for heartbeats // Heartbeat received, do nothing special
// send a heartbeat to the control plane
// this is a good place to send a heartbeat to the control plane
// to keep the connection alive
} }
MockNetworkDataPlaneHeaders::Sentinel => { MockNetworkDataPlaneHeaders::Sentinel => {
// todo(metrics): increment metric counter for sentinels // End of stream
// the stream has ended
// send a sentinel to the control plane
// this is a good place to send a sentinel to the control plane
// to indicate the end of the stream
return futures::future::ready(None); return futures::future::ready(None);
} }
_ => {} _ => {}
} }
} }
futures::future::ready(Some(item)) futures::future::ready(Some(item))
}) })
// decode the response
.map(move |item| { .map(move |item| {
serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response") serde_json::from_slice::<U>(&item.body).expect("failed to deserialize response")
}); });
// cancellation can be tricky and is transport / protocol specific
// in this case, our channel for this is both ordered and 1:1, thus we can
// use that fact to first send the request, then forward any cancellation requests
// this ensures the downstream node should register the context/request id before any
// cancellation requests are sent
// create the cancellation monitor object // create the cancellation monitor object
let cancellation_monitor = CancellationMonitor { let cancellation_monitor = CancellationMonitor {
ctx: context.clone(), ctx: context.clone(),
ctrl_tx: self.ctrl_tx.clone(), ctrl_tx,
finish_tx: finished_tx, finish_tx: finished_tx,
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment