Unverified Commit 56d20f53 authored by ryan-lempka's avatar ryan-lempka Committed by GitHub
Browse files

feat: add audit logging for chat completions (#3062)


Signed-off-by: Ryan Lempka <rlempka@nvidia.com>
parent 5b457b70
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::handle::AuditRecord;
use std::sync::{Arc, OnceLock};
use tokio::sync::broadcast;
static BUS: OnceLock<broadcast::Sender<Arc<AuditRecord>>> = OnceLock::new();
/// Create the process-wide audit broadcast bus.
///
/// `capacity` is the number of records the channel retains for slow
/// subscribers. It is clamped to `1..=usize::MAX / 2` because
/// `tokio::sync::broadcast::channel` panics outside that range, and the value
/// typically originates from an environment variable.
///
/// Only the first call has any effect: if the bus is already set, the newly
/// created sender is discarded and the existing bus is kept.
pub fn init(capacity: usize) {
    // Guard against a user-supplied 0 (or an absurdly large value) panicking at startup.
    let capacity = capacity.clamp(1, usize::MAX / 2);
    // The initial receiver is dropped; sinks obtain their own via `subscribe()`.
    let (tx, _rx) = broadcast::channel::<Arc<AuditRecord>>(capacity);
    let _ = BUS.set(tx);
}
/// Obtain a fresh receiver on the audit bus.
///
/// # Panics
///
/// Panics with "audit bus not initialized" if `init` has not been called.
pub fn subscribe() -> broadcast::Receiver<Arc<AuditRecord>> {
    let sender = BUS.get().expect("audit bus not initialized");
    sender.subscribe()
}
/// Broadcast `rec` to every current subscriber.
///
/// A no-op when the bus was never initialized; a failed send (for example,
/// no live receivers) is deliberately ignored.
pub fn publish(rec: AuditRecord) {
    let Some(sender) = BUS.get() else {
        return;
    };
    let _ = sender.send(Arc::new(rec));
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::OnceLock;
/// Process-wide audit configuration.
#[derive(Clone, Copy)]
pub struct AuditPolicy {
/// Whether audit logging is switched on; `init_from_env` derives this from
/// the `DYN_AUDIT_ENABLED` environment variable.
pub enabled: bool,
}
static POLICY: OnceLock<AuditPolicy> = OnceLock::new();
/// Build an [`AuditPolicy`] from the `DYN_AUDIT_ENABLED` environment variable.
///
/// Auditing is enabled when the value, after trimming surrounding whitespace,
/// is `1` or `true` (ASCII case-insensitive). Any other value, an unset
/// variable, or a non-Unicode value yields a disabled policy.
pub fn init_from_env() -> AuditPolicy {
    let enabled = std::env::var("DYN_AUDIT_ENABLED")
        .map(|raw| {
            // Values injected from files or manifests often carry a trailing
            // newline or padding; do not let that silently disable auditing.
            let flag = raw.trim();
            flag == "1" || flag.eq_ignore_ascii_case("true")
        })
        .unwrap_or(false);
    AuditPolicy { enabled }
}
/// Return the cached audit policy, reading the environment on first use.
///
/// The result of the first `init_from_env` call is stored in `POLICY` and
/// returned (by copy) on every subsequent call.
pub fn policy() -> AuditPolicy {
    let cached: &AuditPolicy = POLICY.get_or_init(init_from_env);
    *cached
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use serde::Serialize;
use std::sync::Arc;
use super::{bus, config};
use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
};
/// Serializable audit entry describing one chat-completion exchange.
///
/// Built by `AuditHandle::emit` and published on the audit bus.
#[derive(Serialize, Clone)]
pub struct AuditRecord {
/// Record schema version; `AuditHandle::emit` currently writes `1`.
pub schema_version: u32,
/// Request identifier supplied to `create_handle`.
pub request_id: String,
/// The request's `stream` flag as captured by `create_handle` (`false` when unset).
pub requested_streaming: bool,
/// Model name copied from the request.
pub model: String,
/// Full request payload; omitted from the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub request: Option<Arc<NvCreateChatCompletionRequest>>,
/// Full response payload; omitted from the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<Arc<NvCreateChatCompletionResponse>>,
}
/// Per-request accumulator for audit data.
///
/// Created by `create_handle`, filled via `set_request` / `set_response`, and
/// consumed by `emit`, which turns it into an `AuditRecord`.
pub struct AuditHandle {
// The request's `stream` flag at handle creation (`false` when unset).
requested_streaming: bool,
// Request identifier passed to `create_handle`.
request_id: String,
// Model name copied from the request.
model: String,
// Full request, if one was attached with `set_request`.
req_full: Option<Arc<NvCreateChatCompletionRequest>>,
// Full response, if one was attached with `set_response`.
resp_full: Option<Arc<NvCreateChatCompletionResponse>>,
}
impl AuditHandle {
    /// Whether the request asked for a streamed response when the handle was created.
    pub fn streaming(&self) -> bool {
        self.requested_streaming
    }

    /// Attach the full request to be included in the audit record.
    pub fn set_request(&mut self, req: Arc<NvCreateChatCompletionRequest>) {
        self.req_full = Some(req);
    }

    /// Attach the full response to be included in the audit record.
    pub fn set_response(&mut self, resp: Arc<NvCreateChatCompletionResponse>) {
        self.resp_full = Some(resp);
    }

    /// Consume the handle and publish its contents on the audit bus.
    ///
    /// Taking `self` by value means a handle can be emitted at most once.
    /// Only the publish happens here; sinks perform the actual I/O.
    pub fn emit(self) {
        let Self {
            requested_streaming,
            request_id,
            model,
            req_full,
            resp_full,
        } = self;
        bus::publish(AuditRecord {
            schema_version: 1,
            request_id,
            requested_streaming,
            model,
            request: req_full,
            response: resp_full,
        });
    }
}
/// Create an audit handle for `req`, or `None` when the request is not audited.
///
/// A handle is produced only when the audit policy is enabled *and* the
/// request sets `store` to `true`. The handle starts without request or
/// response payloads attached.
pub fn create_handle(req: &NvCreateChatCompletionRequest, request_id: &str) -> Option<AuditHandle> {
    let audited = config::policy().enabled && req.inner.store.unwrap_or(false);
    if !audited {
        return None;
    }
    Some(AuditHandle {
        requested_streaming: req.inner.stream.unwrap_or(false),
        request_id: request_id.to_string(),
        model: req.inner.model.clone(),
        req_full: None,
        resp_full: None,
    })
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// Process-wide broadcast channel on which audit records are published.
pub mod bus;
/// Audit policy read from the environment.
pub mod config;
/// Per-request audit handle and the serializable audit record.
pub mod handle;
/// Audit sinks and the background workers that feed them from the bus.
pub mod sink;
/// Stream adapters that aggregate response chunks for auditing.
pub mod stream;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use tokio::sync::broadcast;
use super::{bus, handle::AuditRecord};
/// Destination for audit records.
///
/// Implementations must be `Send + Sync` because each sink is moved into a
/// spawned worker task.
pub trait AuditSink: Send + Sync {
/// Short static identifier for the sink (used in worker log messages).
fn name(&self) -> &'static str;
/// Deliver a single audit record to this sink.
fn emit(&self, rec: &AuditRecord);
}
/// Sink registered under the name `"stderr"`.
///
/// Records are serialized to JSON and written through `tracing` at info level
/// rather than directly to the stderr stream.
pub struct StderrSink;

impl AuditSink for StderrSink {
    fn name(&self) -> &'static str {
        "stderr"
    }

    /// Serialize `rec` to JSON and log it; a serialization failure is logged
    /// as a warning and the record is dropped.
    fn emit(&self, rec: &AuditRecord) {
        let js = match serde_json::to_string(rec) {
            Ok(js) => js,
            Err(e) => {
                tracing::warn!("audit: serialize failed: {e}");
                return;
            }
        };
        tracing::info!(target="dynamo_llm::audit", log_type="audit", record=%js, "audit")
    }
}
/// Build the list of audit sinks from the `DYN_AUDIT_SINKS` environment variable.
///
/// The variable is a comma-separated list of sink names (case-insensitive,
/// surrounding whitespace ignored). When unset it defaults to `"stderr"`.
/// Empty entries map to the stderr sink, as before. Unknown names are logged
/// and skipped.
///
/// Each sink is registered at most once: previously a value such as
/// `"stderr,"` or `"stderr,stderr"` spawned several identical workers and
/// every audit record was emitted multiple times.
fn parse_sinks_from_env() -> Vec<Arc<dyn AuditSink>> {
    let cfg = std::env::var("DYN_AUDIT_SINKS").unwrap_or_else(|_| "stderr".into());
    let mut out: Vec<Arc<dyn AuditSink>> = Vec::new();
    for name in cfg.split(',').map(|s| s.trim().to_lowercase()) {
        let sink: Arc<dyn AuditSink> = match name.as_str() {
            "stderr" | "" => Arc::new(StderrSink),
            // "nats" => Arc::new(NatsSink::from_env()),
            // "pg" => Arc::new(PostgresSink::from_env()),
            other => {
                tracing::warn!(%other, "audit: unknown sink ignored");
                continue;
            }
        };
        // Skip sinks that are already registered so records are not duplicated.
        if out.iter().any(|existing| existing.name() == sink.name()) {
            continue;
        }
        out.push(sink);
    }
    if out.is_empty() {
        // Every configured name was unknown: make the resulting data loss visible.
        tracing::warn!("audit: no valid sinks configured; audit records will be dropped");
    }
    out
}
/// spawn one worker per sink; each subscribes to the bus (off hot path)
pub fn spawn_workers_from_env() {
let sinks = parse_sinks_from_env();
for sink in sinks {
let name = sink.name();
let mut rx: broadcast::Receiver<Arc<AuditRecord>> = bus::subscribe();
tokio::spawn(async move {
loop {
match rx.recv().await {
Ok(rec) => sink.emit(&rec),
Err(broadcast::error::RecvError::Lagged(n)) => tracing::warn!(
sink = name,
dropped = n,
"audit bus lagged; dropped records"
),
Err(broadcast::error::RecvError::Closed) => break,
}
}
});
}
tracing::info!("Audit sinks ready.");
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::oneshot;
use crate::protocols::openai::ParsingOptions;
use crate::protocols::openai::chat_completions::{
DeltaAggregator, NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse,
};
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_async_openai::types::{ChatChoiceStream, ChatCompletionStreamResponseDelta};
use futures::StreamExt;
/// Boxed, pinned, `Send` stream of annotated chat-completion stream chunks.
type AuditStream =
Pin<Box<dyn Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>>;
/// Boxed, pinned, `Send` future resolving to an aggregated chat-completion response.
type AuditFuture =
Pin<Box<dyn std::future::Future<Output = NvCreateChatCompletionResponse> + Send>>;
/// Forwards transformed chunks unchanged; collects them for aggregation.
///
/// Every chunk yielded by `inner` is cloned into `chunks` before being passed
/// on. When `inner` ends, the collected chunks are aggregated and the result
/// is sent through `done_tx`.
pub struct PassThroughWithAgg<S> {
// Upstream chunk stream being wrapped.
inner: S,
// Copies of every chunk seen so far, kept until the stream ends.
chunks: Vec<Annotated<NvCreateChatCompletionStreamResponse>>,
// One-shot sender for the aggregated response; taken (set to `None`) once used.
done_tx: Option<oneshot::Sender<NvCreateChatCompletionResponse>>,
}
impl<S> PassThroughWithAgg<S> {
fn new(inner: S, tx: oneshot::Sender<NvCreateChatCompletionResponse>) -> Self {
Self {
inner,
chunks: Vec::new(),
done_tx: Some(tx),
}
}
}
impl<S> Stream for PassThroughWithAgg<S>
where
    S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Unpin,
{
    type Item = Annotated<NvCreateChatCompletionStreamResponse>;

    /// Poll the wrapped stream, recording a copy of each chunk on its way through.
    ///
    /// When the wrapped stream reports completion, the recorded chunks are
    /// handed to a spawned task that aggregates them and sends the result on
    /// the one-shot channel. The sender is taken on first use, so aggregation
    /// is started at most once.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        let next = match Pin::new(&mut this.inner).poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(next) => next,
        };

        if let Some(chunk) = next {
            // Keep a copy for aggregation, forward the original untouched.
            this.chunks.push(chunk.clone());
            return Poll::Ready(Some(chunk));
        }

        // Upstream is exhausted: kick off aggregation of everything collected.
        if let Some(tx) = this.done_tx.take() {
            let collected = std::mem::take(&mut this.chunks);
            let replay = futures::stream::iter(collected);
            let parsing_options = ParsingOptions::default();
            tokio::spawn(async move {
                match DeltaAggregator::apply(replay, parsing_options).await {
                    Ok(final_resp) => {
                        let _ = tx.send(final_resp);
                    }
                    Err(e) => {
                        tracing::warn!("audit: aggregation failed: {e}");
                    }
                }
            });
        }
        Poll::Ready(None)
    }
}
/// Split `stream` into a pass-through stream and a future for the aggregated response.
///
/// The returned stream yields the input chunks unchanged. The returned future
/// resolves to the aggregated final response once the stream has been driven
/// to completion. If the sending side is dropped without a result (stream
/// dropped early, or aggregation failed), the future logs a warning and
/// resolves to an empty placeholder response.
pub fn scan_aggregate_with_future<S>(stream: S) -> (AuditStream, AuditFuture)
where
    S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Unpin + Send + 'static,
{
    let (tx, rx) = oneshot::channel::<NvCreateChatCompletionResponse>();
    let passthrough: AuditStream = Box::pin(PassThroughWithAgg::new(stream, tx));
    let aggregated: AuditFuture = Box::pin(async move {
        match rx.await {
            Ok(final_resp) => final_resp,
            Err(_) => {
                tracing::warn!("audit: aggregation future canceled/failed");
                // Minimal placeholder so the audit task still completes.
                NvCreateChatCompletionResponse {
                    id: String::new(),
                    created: 0,
                    usage: None,
                    model: String::new(),
                    object: "chat.completion".to_string(),
                    system_fingerprint: None,
                    choices: vec![],
                    service_tier: None,
                }
            }
        }
    });
    (passthrough, aggregated)
}
/// Collapse `stream` into a single final chunk (non-streaming path) plus an audit future.
///
/// The returned stream first drains the whole input, aggregates it, and then
/// yields exactly one chunk carrying the complete response. The same
/// aggregated response is delivered through the returned future. If
/// aggregation fails, an empty placeholder response is used for both; if the
/// stream is dropped before finishing, the future resolves to the placeholder.
pub fn fold_aggregate_with_future<S>(stream: S) -> (AuditStream, AuditFuture)
where
    S: Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send + 'static,
{
    /// Empty placeholder used whenever no real aggregated response is available.
    fn empty_response() -> NvCreateChatCompletionResponse {
        NvCreateChatCompletionResponse {
            id: String::new(),
            created: 0,
            usage: None,
            model: String::new(),
            object: "chat.completion".to_string(),
            system_fingerprint: None,
            choices: vec![],
            service_tier: None,
        }
    }

    let (tx, rx) = oneshot::channel::<NvCreateChatCompletionResponse>();

    let single_chunk_stream = async move {
        let collected: Vec<_> = stream.collect().await;
        let replay = futures::stream::iter(collected);
        let parsing_options = ParsingOptions::default();
        let resolved = match DeltaAggregator::apply(replay, parsing_options).await {
            Ok(final_resp) => final_resp,
            Err(e) => {
                tracing::warn!("fold aggregation failed: {e}");
                empty_response()
            }
        };
        // Hand one copy to the audit future, turn the other into the client chunk.
        let _ = tx.send(resolved.clone());
        final_response_to_one_chunk_stream(resolved)
    };

    let future = Box::pin(async move {
        match rx.await {
            Ok(final_resp) => final_resp,
            Err(_) => {
                tracing::warn!("fold aggregation future canceled");
                empty_response()
            }
        }
    });

    (
        Box::pin(futures::stream::once(single_chunk_stream).flatten()),
        future,
    )
}
/// Convert a final (non-streaming) response into a single "final chunk" stream.
///
/// The entire final text / tool calls / function call of each choice is put
/// into that choice's `delta`, so a downstream aggregation over the resulting
/// one-element stream is a no-op. Choice and tool-call indices are assigned
/// from their position in the respective lists. The chunk's `object` is
/// `"chat.completion.chunk"` and it is wrapped in an `Annotated` with no id,
/// event, or comment.
///
/// `resp` is consumed: its fields are moved into the chunk rather than cloned.
pub fn final_response_to_one_chunk_stream(
    resp: NvCreateChatCompletionResponse,
) -> std::pin::Pin<
    Box<dyn futures::Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>> + Send>,
> {
    let choices: Vec<ChatChoiceStream> = resp
        .choices
        .into_iter()
        .enumerate()
        .map(|(idx, ch)| {
            let message = ch.message;

            // Convert FunctionCall to FunctionCallStream if present
            #[allow(deprecated)]
            let function_call = message.function_call.map(|fc| {
                dynamo_async_openai::types::FunctionCallStream {
                    name: Some(fc.name),
                    arguments: Some(fc.arguments),
                }
            });

            // Convert tool calls
            let tool_calls = message.tool_calls.map(|calls| {
                calls
                    .into_iter()
                    .enumerate()
                    .map(
                        |(i, call)| dynamo_async_openai::types::ChatCompletionMessageToolCallChunk {
                            index: i as u32,
                            id: Some(call.id),
                            r#type: Some(call.r#type),
                            function: Some(dynamo_async_openai::types::FunctionCallStream {
                                name: Some(call.function.name),
                                arguments: Some(call.function.arguments),
                            }),
                        },
                    )
                    .collect()
            });

            #[allow(deprecated)]
            let delta = ChatCompletionStreamResponseDelta {
                role: Some(message.role),
                content: message.content,
                tool_calls,
                function_call,
                refusal: message.refusal,
                reasoning_content: message.reasoning_content,
            };

            ChatChoiceStream {
                index: idx as u32,
                delta,
                finish_reason: ch.finish_reason,
                logprobs: ch.logprobs,
            }
        })
        .collect();

    let chunk = NvCreateChatCompletionStreamResponse {
        id: resp.id,
        object: "chat.completion.chunk".to_string(),
        created: resp.created,
        model: resp.model,
        system_fingerprint: resp.system_fingerprint,
        service_tier: resp.service_tier,
        choices,
        usage: resp.usage,
    };
    let annotated = Annotated {
        data: Some(chunk),
        id: None,
        event: None,
        comment: None,
    };
    Box::pin(futures::stream::once(async move { annotated }))
}
#[cfg(test)]
mod tests {
use super::*;
use dynamo_async_openai::types::{
ChatChoiceStream, ChatCompletionStreamResponseDelta, FinishReason, Role,
};
use futures::StreamExt;
use futures::stream;
/// Build an annotated assistant chunk carrying `content` for choice `index`.
///
/// The chunk has no finish reason and no annotation metadata.
#[allow(deprecated)]
fn create_mock_chunk(
    content: String,
    index: u32,
) -> Annotated<NvCreateChatCompletionStreamResponse> {
    let delta = ChatCompletionStreamResponseDelta {
        role: Some(Role::Assistant),
        content: Some(content),
        tool_calls: None,
        function_call: None,
        refusal: None,
        reasoning_content: None,
    };
    let data = NvCreateChatCompletionStreamResponse {
        id: "test-id".to_string(),
        choices: vec![ChatChoiceStream {
            index,
            delta,
            finish_reason: None,
            logprobs: None,
        }],
        created: 1234567890,
        model: "test-model".to_string(),
        system_fingerprint: Some("test-fingerprint".to_string()),
        object: "chat.completion.chunk".to_string(),
        usage: None,
        service_tier: None,
    };
    Annotated {
        data: Some(data),
        id: None,
        event: None,
        comment: None,
    }
}
/// Build an annotated terminating chunk for choice `index`.
///
/// The delta is empty (no role, no content) and the finish reason is `Stop`.
#[allow(deprecated)]
fn create_final_chunk(index: u32) -> Annotated<NvCreateChatCompletionStreamResponse> {
    let empty_delta = ChatCompletionStreamResponseDelta {
        role: None,
        content: None,
        tool_calls: None,
        function_call: None,
        refusal: None,
        reasoning_content: None,
    };
    let data = NvCreateChatCompletionStreamResponse {
        id: "test-id".to_string(),
        choices: vec![ChatChoiceStream {
            index,
            delta: empty_delta,
            finish_reason: Some(FinishReason::Stop),
            logprobs: None,
        }],
        created: 1234567890,
        model: "test-model".to_string(),
        system_fingerprint: Some("test-fingerprint".to_string()),
        object: "chat.completion.chunk".to_string(),
        usage: None,
        service_tier: None,
    };
    Annotated {
        data: Some(data),
        id: None,
        event: None,
        comment: None,
    }
}
/// Return the text content of the first choice in `chunk`.
///
/// Yields an empty string when the chunk has no data, no choices, or no content.
fn extract_content(chunk: &Annotated<NvCreateChatCompletionStreamResponse>) -> String {
    let Some(data) = chunk.data.as_ref() else {
        return String::new();
    };
    match data.choices.first().and_then(|c| c.delta.content.as_ref()) {
        Some(text) => text.clone(),
        None => String::new(),
    }
}
/// Helper to reconstruct all content from results
fn reconstruct_content(results: &[Annotated<NvCreateChatCompletionStreamResponse>]) -> String {
results
.iter()
.map(extract_content)
.collect::<Vec<_>>()
.join("")
}
#[tokio::test]
async fn test_passthrough_forwards_chunks_unchanged() {
    // Two content chunks followed by a terminating chunk.
    let chunks = vec![
        create_mock_chunk("Hello ".to_string(), 0),
        create_mock_chunk("World".to_string(), 0),
        create_final_chunk(0),
    ];
    let input_stream = stream::iter(chunks.clone());

    let (passthrough, _future) = scan_aggregate_with_future(input_stream);
    let forwarded: Vec<_> = passthrough.collect().await;

    // Same number of chunks out as in.
    assert_eq!(forwarded.len(), 3, "Should pass through all chunks unchanged");

    // Per-chunk content is untouched; the terminating chunk carries none.
    let expected = ["Hello ", "World", ""];
    for (chunk, want) in forwarded.iter().zip(expected) {
        assert_eq!(extract_content(chunk), want);
    }

    // The concatenated content matches the original message.
    assert_eq!(reconstruct_content(&forwarded), "Hello World");
}
#[tokio::test]
async fn test_empty_stream_handling() {
    // An empty input must neither panic nor hang the audit future.
    let no_chunks: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = vec![];
    let (passthrough, future) = scan_aggregate_with_future(stream::iter(no_chunks));

    let forwarded: Vec<_> = passthrough.collect().await;
    let final_resp = future.await;

    // Nothing in, nothing out.
    assert_eq!(forwarded.len(), 0, "Empty stream should produce no chunks");

    // Whether this is the aggregated result or the fallback placeholder
    // (presumably the latter if aggregation rejects an empty stream — not
    // verified here), the object type must be a chat completion.
    assert_eq!(final_resp.object, "chat.completion");
}
#[tokio::test]
async fn test_single_chunk_stream() {
    // A lone chunk is forwarded as-is and the audit future still resolves.
    let only_chunk = create_mock_chunk("Single chunk".to_string(), 0);
    let (passthrough, future) = scan_aggregate_with_future(stream::iter(vec![only_chunk]));

    let forwarded: Vec<_> = passthrough.collect().await;
    let final_resp = future.await;

    // Pass-through side.
    assert_eq!(forwarded.len(), 1);
    assert_eq!(extract_content(&forwarded[0]), "Single chunk");

    // Aggregation side.
    assert_eq!(final_resp.object, "chat.completion");
}
#[tokio::test]
async fn test_chunks_with_metadata_preserved() {
// Test that metadata (id, event, comment) is preserved through passthrough
let chunk_with_metadata = Annotated {
data: Some(NvCreateChatCompletionStreamResponse {
id: "test-id".to_string(),
choices: vec![{
#[allow(deprecated)]
ChatChoiceStream {
index: 0,
delta: ChatCompletionStreamResponseDelta {
role: Some(Role::Assistant),
content: Some("Content".to_string()),
tool_calls: None,
function_call: None,
refusal: None,
reasoning_content: None,
},
finish_reason: None,
logprobs: None,
}
}],
created: 1234567890,
model: "test-model".to_string(),
system_fingerprint: None,
object: "chat.completion.chunk".to_string(),
usage: None,
service_tier: None,
}),
id: Some("correlation-123".to_string()),
event: Some("test-event".to_string()),
comment: Some(vec!["test-comment".to_string()]),
};
let input_stream = stream::iter(vec![chunk_with_metadata.clone()]);
let (passthrough, _future) = scan_aggregate_with_future(input_stream);
let results: Vec<_> = passthrough.collect().await;
// Verify metadata is preserved
assert_eq!(results.len(), 1);
assert_eq!(results[0].id, Some("correlation-123".to_string()));
assert_eq!(results[0].event, Some("test-event".to_string()));
assert_eq!(results[0].comment, Some(vec!["test-comment".to_string()]));
}
#[tokio::test]
async fn test_concurrent_futures() {
// Test that multiple concurrent audit streams don't interfere
let chunks1 = vec![create_mock_chunk("Stream 1".to_string(), 0)];
let chunks2 = vec![create_mock_chunk("Stream 2".to_string(), 0)];
let (_, future1) = scan_aggregate_with_future(stream::iter(chunks1));
let (_, future2) = scan_aggregate_with_future(stream::iter(chunks2));
// Run both futures concurrently
let (resp1, resp2) = tokio::join!(future1, future2);
// Both should complete successfully
assert_eq!(resp1.object, "chat.completion");
assert_eq!(resp2.object, "chat.completion");
}
}
...@@ -115,6 +115,18 @@ pub async fn run_input( ...@@ -115,6 +115,18 @@ pub async fn run_input(
Either::Left(rt) => rt.clone(), Either::Left(rt) => rt.clone(),
Either::Right(drt) => drt.runtime().clone(), Either::Right(drt) => drt.runtime().clone(),
}; };
// Initialize audit bus + sink workers (off hot path; fan-out supported)
if crate::audit::config::policy().enabled {
let cap: usize = std::env::var("DYN_AUDIT_CAPACITY")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(1024);
crate::audit::bus::init(cap);
crate::audit::sink::spawn_workers_from_env();
tracing::info!("Audit initialized: bus cap={}", cap);
}
match in_opt { match in_opt {
Input::Http => { Input::Http => {
http::run(runtime, engine_config).await?; http::run(runtime, engine_config).await?;
......
...@@ -334,12 +334,6 @@ async fn completions( ...@@ -334,12 +334,6 @@ async fn completions(
// Create http_queue_guard early - tracks time waiting to be processed // Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
// update the request to always stream
let request = request.map(|mut req| {
req.inner.stream = Some(true);
req
});
// todo - error handling should be more robust // todo - error handling should be more robust
let engine = state let engine = state
.manager() .manager()
...@@ -589,12 +583,6 @@ async fn chat_completions( ...@@ -589,12 +583,6 @@ async fn chat_completions(
// todo - decide on default // todo - decide on default
let streaming = request.inner.stream.unwrap_or(false); let streaming = request.inner.stream.unwrap_or(false);
// update the request to always stream
let request = request.map(|mut req| {
req.inner.stream = Some(true);
req
});
// todo - make the protocols be optional for model name // todo - make the protocols be optional for model name
// todo - when optional, if none, apply a default // todo - when optional, if none, apply a default
// todo - determine the proper error code for when a request model is not present // todo - determine the proper error code for when a request model is not present
......
...@@ -22,6 +22,7 @@ pub mod grpc; ...@@ -22,6 +22,7 @@ pub mod grpc;
pub mod http; pub mod http;
pub mod hub; pub mod hub;
// pub mod key_value_store; // pub mod key_value_store;
pub mod audit;
pub mod kv_router; pub mod kv_router;
pub mod local_model; pub mod local_model;
pub mod migration; pub mod migration;
......
...@@ -692,7 +692,20 @@ impl ...@@ -692,7 +692,20 @@ impl
>, >,
) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
// unpack the request // unpack the request
let (request, context) = request.into_parts(); let (mut request, context) = request.into_parts();
// Preserve original inbound streaming flag before any internal overrides
let request_id = context.id().to_string();
// Build audit handle (None if DYN_AUDIT_ENABLED=0)
let mut audit_handle = crate::audit::handle::create_handle(&request, &request_id);
if let Some(ref mut h) = audit_handle {
h.set_request(std::sync::Arc::new(request.clone()));
}
// Set stream=true for internal processing (after audit capture)
request.inner.stream = Some(true);
// create a response generator // create a response generator
let response_generator = request.response_generator(context.id().to_string()); let response_generator = request.response_generator(context.id().to_string());
...@@ -735,8 +748,6 @@ impl ...@@ -735,8 +748,6 @@ impl
let has_tools = let has_tools =
request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty(); request.inner.tools.is_some() && !request.inner.tools.as_ref().unwrap().is_empty();
// Context was already extracted above from response_stream
// Determine if we should apply jail (do this before moving request) // Determine if we should apply jail (do this before moving request)
let should_jail = Self::should_apply_tool_jail( let should_jail = Self::should_apply_tool_jail(
self.tool_call_parser.as_ref(), self.tool_call_parser.as_ref(),
...@@ -745,7 +756,7 @@ impl ...@@ -745,7 +756,7 @@ impl
)?; )?;
// Apply jail conditionally // Apply jail conditionally
let stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail { let transformed_stream: Pin<Box<dyn Stream<Item = _> + Send>> = if should_jail {
if let Some(parser) = self.tool_call_parser.clone() { if let Some(parser) = self.tool_call_parser.clone() {
Box::pin(Self::apply_tool_calling_jail(parser, stream)) Box::pin(Self::apply_tool_calling_jail(parser, stream))
} else { } else {
...@@ -754,8 +765,31 @@ impl ...@@ -754,8 +765,31 @@ impl
} else { } else {
Box::pin(stream) Box::pin(stream)
}; };
// Step 4: Apply audit aggregation strategy
let final_stream = if let Some(mut audit) = audit_handle {
let (stream, agg_fut) = if audit.streaming() {
// Streaming: apply scan (pass-through + parallel aggregation)
crate::audit::stream::scan_aggregate_with_future(transformed_stream)
} else {
// Non-streaming: apply fold (collect all, then emit single chunk)
crate::audit::stream::fold_aggregate_with_future(transformed_stream)
};
// Spawn audit task
tokio::spawn(async move {
let final_resp = agg_fut.await;
audit.set_response(Arc::new(final_resp));
audit.emit();
});
Box::pin(stream)
} else {
transformed_stream
};
// prepend the annotations to the response stream // prepend the annotations to the response stream
let stream = annotations_stream.chain(stream); let stream = annotations_stream.chain(final_stream);
// return the response stream - single boxing at the end // return the response stream - single boxing at the end
Ok(ResponseStream::new(Box::pin(stream), context)) Ok(ResponseStream::new(Box::pin(stream), context))
...@@ -779,7 +813,9 @@ impl ...@@ -779,7 +813,9 @@ impl
>, >,
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
// unpack the request // unpack the request
let (request, context) = request.into_parts(); let (mut request, context) = request.into_parts();
request.inner.stream = Some(true);
// create a response generator // create a response generator
let response_generator = request.response_generator(context.id().to_string()); let response_generator = request.response_generator(context.id().to_string());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment