Unverified Commit 6bfb41de authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

feat: add continuous_usage_stats option for per-chunk usage (#5139)


Signed-off-by: default avatarGuan Luo <gluo@nvidia.com>
parent 3e341fd6
...@@ -966,6 +966,9 @@ pub struct CreateChatCompletionRequest { ...@@ -966,6 +966,9 @@ pub struct CreateChatCompletionRequest {
pub struct ChatCompletionStreamOptions { pub struct ChatCompletionStreamOptions {
/// If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value. /// If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.
pub include_usage: bool, pub include_usage: bool,
/// NVIDIA-specific and industrial common extensions for per chunk usage reporting.
#[serde(default)]
pub continuous_usage_stats: bool,
} }
#[derive(ToSchema, Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] #[derive(ToSchema, Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
......
...@@ -226,6 +226,9 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>: ...@@ -226,6 +226,9 @@ pub trait DeltaGeneratorExt<ResponseType: Send + 'static + std::fmt::Debug>:
/// Check if usage tracking is enabled. /// Check if usage tracking is enabled.
fn is_usage_enabled(&self) -> bool; fn is_usage_enabled(&self) -> bool;
/// Check if continuous usage tracking is enabled.
fn is_continuous_usage_enabled(&self) -> bool;
/// Get the current usage statistics with properly calculated total_tokens. /// Get the current usage statistics with properly calculated total_tokens.
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage; fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage;
} }
......
...@@ -30,6 +30,7 @@ impl NvCreateChatCompletionRequest { ...@@ -30,6 +30,7 @@ impl NvCreateChatCompletionRequest {
self.inner.stream_options = self.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions { Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true, include_usage: true,
continuous_usage_stats: false,
}); });
} else if let Some(ref mut opts) = self.inner.stream_options { } else if let Some(ref mut opts) = self.inner.stream_options {
// If stream_options exists, ensure include_usage is true for non-streaming // If stream_options exists, ensure include_usage is true for non-streaming
...@@ -68,6 +69,12 @@ impl NvCreateChatCompletionRequest { ...@@ -68,6 +69,12 @@ impl NvCreateChatCompletionRequest {
.as_ref() .as_ref()
.map(|opts| opts.include_usage) .map(|opts| opts.include_usage)
.unwrap_or(false), .unwrap_or(false),
continuous_usage_stats: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(false) enable_logprobs: self.inner.logprobs.unwrap_or(false)
|| self.inner.top_logprobs.unwrap_or(0) > 0, || self.inner.top_logprobs.unwrap_or(0) > 0,
enable_tracking, enable_tracking,
...@@ -83,6 +90,8 @@ impl NvCreateChatCompletionRequest { ...@@ -83,6 +90,8 @@ impl NvCreateChatCompletionRequest {
pub struct DeltaGeneratorOptions { pub struct DeltaGeneratorOptions {
/// Determines whether token usage statistics should be included in the response. /// Determines whether token usage statistics should be included in the response.
pub enable_usage: bool, pub enable_usage: bool,
/// Determines whether continuous usage statistics should be included in the response.
pub continuous_usage_stats: bool,
/// Determines whether log probabilities should be included in the response. /// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool, pub enable_logprobs: bool,
/// Determines whether request tracking (timing, KV hit rate) should be enabled. /// Determines whether request tracking (timing, KV hit rate) should be enabled.
...@@ -296,7 +305,11 @@ impl DeltaGenerator { ...@@ -296,7 +305,11 @@ impl DeltaGenerator {
model: self.model.clone(), model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
choices, choices,
usage: None, // Always None for chunks with content/choices usage: if self.options.enable_usage && self.options.continuous_usage_stats {
Some(self.get_usage())
} else {
None
},
service_tier: self.service_tier.clone(), service_tier: self.service_tier.clone(),
nvext: None, // Will be populated by router layer if needed nvext: None, // Will be populated by router layer if needed
} }
...@@ -328,6 +341,11 @@ impl DeltaGenerator { ...@@ -328,6 +341,11 @@ impl DeltaGenerator {
self.options.enable_usage self.options.enable_usage
} }
/// Check if continuous usage tracking is enabled
pub fn is_continuous_usage_enabled(&self) -> bool {
self.options.continuous_usage_stats
}
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone(); let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens); usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
...@@ -476,6 +494,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -476,6 +494,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
DeltaGenerator::is_usage_enabled(self) DeltaGenerator::is_usage_enabled(self)
} }
fn is_continuous_usage_enabled(&self) -> bool {
DeltaGenerator::is_continuous_usage_enabled(self)
}
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self) DeltaGenerator::get_usage(self)
} }
...@@ -529,6 +551,10 @@ mod tests { ...@@ -529,6 +551,10 @@ mod tests {
request.inner.stream_options.unwrap().include_usage, request.inner.stream_options.unwrap().include_usage,
"Non-streaming request should have include_usage=true for OpenAI compliance" "Non-streaming request should have include_usage=true for OpenAI compliance"
); );
assert!(
!request.inner.stream_options.unwrap().continuous_usage_stats,
"Non-streaming request should have continuous_usage_stats=false for OpenAI compliance"
);
} }
#[test] #[test]
......
...@@ -30,6 +30,7 @@ impl NvCreateCompletionRequest { ...@@ -30,6 +30,7 @@ impl NvCreateCompletionRequest {
self.inner.stream_options = self.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions { Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true, include_usage: true,
continuous_usage_stats: false,
}); });
} else if let Some(ref mut opts) = self.inner.stream_options { } else if let Some(ref mut opts) = self.inner.stream_options {
// If stream_options exists, ensure include_usage is true for non-streaming // If stream_options exists, ensure include_usage is true for non-streaming
...@@ -63,6 +64,12 @@ impl NvCreateCompletionRequest { ...@@ -63,6 +64,12 @@ impl NvCreateCompletionRequest {
.as_ref() .as_ref()
.map(|opts| opts.include_usage) .map(|opts| opts.include_usage)
.unwrap_or(false), .unwrap_or(false),
continuous_usage_stats: self
.inner
.stream_options
.as_ref()
.map(|opts| opts.continuous_usage_stats)
.unwrap_or(false),
enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0, enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0,
enable_tracking, enable_tracking,
}; };
...@@ -74,6 +81,7 @@ impl NvCreateCompletionRequest { ...@@ -74,6 +81,7 @@ impl NvCreateCompletionRequest {
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions { pub struct DeltaGeneratorOptions {
pub enable_usage: bool, pub enable_usage: bool,
pub continuous_usage_stats: bool,
pub enable_logprobs: bool, pub enable_logprobs: bool,
pub enable_tracking: bool, pub enable_tracking: bool,
} }
...@@ -226,7 +234,11 @@ impl DeltaGenerator { ...@@ -226,7 +234,11 @@ impl DeltaGenerator {
finish_reason, finish_reason,
logprobs, logprobs,
}], }],
usage: None, // Always None for chunks with content/choices usage: if self.options.enable_usage && self.options.continuous_usage_stats {
Some(self.get_usage())
} else {
None
},
nvext: None, // Will be populated by router layer if needed nvext: None, // Will be populated by router layer if needed
}; };
...@@ -260,6 +272,11 @@ impl DeltaGenerator { ...@@ -260,6 +272,11 @@ impl DeltaGenerator {
self.options.enable_usage self.options.enable_usage
} }
/// Check if continuous usage tracking is enabled
pub fn is_continuous_usage_enabled(&self) -> bool {
self.options.continuous_usage_stats
}
pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
let mut usage = self.usage.clone(); let mut usage = self.usage.clone();
usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens); usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens);
...@@ -371,6 +388,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -371,6 +388,10 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
DeltaGenerator::is_usage_enabled(self) DeltaGenerator::is_usage_enabled(self)
} }
fn is_continuous_usage_enabled(&self) -> bool {
DeltaGenerator::is_continuous_usage_enabled(self)
}
fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage {
DeltaGenerator::get_usage(self) DeltaGenerator::get_usage(self)
} }
......
...@@ -161,7 +161,10 @@ fn create_backend_stream_with_cached_tokens( ...@@ -161,7 +161,10 @@ fn create_backend_stream_with_cached_tokens(
} }
/// Helper to create a chat completion request with optional stream_options /// Helper to create a chat completion request with optional stream_options
fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionRequest { fn create_chat_request(
include_usage: Option<bool>,
continuous_usage: Option<bool>,
) -> NvCreateChatCompletionRequest {
let messages = vec![ChatCompletionRequestMessage::User( let messages = vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage { ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()), content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
...@@ -171,6 +174,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq ...@@ -171,6 +174,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
let stream_options = include_usage.map(|include| ChatCompletionStreamOptions { let stream_options = include_usage.map(|include| ChatCompletionStreamOptions {
include_usage: include, include_usage: include,
continuous_usage_stats: continuous_usage.unwrap_or(false),
}); });
let inner = CreateChatCompletionRequest { let inner = CreateChatCompletionRequest {
...@@ -194,7 +198,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq ...@@ -194,7 +198,7 @@ fn create_chat_request(include_usage: Option<bool>) -> NvCreateChatCompletionReq
#[tokio::test] #[tokio::test]
async fn test_streaming_without_usage() { async fn test_streaming_without_usage() {
// Create request without stream_options (usage should not be included) // Create request without stream_options (usage should not be included)
let request = create_chat_request(None); let request = create_chat_request(None, None);
let request_id = "test-123".to_string(); let request_id = "test-123".to_string();
let response_generator = Box::new(request.response_generator(request_id)); let response_generator = Box::new(request.response_generator(request_id));
...@@ -253,7 +257,7 @@ async fn test_streaming_without_usage() { ...@@ -253,7 +257,7 @@ async fn test_streaming_without_usage() {
#[tokio::test] #[tokio::test]
async fn test_streaming_with_usage_compliance() { async fn test_streaming_with_usage_compliance() {
// Create request with stream_options.include_usage = true // Create request with stream_options.include_usage = true
let request = create_chat_request(Some(true)); let request = create_chat_request(Some(true), None);
let request_id = "test-456".to_string(); let request_id = "test-456".to_string();
let response_generator = Box::new(request.response_generator(request_id)); let response_generator = Box::new(request.response_generator(request_id));
...@@ -323,10 +327,101 @@ async fn test_streaming_with_usage_compliance() { ...@@ -323,10 +327,101 @@ async fn test_streaming_with_usage_compliance() {
} }
} }
#[tokio::test]
async fn test_streaming_with_continuous_usage() {
// Create request with stream_options.include_usage = true, stream_options.continuous_usage_stats = true
let request = create_chat_request(Some(true), Some(true));
let request_id = "test-456".to_string();
let response_generator = Box::new(request.response_generator(request_id));
// Create mock backend stream
let ctx = Arc::new(MockContext::new());
let backend_stream = create_mock_backend_stream(ctx.clone());
// Transform the stream
let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
backend_stream,
response_generator,
ctx.clone(),
);
// Collect all chunks
let chunks: Vec<_> = transformed_stream.collect().await;
// Verify we got 4 chunks (3 content + 1 usage)
assert_eq!(
chunks.len(),
4,
"Should have 3 content chunks + 1 usage chunk"
);
// Verify first 3 chunks have usage: None and non-empty choices
for (i, chunk) in chunks.iter().take(3).enumerate() {
if let Some(response) = &chunk.data {
assert!(
response.usage.is_some(),
"Content chunk {} should have usage: Some",
i
);
assert!(
!response.choices.is_empty(),
"Content chunk {} should have choices",
i
);
// Verify usage counts are properly accumulated for each chunk
let usage = response.usage.as_ref().unwrap();
assert_eq!(
usage.completion_tokens,
i as u32 + 1,
"Should have {} completion tokens",
i + 1
);
assert_eq!(
usage.prompt_tokens, 0,
"Should have 0 prompt tokens (not set in test)"
);
assert_eq!(
usage.total_tokens,
i as u32 + 1,
"Total tokens should be prompt + completion"
);
}
}
// Verify the final chunk is the usage-only chunk
if let Some(final_response) = &chunks[3].data {
assert!(
final_response.choices.is_empty(),
"Final usage chunk should have empty choices array"
);
assert!(
final_response.usage.is_some(),
"Final usage chunk should have usage statistics"
);
let usage = final_response.usage.as_ref().unwrap();
assert_eq!(
usage.completion_tokens, 3,
"Should have 3 completion tokens"
);
assert_eq!(
usage.prompt_tokens, 0,
"Should have 0 prompt tokens (not set in test)"
);
assert_eq!(
usage.total_tokens, 3,
"Total tokens should be prompt + completion"
);
} else {
panic!("Final chunk should be a valid response");
}
}
#[tokio::test] #[tokio::test]
async fn test_streaming_with_usage_false() { async fn test_streaming_with_usage_false() {
// Create request with stream_options.include_usage = false (explicitly disabled) // Create request with stream_options.include_usage = false (explicitly disabled)
let request = create_chat_request(Some(false)); let request = create_chat_request(Some(false), None);
let request_id = "test-789".to_string(); let request_id = "test-789".to_string();
let response_generator = Box::new(request.response_generator(request_id)); let response_generator = Box::new(request.response_generator(request_id));
...@@ -388,6 +483,7 @@ fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCom ...@@ -388,6 +483,7 @@ fn create_cmpl_request(include_usage: Option<bool>, stream: bool) -> NvCreateCom
if let Some(include) = include_usage { if let Some(include) = include_usage {
builder.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions { builder.stream_options(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: include, include_usage: include,
continuous_usage_stats: false,
}); });
} }
builder.build().unwrap() builder.build().unwrap()
...@@ -610,7 +706,7 @@ async fn test_cmpl_streaming_with_cached_tokens_propagation() { ...@@ -610,7 +706,7 @@ async fn test_cmpl_streaming_with_cached_tokens_propagation() {
#[tokio::test] #[tokio::test]
async fn test_chat_streaming_with_cached_tokens_propagation() { async fn test_chat_streaming_with_cached_tokens_propagation() {
// Chat Completions: include_usage=true, backend provides cached_tokens -> must propagate // Chat Completions: include_usage=true, backend provides cached_tokens -> must propagate
let request = create_chat_request(Some(true)); let request = create_chat_request(Some(true), Some(true));
let request_id = "chat-usage-cached-1".to_string(); let request_id = "chat-usage-cached-1".to_string();
let mut response_generator = Box::new(request.response_generator(request_id)); let mut response_generator = Box::new(request.response_generator(request_id));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment