Commit 151a2a1d authored by Paul Hendricks, committed by GitHub

refactor: removes wrapper for ChatCompletionContent and adds documentation (#296)

parent 85cc7b67
@@ -27,6 +27,13 @@ mod delta;
 pub use aggregator::DeltaAggregator;
 pub use delta::DeltaGenerator;
 
+/// A request structure for creating a chat completion, extending OpenAI's
+/// `CreateChatCompletionRequest` with [`NvExt`] extensions.
+///
+/// # Fields
+/// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`.
+/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for
+///   more details.
 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
 pub struct NvCreateChatCompletionRequest {
     #[serde(flatten)]
@@ -34,41 +41,61 @@ pub struct NvCreateChatCompletionRequest {
     pub nvext: Option<NvExt>,
 }
 
+/// A response structure for unary chat completion responses, embedding OpenAI's
+/// `CreateChatCompletionResponse`.
+///
+/// # Fields
+/// - `inner`: The base OpenAI unary chat completion response, embedded
+///   using `serde(flatten)`.
 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
 pub struct NvCreateChatCompletionResponse {
     #[serde(flatten)]
     pub inner: async_openai::types::CreateChatCompletionResponse,
 }
 
-#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
-pub struct ChatCompletionContent {
-    #[serde(flatten)]
-    pub inner: async_openai::types::ChatCompletionStreamResponseDelta,
-}
-
+/// A response structure for streamed chat completions, embedding OpenAI's
+/// `CreateChatCompletionStreamResponse`.
+///
+/// # Fields
+/// - `inner`: The base OpenAI streaming chat completion response, embedded
+///   using `serde(flatten)`.
 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
 pub struct NvCreateChatCompletionStreamResponse {
     #[serde(flatten)]
     pub inner: async_openai::types::CreateChatCompletionStreamResponse,
 }
 
+/// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`,
+/// providing access to NVIDIA-specific extensions.
 impl NvExtProvider for NvCreateChatCompletionRequest {
+    /// Returns a reference to the optional `NvExt` extension, if available.
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
     }
 
+    /// Returns `None`, as raw prompt extraction is not implemented.
     fn raw_prompt(&self) -> Option<String> {
         None
     }
 }
 
+/// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`,
+/// enabling retrieval and management of request annotations.
 impl AnnotationsProvider for NvCreateChatCompletionRequest {
+    /// Retrieves the list of annotations from `NvExt`, if present.
     fn annotations(&self) -> Option<Vec<String>> {
         self.nvext
             .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
     }
 
+    /// Checks whether a specific annotation exists in the request.
+    ///
+    /// # Arguments
+    /// * `annotation` - A string slice representing the annotation to check.
+    ///
+    /// # Returns
+    /// `true` if the annotation exists, `false` otherwise.
     fn has_annotation(&self, annotation: &str) -> bool {
         self.nvext
             .as_ref()
@@ -78,47 +105,72 @@ impl AnnotationsProvider for NvCreateChatCompletionRequest {
     }
 }
 
+/// Implements `OpenAISamplingOptionsProvider` for `NvCreateChatCompletionRequest`,
+/// exposing OpenAI's sampling parameters for chat completion.
 impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
+    /// Retrieves the temperature parameter for sampling, if set.
     fn get_temperature(&self) -> Option<f32> {
         self.inner.temperature
     }
 
+    /// Retrieves the top-p (nucleus sampling) parameter, if set.
     fn get_top_p(&self) -> Option<f32> {
         self.inner.top_p
     }
 
+    /// Retrieves the frequency penalty parameter, if set.
     fn get_frequency_penalty(&self) -> Option<f32> {
         self.inner.frequency_penalty
     }
 
+    /// Retrieves the presence penalty parameter, if set.
     fn get_presence_penalty(&self) -> Option<f32> {
         self.inner.presence_penalty
     }
 
+    /// Returns a reference to the optional `NvExt` extension, if available.
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
     }
 }
 
+/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
+/// providing access to stop conditions that control chat completion behavior.
 #[allow(deprecated)]
 impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
+    /// Retrieves the maximum number of tokens allowed in the response.
+    ///
+    /// # Note
+    /// This field is deprecated in favor of `max_completion_tokens`.
     fn get_max_tokens(&self) -> Option<u32> {
         // ALLOW: max_tokens is deprecated in favor of max_completion_tokens
         self.inner.max_tokens
     }
 
+    /// Retrieves the minimum number of tokens required in the response.
+    ///
+    /// # Note
+    /// This method is currently a placeholder and always returns `None`
+    /// since `min_tokens` is not an OpenAI-supported parameter.
     fn get_min_tokens(&self) -> Option<u32> {
-        // TODO THIS IS WRONG min_tokens does not exist
         None
     }
 
+    /// Retrieves the stop conditions that terminate the chat completion response.
+    ///
+    /// Converts OpenAI's `Stop` enum to a `Vec<String>`, normalizing the representation.
+    ///
+    /// # Returns
+    /// * `Some(Vec<String>)` if stop conditions are set.
+    /// * `None` if no stop conditions are defined.
     fn get_stop(&self) -> Option<Vec<String>> {
-        // TODO THIS IS WRONG should instead do
-        // Vec<String> -> async_openai::types::Stop
-        // self.inner.stop.clone()
-        None
+        self.inner.stop.as_ref().map(|stop| match stop {
+            async_openai::types::Stop::String(s) => vec![s.clone()],
+            async_openai::types::Stop::StringArray(arr) => arr.clone(),
+        })
     }
 
+    /// Returns a reference to the optional `NvExt` extension, if available.
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
     }
...
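The two rewritten accessors above are easiest to see end to end. Below is a minimal usage sketch, not part of this commit: it assumes `serde_json` is available, that the trait impls above (`AnnotationsProvider`, `OpenAIStopConditionsProvider`) are in scope, and that `async_openai` accepts the plain `{"role": "user", "content": ...}` message shape; the model name and annotation value are illustrative only.

```rust
fn demo_flatten_and_stop() -> anyhow::Result<()> {
    // With `serde(flatten)`, OpenAI's fields and `nvext` share one JSON object.
    let request: NvCreateChatCompletionRequest = serde_json::from_value(serde_json::json!({
        "model": "example-model",                 // illustrative value
        "messages": [{ "role": "user", "content": "Hello" }],
        "stop": "</s>",                           // OpenAI's `Stop::String` variant
        "nvext": { "annotations": ["trace"] }     // NVIDIA extension, flattened in
    }))?;

    // `get_stop` now normalizes `Stop::String` and `Stop::StringArray`
    // into a single `Vec<String>` representation.
    assert_eq!(request.get_stop(), Some(vec!["</s>".to_string()]));

    // `has_annotation` looks the value up inside the optional `NvExt`.
    assert!(request.has_annotation("trace"));
    Ok(())
}
```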
@@ -22,37 +22,54 @@ use crate::protocols::{
 use futures::{Stream, StreamExt};
 use std::{collections::HashMap, pin::Pin};
 
+/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
 type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
 
-/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single [`NvCreateChatCompletionResponse`].
+/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
+/// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses
+/// from a streaming OpenAI API call into a complete final response.
 pub struct DeltaAggregator {
+    /// Unique identifier for the chat completion.
     id: String,
+    /// Model name used for the chat completion.
     model: String,
+    /// Timestamp (Unix epoch) indicating when the response was created.
     created: u32,
+    /// Optional usage statistics for the completion request.
     usage: Option<async_openai::types::CompletionUsage>,
+    /// Optional system fingerprint for version tracking.
     system_fingerprint: Option<String>,
+    /// Map of incremental response choices, keyed by index.
     choices: HashMap<u32, DeltaChoice>,
+    /// Optional error message if an error occurs during aggregation.
     error: Option<String>,
+    /// Optional service tier information for the response.
     service_tier: Option<async_openai::types::ServiceTierResponse>,
 }
 
-// Holds the accumulated state of a choice
+/// Represents the accumulated state of a single chat choice during streaming aggregation.
 struct DeltaChoice {
+    /// The index of the choice in the completion.
     index: u32,
+    /// The accumulated text content for the choice.
     text: String,
+    /// The role associated with this message (e.g., `system`, `user`, `assistant`).
     role: Option<async_openai::types::Role>,
+    /// The reason the completion was finished (if applicable).
     finish_reason: Option<async_openai::types::FinishReason>,
+    /// Optional log probabilities for the chat choice.
     logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
 }
 
 impl Default for DeltaAggregator {
+    /// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
     fn default() -> Self {
         Self::new()
     }
 }
 
 impl DeltaAggregator {
-    /// Creates a new [`DeltaAggregator`].
+    /// Creates a new, empty [`DeltaAggregator`] instance.
     pub fn new() -> Self {
         Self {
             id: "".to_string(),
@@ -66,14 +83,21 @@ impl DeltaAggregator {
         }
     }
 
-    /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single [`NvCreateChatCompletionResponse`].
+    /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
+    /// [`NvCreateChatCompletionResponse`].
+    ///
+    /// # Arguments
+    /// * `stream` - A stream of annotated chat completion responses.
+    ///
+    /// # Returns
+    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
+    /// * `Err(String)` if an error occurs during processing.
     pub async fn apply(
         stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
-                // these are cheap to move so we do it every time since we are consuming the delta
+                // Attempt to unwrap the delta, capturing any errors.
                 let delta = match delta.ok() {
                     Ok(delta) => delta,
                     Err(error) => {
@@ -83,15 +107,14 @@ impl DeltaAggregator {
                 };
 
                 if aggregator.error.is_none() && delta.data.is_some() {
-                    // note: we could extract annotations here and add them to the aggregator
-                    // to be return as part of the NIM Response Extension
-                    // TODO(#14) - Aggregate Annotation
+                    // Extract the data payload from the delta.
                     let delta = delta.data.unwrap();
                     aggregator.id = delta.inner.id;
                     aggregator.model = delta.inner.model;
                     aggregator.created = delta.inner.created;
                     aggregator.service_tier = delta.inner.service_tier;
+                    // Aggregate usage statistics if available.
                     if let Some(usage) = delta.inner.usage {
                         aggregator.usage = Some(usage);
                     }
@@ -99,7 +122,7 @@ impl DeltaAggregator {
                         aggregator.system_fingerprint = Some(system_fingerprint);
                     }
 
-                    // handle the choices
+                    // Aggregate choices incrementally.
                     for choice in delta.inner.choices {
                         let state_choice =
                             aggregator
@@ -113,10 +136,12 @@ impl DeltaAggregator {
                                 logprobs: choice.logprobs,
                             });
 
+                        // Append content if available.
                         if let Some(content) = &choice.delta.content {
                             state_choice.text.push_str(content);
                         }
 
+                        // Update finish reason if provided.
                        if let Some(finish_reason) = choice.finish_reason {
                             state_choice.finish_reason = Some(finish_reason);
                         }
@@ -126,14 +151,14 @@ impl DeltaAggregator {
             })
             .await;
 
-        // If we have an error, return it
+        // Return early if an error was encountered.
         let aggregator = if let Some(error) = aggregator.error {
             return Err(error);
         } else {
             aggregator
         };
 
-        // extra the aggregated deltas and sort by index
+        // Extract aggregated choices and sort them by index.
         let mut choices: Vec<_> = aggregator
             .choices
             .into_values()
@@ -142,6 +167,7 @@ impl DeltaAggregator {
         choices.sort_by(|a, b| a.index.cmp(&b.index));
 
+        // Construct the final response object.
         let inner = async_openai::types::CreateChatCompletionResponse {
             id: aggregator.id,
             created: aggregator.created,
@@ -159,11 +185,13 @@ impl DeltaAggregator {
     }
 }
 
-// todo - handle tool calls
 #[allow(deprecated)]
 impl From<DeltaChoice> for async_openai::types::ChatChoice {
+    /// Converts a [`DeltaChoice`] into an [`async_openai::types::ChatChoice`].
+    ///
+    /// # Note
+    /// The `function_call` field is deprecated.
     fn from(delta: DeltaChoice) -> Self {
-        // ALLOW: function_call is deprecated
         async_openai::types::ChatChoice {
             message: async_openai::types::ChatCompletionResponseMessage {
                 role: delta.role.expect("delta should have a Role"),
@@ -181,6 +209,14 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice {
 }
 
 impl NvCreateChatCompletionResponse {
+    /// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
+    ///
+    /// # Arguments
+    /// * `stream` - A stream of SSE messages containing chat completion responses.
+    ///
+    /// # Returns
+    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
+    /// * `Err(String)` if an error occurs.
     pub async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
@@ -188,6 +224,14 @@ impl NvCreateChatCompletionResponse {
         NvCreateChatCompletionResponse::from_annotated_stream(stream).await
     }
 
+    /// Aggregates an annotated stream of chat completion responses into a final response.
+    ///
+    /// # Arguments
+    /// * `stream` - A stream of annotated chat completion responses.
+    ///
+    /// # Returns
+    /// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
+    /// * `Err(String)` if an error occurs.
     pub async fn from_annotated_stream(
         stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
...
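Taken together, the aggregator folds each streamed chunk into per-index `DeltaChoice` state and then materializes the final response. A hedged sketch of calling it, not part of this commit, assuming the `futures` crate shown in the imports above and a vector of already-constructed chunks supplied by the caller:

```rust
async fn demo_aggregate(
    chunks: Vec<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
    // Pin and box the stream so it matches the `DataStream<T>` alias above.
    let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
        Box::pin(futures::stream::iter(chunks));

    // Content deltas are appended per choice index, the last `finish_reason`
    // wins, and choices are sorted by index in the final response.
    DeltaAggregator::apply(stream).await
}
```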
@@ -16,9 +16,12 @@
 use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
 use crate::protocols::common;
 
+/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
 impl NvCreateChatCompletionRequest {
-    // put this method on the request
-    // inspect the request to extract options
+    /// Creates a [`DeltaGenerator`] instance based on the chat completion request.
+    ///
+    /// # Returns
+    /// * [`DeltaGenerator`] configured with model name and response options.
     pub fn response_generator(&self) -> DeltaGenerator {
         let options = DeltaGeneratorOptions {
             enable_usage: true,
@@ -29,34 +32,50 @@ impl NvCreateChatCompletionRequest {
     }
 }
 
+/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
 #[derive(Debug, Clone, Default)]
 pub struct DeltaGeneratorOptions {
+    /// Determines whether token usage statistics should be included in the response.
     pub enable_usage: bool,
+    /// Determines whether log probabilities should be included in the response.
     pub enable_logprobs: bool,
 }
 
+/// Generates incremental chat completion responses in a streaming fashion.
 #[derive(Debug, Clone)]
 pub struct DeltaGenerator {
+    /// Unique identifier for the chat completion session.
     id: String,
+    /// Object type, representing a streamed chat completion response.
     object: String,
+    /// Timestamp (Unix epoch) when the response was created.
     created: u32,
+    /// Model name used for generating responses.
     model: String,
+    /// Optional system fingerprint for version tracking.
     system_fingerprint: Option<String>,
+    /// Optional service tier information for the response.
     service_tier: Option<async_openai::types::ServiceTierResponse>,
+    /// Tracks token usage for the completion request.
     usage: async_openai::types::CompletionUsage,
-    // counter on how many messages we have issued
+    /// Counter tracking the number of messages issued.
     msg_counter: u64,
+    /// Configuration options for response generation.
     options: DeltaGeneratorOptions,
 }
 
 impl DeltaGenerator {
+    /// Creates a new [`DeltaGenerator`] instance with the specified model and options.
+    ///
+    /// # Arguments
+    /// * `model` - The model name used for response generation.
+    /// * `options` - Configuration options for enabling usage and log probabilities.
+    ///
+    /// # Returns
+    /// * A new instance of [`DeltaGenerator`].
     pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
-        // SAFETY: This is a fun one to write. We are casting from u64 to u32
-        // which typically is unsafe due to loss of precision after it
-        // exceeds u32::MAX. Fortunately, this won't be an issue until
-        // 2106. So whoever is still maintaining this then, enjoy!
+        // SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
+        // but this will not be an issue until 2106.
         let now = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
             .unwrap()
@@ -83,10 +102,24 @@ impl DeltaGenerator {
         }
     }
 
+    /// Updates the prompt token usage count.
+    ///
+    /// # Arguments
+    /// * `isl` - The number of prompt tokens used.
     pub fn update_isl(&mut self, isl: u32) {
         self.usage.prompt_tokens = isl;
     }
 
+    /// Creates a choice within a chat completion response.
+    ///
+    /// # Arguments
+    /// * `index` - The index of the choice in the completion response.
+    /// * `text` - The text content for the response.
+    /// * `finish_reason` - The reason why the response finished (e.g., stop, length, etc.).
+    /// * `logprobs` - Optional log probabilities of the generated tokens.
+    ///
+    /// # Returns
+    /// * An [`async_openai::types::CreateChatCompletionStreamResponse`] instance representing the choice.
     #[allow(deprecated)]
     pub fn create_choice(
         &self,
@@ -96,7 +129,6 @@ impl DeltaGenerator {
         logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
     ) -> async_openai::types::CreateChatCompletionStreamResponse {
         // TODO: Update for tool calling
-        // ALLOW: function_call is deprecated
         let delta = async_openai::types::ChatCompletionStreamResponseDelta {
             role: if self.msg_counter == 0 {
                 Some(async_openai::types::Role::Assistant)
@@ -135,21 +167,32 @@ impl DeltaGenerator {
     }
 }
 
+/// Implements the [`DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
+/// it to transform backend responses into OpenAI-style streaming responses.
 impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
     for DeltaGenerator
 {
+    /// Converts a backend response into a structured OpenAI-style streaming response.
+    ///
+    /// # Arguments
+    /// * `delta` - The backend response containing generated text and metadata.
+    ///
+    /// # Returns
+    /// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
+    /// * `Err(anyhow::Error)` if an error occurs.
     fn choice_from_postprocessor(
         &mut self,
         delta: crate::protocols::common::llm_backend::BackendOutput,
     ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
-        // aggregate usage
+        // Aggregate token usage if enabled.
         if self.options.enable_usage {
             self.usage.completion_tokens += delta.token_ids.len() as u32;
         }
 
-        // todo logprobs
+        // TODO: Implement log probabilities aggregation.
         let logprobs = None;
 
+        // Map backend finish reasons to OpenAI's finish reasons.
         let finish_reason = match delta.finish_reason {
             Some(common::FinishReason::EoS) => Some(async_openai::types::FinishReason::Stop),
             Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
@@ -161,7 +204,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
             None => None,
         };
 
-        // create choice
+        // Create the streaming response.
         let index = 0;
         let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
...
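For orientation, a hedged end-to-end sketch of the generator side, not part of this commit: the `BackendOutput` fields (`token_ids`, `text`, `finish_reason`) are taken from their usage above, and `prompt_tokens` stands in for a tokenizer-provided input length.

```rust
use crate::protocols::openai::DeltaGeneratorExt;

fn demo_stream_chunk(
    request: &NvCreateChatCompletionRequest,
    output: crate::protocols::common::llm_backend::BackendOutput,
    prompt_tokens: u32,
) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
    // Build a generator configured from the request (usage enabled by default).
    let mut generator = request.response_generator();

    // Record the input sequence length (ISL) so usage statistics are accurate.
    generator.update_isl(prompt_tokens);

    // Maps the backend finish reason to OpenAI's enum and emits the next delta;
    // the assistant role is only attached to the first message.
    generator.choice_from_postprocessor(output)
}
```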