"vscode:/vscode.git/clone" did not exist on "32a748e4030859dbb2a4dd9eaac389e2c84966b3"
Commit 151a2a1d authored by Paul Hendricks, committed by GitHub

refactor: removes wrapper for ChatCompletionContent and adds documentation (#296)

parent 85cc7b67
@@ -27,6 +27,13 @@ mod delta;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
/// A request structure for creating a chat completion, extending OpenAI's
/// `CreateChatCompletionRequest` with [`NvExt`] extensions.
///
/// # Fields
/// - `inner`: The base OpenAI chat completion request, embedded using `serde(flatten)`.
/// - `nvext`: The optional NVIDIA extension field. See [`NvExt`] for
/// more details.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionRequest {
#[serde(flatten)]
@@ -34,41 +41,61 @@ pub struct NvCreateChatCompletionRequest {
pub nvext: Option<NvExt>,
}
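// --- Illustrative aside (not part of this commit) ---
// A minimal, self-contained sketch of what `#[serde(flatten)]` does for the
// wrapper above: the embedded request's fields serialize at the top level of
// the JSON object, right alongside `nvext`. The stand-in types below
// (`InnerRequest`, `WrappedRequest`, `flatten_demo`) are hypothetical and
// exist only for this sketch; the real types come from `async_openai`.
mod flatten_sketch {
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct InnerRequest {
        model: String,
    }

    #[derive(Serialize, Deserialize)]
    struct WrappedRequest {
        #[serde(flatten)]
        inner: InnerRequest,
        #[serde(skip_serializing_if = "Option::is_none")]
        nvext: Option<serde_json::Value>,
    }

    pub fn flatten_demo() {
        let req = WrappedRequest {
            inner: InnerRequest { model: "my-model".into() },
            nvext: Some(serde_json::json!({ "annotations": ["demo"] })),
        };
        // Prints: {"model":"my-model","nvext":{"annotations":["demo"]}}
        println!("{}", serde_json::to_string(&req).unwrap());
    }
}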
/// A response structure for unary chat completion responses, embedding OpenAI's
/// `CreateChatCompletionResponse`.
///
/// # Fields
/// - `inner`: The base OpenAI unary chat completion response, embedded
/// using `serde(flatten)`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionResponse {
#[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionResponse,
}
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionContent {
#[serde(flatten)]
pub inner: async_openai::types::ChatCompletionStreamResponseDelta,
}
/// A response structure for streamed chat completions, embedding OpenAI's
/// `CreateChatCompletionStreamResponse`.
///
/// # Fields
/// - `inner`: The base OpenAI streaming chat completion response, embedded
/// using `serde(flatten)`.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateChatCompletionStreamResponse {
#[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionStreamResponse,
}
/// Implements `NvExtProvider` for `NvCreateChatCompletionRequest`,
/// providing access to NVIDIA-specific extensions.
impl NvExtProvider for NvCreateChatCompletionRequest {
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
/// Returns `None`, as raw prompt extraction is not implemented.
fn raw_prompt(&self) -> Option<String> {
None
}
}
/// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`,
/// enabling retrieval and management of request annotations.
impl AnnotationsProvider for NvCreateChatCompletionRequest {
/// Retrieves the list of annotations from `NvExt`, if present.
fn annotations(&self) -> Option<Vec<String>> {
self.nvext
.as_ref()
.and_then(|nvext| nvext.annotations.clone())
}
/// Checks whether a specific annotation exists in the request.
///
/// # Arguments
/// * `annotation` - A string slice representing the annotation to check.
///
/// # Returns
/// `true` if the annotation exists, `false` otherwise.
fn has_annotation(&self, annotation: &str) -> bool {
self.nvext
.as_ref()
@@ -78,47 +105,72 @@ impl AnnotationsProvider for NvCreateChatCompletionRequest {
}
}
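// --- Illustrative aside (not part of this commit) ---
// The lookups above are plain `Option` chaining over the optional `NvExt`.
// A self-contained sketch of the same pattern; `Ext` is a hypothetical
// stand-in, and the elided `has_annotation` body above may differ in detail.
mod annotations_sketch {
    struct Ext {
        annotations: Option<Vec<String>>,
    }

    fn has_annotation(ext: Option<&Ext>, annotation: &str) -> bool {
        ext.and_then(|e| e.annotations.as_ref())
            .map(|list| list.iter().any(|a| a == annotation))
            .unwrap_or(false)
    }

    pub fn demo() {
        let ext = Ext { annotations: Some(vec!["trace".to_string()]) };
        assert!(has_annotation(Some(&ext), "trace"));
        assert!(!has_annotation(None, "trace"));
    }
}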
/// Implements `OpenAISamplingOptionsProvider` for `NvCreateChatCompletionRequest`,
/// exposing OpenAI's sampling parameters for chat completion.
impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
/// Retrieves the temperature parameter for sampling, if set.
fn get_temperature(&self) -> Option<f32> {
self.inner.temperature
}
/// Retrieves the top-p (nucleus sampling) parameter, if set.
fn get_top_p(&self) -> Option<f32> {
self.inner.top_p
}
/// Retrieves the frequency penalty parameter, if set.
fn get_frequency_penalty(&self) -> Option<f32> {
self.inner.frequency_penalty
}
/// Retrieves the presence penalty parameter, if set.
fn get_presence_penalty(&self) -> Option<f32> {
self.inner.presence_penalty
}
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
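// --- Illustrative aside (not part of this commit) ---
// A hypothetical generic consumer of the accessors above, to show what the
// trait buys us: any request type implementing
// `OpenAISamplingOptionsProvider` can be inspected uniformly.
fn log_sampling_options<P: OpenAISamplingOptionsProvider>(provider: &P) {
    if let Some(temperature) = provider.get_temperature() {
        println!("temperature: {temperature}");
    }
    if let Some(top_p) = provider.get_top_p() {
        println!("top_p: {top_p}");
    }
}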
/// Implements `OpenAIStopConditionsProvider` for `NvCreateChatCompletionRequest`,
/// providing access to stop conditions that control chat completion behavior.
#[allow(deprecated)]
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
/// Retrieves the maximum number of tokens allowed in the response.
///
/// # Note
/// This field is deprecated in favor of `max_completion_tokens`.
fn get_max_tokens(&self) -> Option<u32> {
// ALLOW: max_tokens is deprecated in favor of max_completion_tokens
self.inner.max_tokens
}
/// Retrieves the minimum number of tokens required in the response.
///
/// # Note
/// This method is currently a placeholder and always returns `None`
/// since `min_tokens` is not an OpenAI-supported parameter.
fn get_min_tokens(&self) -> Option<u32> {
// `min_tokens` is not an OpenAI-supported parameter, so there is nothing to read here.
None
}
/// Retrieves the stop conditions that terminate the chat completion response.
///
/// Converts OpenAI's `Stop` enum to a `Vec<String>`, normalizing the representation.
///
/// # Returns
/// * `Some(Vec<String>)` if stop conditions are set.
/// * `None` if no stop conditions are defined.
fn get_stop(&self) -> Option<Vec<String>> {
self.inner.stop.as_ref().map(|stop| match stop {
async_openai::types::Stop::String(s) => vec![s.clone()],
async_openai::types::Stop::StringArray(arr) => arr.clone(),
})
}
/// Returns a reference to the optional `NvExt` extension, if available.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
......
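// --- Illustrative aside (not part of this commit) ---
// `get_stop` above normalizes OpenAI's `Stop` enum into a `Vec<String>`.
// A standalone sketch of that conversion, using the same two variants the
// match above handles:
fn normalize_stop(stop: &async_openai::types::Stop) -> Vec<String> {
    match stop {
        // A single stop sequence becomes a one-element vector.
        async_openai::types::Stop::String(s) => vec![s.clone()],
        // Multiple stop sequences are returned as-is.
        async_openai::types::Stop::StringArray(arr) => arr.clone(),
    }
}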
@@ -22,37 +22,54 @@ use crate::protocols::{
use futures::{Stream, StreamExt};
use std::{collections::HashMap, pin::Pin};
/// A type alias for a pinned, dynamically-dispatched stream that is `Send` and `Sync`.
type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
/// [`NvCreateChatCompletionResponse`]. This struct accumulates incremental responses
/// from a streaming OpenAI API call into a complete final response.
pub struct DeltaAggregator {
/// Unique identifier for the chat completion.
id: String,
/// Model name used for the chat completion.
model: String,
/// Timestamp (Unix epoch) indicating when the response was created.
created: u32,
/// Optional usage statistics for the completion request.
usage: Option<async_openai::types::CompletionUsage>,
/// Optional system fingerprint for version tracking.
system_fingerprint: Option<String>,
/// Map of incremental response choices, keyed by index.
choices: HashMap<u32, DeltaChoice>,
/// Optional error message if an error occurs during aggregation.
error: Option<String>,
/// Optional service tier information for the response.
service_tier: Option<async_openai::types::ServiceTierResponse>,
}
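// --- Illustrative aside (not part of this commit) ---
// A minimal sketch of the fold-based accumulation pattern that
// `DeltaAggregator::apply` uses below, with plain string pieces standing in
// for chat deltas:
async fn fold_demo() {
    use futures::StreamExt;
    let deltas = futures::stream::iter(["Hel", "lo, ", "world"]);
    let full = deltas
        .fold(String::new(), |mut acc, piece| async move {
            acc.push_str(piece); // append each incremental piece
            acc
        })
        .await;
    assert_eq!(full, "Hello, world");
}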
/// Represents the accumulated state of a single chat choice during streaming aggregation.
struct DeltaChoice {
/// The index of the choice in the completion.
index: u32,
/// The accumulated text content for the choice.
text: String,
/// The role associated with this message (e.g., `system`, `user`, `assistant`).
role: Option<async_openai::types::Role>,
/// The reason the completion was finished (if applicable).
finish_reason: Option<async_openai::types::FinishReason>,
/// Optional log probabilities for the chat choice.
logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
}
impl Default for DeltaAggregator {
/// Provides a default implementation for `DeltaAggregator` by calling [`DeltaAggregator::new`].
fn default() -> Self {
Self::new()
}
}
impl DeltaAggregator {
/// Creates a new, empty [`DeltaAggregator`] instance.
pub fn new() -> Self {
Self {
id: "".to_string(),
@@ -66,14 +83,21 @@ impl DeltaAggregator {
}
}
/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single
/// [`NvCreateChatCompletionResponse`].
///
/// # Arguments
/// * `stream` - A stream of annotated chat completion responses.
///
/// # Returns
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation is successful.
/// * `Err(String)` if an error occurs during processing.
pub async fn apply(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
// Attempt to unwrap the delta, capturing any errors. Deltas are cheap
// to move, so we consume them here.
let delta = match delta.ok() {
Ok(delta) => delta,
Err(error) => {
@@ -83,15 +107,14 @@ impl DeltaAggregator {
};
if aggregator.error.is_none() && delta.data.is_some() {
// note: we could extract annotations here and add them to the aggregator
// to be returned as part of the NIM Response Extension
// TODO(#14) - Aggregate Annotation
// Extract the data payload from the delta.
let delta = delta.data.unwrap();
aggregator.id = delta.inner.id;
aggregator.model = delta.inner.model;
aggregator.created = delta.inner.created;
aggregator.service_tier = delta.inner.service_tier;
// Aggregate usage statistics if available.
if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage);
}
@@ -99,7 +122,7 @@ impl DeltaAggregator {
aggregator.system_fingerprint = Some(system_fingerprint);
}
// Aggregate choices incrementally.
for choice in delta.inner.choices {
let state_choice =
aggregator
@@ -113,10 +136,12 @@ impl DeltaAggregator {
logprobs: choice.logprobs,
});
// Append content if available.
if let Some(content) = &choice.delta.content {
state_choice.text.push_str(content);
}
// Update finish reason if provided.
if let Some(finish_reason) = choice.finish_reason {
state_choice.finish_reason = Some(finish_reason);
}
@@ -126,14 +151,14 @@ impl DeltaAggregator {
})
.await;
// Return early if an error was encountered.
let aggregator = if let Some(error) = aggregator.error {
return Err(error);
} else {
aggregator
};
// Extract aggregated choices and sort them by index.
let mut choices: Vec<_> = aggregator
.choices
.into_values()
@@ -142,6 +167,7 @@ impl DeltaAggregator {
choices.sort_by(|a, b| a.index.cmp(&b.index));
// Construct the final response object.
let inner = async_openai::types::CreateChatCompletionResponse {
id: aggregator.id,
created: aggregator.created,
@@ -159,11 +185,13 @@ impl DeltaAggregator {
}
}
// TODO: handle tool calls
#[allow(deprecated)]
impl From<DeltaChoice> for async_openai::types::ChatChoice {
/// Converts a [`DeltaChoice`] into an [`async_openai::types::ChatChoice`].
///
/// # Note
/// The `function_call` field is deprecated.
fn from(delta: DeltaChoice) -> Self {
// ALLOW: function_call is deprecated
async_openai::types::ChatChoice {
message: async_openai::types::ChatCompletionResponseMessage {
role: delta.role.expect("delta should have a Role"),
@@ -181,6 +209,14 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice {
}
impl NvCreateChatCompletionResponse {
/// Converts an SSE stream into a [`NvCreateChatCompletionResponse`].
///
/// # Arguments
/// * `stream` - A stream of SSE messages containing chat completion responses.
///
/// # Returns
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<NvCreateChatCompletionResponse, String> {
@@ -188,6 +224,14 @@ impl NvCreateChatCompletionResponse {
NvCreateChatCompletionResponse::from_annotated_stream(stream).await
}
/// Aggregates an annotated stream of chat completion responses into a final response.
///
/// # Arguments
/// * `stream` - A stream of annotated chat completion responses.
///
/// # Returns
/// * `Ok(NvCreateChatCompletionResponse)` if aggregation succeeds.
/// * `Err(String)` if an error occurs.
pub async fn from_annotated_stream(
stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
) -> Result<NvCreateChatCompletionResponse, String> {
......
@@ -16,9 +16,12 @@
use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
use crate::protocols::common;
/// Provides a method for generating a [`DeltaGenerator`] from a chat completion request.
impl NvCreateChatCompletionRequest {
/// Creates a [`DeltaGenerator`] instance based on the chat completion request.
///
/// # Returns
/// * [`DeltaGenerator`] configured with model name and response options.
pub fn response_generator(&self) -> DeltaGenerator {
let options = DeltaGeneratorOptions {
enable_usage: true,
@@ -29,34 +32,50 @@ impl NvCreateChatCompletionRequest {
}
}
/// Configuration options for the [`DeltaGenerator`], controlling response behavior.
#[derive(Debug, Clone, Default)]
pub struct DeltaGeneratorOptions {
/// Determines whether token usage statistics should be included in the response.
pub enable_usage: bool,
/// Determines whether log probabilities should be included in the response.
pub enable_logprobs: bool,
}
/// Generates incremental chat completion responses in a streaming fashion.
#[derive(Debug, Clone)]
pub struct DeltaGenerator {
/// Unique identifier for the chat completion session.
id: String,
/// Object type, representing a streamed chat completion response.
object: String,
/// Timestamp (Unix epoch) when the response was created.
created: u32,
/// Model name used for generating responses.
model: String,
/// Optional system fingerprint for version tracking.
system_fingerprint: Option<String>,
/// Optional service tier information for the response.
service_tier: Option<async_openai::types::ServiceTierResponse>,
/// Tracks token usage for the completion request.
usage: async_openai::types::CompletionUsage,
/// Counter tracking the number of messages issued.
msg_counter: u64,
/// Configuration options for response generation.
options: DeltaGeneratorOptions,
}
impl DeltaGenerator {
/// Creates a new [`DeltaGenerator`] instance with the specified model and options.
///
/// # Arguments
/// * `model` - The model name used for response generation.
/// * `options` - Configuration options for enabling usage and log probabilities.
///
/// # Returns
/// * A new instance of [`DeltaGenerator`].
pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
@@ -83,10 +102,24 @@ impl DeltaGenerator {
}
}
/// Updates the prompt token usage count.
///
/// # Arguments
/// * `isl` - The number of prompt tokens used.
pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl;
}
/// Creates a choice within a chat completion response.
///
/// # Arguments
/// * `index` - The index of the choice in the completion response.
/// * `text` - The text content for the response.
/// * `finish_reason` - The reason the response finished (e.g., stop or length).
/// * `logprobs` - Optional log probabilities of the generated tokens.
///
/// # Returns
/// * An [`async_openai::types::CreateChatCompletionStreamResponse`] containing the new choice.
#[allow(deprecated)]
pub fn create_choice(
&self,
@@ -96,7 +129,6 @@ impl DeltaGenerator {
logprobs: Option<async_openai::types::ChatChoiceLogprobs>,
) -> async_openai::types::CreateChatCompletionStreamResponse {
// TODO: Update for tool calling
// ALLOW: function_call is deprecated
let delta = async_openai::types::ChatCompletionStreamResponseDelta {
role: if self.msg_counter == 0 {
Some(async_openai::types::Role::Assistant)
@@ -135,21 +167,32 @@ impl DeltaGenerator {
}
}
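// --- Illustrative aside (not part of this commit) ---
// The cast discussed in `new()` above: Unix epoch seconds fit in a `u32`
// until the year 2106 (`u32::MAX` seconds after 1970), so the narrowing
// cast from `u64` is acceptable. A standalone sketch:
fn epoch_secs_u32() -> u32 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs() as u32 // lossless until 2106
}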
/// Implements the [`DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing
/// it to transform backend responses into OpenAI-style streaming responses.
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
for DeltaGenerator
{
/// Converts a backend response into a structured OpenAI-style streaming response.
///
/// # Arguments
/// * `delta` - The backend response containing generated text and metadata.
///
/// # Returns
/// * `Ok(NvCreateChatCompletionStreamResponse)` if conversion succeeds.
/// * `Err(anyhow::Error)` if an error occurs.
fn choice_from_postprocessor(
&mut self,
delta: crate::protocols::common::llm_backend::BackendOutput,
) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
// Aggregate token usage if enabled.
if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32;
}
// TODO: Implement log probabilities aggregation.
let logprobs = None;
// Map backend finish reasons to OpenAI's finish reasons.
let finish_reason = match delta.finish_reason {
Some(common::FinishReason::EoS) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
@@ -161,7 +204,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None => None,
};
// Create the streaming response.
let index = 0;
let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
......
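// --- Illustrative aside (not part of this commit) ---
// The match above collapses several backend finish reasons onto OpenAI's
// enum. A standalone sketch with a hypothetical stand-in for the backend
// enum (`common::FinishReason` is the real type); the `Length` arm is an
// assumption, since the remaining arms above are elided from this diff.
enum BackendFinishReason {
    EoS,
    Stop,
    Length,
}

fn map_finish_reason(
    reason: Option<BackendFinishReason>,
) -> Option<async_openai::types::FinishReason> {
    match reason {
        // Both end-of-sequence and explicit stop map to OpenAI's `Stop`.
        Some(BackendFinishReason::EoS) | Some(BackendFinishReason::Stop) => {
            Some(async_openai::types::FinishReason::Stop)
        }
        Some(BackendFinishReason::Length) => Some(async_openai::types::FinishReason::Length),
        None => None,
    }
}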