Unverified Commit 0edc886f authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: using async_openai::types::Logprobs (#1625)

parent 0b7cdf55
......@@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use derive_builder::Builder;
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
use serde::{Deserialize, Serialize};
......@@ -92,7 +90,7 @@ pub struct CompletionChoice {
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub logprobs: Option<LogprobResult>,
pub logprobs: Option<async_openai::types::Logprobs>,
}
impl ContentProvider for CompletionChoice {
......@@ -107,16 +105,6 @@ impl CompletionChoice {
}
}
// TODO: validate this is the correct format
/// Legacy OpenAI LogprobResult component
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LogprobResult {
pub tokens: Vec<String>,
pub token_logprobs: Vec<f32>,
pub top_logprobs: Vec<HashMap<String, f32>>,
pub text_offset: Vec<i32>,
}
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
match prompt {
async_openai::types::Prompt::String(s) => s.clone(),
......
......@@ -18,7 +18,7 @@ use std::{collections::HashMap, str::FromStr};
use anyhow::Result;
use futures::StreamExt;
use super::{CompletionChoice, CompletionResponse, LogprobResult};
use super::{CompletionChoice, CompletionResponse};
use crate::protocols::{
codec::{Message, SseCodecError},
common::FinishReason,
......@@ -40,7 +40,7 @@ struct DeltaChoice {
index: u64,
text: String,
finish_reason: Option<FinishReason>,
logprobs: Option<LogprobResult>,
logprobs: Option<async_openai::types::Logprobs>,
}
impl Default for DeltaAggregator {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment