Unverified commit 3fe5653b, authored by Ayush Agarwal, committed by GitHub
Browse files

chore: add usage field to non-streaming responses by default (#3922)


Signed-off-by: ayushag <ayushag@nvidia.com>
Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 502df72b
...@@ -818,6 +818,7 @@ impl ...@@ -818,6 +818,7 @@ impl
// Preserve original inbound streaming flag before any internal overrides // Preserve original inbound streaming flag before any internal overrides
let request_id = context.id().to_string(); let request_id = context.id().to_string();
let original_stream_flag = request.inner.stream.unwrap_or(false);
// Build audit handle (None if DYN_AUDIT_ENABLED=0) // Build audit handle (None if DYN_AUDIT_ENABLED=0)
let mut audit_handle = crate::audit::handle::create_handle(&request, &request_id); let mut audit_handle = crate::audit::handle::create_handle(&request, &request_id);
...@@ -826,6 +827,11 @@ impl ...@@ -826,6 +827,11 @@ impl
h.set_request(std::sync::Arc::new(request.clone())); h.set_request(std::sync::Arc::new(request.clone()));
} }
// For non-streaming requests (stream=false), enable usage by default
// This ensures compliance with OpenAI API spec where non-streaming responses
// always include usage statistics
request.enable_usage_for_nonstreaming(original_stream_flag);
// Set stream=true for internal processing (after audit capture) // Set stream=true for internal processing (after audit capture)
request.inner.stream = Some(true); request.inner.stream = Some(true);
...@@ -952,6 +958,14 @@ impl ...@@ -952,6 +958,14 @@ impl
// unpack the request // unpack the request
let (mut request, context) = request.into_parts(); let (mut request, context) = request.into_parts();
// Preserve original streaming flag
let original_stream_flag = request.inner.stream.unwrap_or(false);
// For non-streaming requests (stream=false), enable usage by default
// This ensures compliance with OpenAI API spec where non-streaming responses
// always include usage statistics
request.enable_usage_for_nonstreaming(original_stream_flag);
request.inner.stream = Some(true); request.inner.stream = Some(true);
// create a response generator // create a response generator
......
...@@ -10,6 +10,29 @@ use crate::{ ...@@ -10,6 +10,29 @@ use crate::{
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request. /// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest { impl NvCreateChatCompletionRequest {
/// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
///
/// According to OpenAI API spec, non-streaming chat completion responses (stream=false)
/// must always include usage statistics. This method ensures `stream_options.include_usage`
/// is set to `true` for non-streaming requests.
///
/// # Arguments
/// * `original_stream_flag` - The original value of the `stream` field before any internal processing
pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
if !original_stream_flag {
// For non-streaming requests (stream=false), enable usage by default
if self.inner.stream_options.is_none() {
self.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true,
});
} else if let Some(ref mut opts) = self.inner.stream_options {
// If stream_options exists, ensure include_usage is true for non-streaming
opts.include_usage = true;
}
}
}
/// Creates a [`DeltaGenerator`] instance based on the chat completion request. /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
/// ///
/// # Arguments /// # Arguments
...@@ -342,3 +365,66 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -342,3 +365,66 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
DeltaGenerator::is_usage_enabled(self) DeltaGenerator::is_usage_enabled(self)
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use dynamo_async_openai::types::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, ChatCompletionStreamOptions,
        CreateChatCompletionRequest,
    };

    /// Builds a request with `stream: Some(false)`, no `stream_options`, and one user message.
    fn create_test_request() -> NvCreateChatCompletionRequest {
        let messages = vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("test".to_string()),
                name: None,
            },
        )];

        NvCreateChatCompletionRequest {
            inner: CreateChatCompletionRequest {
                model: "test-model".to_string(),
                messages,
                stream: Some(false),
                stream_options: None,
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
            chat_template_args: None,
        }
    }

    /// Reads `include_usage` without consuming the request; `None` means no `stream_options`.
    fn include_usage(request: &NvCreateChatCompletionRequest) -> Option<bool> {
        request
            .inner
            .stream_options
            .as_ref()
            .map(|opts| opts.include_usage)
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_enables_usage() {
        // Test that non-streaming requests get usage enabled
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        assert!(
            request.inner.stream_options.is_some(),
            "Non-streaming request should have stream_options created"
        );
        assert_eq!(
            include_usage(&request),
            Some(true),
            "Non-streaming request should have include_usage=true for OpenAI compliance"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_overrides_existing_options() {
        // Test that caller-supplied stream_options with include_usage=false are overridden
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(false); // false = non-streaming

        assert_eq!(
            include_usage(&request),
            Some(true),
            "Non-streaming request should force include_usage=true on existing stream_options"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_ignores_streaming() {
        // Test that streaming requests are not modified
        let mut request = create_test_request();
        assert!(request.inner.stream_options.is_none());

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert!(
            request.inner.stream_options.is_none(),
            "Streaming request should not have stream_options modified"
        );
    }

    #[test]
    fn test_enable_usage_for_nonstreaming_keeps_streaming_choice() {
        // Test that a streaming caller's explicit include_usage=false is left as-is
        let mut request = create_test_request();
        request.inner.stream_options = Some(ChatCompletionStreamOptions {
            include_usage: false,
        });

        request.enable_usage_for_nonstreaming(true); // true = streaming

        assert_eq!(
            include_usage(&request),
            Some(false),
            "Streaming request should keep its own include_usage setting"
        );
    }
}
...@@ -5,6 +5,31 @@ use super::{NvCreateCompletionRequest, NvCreateCompletionResponse}; ...@@ -5,6 +5,31 @@ use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::{protocols::common, types::TokenIdType}; use crate::{protocols::common, types::TokenIdType};
impl NvCreateCompletionRequest { impl NvCreateCompletionRequest {
/// Enables usage tracking for non-streaming requests to comply with OpenAI API specification.
///
/// According to OpenAI API spec, non-streaming completion responses (stream=false)
/// must always include usage statistics. This method ensures `stream_options.include_usage`
/// is set to `true` for non-streaming requests.
///
/// Reference: https://platform.openai.com/docs/api-reference/completions/create
///
/// # Arguments
/// * `original_stream_flag` - The original value of the `stream` field before any internal processing
pub fn enable_usage_for_nonstreaming(&mut self, original_stream_flag: bool) {
if !original_stream_flag {
// For non-streaming requests (stream=false), enable usage by default
if self.inner.stream_options.is_none() {
self.inner.stream_options =
Some(dynamo_async_openai::types::ChatCompletionStreamOptions {
include_usage: true,
});
} else if let Some(ref mut opts) = self.inner.stream_options {
// If stream_options exists, ensure include_usage is true for non-streaming
opts.include_usage = true;
}
}
}
// put this method on the request // put this method on the request
// inspect the request to extract options // inspect the request to extract options
pub fn response_generator(&self, request_id: String) -> DeltaGenerator { pub fn response_generator(&self, request_id: String) -> DeltaGenerator {
......
...@@ -9,7 +9,10 @@ use dynamo_async_openai::types::{ ...@@ -9,7 +9,10 @@ use dynamo_async_openai::types::{
}; };
use dynamo_llm::preprocessor::OpenAIPreprocessor; use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason}; use dynamo_llm::protocols::common::llm_backend::{BackendOutput, FinishReason};
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; use dynamo_llm::protocols::openai::ParsingOptions;
use dynamo_llm::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, aggregator::ChatCompletionAggregator,
};
use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream}; use dynamo_runtime::engine::{AsyncEngineContext, AsyncEngineStream};
use dynamo_runtime::protocols::annotated::Annotated; use dynamo_runtime::protocols::annotated::Annotated;
use futures::StreamExt; use futures::StreamExt;
...@@ -303,3 +306,99 @@ async fn test_streaming_with_usage_false() { ...@@ -303,3 +306,99 @@ async fn test_streaming_with_usage_false() {
} }
} }
} }
/// Test helper: builds a chat completion request with `stream: Some(false)`, no
/// `stream_options`, and a single user message.
fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest {
    let user_message = ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("Hello".to_string()),
        name: None,
    };

    NvCreateChatCompletionRequest {
        inner: CreateChatCompletionRequest {
            model: "test-model".to_string(),
            messages: vec![ChatCompletionRequestMessage::User(user_message)],
            stream: Some(false),
            stream_options: None,
            ..Default::default()
        },
        common: Default::default(),
        nvext: None,
        chat_template_args: None,
    }
}
/// End-to-end check that a non-streaming chat request, after
/// `enable_usage_for_nonstreaming`, aggregates into a response whose `usage` is populated.
#[tokio::test]
async fn test_nonstreaming_has_usage_field() {
    let mut request = create_nonstreaming_chat_request();
    assert_eq!(
        request.inner.stream,
        Some(false),
        "Request should be non-streaming"
    );
    assert!(
        request.inner.stream_options.is_none(),
        "stream_options should not be set initially"
    );

    // Simulate what the preprocessor does for non-streaming requests
    let original_stream_flag = request.inner.stream.unwrap_or(false);

    // Enable usage for non-streaming requests
    request.enable_usage_for_nonstreaming(original_stream_flag);

    let request_id = "test-nonstream-123".to_string();
    let response_generator = Box::new(request.response_generator(request_id));

    // Create mock backend stream
    let ctx = Arc::new(MockContext::new());
    let backend_stream = create_mock_backend_stream(ctx.clone());

    // Transform the stream (this generates streaming chunks)
    let transformed_stream = OpenAIPreprocessor::transform_postprocessor_stream(
        backend_stream,
        response_generator,
        ctx.clone(),
    );

    // Aggregate the streaming chunks into a single non-streaming response
    // This simulates what the HTTP service does for non-streaming requests
    // `expect` (rather than `assert!(is_ok())` + `unwrap()`) keeps the failure message and
    // also prints the underlying error when aggregation fails.
    let response = dynamo_async_openai::types::CreateChatCompletionResponse::from_annotated_stream(
        transformed_stream,
        ParsingOptions::default(),
    )
    .await
    .expect("Aggregation should succeed");

    let usage = response.usage.expect(
        "Non-streaming chat completion response MUST have a usage field populated. \
         This is required for OpenAI API compliance.",
    );

    // Verify usage contains valid token counts
    // In our mock, we generated 3 tokens (from the 3 backend outputs)
    assert_eq!(
        usage.completion_tokens, 3,
        "Completion tokens should match the number of tokens generated"
    );
    assert!(
        usage.total_tokens > 0,
        "Total tokens should be greater than 0"
    );
    assert_eq!(
        usage.total_tokens,
        usage.prompt_tokens + usage.completion_tokens,
        "Total tokens should equal prompt_tokens + completion_tokens"
    );
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment