"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "a1b38af25eec422a454d541d7d1cef5dfb235cbf"
Unverified Commit 8a2d6529 authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: removing unsized integer conversions (#1668)

parent d4c2f0a3
...@@ -74,12 +74,14 @@ impl DeltaGenerator { ...@@ -74,12 +74,14 @@ impl DeltaGenerator {
/// # Returns /// # Returns
/// * A new instance of [`DeltaGenerator`]. /// * A new instance of [`DeltaGenerator`].
pub fn new(model: String, options: DeltaGeneratorOptions) -> Self { pub fn new(model: String, options: DeltaGeneratorOptions) -> Self {
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now = std::time::SystemTime::now() let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap() .unwrap()
.as_secs() as u32; .as_secs();
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
let usage = async_openai::types::CompletionUsage { let usage = async_openai::types::CompletionUsage {
prompt_tokens: 0, prompt_tokens: 0,
...@@ -191,7 +193,15 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -191,7 +193,15 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
) -> anyhow::Result<NvCreateChatCompletionStreamResponse> { ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
// Aggregate token usage if enabled. // Aggregate token usage if enabled.
if self.options.enable_usage { if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32; // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
} }
// TODO: Implement log probabilities aggregation. // TODO: Implement log probabilities aggregation.
......
...@@ -161,8 +161,8 @@ pub struct ResponseFactory { ...@@ -161,8 +161,8 @@ pub struct ResponseFactory {
#[builder(default = "\"text_completion\".to_string()")] #[builder(default = "\"text_completion\".to_string()")]
pub object: String, pub object: String,
#[builder(default = "chrono::Utc::now().timestamp() as u64")] #[builder(default = "chrono::Utc::now().timestamp() as u32")]
pub created: u64, pub created: u32,
} }
impl ResponseFactory { impl ResponseFactory {
...@@ -178,7 +178,7 @@ impl ResponseFactory { ...@@ -178,7 +178,7 @@ impl ResponseFactory {
let inner = async_openai::types::CreateCompletionResponse { let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(), id: self.id.clone(),
object: self.object.clone(), object: self.object.clone(),
created: self.created as u32, created: self.created,
model: self.model.clone(), model: self.model.clone(),
choices: vec![choice], choices: vec![choice],
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
...@@ -250,9 +250,14 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic ...@@ -250,9 +250,14 @@ impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choic
.text .text
.ok_or(anyhow::anyhow!("No text in response"))?; .ok_or(anyhow::anyhow!("No text in response"))?;
// Safety: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295 // SAFETY: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
// so we're fairly safe knowing we won't generate that many Choices // so we're fairly safe knowing we won't generate that many Choices
let index = response.delta.index.unwrap_or(0) as u32; let index: u32 = response
.delta
.index
.unwrap_or(0)
.try_into()
.expect("index exceeds u32::MAX");
// TODO handle aggregating logprobs // TODO handle aggregating logprobs
let logprobs = None; let logprobs = None;
......
...@@ -29,15 +29,15 @@ use crate::protocols::{ ...@@ -29,15 +29,15 @@ use crate::protocols::{
pub struct DeltaAggregator { pub struct DeltaAggregator {
id: String, id: String,
model: String, model: String,
created: u64, created: u32,
usage: Option<async_openai::types::CompletionUsage>, usage: Option<async_openai::types::CompletionUsage>,
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
choices: HashMap<u64, DeltaChoice>, choices: HashMap<u32, DeltaChoice>,
error: Option<String>, error: Option<String>,
} }
struct DeltaChoice { struct DeltaChoice {
index: u64, index: u32,
text: String, text: String,
finish_reason: Option<FinishReason>, finish_reason: Option<FinishReason>,
logprobs: Option<async_openai::types::Logprobs>, logprobs: Option<async_openai::types::Logprobs>,
...@@ -85,7 +85,7 @@ impl DeltaAggregator { ...@@ -85,7 +85,7 @@ impl DeltaAggregator {
let delta = delta.data.unwrap(); let delta = delta.data.unwrap();
aggregator.id = delta.inner.id; aggregator.id = delta.inner.id;
aggregator.model = delta.inner.model; aggregator.model = delta.inner.model;
aggregator.created = delta.inner.created as u64; aggregator.created = delta.inner.created;
if let Some(usage) = delta.inner.usage { if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage); aggregator.usage = Some(usage);
} }
...@@ -98,9 +98,9 @@ impl DeltaAggregator { ...@@ -98,9 +98,9 @@ impl DeltaAggregator {
let state_choice = let state_choice =
aggregator aggregator
.choices .choices
.entry(choice.index as u64) .entry(choice.index)
.or_insert(DeltaChoice { .or_insert(DeltaChoice {
index: choice.index as u64, index: choice.index,
text: "".to_string(), text: "".to_string(),
finish_reason: None, finish_reason: None,
logprobs: choice.logprobs, logprobs: choice.logprobs,
...@@ -147,7 +147,7 @@ impl DeltaAggregator { ...@@ -147,7 +147,7 @@ impl DeltaAggregator {
let inner = async_openai::types::CreateCompletionResponse { let inner = async_openai::types::CreateCompletionResponse {
id: aggregator.id, id: aggregator.id,
created: aggregator.created as u32, created: aggregator.created,
usage: aggregator.usage, usage: aggregator.usage,
model: aggregator.model, model: aggregator.model,
object: "text_completion".to_string(), object: "text_completion".to_string(),
...@@ -166,7 +166,7 @@ impl From<DeltaChoice> for async_openai::types::Choice { ...@@ -166,7 +166,7 @@ impl From<DeltaChoice> for async_openai::types::Choice {
let finish_reason = delta.finish_reason.map(Into::into); let finish_reason = delta.finish_reason.map(Into::into);
async_openai::types::Choice { async_openai::types::Choice {
index: delta.index as u32, index: delta.index,
text: delta.text, text: delta.text,
finish_reason, finish_reason,
logprobs: delta.logprobs, logprobs: delta.logprobs,
...@@ -199,7 +199,7 @@ mod tests { ...@@ -199,7 +199,7 @@ mod tests {
use crate::protocols::openai::completions::NvCreateCompletionResponse; use crate::protocols::openai::completions::NvCreateCompletionResponse;
fn create_test_delta( fn create_test_delta(
index: u64, index: u32,
text: &str, text: &str,
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Annotated<NvCreateCompletionResponse> { ) -> Annotated<NvCreateCompletionResponse> {
...@@ -217,7 +217,7 @@ mod tests { ...@@ -217,7 +217,7 @@ mod tests {
usage: None, usage: None,
system_fingerprint: None, system_fingerprint: None,
choices: vec![async_openai::types::Choice { choices: vec![async_openai::types::Choice {
index: index as u32, index,
text: text.to_string(), text: text.to_string(),
finish_reason, finish_reason,
logprobs: None, logprobs: None,
......
...@@ -39,7 +39,7 @@ pub struct DeltaGeneratorOptions { ...@@ -39,7 +39,7 @@ pub struct DeltaGeneratorOptions {
pub struct DeltaGenerator { pub struct DeltaGenerator {
id: String, id: String,
object: String, object: String,
created: u64, created: u32,
model: String, model: String,
system_fingerprint: Option<String>, system_fingerprint: Option<String>,
usage: async_openai::types::CompletionUsage, usage: async_openai::types::CompletionUsage,
...@@ -53,6 +53,10 @@ impl DeltaGenerator { ...@@ -53,6 +53,10 @@ impl DeltaGenerator {
.unwrap() .unwrap()
.as_secs(); .as_secs();
// SAFETY: Casting from `u64` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until 2106.
let now: u32 = now.try_into().expect("timestamp exceeds u32::MAX");
// Previously, our home-rolled CompletionUsage impl'd Default // Previously, our home-rolled CompletionUsage impl'd Default
// PR !387 - https://github.com/64bit/async-openai/pull/387 // PR !387 - https://github.com/64bit/async-openai/pull/387
let usage = async_openai::types::CompletionUsage { let usage = async_openai::types::CompletionUsage {
...@@ -80,7 +84,7 @@ impl DeltaGenerator { ...@@ -80,7 +84,7 @@ impl DeltaGenerator {
pub fn create_choice( pub fn create_choice(
&self, &self,
index: u64, index: u32,
text: Option<String>, text: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>, finish_reason: Option<async_openai::types::CompletionFinishReason>,
) -> NvCreateCompletionResponse { ) -> NvCreateCompletionResponse {
...@@ -94,12 +98,12 @@ impl DeltaGenerator { ...@@ -94,12 +98,12 @@ impl DeltaGenerator {
let inner = async_openai::types::CreateCompletionResponse { let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(), id: self.id.clone(),
object: self.object.clone(), object: self.object.clone(),
created: self.created as u32, created: self.created,
model: self.model.clone(), model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
choices: vec![async_openai::types::Choice { choices: vec![async_openai::types::Choice {
text: text.unwrap_or_default(), text: text.unwrap_or_default(),
index: index as u32, index,
finish_reason, finish_reason,
logprobs: None, logprobs: None,
}], }],
...@@ -121,7 +125,15 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -121,7 +125,15 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
) -> anyhow::Result<NvCreateCompletionResponse> { ) -> anyhow::Result<NvCreateCompletionResponse> {
// aggregate usage // aggregate usage
if self.options.enable_usage { if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32; // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`,
// but this will not be an issue until context lengths exceed 4_294_967_295.
let token_length: u32 = delta
.token_ids
.len()
.try_into()
.expect("token_ids length exceeds u32::MAX");
self.usage.completion_tokens += token_length;
} }
// TODO logprobs // TODO logprobs
...@@ -129,7 +141,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -129,7 +141,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
let finish_reason = delta.finish_reason.map(Into::into); let finish_reason = delta.finish_reason.map(Into::into);
// create choice // create choice
let index = delta.index.unwrap_or(0).into(); let index = delta.index.unwrap_or(0);
let response = self.create_choice(index, delta.text.clone(), finish_reason); let response = self.create_choice(index, delta.text.clone(), finish_reason);
Ok(response) Ok(response)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment