Unverified Commit c12fe501 authored by nachiketb-nvidia's avatar nachiketb-nvidia Committed by GitHub
Browse files

chore: remove flatten for chat response types, add reasoning_content (#2543)

Changing the chat completions response objects from structs to types of dynamo_async_openai

Implement aggregator traits for them chat completion structs

add reasoning_content under message and delta message in lib/async-openai
parent a0ddcbce
...@@ -50,6 +50,7 @@ jobs: ...@@ -50,6 +50,7 @@ jobs:
# Set GITHUB_TOKEN to avoid github rate limits on URL checks # Set GITHUB_TOKEN to avoid github rate limits on URL checks
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: | run: |
cd docs
set -euo pipefail set -euo pipefail
# Run lychee against all files in repo # Run lychee against all files in repo
lychee \ lychee \
......
...@@ -449,6 +449,9 @@ pub struct ChatCompletionResponseMessage { ...@@ -449,6 +449,9 @@ pub struct ChatCompletionResponseMessage {
/// If the audio output modality is requested, this object contains data about the audio response from the model. [Learn more](https://platform.openai.com/docs/guides/audio). /// If the audio output modality is requested, this object contains data about the audio response from the model. [Learn more](https://platform.openai.com/docs/guides/audio).
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub audio: Option<ChatCompletionResponseMessageAudio>, pub audio: Option<ChatCompletionResponseMessageAudio>,
/// NVIDIA-specific extensions for the chat completion response.
pub reasoning_content: Option<String>,
} }
#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] #[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)]
...@@ -1021,6 +1024,9 @@ pub struct ChatCompletionStreamResponseDelta { ...@@ -1021,6 +1024,9 @@ pub struct ChatCompletionStreamResponseDelta {
pub role: Option<Role>, pub role: Option<Role>,
/// The refusal message generated by the model. /// The refusal message generated by the model.
pub refusal: Option<String>, pub refusal: Option<String>,
/// NVIDIA-specific extensions for the chat completion response.
pub reasoning_content: Option<String>,
} }
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
......
...@@ -396,7 +396,7 @@ impl ...@@ -396,7 +396,7 @@ impl
//tracing::trace!("from_assistant: {from_assistant}"); //tracing::trace!("from_assistant: {from_assistant}");
#[allow(deprecated)] #[allow(deprecated)]
let inner = dynamo_async_openai::types::CreateChatCompletionStreamResponse{ let delta = NvCreateChatCompletionStreamResponse {
id: c.id, id: c.id,
choices: vec![dynamo_async_openai::types::ChatChoiceStream{ choices: vec![dynamo_async_openai::types::ChatChoiceStream{
index: 0, index: 0,
...@@ -407,6 +407,7 @@ impl ...@@ -407,6 +407,7 @@ impl
tool_calls: None, tool_calls: None,
refusal: None, refusal: None,
function_call: None, function_call: None,
reasoning_content: None,
}, },
logprobs: None, logprobs: None,
finish_reason, finish_reason,
...@@ -418,7 +419,6 @@ impl ...@@ -418,7 +419,6 @@ impl
system_fingerprint: Some(c.system_fingerprint), system_fingerprint: Some(c.system_fingerprint),
service_tier: None, service_tier: None,
}; };
let delta = NvCreateChatCompletionStreamResponse{inner};
let ann = Annotated{ let ann = Annotated{
id: None, id: None,
data: Some(delta), data: Some(delta),
......
...@@ -204,18 +204,12 @@ impl ...@@ -204,18 +204,12 @@ impl
for c in prompt.chars() { for c in prompt.chars() {
// we are returning characters not tokens, so there will be some postprocessing overhead // we are returning characters not tokens, so there will be some postprocessing overhead
tokio::time::sleep(*TOKEN_ECHO_DELAY).await; tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
let inner = deltas.create_choice(0, Some(c.to_string()), None, None); let response = deltas.create_choice(0, Some(c.to_string()), None, None);
let response = NvCreateChatCompletionStreamResponse {
inner,
};
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1; id += 1;
} }
let inner = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None); let response = deltas.create_choice(0, None, Some(dynamo_async_openai::types::FinishReason::Stop), None);
let response = NvCreateChatCompletionStreamResponse {
inner,
};
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
}; };
......
...@@ -233,7 +233,7 @@ async fn evaluate( ...@@ -233,7 +233,7 @@ async fn evaluate(
match (item.data.as_ref(), item.event.as_deref()) { match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => { (Some(data), _) => {
// Normal case // Normal case
let choice = data.inner.choices.first(); let choice = data.choices.first();
let chat_comp = choice.as_ref().unwrap(); let chat_comp = choice.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content { if let Some(c) = &chat_comp.delta.content {
output += c; output += c;
......
...@@ -143,7 +143,7 @@ async fn main_loop( ...@@ -143,7 +143,7 @@ async fn main_loop(
match (item.data.as_ref(), item.event.as_deref()) { match (item.data.as_ref(), item.event.as_deref()) {
(Some(data), _) => { (Some(data), _) => {
// Normal case // Normal case
let entry = data.inner.choices.first(); let entry = data.choices.first();
let chat_comp = entry.as_ref().unwrap(); let chat_comp = entry.as_ref().unwrap();
if let Some(c) = &chat_comp.delta.content { if let Some(c) = &chat_comp.delta.content {
let _ = stdout.write(c.as_bytes()); let _ = stdout.write(c.as_bytes());
......
...@@ -31,6 +31,7 @@ use super::{ ...@@ -31,6 +31,7 @@ use super::{
service_v2, RouteDoc, service_v2, RouteDoc,
}; };
use crate::preprocessor::LLMMetricAnnotation; use crate::preprocessor::LLMMetricAnnotation;
use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......
...@@ -128,7 +128,7 @@ impl LogprobExtractor for NvCreateChatCompletionStreamResponse { ...@@ -128,7 +128,7 @@ impl LogprobExtractor for NvCreateChatCompletionStreamResponse {
fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>> { fn extract_logprobs_by_choice(&self) -> HashMap<u32, Vec<TokenLogProbs>> {
let mut result = HashMap::new(); let mut result = HashMap::new();
for choice in &self.inner.choices { for choice in &self.choices {
let choice_index = choice.index; let choice_index = choice.index;
let choice_logprobs = choice let choice_logprobs = choice
...@@ -574,8 +574,7 @@ mod tests { ...@@ -574,8 +574,7 @@ mod tests {
use approx::assert_abs_diff_eq; use approx::assert_abs_diff_eq;
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta, ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
ChatCompletionTokenLogprob, CreateChatCompletionStreamResponse, FinishReason, Role, ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
TopLogprobs,
}; };
use futures::StreamExt; use futures::StreamExt;
use std::sync::Arc; use std::sync::Arc;
...@@ -949,7 +948,7 @@ mod tests { ...@@ -949,7 +948,7 @@ mod tests {
token_logprobs: Vec<ChatCompletionTokenLogprob>, token_logprobs: Vec<ChatCompletionTokenLogprob>,
) -> NvCreateChatCompletionStreamResponse { ) -> NvCreateChatCompletionStreamResponse {
#[expect(deprecated)] #[expect(deprecated)]
let inner = CreateChatCompletionStreamResponse { NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
choices: vec![ChatChoiceStream { choices: vec![ChatChoiceStream {
index: 0, index: 0,
...@@ -959,6 +958,7 @@ mod tests { ...@@ -959,6 +958,7 @@ mod tests {
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
...@@ -972,9 +972,7 @@ mod tests { ...@@ -972,9 +972,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
}; }
NvCreateChatCompletionStreamResponse { inner }
} }
fn create_mock_response_with_multiple_choices( fn create_mock_response_with_multiple_choices(
...@@ -992,6 +990,7 @@ mod tests { ...@@ -992,6 +990,7 @@ mod tests {
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
...@@ -1001,7 +1000,7 @@ mod tests { ...@@ -1001,7 +1000,7 @@ mod tests {
}) })
.collect(); .collect();
let inner = CreateChatCompletionStreamResponse { NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
choices, choices,
created: 1234567890, created: 1234567890,
...@@ -1010,9 +1009,7 @@ mod tests { ...@@ -1010,9 +1009,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
}; }
NvCreateChatCompletionStreamResponse { inner }
} }
#[test] #[test]
...@@ -1331,7 +1328,7 @@ mod tests { ...@@ -1331,7 +1328,7 @@ mod tests {
fn test_logprob_extractor_with_missing_data() { fn test_logprob_extractor_with_missing_data() {
// Test with choice that has no logprobs // Test with choice that has no logprobs
#[expect(deprecated)] #[expect(deprecated)]
let inner = CreateChatCompletionStreamResponse { let response = NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
choices: vec![ChatChoiceStream { choices: vec![ChatChoiceStream {
index: 0, index: 0,
...@@ -1341,6 +1338,7 @@ mod tests { ...@@ -1341,6 +1338,7 @@ mod tests {
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
logprobs: None, // No logprobs logprobs: None, // No logprobs
...@@ -1353,7 +1351,6 @@ mod tests { ...@@ -1353,7 +1351,6 @@ mod tests {
usage: None, usage: None,
}; };
let response = NvCreateChatCompletionStreamResponse { inner };
let logprobs = response.extract_logprobs_by_choice(); let logprobs = response.extract_logprobs_by_choice();
assert_eq!(logprobs.len(), 1); assert_eq!(logprobs.len(), 1);
assert!(logprobs.values().any(|v| v.is_empty())); assert!(logprobs.values().any(|v| v.is_empty()));
...@@ -1556,9 +1553,8 @@ mod tests { ...@@ -1556,9 +1553,8 @@ mod tests {
fn create_mock_response() -> NvCreateChatCompletionStreamResponse { fn create_mock_response() -> NvCreateChatCompletionStreamResponse {
// Create a mock response for testing // Create a mock response for testing
// In practice, this would have real logprobs data // In practice, this would have real logprobs data
use dynamo_async_openai::types::CreateChatCompletionStreamResponse;
let inner = CreateChatCompletionStreamResponse { NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
choices: vec![], choices: vec![],
created: 1234567890, created: 1234567890,
...@@ -1567,9 +1563,7 @@ mod tests { ...@@ -1567,9 +1563,7 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
}; }
NvCreateChatCompletionStreamResponse { inner }
} }
// Mock context for testing // Mock context for testing
......
...@@ -27,7 +27,7 @@ use super::{ ...@@ -27,7 +27,7 @@ use super::{
OpenAIStopConditionsProvider, OpenAIStopConditionsProvider,
}; };
mod aggregator; pub mod aggregator;
mod delta; mod delta;
pub use aggregator::DeltaAggregator; pub use aggregator::DeltaAggregator;
...@@ -59,11 +59,7 @@ pub struct NvCreateChatCompletionRequest { ...@@ -59,11 +59,7 @@ pub struct NvCreateChatCompletionRequest {
/// # Fields /// # Fields
/// - `inner`: The base OpenAI unary chat completion response, embedded /// - `inner`: The base OpenAI unary chat completion response, embedded
/// using `serde(flatten)`. /// using `serde(flatten)`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] pub type NvCreateChatCompletionResponse = dynamo_async_openai::types::CreateChatCompletionResponse;
pub struct NvCreateChatCompletionResponse {
#[serde(flatten)]
pub inner: dynamo_async_openai::types::CreateChatCompletionResponse,
}
/// A response structure for streamed chat completions, embedding OpenAI's /// A response structure for streamed chat completions, embedding OpenAI's
/// `CreateChatCompletionStreamResponse`. /// `CreateChatCompletionStreamResponse`.
...@@ -71,11 +67,8 @@ pub struct NvCreateChatCompletionResponse { ...@@ -71,11 +67,8 @@ pub struct NvCreateChatCompletionResponse {
/// # Fields /// # Fields
/// - `inner`: The base OpenAI streaming chat completion response, embedded /// - `inner`: The base OpenAI streaming chat completion response, embedded
/// using `serde(flatten)`. /// using `serde(flatten)`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] pub type NvCreateChatCompletionStreamResponse =
pub struct NvCreateChatCompletionStreamResponse { dynamo_async_openai::types::CreateChatCompletionStreamResponse;
#[serde(flatten)]
pub inner: dynamo_async_openai::types::CreateChatCompletionStreamResponse,
}
/// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`, /// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`,
/// providing access to NVIDIA-specific extensions. /// providing access to NVIDIA-specific extensions.
......
...@@ -110,21 +110,21 @@ impl DeltaAggregator { ...@@ -110,21 +110,21 @@ impl DeltaAggregator {
if aggregator.error.is_none() && delta.data.is_some() { if aggregator.error.is_none() && delta.data.is_some() {
// Extract the data payload from the delta. // Extract the data payload from the delta.
let delta = delta.data.unwrap(); let delta = delta.data.unwrap();
aggregator.id = delta.inner.id; aggregator.id = delta.id;
aggregator.model = delta.inner.model; aggregator.model = delta.model;
aggregator.created = delta.inner.created; aggregator.created = delta.created;
aggregator.service_tier = delta.inner.service_tier; aggregator.service_tier = delta.service_tier;
// Aggregate usage statistics if available. // Aggregate usage statistics if available.
if let Some(usage) = delta.inner.usage { if let Some(usage) = delta.usage {
aggregator.usage = Some(usage); aggregator.usage = Some(usage);
} }
if let Some(system_fingerprint) = delta.inner.system_fingerprint { if let Some(system_fingerprint) = delta.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint); aggregator.system_fingerprint = Some(system_fingerprint);
} }
// Aggregate choices incrementally. // Aggregate choices incrementally.
for choice in delta.inner.choices { for choice in delta.choices {
let state_choice = let state_choice =
aggregator aggregator
.choices .choices
...@@ -198,7 +198,7 @@ impl DeltaAggregator { ...@@ -198,7 +198,7 @@ impl DeltaAggregator {
choices.sort_by(|a, b| a.index.cmp(&b.index)); choices.sort_by(|a, b| a.index.cmp(&b.index));
// Construct the final response object. // Construct the final response object.
let inner = dynamo_async_openai::types::CreateChatCompletionResponse { let response = NvCreateChatCompletionResponse {
id: aggregator.id, id: aggregator.id,
created: aggregator.created, created: aggregator.created,
usage: aggregator.usage, usage: aggregator.usage,
...@@ -209,8 +209,6 @@ impl DeltaAggregator { ...@@ -209,8 +209,6 @@ impl DeltaAggregator {
service_tier: aggregator.service_tier, service_tier: aggregator.service_tier,
}; };
let response = NvCreateChatCompletionResponse { inner };
Ok(response) Ok(response)
} }
} }
...@@ -234,6 +232,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -234,6 +232,7 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
refusal: None, refusal: None,
function_call: None, function_call: None,
audio: None, audio: None,
reasoning_content: None,
}, },
index: delta.index, index: delta.index,
finish_reason: delta.finish_reason, finish_reason: delta.finish_reason,
...@@ -242,35 +241,48 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice { ...@@ -242,35 +241,48 @@ impl From<DeltaChoice> for dynamo_async_openai::types::ChatChoice {
} }
} }
impl NvCreateChatCompletionResponse { /// Trait for aggregating chat completion responses from streams.
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`]. /// Setting this macro because our async functions are not used outside of the library
#[allow(async_fn_in_trait)]
pub trait ChatCompletionAggregator {
/// Aggregates an annotated stream of chat completion responses into a final response.
/// ///
/// # Arguments /// # Arguments
/// * `stream` - A stream of SSE messages containing chat completion responses. /// * `stream` - A stream of annotated chat completion responses.
/// ///
/// # Returns /// # Returns
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds. /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs. /// * `Err(String)` if an error occurs.
pub async fn from_sse_stream( async fn from_annotated_stream(
stream: DataStream<Result<Message, SseCodecError>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String>;
let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
NvCreateChatCompletionResponse::from_annotated_stream(stream).await
}
/// Aggregates an annotated stream of chat completion responses into a final response. /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
/// ///
/// # Arguments /// # Arguments
/// * `stream` - A stream of annotated chat completion responses. /// * `stream` - A stream of SSE messages containing chat completion responses.
/// ///
/// # Returns /// # Returns
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds. /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs. /// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream( async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<NvCreateChatCompletionResponse, String>;
}
impl ChatCompletionAggregator for dynamo_async_openai::types::CreateChatCompletionResponse {
async fn from_annotated_stream(
stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>, stream: impl Stream<Item = Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream).await
} }
async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<NvCreateChatCompletionResponse, String> {
let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
NvCreateChatCompletionResponse::from_annotated_stream(stream).await
}
} }
#[cfg(test)] #[cfg(test)]
...@@ -293,6 +305,7 @@ mod tests { ...@@ -293,6 +305,7 @@ mod tests {
tool_calls: None, tool_calls: None,
role, role,
refusal: None, refusal: None,
reasoning_content: None,
}; };
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
index, index,
...@@ -301,7 +314,7 @@ mod tests { ...@@ -301,7 +314,7 @@ mod tests {
logprobs: None, logprobs: None,
}; };
let inner = dynamo_async_openai::types::CreateChatCompletionStreamResponse { let data = NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
model: "meta/llama-3.1-8b-instruct".to_string(), model: "meta/llama-3.1-8b-instruct".to_string(),
created: 1234567890, created: 1234567890,
...@@ -312,8 +325,6 @@ mod tests { ...@@ -312,8 +325,6 @@ mod tests {
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
}; };
let data = NvCreateChatCompletionStreamResponse { inner };
Annotated { Annotated {
data: Some(data), data: Some(data),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
...@@ -336,13 +347,13 @@ mod tests { ...@@ -336,13 +347,13 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// Verify that the response is empty and has default values // Verify that the response is empty and has default values
assert_eq!(response.inner.id, ""); assert_eq!(response.id, "");
assert_eq!(response.inner.model, ""); assert_eq!(response.model, "");
assert_eq!(response.inner.created, 0); assert_eq!(response.created, 0);
assert!(response.inner.usage.is_none()); assert!(response.usage.is_none());
assert!(response.inner.system_fingerprint.is_none()); assert!(response.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 0); assert_eq!(response.choices.len(), 0);
assert!(response.inner.service_tier.is_none()); assert!(response.service_tier.is_none());
} }
#[tokio::test] #[tokio::test]
...@@ -366,18 +377,18 @@ mod tests { ...@@ -366,18 +377,18 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// Verify the response fields // Verify the response fields
assert_eq!(response.inner.id, "test_id"); assert_eq!(response.id, "test_id");
assert_eq!(response.inner.model, "meta/llama-3.1-8b-instruct"); assert_eq!(response.model, "meta/llama-3.1-8b-instruct");
assert_eq!(response.inner.created, 1234567890); assert_eq!(response.created, 1234567890);
assert!(response.inner.usage.is_none()); assert!(response.usage.is_none());
assert!(response.inner.system_fingerprint.is_none()); assert!(response.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 1); assert_eq!(response.choices.len(), 1);
let choice = &response.inner.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,"); assert_eq!(choice.message.content.as_ref().unwrap(), "Hello,");
assert!(choice.finish_reason.is_none()); assert!(choice.finish_reason.is_none());
assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User); assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
assert!(response.inner.service_tier.is_none()); assert!(response.service_tier.is_none());
} }
#[tokio::test] #[tokio::test]
...@@ -410,8 +421,8 @@ mod tests { ...@@ -410,8 +421,8 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// Verify the response fields // Verify the response fields
assert_eq!(response.inner.choices.len(), 1); assert_eq!(response.choices.len(), 1);
let choice = &response.inner.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!"); assert_eq!(choice.message.content.as_ref().unwrap(), "Hello, world!");
assert_eq!( assert_eq!(
...@@ -426,7 +437,7 @@ mod tests { ...@@ -426,7 +437,7 @@ mod tests {
async fn test_multiple_choices() { async fn test_multiple_choices() {
// Create a delta with multiple choices // Create a delta with multiple choices
// ALLOW: function_call is deprecated // ALLOW: function_call is deprecated
let delta = dynamo_async_openai::types::CreateChatCompletionStreamResponse { let data = NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
model: "test_model".to_string(), model: "test_model".to_string(),
created: 1234567890, created: 1234567890,
...@@ -442,6 +453,7 @@ mod tests { ...@@ -442,6 +453,7 @@ mod tests {
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop), finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
logprobs: None, logprobs: None,
...@@ -454,6 +466,7 @@ mod tests { ...@@ -454,6 +466,7 @@ mod tests {
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop), finish_reason: Some(dynamo_async_openai::types::FinishReason::Stop),
logprobs: None, logprobs: None,
...@@ -462,8 +475,6 @@ mod tests { ...@@ -462,8 +475,6 @@ mod tests {
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
}; };
let data = NvCreateChatCompletionStreamResponse { inner: delta };
// Wrap it in Annotated and create a stream // Wrap it in Annotated and create a stream
let annotated_delta = Annotated { let annotated_delta = Annotated {
data: Some(data), data: Some(data),
...@@ -481,9 +492,9 @@ mod tests { ...@@ -481,9 +492,9 @@ mod tests {
let mut response = result.unwrap(); let mut response = result.unwrap();
// Verify the response fields // Verify the response fields
assert_eq!(response.inner.choices.len(), 2); assert_eq!(response.choices.len(), 2);
response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.inner.choices[0]; let choice0 = &response.choices[0];
assert_eq!(choice0.index, 0); assert_eq!(choice0.index, 0);
assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0"); assert_eq!(choice0.message.content.as_ref().unwrap(), "Choice 0");
assert_eq!( assert_eq!(
...@@ -495,7 +506,7 @@ mod tests { ...@@ -495,7 +506,7 @@ mod tests {
dynamo_async_openai::types::Role::Assistant dynamo_async_openai::types::Role::Assistant
); );
let choice1 = &response.inner.choices[1]; let choice1 = &response.choices[1];
assert_eq!(choice1.index, 1); assert_eq!(choice1.index, 1);
assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1"); assert_eq!(choice1.message.content.as_ref().unwrap(), "Choice 1");
assert_eq!( assert_eq!(
...@@ -520,9 +531,7 @@ mod tests { ...@@ -520,9 +531,7 @@ mod tests {
Some(dynamo_async_openai::types::Role::Assistant), Some(dynamo_async_openai::types::Role::Assistant),
Some(dynamo_async_openai::types::FinishReason::ToolCalls), Some(dynamo_async_openai::types::FinishReason::ToolCalls),
); );
let delta = annotated_delta.data.unwrap().inner; let data = annotated_delta.data.unwrap();
let data = NvCreateChatCompletionStreamResponse { inner: delta };
// Wrap it in Annotated and create a stream // Wrap it in Annotated and create a stream
let annotated_delta = Annotated { let annotated_delta = Annotated {
...@@ -541,8 +550,8 @@ mod tests { ...@@ -541,8 +550,8 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// There should be one choice // There should be one choice
assert_eq!(response.inner.choices.len(), 1); assert_eq!(response.choices.len(), 1);
let choice = &response.inner.choices[0]; let choice = &response.choices[0];
// The tool_calls field should be present and parsed // The tool_calls field should be present and parsed
assert!(choice.message.tool_calls.is_some()); assert!(choice.message.tool_calls.is_some());
......
...@@ -209,6 +209,7 @@ impl DeltaGenerator { ...@@ -209,6 +209,7 @@ impl DeltaGenerator {
None None
}, },
refusal: None, refusal: None,
reasoning_content: None,
}; };
let choice = dynamo_async_openai::types::ChatChoiceStream { let choice = dynamo_async_openai::types::ChatChoiceStream {
...@@ -304,9 +305,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -304,9 +305,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
let index = 0; let index = 0;
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs); let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
Ok(NvCreateChatCompletionStreamResponse { Ok(stream_response)
inner: stream_response,
})
} }
fn get_isl(&self) -> Option<u32> { fn get_isl(&self) -> Option<u32> {
......
...@@ -199,7 +199,7 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse { ...@@ -199,7 +199,7 @@ impl TryFrom<NvCreateChatCompletionResponse> for NvResponse {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(nv_resp: NvCreateChatCompletionResponse) -> Result<Self, Self::Error> { fn try_from(nv_resp: NvCreateChatCompletionResponse) -> Result<Self, Self::Error> {
let chat_resp = nv_resp.inner; let chat_resp = nv_resp;
let content_text = chat_resp let content_text = chat_resp
.choices .choices
.into_iter() .into_iter()
...@@ -341,7 +341,6 @@ mod tests { ...@@ -341,7 +341,6 @@ mod tests {
fn test_into_nvresponse_from_chat_response() { fn test_into_nvresponse_from_chat_response() {
let now = 1_726_000_000; let now = 1_726_000_000;
let chat_resp = NvCreateChatCompletionResponse { let chat_resp = NvCreateChatCompletionResponse {
inner: dynamo_async_openai::types::CreateChatCompletionResponse {
id: "chatcmpl-xyz".into(), id: "chatcmpl-xyz".into(),
choices: vec![dynamo_async_openai::types::ChatChoice { choices: vec![dynamo_async_openai::types::ChatChoice {
index: 0, index: 0,
...@@ -352,6 +351,7 @@ mod tests { ...@@ -352,6 +351,7 @@ mod tests {
role: dynamo_async_openai::types::Role::Assistant, role: dynamo_async_openai::types::Role::Assistant,
function_call: None, function_call: None,
audio: None, audio: None,
reasoning_content: None,
}, },
finish_reason: None, finish_reason: None,
logprobs: None, logprobs: None,
...@@ -362,7 +362,6 @@ mod tests { ...@@ -362,7 +362,6 @@ mod tests {
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
usage: None, usage: None,
},
}; };
let wrapped: NvResponse = chat_resp.try_into().unwrap(); let wrapped: NvResponse = chat_resp.try_into().unwrap();
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
codec::{create_message_stream, Message, SseCodecError}, codec::{create_message_stream, Message, SseCodecError},
openai::{ openai::{
chat_completions::NvCreateChatCompletionResponse, completions::NvCreateCompletionResponse, chat_completions::{aggregator::ChatCompletionAggregator, NvCreateChatCompletionResponse},
completions::NvCreateCompletionResponse,
}, },
ContentProvider, DataStream, ContentProvider, DataStream,
}; };
...@@ -43,7 +44,6 @@ async fn test_openai_chat_stream() { ...@@ -43,7 +44,6 @@ async fn test_openai_chat_stream() {
// todo: provide a cleaner way to extract the content from choices // todo: provide a cleaner way to extract the content from choices
assert_eq!( assert_eq!(
result result
.inner
.choices .choices
.first() .first()
.unwrap() .unwrap()
...@@ -65,7 +65,6 @@ async fn test_openai_chat_edge_case_multi_line_data() { ...@@ -65,7 +65,6 @@ async fn test_openai_chat_edge_case_multi_line_data() {
assert_eq!( assert_eq!(
result result
.inner
.choices .choices
.first() .first()
.unwrap() .unwrap()
...@@ -86,7 +85,6 @@ async fn test_openai_chat_edge_case_comments_per_response() { ...@@ -86,7 +85,6 @@ async fn test_openai_chat_edge_case_comments_per_response() {
assert_eq!( assert_eq!(
result result
.inner
.choices .choices
.first() .first()
.unwrap() .unwrap()
......
...@@ -100,11 +100,7 @@ impl ...@@ -100,11 +100,7 @@ impl
let stream = stream! { let stream = stream! {
tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await; tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
for i in 0..10 { for i in 0..10 {
let inner = generator.create_choice(i,Some(format!("choice {i}")), None, None); let output = generator.create_choice(i,Some(format!("choice {i}")), None, None);
let output = NvCreateChatCompletionStreamResponse {
inner,
};
yield Annotated::from_data(output); yield Annotated::from_data(output);
} }
......
...@@ -12,8 +12,7 @@ use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStrea ...@@ -12,8 +12,7 @@ use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStrea
use dynamo_async_openai::types::{ use dynamo_async_openai::types::{
ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta, ChatChoiceLogprobs, ChatChoiceStream, ChatCompletionStreamResponseDelta,
ChatCompletionTokenLogprob, CreateChatCompletionStreamResponse, FinishReason, Role, ChatCompletionTokenLogprob, FinishReason, Role, TopLogprobs,
TopLogprobs,
}; };
// Type aliases to simplify complex test data structures // Type aliases to simplify complex test data structures
...@@ -387,6 +386,7 @@ fn create_response_with_linear_probs( ...@@ -387,6 +386,7 @@ fn create_response_with_linear_probs(
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
...@@ -395,7 +395,7 @@ fn create_response_with_linear_probs( ...@@ -395,7 +395,7 @@ fn create_response_with_linear_probs(
}), }),
}; };
let inner = CreateChatCompletionStreamResponse { NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
choices: vec![choice], choices: vec![choice],
created: 1234567890, created: 1234567890,
...@@ -404,9 +404,7 @@ fn create_response_with_linear_probs( ...@@ -404,9 +404,7 @@ fn create_response_with_linear_probs(
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
}; }
NvCreateChatCompletionStreamResponse { inner }
} }
fn create_multi_choice_response( fn create_multi_choice_response(
...@@ -466,6 +464,7 @@ fn create_multi_choice_response( ...@@ -466,6 +464,7 @@ fn create_multi_choice_response(
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
refusal: None, refusal: None,
reasoning_content: None,
}, },
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
logprobs: Some(ChatChoiceLogprobs { logprobs: Some(ChatChoiceLogprobs {
...@@ -476,7 +475,7 @@ fn create_multi_choice_response( ...@@ -476,7 +475,7 @@ fn create_multi_choice_response(
}) })
.collect(); .collect();
let inner = CreateChatCompletionStreamResponse { NvCreateChatCompletionStreamResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
choices, choices,
created: 1234567890, created: 1234567890,
...@@ -485,7 +484,5 @@ fn create_multi_choice_response( ...@@ -485,7 +484,5 @@ fn create_multi_choice_response(
system_fingerprint: None, system_fingerprint: None,
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
usage: None, usage: None,
}; }
NvCreateChatCompletionStreamResponse { inner }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment