Unverified Commit 8a2d6529 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: removing unsized integer conversions (#1668)

parent d4c2f0a3
......@@ -74,12 +74,14 @@ impl DeltaGenerator {
/// # Returns
/// * A new instance of [`DeltaGenerator`].
pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as u32;
.as_secs();
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
let usage = async_openai::types::CompletionUsage {
prompt_tokens: 0,
......@@ -191,7 +193,15 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
// Aggregate token usage if enabled.
if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32;
// SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
}
// TODO: Implement log probabilities aggregation.
......
......@@ -161,8 +161,8 @@ pub struct ResponseFactory {
#[builder(default = "\"text_completion\".to_string()")]
pub object: String,
#[builder(default = "chrono::Utc::now().timestamp() as u64")]
pub created: u64,
#[builder(default = "chrono::Utc::now().timestamp() as u32")]
pub created: u32,
}
impl ResponseFactory {
......@@ -178,7 +178,7 @@ impl ResponseFactory {
let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created as u32,
created: self.created,
model: self.model.clone(),
choices: vec![choice],
system_fingerprint: self.system_fingerprint.clone(),
......@@ -250,9 +250,14 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
.text
.ok_or(anyhow::anyhow!("No text in response"))?;
// Safety: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
// SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
// so we're fairly safe knowing we won't generate that many Choices
let index = response.delta.index.unwrap_or(0) as u32;
let index: u32 = response
.delta
.index
.unwrap_or(0)
.try_into()
.expect("index exceeds u32::MAX");
// TODO handle aggregating logprobs
let logprobs = None;
......
......@@ -29,15 +29,15 @@ use crate::protocols::{
pub struct DeltaAggregator {
id: String,
model: String,
created: u64,
created: u32,
usage: Option<async_openai::types::CompletionUsage>,
system_fingerprint: Option<String>,
choices: HashMap<u64, DeltaChoice>,
choices: HashMap<u32, DeltaChoice>,
error: Option<String>,
}
struct DeltaChoice {
index: u64,
index: u32,
text: String,
finish_reason: Option<FinishReason>,
logprobs: Option<async_openai::types::Logprobs>,
......@@ -85,7 +85,7 @@ impl DeltaAggregator {
let delta = delta.data.unwrap();
aggregator.id = delta.inner.id;
aggregator.model = delta.inner.model;
aggregator.created = delta.inner.created as u64;
aggregator.created = delta.inner.created;
if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage);
}
......@@ -98,9 +98,9 @@ impl DeltaAggregator {
let state_choice =
aggregator
.choices
.entry(choice.index as u64)
.entry(choice.index)
.or_insert(DeltaChoice {
index: choice.index as u64,
index: choice.index,
text: "".to_string(),
finish_reason: None,
logprobs: choice.logprobs,
......@@ -147,7 +147,7 @@ impl DeltaAggregator {
let inner = async_openai::types::CreateCompletionResponse {
id: aggregator.id,
created: aggregator.created as u32,
created: aggregator.created,
usage: aggregator.usage,
model: aggregator.model,
object: "text_completion".to_string(),
......@@ -166,7 +166,7 @@ impl From<DeltaChoice> for async_openai::types::Choice {
let finish_reason = delta.finish_reason.map(Into::into);
async_openai::types::Choice {
index: delta.index as u32,
index: delta.index,
text: delta.text,
finish_reason,
logprobs: delta.logprobs,
......@@ -199,7 +199,7 @@ mod tests {
use crate::protocols::openai::completions::NvCreateCompletionResponse;
fn create_test_delta(
index: u64,
index: u32,
text: &str,
finish_reason: Option<String>,
) -> Annotated<NvCreateCompletionResponse> {
......@@ -217,7 +217,7 @@ mod tests {
usage: None,
system_fingerprint: None,
choices: vec![async_openai::types::Choice {
index: index as u32,
index,
text: text.to_string(),
finish_reason,
logprobs: None,
......
......@@ -39,7 +39,7 @@ pub struct DeltaGeneratorOptions {
pub struct DeltaGenerator {
id: String,
object: String,
created: u64,
created: u32,
model: String,
system_fingerprint: Option<String>,
usage: async_openai::types::CompletionUsage,
......@@ -53,6 +53,10 @@ impl DeltaGenerator {
.unwrap()
.as_secs();
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
// Previously, our home-rolled CompletionUsage impl'd Default
// PR !387 - https://github.com/64bit/async-openai/pull/387
let usage = async_openai::types::CompletionUsage {
......@@ -80,7 +84,7 @@ impl DeltaGenerator {
pub fn create_choice(
&self,
index: u64,
index: u32,
text: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>,
) -> NvCreateCompletionResponse {
......@@ -94,12 +98,12 @@ impl DeltaGenerator {
let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created as u32,
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![async_openai::types::Choice {
text: text.unwrap_or_default(),
index: index as u32,
index,
finish_reason,
logprobs: None,
}],
......@@ -121,7 +125,15 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
) -> anyhow::Result<NvCreateCompletionResponse> {
// aggregate usage
if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32;
// SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
}
// TODO logprobs
......@@ -129,7 +141,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let finish_reason = delta.finish_reason.map(Into::into);
// create choice
let index = delta.index.unwrap_or(0).into();
let index = delta.index.unwrap_or(0);
let response = self.create_choice(index, delta.text.clone(), finish_reason);
Ok(response)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment