"vscode:/vscode.git/clone" did not exist on "a01cd9c19d4b3a015579caa3293d6bb2091aaab4"
Unverified Commit 0c9ae4dd authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: refactoring to use async_openai::types::CompletionUsage (#1397)

parent cd18cf2e
......@@ -414,7 +414,7 @@ impl
let (common_request, annotations) = self.preprocess_request(&request)?;
// update isl
response_generator.update_isl(common_request.token_ids.len() as i32);
response_generator.update_isl(common_request.token_ids.len() as u32);
// repack the common completion request
let common_request = context.map(|_| common_request);
......
......@@ -24,13 +24,13 @@
//! need some additional information to propagate intermediate results for improved observability.
//! The metadata is transferred via the other arms of the `StreamingResponse` enum.
//!
use std::collections::HashMap;
use std::time::SystemTime;
use anyhow::Result;
use derive_builder::Builder;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::time::SystemTime;
use super::TokenIdType;
......
......@@ -67,47 +67,6 @@ pub const MAX_PRESENCE_PENALTY: f32 = 2.0;
/// Allowed range of values for OpenAI's `presence_penalty` sampling option
pub const PRESENCE_PENALTY_RANGE: (f32, f32) = (MIN_PRESENCE_PENALTY, MAX_PRESENCE_PENALTY);
/// Usage statistics for the completion request
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct CompletionUsage {
/// Number of tokens in the generated completion.
pub completion_tokens: i32,
/// Number of tokens in the prompt.
pub prompt_tokens: i32,
/// Total number of tokens used in the request (prompt + completion).
pub total_tokens: i32,
/// Breakdown of tokens used in a completion, optional.
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
/// Breakdown of tokens used in the prompt, optional.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<PromptTokensDetails>,
}
// Struct for details on completion tokens
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CompletionTokensDetails {
/// Audio input tokens generated by the model.
pub audio_tokens: Option<i32>,
/// Tokens generated by the model for reasoning.
pub reasoning_tokens: Option<i32>,
}
// Struct for details on prompt tokens
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PromptTokensDetails {
/// Audio input tokens present in the prompt.
pub audio_tokens: Option<i32>,
/// Cached tokens present in the prompt.
pub cached_tokens: Option<i32>,
}
/// Represents a streaming response from the OpenAI API
/// The object is generalized on R, which is the type of the response.
/// For SSE streaming responses, the expected `data: ` field is always a JSON
......@@ -247,7 +206,7 @@ pub struct GenericCompletionResponse<C>
/// or is `chat.completion.chunk` if the type of `Choice` is `ChatCompletionChoiceDelta`.
pub object: String,
pub usage: Option<CompletionUsage>,
pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with.
///
......
......@@ -16,22 +16,21 @@
use std::collections::HashMap;
use derive_builder::Builder;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
use validator::Validate;
mod aggregator;
mod delta;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider},
nvext::{NvExt, NvExtProvider},
CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
mod aggregator;
mod delta;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionRequest {
......@@ -64,7 +63,7 @@ pub struct CompletionResponse {
pub object: String,
/// Usage statistics for the completion request.
pub usage: Option<CompletionUsage>,
pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with.
/// Can be used in conjunction with the seed request parameter to understand when backend
......@@ -240,7 +239,7 @@ impl ResponseFactory {
pub fn make_response(
&self,
choice: CompletionChoice,
usage: Option<CompletionUsage>,
usage: Option<async_openai::types::CompletionUsage>,
) -> CompletionResponse {
CompletionResponse {
id: self.id.clone(),
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -18,7 +18,7 @@ use std::{collections::HashMap, str::FromStr};
use anyhow::Result;
use futures::StreamExt;
use super::{CompletionChoice, CompletionResponse, CompletionUsage, LogprobResult};
use super::{CompletionChoice, CompletionResponse, LogprobResult};
use crate::protocols::{
codec::{Message, SseCodecError},
common::FinishReason,
......@@ -30,7 +30,7 @@ pub struct DeltaAggregator {
id: String,
model: String,
created: u64,
usage: Option<CompletionUsage>,
usage: Option<async_openai::types::CompletionUsage>,
system_fingerprint: Option<String>,
choices: HashMap<u64, DeltaChoice>,
error: Option<String>,
......
......@@ -15,7 +15,6 @@
use super::{CompletionChoice, CompletionResponse, NvCreateCompletionRequest};
use crate::protocols::common;
use crate::protocols::openai::CompletionUsage;
impl NvCreateCompletionRequest {
// put this method on the request
......@@ -43,8 +42,7 @@ pub struct DeltaGenerator {
created: u64,
model: String,
system_fingerprint: Option<String>,
usage: CompletionUsage,
usage: async_openai::types::CompletionUsage,
options: DeltaGeneratorOptions,
}
......@@ -55,18 +53,28 @@ impl DeltaGenerator {
.unwrap()
.as_secs();
// Previously, our home-rolled CompletionUsage impl'd Default
// PR !387 - https://github.com/64bit/async-openai/pull/387
let usage = async_openai::types::CompletionUsage {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0,
completion_tokens_details: None,
prompt_tokens_details: None,
};
Self {
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
object: "text_completion".to_string(),
created: now,
model,
system_fingerprint: None,
usage: CompletionUsage::default(),
usage,
options,
}
}
pub fn update_isl(&mut self, isl: i32) {
pub fn update_isl(&mut self, isl: u32) {
self.usage.prompt_tokens = isl;
}
......@@ -106,7 +114,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
) -> anyhow::Result<CompletionResponse> {
// aggregate usage
if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as i32;
self.usage.completion_tokens += delta.token_ids.len() as u32;
}
// todo logprobs
......@@ -127,8 +135,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
Ok(self.create_choice(index, delta.text, finish_reason))
}
// TODO: This is a hack. Change `prompt_tokens` to u32
fn get_isl(&self) -> Option<u32> {
Some(self.usage.prompt_tokens as u32)
Some(self.usage.prompt_tokens)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment