Unverified commit aa80ac41, authored by Ryan Olson and committed by GitHub
Browse files

fix: implement OpenAI-compliant usage stats for streaming responses (#3022)


Signed-off-by: Ryan Olson <rolson@nvidia.com>
parent 23033136
...@@ -91,3 +91,10 @@ generated-values.yaml ...@@ -91,3 +91,10 @@ generated-values.yaml
.build/ .build/
**/.devcontainer/.env **/.devcontainer/.env
TensorRT-LLM TensorRT-LLM
# Ruler Generated Files
/.cursor/instructions.md
/.cursor/instructions.md.bak
/CLAUDE.md
/CLAUDE.md.bak
...@@ -473,6 +473,8 @@ impl OpenAIPreprocessor { ...@@ -473,6 +473,8 @@ impl OpenAIPreprocessor {
context: Arc<dyn AsyncEngineContext>, context: Arc<dyn AsyncEngineContext>,
cancelled: bool, cancelled: bool,
cumulative_output_tokens: usize, cumulative_output_tokens: usize,
finish_reason_sent: bool,
usage_chunk_sent: bool,
finished: bool, // Add this flag to track if stream is finished finished: bool, // Add this flag to track if stream is finished
} }
...@@ -482,6 +484,8 @@ impl OpenAIPreprocessor { ...@@ -482,6 +484,8 @@ impl OpenAIPreprocessor {
context: context.clone(), context: context.clone(),
cancelled: false, cancelled: false,
cumulative_output_tokens: 0, cumulative_output_tokens: 0,
finish_reason_sent: false,
usage_chunk_sent: false,
finished: false, // Initialize as not finished finished: false, // Initialize as not finished
}; };
...@@ -509,6 +513,13 @@ impl OpenAIPreprocessor { ...@@ -509,6 +513,13 @@ impl OpenAIPreprocessor {
response response
); );
// Check if this response has a finish_reason
let has_finish_reason = response
.data
.as_ref()
.map(|d| d.finish_reason.is_some())
.unwrap_or(false);
let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data { let (chunk_tokens, isl) = if let Some(ref backend_output) = response.data {
let chunk_tokens = backend_output.token_ids.len(); let chunk_tokens = backend_output.token_ids.len();
inner.cumulative_output_tokens += chunk_tokens; inner.cumulative_output_tokens += chunk_tokens;
...@@ -553,6 +564,11 @@ impl OpenAIPreprocessor { ...@@ -553,6 +564,11 @@ impl OpenAIPreprocessor {
} }
} }
// Mark if we've seen a finish_reason
if has_finish_reason {
inner.finish_reason_sent = true;
}
tracing::trace!( tracing::trace!(
request_id = inner.context.id(), request_id = inner.context.id(),
"OpenAI NvCreateChatCompletionStreamResponse: {:?}", "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
...@@ -561,10 +577,34 @@ impl OpenAIPreprocessor { ...@@ -561,10 +577,34 @@ impl OpenAIPreprocessor {
Some((response, inner)) Some((response, inner))
} else { } else {
// stream closed with out graceful closure // Stream has ended - check if we need to send a usage chunk
// we did not detect an is_finished/completed message if inner.response_generator.is_usage_enabled()
inner.finished = true; // Mark as finished && inner.finish_reason_sent
None && !inner.usage_chunk_sent
&& !inner.finished
{
inner.usage_chunk_sent = true;
// Create the final usage chunk
let usage_chunk = inner.response_generator.create_usage_chunk();
let annotated_usage = Annotated::<Resp> {
id: None,
data: Some(usage_chunk),
event: Some(ANNOTATION_LLM_METRICS.to_string()),
comment: None,
};
tracing::trace!(
request_id = inner.context.id(),
"Sending final usage chunk for OpenAI compliance"
);
Some((annotated_usage, inner))
} else {
// stream closed
inner.finished = true; // Mark as finished
None
}
} }
} }
}); });
......
...@@ -206,6 +206,12 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>: ...@@ -206,6 +206,12 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Gets the current prompt token count (Input Sequence Length). /// Gets the current prompt token count (Input Sequence Length).
fn get_isl(&self) -> Option<u32>; fn get_isl(&self) -> Option<u32>;
/// Creates a final usage-only chunk for OpenAI compliance.
fn create_usage_chunk(&self) -> ResponseType;
/// Check if usage tracking is enabled.
fn is_usage_enabled(&self) -> bool;
} }
#[derive(Clone, Debug, Serialize, Deserialize, Default)] #[derive(Clone, Debug, Serialize, Deserialize, Default)]
......
...@@ -20,7 +20,12 @@ impl NvCreateChatCompletionRequest { ...@@ -20,7 +20,12 @@ impl NvCreateChatCompletionRequest {
/// * [`DeltaGenerator`] configured with model name and response options. /// * [`DeltaGenerator`] configured with model name and response options.
pub fn response_generator(&self, request_id: String) -> DeltaGenerator { pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
let options = DeltaGeneratorOptions { let options = DeltaGeneratorOptions {
enable_usage: true, enable_usage: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.include_usage)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false) enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0, || self.inner.top_logprobs.unwrap_or(0) > 0,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
...@@ -252,10 +257,29 @@ impl DeltaGenerator { ...@@ -252,10 +257,29 @@ impl DeltaGenerator {
let choices = vec![choice]; let choices = vec![choice];
let mut usage = self.usage.clone(); // According to OpenAI spec: when stream_options.include_usage is true,
if self.options.enable_usage { // all intermediate chunks should have usage: null
usage.total_tokens = usage.prompt_tokens + usage.completion_tokens; // The final usage chunk will be sent separately with empty choices
dynamo_async_openai::types::CreateChatCompletionStreamResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices,
usage: None, // Always None for chunks with content/choices
service_tier: self.service_tier.clone(),
} }
}
/// Creates a final usage-only chunk for OpenAI compliance.
/// This should be sent after the last content chunk when stream_options.include_usage is true.
///
/// # Returns
/// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats.
pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
dynamo_async_openai::types::CreateChatCompletionStreamResponse { dynamo_async_openai::types::CreateChatCompletionStreamResponse {
id: self.id.clone(), id: self.id.clone(),
...@@ -263,15 +287,16 @@ impl DeltaGenerator { ...@@ -263,15 +287,16 @@ impl DeltaGenerator {
created: self.created, created: self.created,
model: self.model.clone(), model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
choices, choices: vec![], // Empty choices for usage-only chunk
usage: if self.options.enable_usage { usage: Some(usage),
Some(usage)
} else {
None
},
service_tier: self.service_tier.clone(), service_tier: self.service_tier.clone(),
} }
} }
/// Check if usage tracking is enabled
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}
} }
/// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing /// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
...@@ -358,4 +383,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -358,4 +383,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
fn get_isl(&self) -> Option<u32> { fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens) Some(self.usage.prompt_tokens)
} }
fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse {
DeltaGenerator::create_usage_chunk(self)
}
fn is_usage_enabled(&self) -> bool {
DeltaGenerator::is_usage_enabled(self)
}
} }
...@@ -9,7 +9,12 @@ impl NvCreateCompletionRequest { ...@@ -9,7 +9,12 @@ impl NvCreateCompletionRequest {
// inspect the request to extract options // inspect the request to extract options
pub fn response_generator(&self, request_id: String) -> DeltaGenerator { pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
let options = DeltaGeneratorOptions { let options = DeltaGeneratorOptions {
enable_usage: true, enable_usage: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.include_usage)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0, enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
}; };
...@@ -143,11 +148,9 @@ impl DeltaGenerator { ...@@ -143,11 +148,9 @@ impl DeltaGenerator {
) -> NvCreateCompletionResponse { ) -> NvCreateCompletionResponse {
// todo - update for tool calling // todo - update for tool calling
let mut usage = self.usage.clone(); // According to OpenAI spec: when stream_options.include_usage is true,
if self.options.enable_usage { // all intermediate chunks should have usage: null
usage.total_tokens = usage.prompt_tokens + usage.completion_tokens; // The final usage chunk will be sent separately with empty choices
}
let inner = dynamo_async_openai::types::CreateCompletionResponse { let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: self.id.clone(), id: self.id.clone(),
object: self.object.clone(), object: self.object.clone(),
...@@ -160,15 +163,38 @@ impl DeltaGenerator { ...@@ -160,15 +163,38 @@ impl DeltaGenerator {
finish_reason, finish_reason,
logprobs, logprobs,
}], }],
usage: if self.options.enable_usage { usage: None, // Always None for chunks with content/choices
Some(usage)
} else {
None
},
}; };
NvCreateCompletionResponse { inner } NvCreateCompletionResponse { inner }
} }
/// Creates a final usage-only chunk for OpenAI compliance.
/// This should be sent after the last content chunk when stream_options.include_usage is true.
///
/// # Returns
/// * A [`NvCreateCompletionResponse`] with empty choices and usage stats.
pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
let inner = dynamo_async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![], // Empty choices for usage-only chunk
usage: Some(usage),
};
NvCreateCompletionResponse { inner }
}
/// Check if usage tracking is enabled
pub fn is_usage_enabled(&self) -> bool {
self.options.enable_usage
}
} }
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator { impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
...@@ -207,4 +233,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -207,4 +233,12 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
fn get_isl(&self) -> Option<u32> { fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens) Some(self.usage.prompt_tokens)
} }
fn create_usage_chunk(&self) -> NvCreateCompletionResponse {
DeltaGenerator::create_usage_chunk(self)
}
fn is_usage_enabled(&self) -> bool {
DeltaGenerator::is_usage_enabled(self)
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use async_trait::async_trait;
use dynamo_async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
CreateChatCompletionRequest,
};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream};
use dynamo_runtime::protocols::annotated::Annotated;
use futures::StreamExt;
use futures::stream;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
/// Minimal stand-in for an engine context, used only by the streaming tests.
#[derive(Debug)]
struct MockContext {
    /// Request identifier handed back by `id()`.
    id: String,
    /// Raised when a stop has been requested.
    stopped: AtomicBool,
    /// Raised when a kill has been requested.
    killed: AtomicBool,
}

impl MockContext {
    /// Creates a context with a fixed request id and both flags cleared.
    fn new() -> Self {
        let id = String::from("test-request-123");
        Self {
            id,
            // `AtomicBool::default()` is `false`, i.e. neither stopped nor killed.
            stopped: AtomicBool::default(),
            killed: AtomicBool::default(),
        }
    }
}
/// Flag-backed implementation of the engine context for tests: stop/kill
/// requests only flip atomics, and the async waiters return immediately.
#[async_trait]
impl AsyncEngineContext for MockContext {
    fn id(&self) -> &str {
        self.id.as_str()
    }

    // --- state queries -----------------------------------------------------

    fn is_stopped(&self) -> bool {
        self.stopped.load(Ordering::SeqCst)
    }

    fn is_killed(&self) -> bool {
        self.killed.load(Ordering::SeqCst)
    }

    // --- state transitions -------------------------------------------------

    fn stop(&self) {
        self.stopped.store(true, Ordering::SeqCst);
    }

    fn stop_generating(&self) {
        // Same effect as `stop`: only the `stopped` flag is raised.
        self.stop();
    }

    fn kill(&self) {
        self.killed.store(true, Ordering::SeqCst);
    }

    // --- async waiters and linking: intentionally inert in tests ------------

    async fn stopped(&self) {
        // No-op for testing
    }

    async fn killed(&self) {
        // No-op for testing
    }

    fn link_child(&self, _child: Arc<dyn AsyncEngineContext>) {
        // No-op for testing
    }
}
/// Builds a three-chunk backend stream ("Hello", " world", "!") bound to `ctx`.
///
/// Every chunk carries exactly one token; only the last one has a finish
/// reason (`FinishReason::Stop`).
fn create_mock_backend_stream(
    ctx: Arc<dyn AsyncEngineContext>,
) -> Pin<Box<dyn AsyncEngineStream<Annotated<BackendOutput>>>> {
    use dynamo_runtime::engine::ResponseStream;

    // (token id, decoded text, finish reason) for each chunk, in emission order.
    let script = vec![
        (15339, "Hello", None),
        (1917, " world", None),
        (0, "!", Some(FinishReason::Stop)),
    ];

    let annotated_chunks = script.into_iter().map(|(token_id, piece, finish_reason)| {
        Annotated::from_data(BackendOutput {
            token_ids: vec![token_id],
            tokens: vec![Some(piece.to_string())],
            text: Some(piece.to_string()),
            cum_log_probs: None,
            log_probs: None,
            top_logprobs: None,
            finish_reason,
            index: Some(0),
        })
    });

    ResponseStream::new(Box::pin(stream::iter(annotated_chunks)), ctx)
}
/// Builds a streaming chat completion request containing a single user message.
///
/// `include_usage` controls `stream_options`: `None` leaves the field unset,
/// while `Some(flag)` sets `stream_options.include_usage` to `flag`.
fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
        name: None,
    };

    NvCreateChatCompletionRequest {
        inner: CreateChatCompletionRequest {
            model: "test-model".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(user_message)],
            stream: Some(true),
            stream_options: include_usage
                .map(|include_usage| ChatCompletionStreamOptions { include_usage }),
            ..Default::default()
        },
        common: Default::default(),
        nvext: None,
    }
}
/// With no `stream_options` on the request, the transformed stream must
/// contain only the three content chunks, none of them carrying usage stats.
#[tokio::test]
async fn test_streaming_without_usage() {
    // Create request without stream_options (usage should not be included)
    let request = create_chat_request(None);
    let request_id = "test-123".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Create mock backend stream
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx);

    // Transform the stream
    let transformed_stream =
        OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator);

    // Collect all chunks
    let chunks: Vec<_> = transformed_stream.collect().await;

    // Verify we got exactly 3 chunks (no extra usage chunk)
    assert_eq!(chunks.len(), 3, "Should have exactly 3 content chunks");

    // Verify all chunks have usage: None
    for (i, chunk) in chunks.iter().enumerate() {
        // A chunk without data must fail the test; silently skipping it would
        // let the assertions below pass without ever running.
        let response = chunk
            .data
            .as_ref()
            .unwrap_or_else(|| panic!("Chunk {} should carry response data", i));

        assert!(
            response.usage.is_none(),
            "Chunk {} should have usage: None when stream_options not set",
            i
        );
        assert!(
            !response.choices.is_empty(),
            "Chunk {} should have choices",
            i
        );
    }
}
/// With `stream_options.include_usage = true`, the stream must consist of the
/// three content chunks (each with `usage: None`) followed by exactly one
/// usage-only chunk with empty `choices` and populated token counts.
#[tokio::test]
async fn test_streaming_with_usage_compliance() {
    // Create request with stream_options.include_usage = true
    let request = create_chat_request(Some(true));
    let request_id = "test-456".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Create mock backend stream
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx);

    // Transform the stream
    let transformed_stream =
        OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator);

    // Collect all chunks
    let chunks: Vec<_> = transformed_stream.collect().await;

    // Verify we got 4 chunks (3 content + 1 usage)
    assert_eq!(
        chunks.len(),
        4,
        "Should have 3 content chunks + 1 usage chunk"
    );

    // Separate the trailing usage chunk from the content chunks before it.
    let (usage_chunk, content_chunks) = chunks
        .split_last()
        .expect("chunk count was asserted to be 4 above");

    // Verify the content chunks have usage: None and non-empty choices
    for (i, chunk) in content_chunks.iter().enumerate() {
        // Missing data is a failure, not something to skip over: otherwise the
        // per-chunk assertions could pass without being evaluated.
        let response = chunk
            .data
            .as_ref()
            .unwrap_or_else(|| panic!("Content chunk {} should carry response data", i));

        assert!(
            response.usage.is_none(),
            "Content chunk {} should have usage: None",
            i
        );
        assert!(
            !response.choices.is_empty(),
            "Content chunk {} should have choices",
            i
        );
    }

    // Verify the final chunk is the usage-only chunk
    let final_response = usage_chunk
        .data
        .as_ref()
        .expect("Final chunk should be a valid response");
    assert!(
        final_response.choices.is_empty(),
        "Final usage chunk should have empty choices array"
    );

    let usage = final_response
        .usage
        .as_ref()
        .expect("Final usage chunk should have usage statistics");
    assert_eq!(
        usage.completion_tokens, 3,
        "Should have 3 completion tokens"
    );
    assert_eq!(
        usage.prompt_tokens, 0,
        "Should have 0 prompt tokens (not set in test)"
    );
    assert_eq!(
        usage.total_tokens, 3,
        "Total tokens should be prompt + completion"
    );
}
/// With `stream_options.include_usage = false`, the stream must behave like
/// the no-`stream_options` case: three content chunks and no usage anywhere.
#[tokio::test]
async fn test_streaming_with_usage_false() {
    // Create request with stream_options.include_usage = false (explicitly disabled)
    let request = create_chat_request(Some(false));
    let request_id = "test-789".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Create mock backend stream
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx);

    // Transform the stream
    let transformed_stream =
        OpenAIPreprocessor::transform_postprocessor_stream(backend_stream, response_generator);

    // Collect all chunks
    let chunks: Vec<_> = transformed_stream.collect().await;

    // Verify we got exactly 3 chunks (no extra usage chunk when explicitly false)
    assert_eq!(
        chunks.len(),
        3,
        "Should have exactly 3 content chunks when include_usage is false"
    );

    // Verify all chunks have usage: None
    for (i, chunk) in chunks.iter().enumerate() {
        // Fail on a data-less chunk instead of skipping it, so the checks
        // below cannot pass vacuously.
        let response = chunk
            .data
            .as_ref()
            .unwrap_or_else(|| panic!("Chunk {} should carry response data", i));

        assert!(
            response.usage.is_none(),
            "Chunk {} should have usage: None when include_usage is false",
            i
        );
        // Same content-chunk check as the no-stream_options test.
        assert!(
            !response.choices.is_empty(),
            "Chunk {} should have choices",
            i
        );
    }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment