Unverified Commit 6bfb41de authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: add continuous_usage_stats option for per-chunk usage (#5139)


Signed-off-by: Guan Luo <gluo@nvidia.com>
parent 3e341fd6
......@@ -966,6 +966,9 @@ pub struct CreateChatCompletionRequest {
pub struct ChatCompletionStreamOptions {
/// If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value (unless `continuous_usage_stats` is also set, see below).
pub include_usage: bool,
/// Extension for per-chunk usage reporting (originally described as NVIDIA-specific and
/// common across the industry).
///
/// When `true` together with `include_usage`, content chunks carry the usage
/// statistics gathered so far instead of a null `usage` field.
/// `#[serde(default)]` makes the field optional on the wire: it deserializes to
/// `false` when the request omits it.
#[serde(default)]
pub continuous_usage_stats: bool,
}
#[derive(ToSchema, Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
......
......@@ -226,6 +226,9 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Check if usage tracking is enabled.
fn is_usage_enabled(&self) -> bool;
/// Check if continuous usage tracking is enabled.
fn is_continuous_usage_enabled(&self) -> bool;
/// Get the current usage statistics with properly calculated total_tokens.
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage;
}
......
......@@ -30,6 +30,7 @@ impl NvCreateChatCompletionRequest {
self.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true,
continuous_usage_stats: false,
});
} else if let Some(ref mut opts) = self.inner.stream_options {
// If stream_options exists, ensure include_usage is true for non-streaming
......@@ -68,6 +69,12 @@ impl NvCreateChatCompletionRequest {
.as_ref()
.map(|opts| opts.include_usage)
.unwrap_or(false),
continuous_usage_stats: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0,
enable_tracking,
......@@ -83,6 +90,8 @@ impl NvCreateChatCompletionRequest {
pub struct DeltaGeneratorOptions {
/// Determines whether token usage statistics should be included in the response.
pub enable_usage: bool,
/// Determines whether continuous usage statistics should be included in the response.
pub continuous_usage_stats: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
/// Determines whether request tracking (timing, KV hit rate) should be enabled.
......@@ -296,7 +305,11 @@ impl DeltaGenerator {
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices,
usage: None, // Always None for chunks with content/choices
usage: if self.options.enable_usage && self.options.continuous_usage_stats {
Some(self.get_usage())
} else {
None
},
service_tier: self.service_tier.clone(),
nvext: None, // Will be populated by router layer if needed
}
......@@ -328,6 +341,11 @@ impl DeltaGenerator {
self.options.enable_usage
}
/// Check if continuous usage tracking is enabled.
///
/// Returns the `continuous_usage_stats` flag stored in this generator's options.
/// The flag alone does not attach usage to content chunks: the chunk
/// construction in this file additionally requires `options.enable_usage`.
pub fn is_continuous_usage_enabled(&self) -> bool {
self.options.continuous_usage_stats
}
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
......@@ -476,6 +494,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
DeltaGenerator::is_usage_enabled(self)
}
/// Check if continuous usage tracking is enabled.
///
/// Forwards to `DeltaGenerator::is_continuous_usage_enabled` via a
/// fully-qualified path.
// NOTE(review): presumably the qualified path is used so the call cannot resolve
// back to this trait method of the same name — confirm against the impl header.
fn is_continuous_usage_enabled(&self) -> bool {
DeltaGenerator::is_continuous_usage_enabled(self)
}
/// Get the current usage statistics.
///
/// Forwards to `DeltaGenerator::get_usage` via a fully-qualified path and
/// returns its `CompletionUsage` unchanged.
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self)
}
......@@ -529,6 +551,10 @@ mod tests {
request.inner.stream_options.unwrap().include_usage,
"Non-streaming request should have include_usage=true for OpenAI compliance"
);
assert!(
!request.inner.stream_options.unwrap().continuous_usage_stats,
"Non-streaming request should have continuous_usage_stats=false for OpenAI compliance"
);
}
#[test]
......
......@@ -30,6 +30,7 @@ impl NvCreateCompletionRequest {
self.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true,
continuous_usage_stats: false,
});
} else if let Some(ref mut opts) = self.inner.stream_options {
// If stream_options exists, ensure include_usage is true for non-streaming
......@@ -63,6 +64,12 @@ impl NvCreateCompletionRequest {
.as_ref()
.map(|opts| opts.include_usage)
.unwrap_or(false),
continuous_usage_stats: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
enable_tracking,
};
......@@ -74,6 +81,7 @@ impl NvCreateCompletionRequest {
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
/// Determines whether token usage statistics should be included in the response.
pub enable_usage: bool,
/// Determines whether usage statistics should also be attached to content chunks
/// (only takes effect when `enable_usage` is set as well).
pub continuous_usage_stats: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
/// Determines whether request tracking should be enabled.
pub enable_tracking: bool,
}
......@@ -226,7 +234,11 @@ impl DeltaGenerator {
finish_reason,
logprobs,
}],
usage: None, // Always None for chunks with content/choices
usage: if self.options.enable_usage && self.options.continuous_usage_stats {
Some(self.get_usage())
} else {
None
},
nvext: None, // Will be populated by router layer if needed
};
......@@ -260,6 +272,11 @@ impl DeltaGenerator {
self.options.enable_usage
}
/// Check if continuous usage tracking is enabled.
///
/// Returns the `continuous_usage_stats` flag stored in this generator's options.
/// The flag alone does not attach usage to content chunks: the chunk
/// construction in this file additionally requires `options.enable_usage`.
pub fn is_continuous_usage_enabled(&self) -> bool {
self.options.continuous_usage_stats
}
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
......@@ -371,6 +388,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
DeltaGenerator::is_usage_enabled(self)
}
/// Check if continuous usage tracking is enabled.
///
/// Forwards to `DeltaGenerator::is_continuous_usage_enabled` via a
/// fully-qualified path.
// NOTE(review): presumably the qualified path is used so the call cannot resolve
// back to this trait method of the same name — confirm against the impl header.
fn is_continuous_usage_enabled(&self) -> bool {
DeltaGenerator::is_continuous_usage_enabled(self)
}
/// Get the current usage statistics.
///
/// Forwards to `DeltaGenerator::get_usage` via a fully-qualified path and
/// returns its `CompletionUsage` unchanged.
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self)
}
......
......@@ -161,7 +161,10 @@ fn create_backend_stream_with_cached_tokens(
}
/// Helper to create a chat completion request with optional stream_options
fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionRequest {
fn create_chat_request(
include_usage: Option<bool>,
continuous_usage: Option<bool>,
) -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
......@@ -171,6 +174,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
let stream_options = include_usage.map(|include| ChatCompletionStreamOptions {
include_usage: include,
continuous_usage_stats: continuous_usage.unwrap_or(false),
});
let inner = CreateChatCompletionRequest {
......@@ -194,7 +198,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
#[tokio::test]
async fn test_streaming_without_usage() {
// Create request without stream_options (usage should not be included)
let request = create_chat_request(None);
let request = create_chat_request(None, None);
let request_id = "test-123".to_string();
let response_generator = Box::new(request.response_generator(request_id));
......@@ -253,7 +257,7 @@ async fn test_streaming_without_usage() {
#[tokio::test]
async fn test_streaming_with_usage_compliance() {
// Create request with stream_options.include_usage = true
let request = create_chat_request(Some(true));
let request = create_chat_request(Some(true), None);
let request_id = "test-456".to_string();
let response_generator = Box::new(request.response_generator(request_id));
......@@ -323,10 +327,101 @@ async fn test_streaming_with_usage_compliance() {
}
}
/// Streaming chat request with both `include_usage` and `continuous_usage_stats`
/// enabled: every content chunk must carry running usage counts, and the stream
/// must still end with the usage-only chunk (empty `choices`, final totals).
#[tokio::test]
async fn test_streaming_with_continuous_usage() {
    // Create request with stream_options.include_usage = true and
    // stream_options.continuous_usage_stats = true
    let request = create_chat_request(Some(true), Some(true));
    let request_id = "test-456".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Create mock backend stream
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx.clone());

    // Transform the stream
    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );

    // Collect all chunks
    let chunks: Vec<_> = transformed_stream.collect().await;

    // Verify we got 4 chunks (3 content + 1 usage)
    assert_eq!(
        chunks.len(),
        4,
        "Should have 3 content chunks + 1 usage chunk"
    );

    // Verify first 3 chunks have usage: Some and non-empty choices
    for (i, chunk) in chunks.iter().take(3).enumerate() {
        // Fail loudly on a chunk without data; silently skipping it would let the
        // test pass without checking anything for that chunk.
        let response = chunk
            .data
            .as_ref()
            .unwrap_or_else(|| panic!("Content chunk {} should be a valid response", i));

        assert!(
            response.usage.is_some(),
            "Content chunk {} should have usage: Some",
            i
        );
        assert!(
            !response.choices.is_empty(),
            "Content chunk {} should have choices",
            i
        );

        // Verify usage counts are properly accumulated for each chunk.
        // `i` is at most 2 here, so the cast to u32 cannot truncate.
        let usage = response.usage.as_ref().unwrap();
        assert_eq!(
            usage.completion_tokens,
            i as u32 + 1,
            "Should have {} completion tokens",
            i + 1
        );
        assert_eq!(
            usage.prompt_tokens, 0,
            "Should have 0 prompt tokens (not set in test)"
        );
        assert_eq!(
            usage.total_tokens,
            i as u32 + 1,
            "Total tokens should be prompt + completion"
        );
    }

    // Verify the final chunk is the usage-only chunk
    if let Some(final_response) = &chunks[3].data {
        assert!(
            final_response.choices.is_empty(),
            "Final usage chunk should have empty choices array"
        );
        assert!(
            final_response.usage.is_some(),
            "Final usage chunk should have usage statistics"
        );

        let usage = final_response.usage.as_ref().unwrap();
        assert_eq!(
            usage.completion_tokens, 3,
            "Should have 3 completion tokens"
        );
        assert_eq!(
            usage.prompt_tokens, 0,
            "Should have 0 prompt tokens (not set in test)"
        );
        assert_eq!(
            usage.total_tokens, 3,
            "Total tokens should be prompt + completion"
        );
    } else {
        panic!("Final chunk should be a valid response");
    }
}
#[tokio::test]
async fn test_streaming_with_usage_false() {
// Create request with stream_options.include_usage = false (explicitly disabled)
let request = create_chat_request(Some(false));
let request = create_chat_request(Some(false), None);
let request_id = "test-789".to_string();
let response_generator = Box::new(request.response_generator(request_id));
......@@ -388,6 +483,7 @@ fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCom
if let Some(include) = include_usage {
builder.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: include,
continuous_usage_stats: false,
});
}
builder.build().unwrap()
......@@ -610,7 +706,7 @@ async fn test_cmpl_streaming_with_cached_tokens_propagation() {
#[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() {
// Chat Completions: include_usage=true, backend provides cached_tokens -> must propagate
let request = create_chat_request(Some(true));
let request = create_chat_request(Some(true), Some(true));
let request_id = "chat-usage-cached-1".to_string();
let mut response_generator = Box::new(request.response_generator(request_id));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment