Unverified commit 0c9ae4dd, authored by Paul Hendricks and committed by GitHub
Browse files

refactor: refactoring to use async_openai::types::CompletionUsage (#1397)

parent cd18cf2e
...@@ -414,7 +414,7 @@ impl ...@@ -414,7 +414,7 @@ impl
let (common_request, annotations) = self.preprocess_request(&request)?; let (common_request, annotations) = self.preprocess_request(&request)?;
// update isl // update isl
response_generator.update_isl(common_request.token_ids.len() as i32); response_generator.update_isl(common_request.token_ids.len() as u32);
// repack the common completion request // repack the common completion request
let common_request = context.map(|_| common_request); let common_request = context.map(|_| common_request);
......
...@@ -24,13 +24,13 @@ ...@@ -24,13 +24,13 @@
//! need some additional information to propagate intermediate results for improved observability. //! need some additional information to propagate intermediate results for improved observability.
//! The metadata is transferred via the other arms of the `StreamingResponse` enum. //! The metadata is transferred via the other arms of the `StreamingResponse` enum.
//! //!
use std::collections::HashMap;
use std::time::SystemTime;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use serde::ser::SerializeStruct; use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::time::SystemTime;
use super::TokenIdType; use super::TokenIdType;
......
...@@ -67,47 +67,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0; ...@@ -67,47 +67,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option /// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY); pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Usage statistics for the completion request
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct CompletionUsage {
/// Number of tokens in the generated completion.
pub completion_tokens: i32,
/// Number of tokens in the prompt.
pub prompt_tokens: i32,
/// Total number of tokens used in the request (prompt + completion).
pub total_tokens: i32,
/// Breakdown of tokens used in a completion, optional.
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
/// Breakdown of tokens used in the prompt, optional.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokensDetails>,
}
// Struct for details on completion tokens
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionTokensDetails {
/// Audio input tokens generated by the model.
pub audio_tokens: Option<i32>,
/// Tokens generated by the model for reasoning.
pub reasoning_tokens: Option<i32>,
}
// Struct for details on prompt tokens
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PromptTokensDetails {
/// Audio input tokens present in the prompt.
pub audio_tokens: Option<i32>,
/// Cached tokens present in the prompt.
pub cached_tokens: Option<i32>,
}
/// Represents a streaming response from the OpenAI API /// Represents a streaming response from the OpenAI API
/// The object is generalized on R, which is the type of the response. /// The object is generalized on R, which is the type of the response.
/// For SSE streaming responses, the expected `data: ` field is always a JSON /// For SSE streaming responses, the expected `data: ` field is always a JSON
...@@ -247,7 +206,7 @@ pub struct GenericCompletionResponse<C> ...@@ -247,7 +206,7 @@ pub struct GenericCompletionResponse<C>
/// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`. /// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
pub object: String, pub object: String,
pub usage: Option<CompletionUsage>, pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with. /// This fingerprint represents the backend configuration that the model runs with.
/// ///
......
...@@ -16,22 +16,21 @@ ...@@ -16,22 +16,21 @@
use std::collections::HashMap; use std::collections::HashMap;
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
mod aggregator;
mod delta;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
use super::{ use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider}, common::{self, SamplingOptionsProvider, StopConditionsProvider},
nvext::{NvExt, NvExtProvider}, nvext::{NvExt, NvExtProvider},
CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
}; };
use dynamo_runtime::protocols::annotated::AnnotationsProvider; mod aggregator;
mod delta;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionRequest { pub struct NvCreateCompletionRequest {
...@@ -64,7 +63,7 @@ pub struct CompletionResponse { ...@@ -64,7 +63,7 @@ pub struct CompletionResponse {
pub object: String, pub object: String,
/// Usage statistics for the completion request. /// Usage statistics for the completion request.
pub usage: Option<CompletionUsage>, pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with. /// This fingerprint represents the backend configuration that the model runs with.
/// Can be used in conjunction with the seed request parameter to understand when backend /// Can be used in conjunction with the seed request parameter to understand when backend
...@@ -240,7 +239,7 @@ impl ResponseFactory { ...@@ -240,7 +239,7 @@ impl ResponseFactory {
pub fn make_response( pub fn make_response(
&self, &self,
choice: CompletionChoice, choice: CompletionChoice,
usage: Option<CompletionUsage>, usage: Option<async_openai::types::CompletionUsage>,
) -> CompletionResponse { ) -> CompletionResponse {
CompletionResponse { CompletionResponse {
id: self.id.clone(), id: self.id.clone(),
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -18,7 +18,7 @@ use std::{collections::HashMap, str::FromStr}; ...@@ -18,7 +18,7 @@ use std::{collections::HashMap, str::FromStr};
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use super::{CompletionChoice, CompletionResponse, CompletionUsage, LogprobResult}; use super::{CompletionChoice, CompletionResponse, LogprobResult};
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
common::FinishReason, common::FinishReason,
...@@ -30,7 +30,7 @@ pub struct DeltaAggregator { ...@@ -30,7 +30,7 @@ pub struct DeltaAggregator {
id: String, id: String,
model: String, model: String,
created: u64, created: u64,
usage: Option<CompletionUsage>, usage: Option<async_openai::types::CompletionUsage>,
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
choices: HashMap<u64, DeltaChoice>, choices: HashMap<u64, DeltaChoice>,
error: Option<String>, error: Option<String>,
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
use super::{CompletionChoice, CompletionResponse, NvCreateCompletionRequest}; use super::{CompletionChoice, CompletionResponse, NvCreateCompletionRequest};
use crate::protocols::common; use crate::protocols::common;
use crate::protocols::openai::CompletionUsage;
impl NvCreateCompletionRequest { impl NvCreateCompletionRequest {
// put this method on the request // put this method on the request
...@@ -43,8 +42,7 @@ pub struct DeltaGenerator { ...@@ -43,8 +42,7 @@ pub struct DeltaGenerator {
created: u64, created: u64,
model: String, model: String,
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
usage: CompletionUsage, usage: async_openai::types::CompletionUsage,
options: DeltaGeneratorOptions, options: DeltaGeneratorOptions,
} }
...@@ -55,18 +53,28 @@ impl DeltaGenerator { ...@@ -55,18 +53,28 @@ impl DeltaGenerator {
.unwrap() .unwrap()
.as_secs(); .as_secs();
// Our previous in-house CompletionUsage implemented Default; here every field is zero-initialized explicitly instead.
// Related upstream pull request: https://github.com/64bit/async-openai/pull/387
let usage = async_openai::types::CompletionUsage {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0,
completion_tokens_details: None,
prompt_tokens_details: None,
};
Self { Self {
id: format!("cmpl-{}", uuid::Uuid::new_v4()), id: format!("cmpl-{}", uuid::Uuid::new_v4()),
object: "text_completion".to_string(), object: "text_completion".to_string(),
created: now, created: now,
model, model,
system_fingerprint: None, system_fingerprint: None,
usage: CompletionUsage::default(), usage,
options, options,
} }
} }
pub fn update_isl(&mut self, isl: i32) { pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl; self.usage.prompt_tokens = isl;
} }
...@@ -106,7 +114,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe ...@@ -106,7 +114,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
) -> anyhow::Result<CompletionResponse> { ) -> anyhow::Result<CompletionResponse> {
// aggregate usage // aggregate usage
if self.options.enable_usage { if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as i32; self.usage.completion_tokens += delta.token_ids.len() as u32;
} }
// todo logprobs // todo logprobs
...@@ -127,8 +135,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe ...@@ -127,8 +135,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
Ok(self.create_choice(index, delta.text, finish_reason)) Ok(self.create_choice(index, delta.text, finish_reason))
} }
// TODO: This is a hack. Change `prompt_tokens` to u32
fn get_isl(&self) -> Option<u32> { fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens as u32) Some(self.usage.prompt_tokens)
} }
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment