Unverified Commit 7b7b6a6d authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: refactored using Choice and CompletionFinishReason (#1635)

parent c95031ed
......@@ -238,7 +238,7 @@ impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<Completi
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1;
}
let response = deltas.create_choice(0, None, Some("stop".to_string()));
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop));
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
};
......
......@@ -64,6 +64,9 @@ pub enum FinishReason {
#[serde(rename = "cancelled")]
Cancelled,
#[serde(rename = "content_filter")]
ContentFilter,
}
impl std::fmt::Display for FinishReason {
......@@ -74,6 +77,7 @@ impl std::fmt::Display for FinishReason {
FinishReason::Stop => write!(f, "stop"),
FinishReason::Error(msg) => write!(f, "error: {}", msg),
FinishReason::Cancelled => write!(f, "cancelled"),
FinishReason::ContentFilter => write!(f, "content_filter"),
}
}
}
......@@ -93,6 +97,33 @@ impl std::str::FromStr for FinishReason {
}
}
impl From<FinishReason> for async_openai::types::CompletionFinishReason {
fn from(reason: FinishReason) -> Self {
match reason {
FinishReason::EoS | FinishReason::Stop | FinishReason::Cancelled => {
async_openai::types::CompletionFinishReason::Stop
}
FinishReason::ContentFilter => {
async_openai::types::CompletionFinishReason::ContentFilter
}
FinishReason::Length => async_openai::types::CompletionFinishReason::Length,
FinishReason::Error(_) => async_openai::types::CompletionFinishReason::Stop,
}
}
}
impl From<async_openai::types::CompletionFinishReason> for FinishReason {
fn from(reason: async_openai::types::CompletionFinishReason) -> Self {
match reason {
async_openai::types::CompletionFinishReason::Stop => FinishReason::Stop,
async_openai::types::CompletionFinishReason::Length => FinishReason::Length,
async_openai::types::CompletionFinishReason::ContentFilter => {
FinishReason::ContentFilter
}
}
}
}
/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that will
......
......@@ -203,6 +203,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length),
Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::ContentFilter) => {
Some(async_openai::types::FinishReason::ContentFilter)
}
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!(err_msg));
}
......
......@@ -49,7 +49,7 @@ pub struct CompletionResponse {
pub id: String,
/// The list of completion choices the model generated for the input prompt.
pub choices: Vec<CompletionChoice>,
pub choices: Vec<async_openai::types::Choice>,
/// The Unix timestamp (in seconds) of when the completion was created.
pub created: u64,
......@@ -76,35 +76,12 @@ pub struct CompletionResponse {
// pub nvext: Option<NimResponseExt>,
}
/// Legacy OpenAI CompletionResponse Choice component
#[derive(Clone, Debug, Deserialize, Serialize, Builder)]
pub struct CompletionChoice {
#[builder(setter(into))]
pub text: String,
#[builder(default = "0")]
pub index: u64,
#[builder(default, setter(into, strip_option))]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub logprobs: Option<async_openai::types::Logprobs>,
}
impl ContentProvider for CompletionChoice {
impl ContentProvider for async_openai::types::Choice {
fn content(&self) -> String {
self.text.clone()
}
}
impl CompletionChoice {
pub fn builder() -> CompletionChoiceBuilder {
CompletionChoiceBuilder::default()
}
}
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
match prompt {
async_openai::types::Prompt::String(s) => s.clone(),
......@@ -226,7 +203,7 @@ impl ResponseFactory {
pub fn make_response(
&self,
choice: CompletionChoice,
choice: async_openai::types::Choice,
usage: Option<async_openai::types::CompletionUsage>,
) -> CompletionResponse {
CompletionResponse {
......@@ -294,27 +271,30 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
}
}
impl TryFrom<common::StreamingCompletionResponse> for CompletionChoice {
impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choice {
type Error = anyhow::Error;
fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
let choice = CompletionChoice {
text: response
.delta
.text
.ok_or(anyhow::anyhow!("No text in response"))?,
index: response.delta.index.unwrap_or(0) as u64,
logprobs: None,
finish_reason: match &response.delta.finish_reason {
Some(common::FinishReason::EoS) => Some("stop".to_string()),
Some(common::FinishReason::Stop) => Some("stop".to_string()),
Some(common::FinishReason::Length) => Some("length".to_string()),
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
}
Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
None => None,
},
let text = response
.delta
.text
.ok_or(anyhow::anyhow!("No text in response"))?;
// Safety: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
// so we're fairly safe knowing we won't generate that many Choices
let index = response.delta.index.unwrap_or(0) as u32;
// TODO handle aggregating logprobs
let logprobs = None;
let finish_reason: Option<async_openai::types::CompletionFinishReason> =
response.delta.finish_reason.map(Into::into);
let choice = async_openai::types::Choice {
text,
index,
logprobs,
finish_reason,
};
Ok(choice)
......
......@@ -13,12 +13,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, str::FromStr};
use std::collections::HashMap;
use anyhow::Result;
use futures::StreamExt;
use super::{CompletionChoice, CompletionResponse};
use super::CompletionResponse;
use crate::protocols::{
codec::{Message, SseCodecError},
common::FinishReason,
......@@ -98,9 +98,9 @@ impl DeltaAggregator {
let state_choice =
aggregator
.choices
.entry(choice.index)
.entry(choice.index as u64)
.or_insert(DeltaChoice {
index: choice.index,
index: choice.index as u64,
text: "".to_string(),
finish_reason: None,
logprobs: choice.logprobs,
......@@ -108,12 +108,21 @@ impl DeltaAggregator {
state_choice.text.push_str(&choice.text);
// todo - handle logprobs
if let Some(finish_reason) = choice.finish_reason {
let reason = FinishReason::from_str(&finish_reason).ok();
state_choice.finish_reason = reason;
}
// TODO - handle logprobs
// Handle CompletionFinishReason -> FinishReason conversation
state_choice.finish_reason = match choice.finish_reason {
Some(async_openai::types::CompletionFinishReason::Stop) => {
Some(FinishReason::Stop)
}
Some(async_openai::types::CompletionFinishReason::Length) => {
Some(FinishReason::Length)
}
Some(async_openai::types::CompletionFinishReason::ContentFilter) => {
Some(FinishReason::ContentFilter)
}
None => None,
};
}
}
aggregator
......@@ -131,7 +140,7 @@ impl DeltaAggregator {
let mut choices: Vec<_> = aggregator
.choices
.into_values()
.map(CompletionChoice::from)
.map(async_openai::types::Choice::from)
.collect();
choices.sort_by(|a, b| a.index.cmp(&b.index));
......@@ -148,12 +157,12 @@ impl DeltaAggregator {
}
}
impl From<DeltaChoice> for CompletionChoice {
impl From<DeltaChoice> for async_openai::types::Choice {
fn from(delta: DeltaChoice) -> Self {
let finish_reason = delta.finish_reason.map(|reason| reason.to_string());
let finish_reason = delta.finish_reason.map(Into::into);
CompletionChoice {
index: delta.index,
async_openai::types::Choice {
index: delta.index as u32,
text: delta.text,
finish_reason,
logprobs: delta.logprobs,
......@@ -178,16 +187,25 @@ impl CompletionResponse {
#[cfg(test)]
mod tests {
use crate::protocols::openai::completions::{CompletionChoice, CompletionResponse};
use std::str::FromStr;
use super::*;
use futures::stream;
use super::*;
use crate::protocols::openai::completions::CompletionResponse;
fn create_test_delta(
index: u64,
text: &str,
finish_reason: Option<String>,
) -> Annotated<CompletionResponse> {
// This will silently discard invalid_finish reason values and fall back
// to None - totally fine since this is test code
let finish_reason = finish_reason
.as_deref()
.and_then(|s| FinishReason::from_str(s).ok())
.map(Into::into);
Annotated {
data: Some(CompletionResponse {
id: "test_id".to_string(),
......@@ -195,8 +213,8 @@ mod tests {
created: 1234567890,
usage: None,
system_fingerprint: None,
choices: vec![CompletionChoice {
index,
choices: vec![async_openai::types::Choice {
index: index as u32,
text: text.to_string(),
finish_reason,
logprobs: None,
......@@ -255,7 +273,10 @@ mod tests {
let choice = &response.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello,".to_string());
assert_eq!(choice.finish_reason, Some("length".to_string()));
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Length)
);
assert!(choice.logprobs.is_none());
}
......@@ -283,7 +304,10 @@ mod tests {
let choice = &response.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello, world!".to_string());
assert_eq!(choice.finish_reason, Some("stop".to_string()));
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
}
#[tokio::test]
......@@ -297,16 +321,16 @@ mod tests {
usage: None,
system_fingerprint: None,
choices: vec![
CompletionChoice {
async_openai::types::Choice {
index: 0,
text: "Choice 0".to_string(),
finish_reason: Some("stop".to_string()),
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None,
},
CompletionChoice {
async_openai::types::Choice {
index: 1,
text: "Choice 1".to_string(),
finish_reason: Some("stop".to_string()),
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None,
},
],
......@@ -333,11 +357,17 @@ mod tests {
let choice0 = &response.choices[0];
assert_eq!(choice0.index, 0);
assert_eq!(choice0.text, "Choice 0".to_string());
assert_eq!(choice0.finish_reason, Some("stop".to_string()));
assert_eq!(
choice0.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
let choice1 = &response.choices[1];
assert_eq!(choice1.index, 1);
assert_eq!(choice1.text, "Choice 1".to_string());
assert_eq!(choice1.finish_reason, Some("stop".to_string()));
assert_eq!(
choice1.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
}
}
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{CompletionChoice, CompletionResponse, NvCreateCompletionRequest};
use super::{CompletionResponse, NvCreateCompletionRequest};
use crate::protocols::common;
impl NvCreateCompletionRequest {
......@@ -82,7 +82,7 @@ impl DeltaGenerator {
&self,
index: u64,
text: Option<String>,
finish_reason: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>,
) -> CompletionResponse {
// todo - update for tool calling
......@@ -97,9 +97,9 @@ impl DeltaGenerator {
created: self.created,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![CompletionChoice {
choices: vec![async_openai::types::Choice {
text: text.unwrap_or_default(),
index,
index: index as u32,
finish_reason,
logprobs: None,
}],
......@@ -122,18 +122,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
self.usage.completion_tokens += delta.token_ids.len() as u32;
}
// todo logprobs
let finish_reason = match delta.finish_reason {
Some(common::FinishReason::EoS) => Some("stop".to_string()),
Some(common::FinishReason::Stop) => Some("stop".to_string()),
Some(common::FinishReason::Length) => Some("length".to_string()),
Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!(err_msg));
}
None => None,
};
// TODO logprobs
let finish_reason = delta.finish_reason.map(Into::into);
// create choice
let index = delta.index.unwrap_or(0).into();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment