Unverified Commit 7b7b6a6d authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: refactored using Choice and CompletionFinishReason (#1635)

parent c95031ed
...@@ -238,7 +238,7 @@ impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<Completi ...@@ -238,7 +238,7 @@ impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<Completi
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
id += 1; id += 1;
} }
let response = deltas.create_choice(0, None, Some("stop".to_string())); let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop));
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None }; yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
}; };
......
...@@ -64,6 +64,9 @@ pub enum FinishReason { ...@@ -64,6 +64,9 @@ pub enum FinishReason {
#[serde(rename = "cancelled")] #[serde(rename = "cancelled")]
Cancelled, Cancelled,
#[serde(rename = "content_filter")]
ContentFilter,
} }
impl std::fmt::Display for FinishReason { impl std::fmt::Display for FinishReason {
...@@ -74,6 +77,7 @@ impl std::fmt::Display for FinishReason { ...@@ -74,6 +77,7 @@ impl std::fmt::Display for FinishReason {
FinishReason::Stop => write!(f, "stop"), FinishReason::Stop => write!(f, "stop"),
FinishReason::Error(msg) => write!(f, "error: {}", msg), FinishReason::Error(msg) => write!(f, "error: {}", msg),
FinishReason::Cancelled => write!(f, "cancelled"), FinishReason::Cancelled => write!(f, "cancelled"),
FinishReason::ContentFilter => write!(f, "content_filter"),
} }
} }
} }
...@@ -93,6 +97,33 @@ impl std::str::FromStr for FinishReason { ...@@ -93,6 +97,33 @@ impl std::str::FromStr for FinishReason {
} }
} }
impl From<FinishReason> for async_openai::types::CompletionFinishReason {
fn from(reason: FinishReason) -> Self {
match reason {
FinishReason::EoS | FinishReason::Stop | FinishReason::Cancelled => {
async_openai::types::CompletionFinishReason::Stop
}
FinishReason::ContentFilter => {
async_openai::types::CompletionFinishReason::ContentFilter
}
FinishReason::Length => async_openai::types::CompletionFinishReason::Length,
FinishReason::Error(_) => async_openai::types::CompletionFinishReason::Stop,
}
}
}
/// Converts an OpenAI completions finish reason into the internal [`FinishReason`].
///
/// Each matched variant (`Stop`, `Length`, `ContentFilter`) maps to the
/// internal variant of the same name.
impl From<async_openai::types::CompletionFinishReason> for FinishReason {
    fn from(reason: async_openai::types::CompletionFinishReason) -> Self {
        use async_openai::types::CompletionFinishReason as OpenAiReason;

        match reason {
            OpenAiReason::ContentFilter => Self::ContentFilter,
            OpenAiReason::Length => Self::Length,
            OpenAiReason::Stop => Self::Stop,
        }
    }
}
/// LLM Inference Engines can accept a variety of input types. Not all Engines will support all /// LLM Inference Engines can accept a variety of input types. Not all Engines will support all
/// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an /// input types. For example, the trtllm::AsyncEngine only supports `PromptType::Tokens` as an
/// input type. The higher-level `Backend` class is a general wrapper around Engines that will /// input type. The higher-level `Backend` class is a general wrapper around Engines that will
......
...@@ -203,6 +203,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -203,6 +203,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop), Some(common::FinishReason::Stop) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length), Some(common::FinishReason::Length) => Some(async_openai::types::FinishReason::Length),
Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop), Some(common::FinishReason::Cancelled) => Some(async_openai::types::FinishReason::Stop),
Some(common::FinishReason::ContentFilter) => {
Some(async_openai::types::FinishReason::ContentFilter)
}
Some(common::FinishReason::Error(err_msg)) => { Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!(err_msg)); return Err(anyhow::anyhow!(err_msg));
} }
......
...@@ -49,7 +49,7 @@ pub struct CompletionResponse { ...@@ -49,7 +49,7 @@ pub struct CompletionResponse {
pub id: String, pub id: String,
/// The list of completion choices the model generated for the input prompt. /// The list of completion choices the model generated for the input prompt.
pub choices: Vec<CompletionChoice>, pub choices: Vec<async_openai::types::Choice>,
/// The Unix timestamp (in seconds) of when the completion was created. /// The Unix timestamp (in seconds) of when the completion was created.
pub created: u64, pub created: u64,
...@@ -76,35 +76,12 @@ pub struct CompletionResponse { ...@@ -76,35 +76,12 @@ pub struct CompletionResponse {
// pub nvext: Option<NimResponseExt>, // pub nvext: Option<NimResponseExt>,
} }
/// Legacy OpenAI CompletionResponse Choice component impl ContentProvider for async_openai::types::Choice {
#[derive(Clone, Debug, Deserialize, Serialize, Builder)]
pub struct CompletionChoice {
#[builder(setter(into))]
pub text: String,
#[builder(default = "0")]
pub index: u64,
#[builder(default, setter(into, strip_option))]
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub logprobs: Option<async_openai::types::Logprobs>,
}
impl ContentProvider for CompletionChoice {
fn content(&self) -> String { fn content(&self) -> String {
self.text.clone() self.text.clone()
} }
} }
impl CompletionChoice {
pub fn builder() -> CompletionChoiceBuilder {
CompletionChoiceBuilder::default()
}
}
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String { pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
match prompt { match prompt {
async_openai::types::Prompt::String(s) => s.clone(), async_openai::types::Prompt::String(s) => s.clone(),
...@@ -226,7 +203,7 @@ impl ResponseFactory { ...@@ -226,7 +203,7 @@ impl ResponseFactory {
pub fn make_response( pub fn make_response(
&self, &self,
choice: CompletionChoice, choice: async_openai::types::Choice,
usage: Option<async_openai::types::CompletionUsage>, usage: Option<async_openai::types::CompletionUsage>,
) -> CompletionResponse { ) -> CompletionResponse {
CompletionResponse { CompletionResponse {
...@@ -294,27 +271,30 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest { ...@@ -294,27 +271,30 @@ impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
} }
} }
impl TryFrom<common::StreamingCompletionResponse> for CompletionChoice { impl TryFrom<common::StreamingCompletionResponse> for async_openai::types::Choice {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> { fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
let choice = CompletionChoice { let text = response
text: response .delta
.delta .text
.text .ok_or(anyhow::anyhow!("No text in response"))?;
.ok_or(anyhow::anyhow!("No text in response"))?,
index: response.delta.index.unwrap_or(0) as u64, // Safety: we're downcasting from u64 to u32 here but u32::MAX is 4_294_967_295
logprobs: None, // so we're fairly safe knowing we won't generate that many Choices
finish_reason: match &response.delta.finish_reason { let index = response.delta.index.unwrap_or(0) as u32;
Some(common::FinishReason::EoS) => Some("stop".to_string()),
Some(common::FinishReason::Stop) => Some("stop".to_string()), // TODO handle aggregating logprobs
Some(common::FinishReason::Length) => Some("length".to_string()), let logprobs = None;
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg)); let finish_reason: Option<async_openai::types::CompletionFinishReason> =
} response.delta.finish_reason.map(Into::into);
Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
None => None, let choice = async_openai::types::Choice {
}, text,
index,
logprobs,
finish_reason,
}; };
Ok(choice) Ok(choice)
......
...@@ -13,12 +13,12 @@ ...@@ -13,12 +13,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{collections::HashMap, str::FromStr}; use std::collections::HashMap;
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use super::{CompletionChoice, CompletionResponse}; use super::CompletionResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
common::FinishReason, common::FinishReason,
...@@ -98,9 +98,9 @@ impl DeltaAggregator { ...@@ -98,9 +98,9 @@ impl DeltaAggregator {
let state_choice = let state_choice =
aggregator aggregator
.choices .choices
.entry(choice.index) .entry(choice.index as u64)
.or_insert(DeltaChoice { .or_insert(DeltaChoice {
index: choice.index, index: choice.index as u64,
text: "".to_string(), text: "".to_string(),
finish_reason: None, finish_reason: None,
logprobs: choice.logprobs, logprobs: choice.logprobs,
...@@ -108,12 +108,21 @@ impl DeltaAggregator { ...@@ -108,12 +108,21 @@ impl DeltaAggregator {
state_choice.text.push_str(&choice.text); state_choice.text.push_str(&choice.text);
// todo - handle logprobs // TODO - handle logprobs
if let Some(finish_reason) = choice.finish_reason { // Handle CompletionFinishReason -> FinishReason conversion
let reason = FinishReason::from_str(&finish_reason).ok(); state_choice.finish_reason = match choice.finish_reason {
state_choice.finish_reason = reason; Some(async_openai::types::CompletionFinishReason::Stop) => {
} Some(FinishReason::Stop)
}
Some(async_openai::types::CompletionFinishReason::Length) => {
Some(FinishReason::Length)
}
Some(async_openai::types::CompletionFinishReason::ContentFilter) => {
Some(FinishReason::ContentFilter)
}
None => None,
};
} }
} }
aggregator aggregator
...@@ -131,7 +140,7 @@ impl DeltaAggregator { ...@@ -131,7 +140,7 @@ impl DeltaAggregator {
let mut choices: Vec<_> = aggregator let mut choices: Vec<_> = aggregator
.choices .choices
.into_values() .into_values()
.map(CompletionChoice::from) .map(async_openai::types::Choice::from)
.collect(); .collect();
choices.sort_by(|a, b| a.index.cmp(&b.index)); choices.sort_by(|a, b| a.index.cmp(&b.index));
...@@ -148,12 +157,12 @@ impl DeltaAggregator { ...@@ -148,12 +157,12 @@ impl DeltaAggregator {
} }
} }
impl From<DeltaChoice> for CompletionChoice { impl From<DeltaChoice> for async_openai::types::Choice {
fn from(delta: DeltaChoice) -> Self { fn from(delta: DeltaChoice) -> Self {
let finish_reason = delta.finish_reason.map(|reason| reason.to_string()); let finish_reason = delta.finish_reason.map(Into::into);
CompletionChoice { async_openai::types::Choice {
index: delta.index, index: delta.index as u32,
text: delta.text, text: delta.text,
finish_reason, finish_reason,
logprobs: delta.logprobs, logprobs: delta.logprobs,
...@@ -178,16 +187,25 @@ impl CompletionResponse { ...@@ -178,16 +187,25 @@ impl CompletionResponse {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::protocols::openai::completions::{CompletionChoice, CompletionResponse}; use std::str::FromStr;
use super::*;
use futures::stream; use futures::stream;
use super::*;
use crate::protocols::openai::completions::CompletionResponse;
fn create_test_delta( fn create_test_delta(
index: u64, index: u64,
text: &str, text: &str,
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Annotated<CompletionResponse> { ) -> Annotated<CompletionResponse> {
// This will silently discard invalid finish_reason values and fall back
// to None - totally fine since this is test code
let finish_reason = finish_reason
.as_deref()
.and_then(|s| FinishReason::from_str(s).ok())
.map(Into::into);
Annotated { Annotated {
data: Some(CompletionResponse { data: Some(CompletionResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
...@@ -195,8 +213,8 @@ mod tests { ...@@ -195,8 +213,8 @@ mod tests {
created: 1234567890, created: 1234567890,
usage: None, usage: None,
system_fingerprint: None, system_fingerprint: None,
choices: vec![CompletionChoice { choices: vec![async_openai::types::Choice {
index, index: index as u32,
text: text.to_string(), text: text.to_string(),
finish_reason, finish_reason,
logprobs: None, logprobs: None,
...@@ -255,7 +273,10 @@ mod tests { ...@@ -255,7 +273,10 @@ mod tests {
let choice = &response.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello,".to_string()); assert_eq!(choice.text, "Hello,".to_string());
assert_eq!(choice.finish_reason, Some("length".to_string())); assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Length)
);
assert!(choice.logprobs.is_none()); assert!(choice.logprobs.is_none());
} }
...@@ -283,7 +304,10 @@ mod tests { ...@@ -283,7 +304,10 @@ mod tests {
let choice = &response.choices[0]; let choice = &response.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello, world!".to_string()); assert_eq!(choice.text, "Hello, world!".to_string());
assert_eq!(choice.finish_reason, Some("stop".to_string())); assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
} }
#[tokio::test] #[tokio::test]
...@@ -297,16 +321,16 @@ mod tests { ...@@ -297,16 +321,16 @@ mod tests {
usage: None, usage: None,
system_fingerprint: None, system_fingerprint: None,
choices: vec![ choices: vec![
CompletionChoice { async_openai::types::Choice {
index: 0, index: 0,
text: "Choice 0".to_string(), text: "Choice 0".to_string(),
finish_reason: Some("stop".to_string()), finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None, logprobs: None,
}, },
CompletionChoice { async_openai::types::Choice {
index: 1, index: 1,
text: "Choice 1".to_string(), text: "Choice 1".to_string(),
finish_reason: Some("stop".to_string()), finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None, logprobs: None,
}, },
], ],
...@@ -333,11 +357,17 @@ mod tests { ...@@ -333,11 +357,17 @@ mod tests {
let choice0 = &response.choices[0]; let choice0 = &response.choices[0];
assert_eq!(choice0.index, 0); assert_eq!(choice0.index, 0);
assert_eq!(choice0.text, "Choice 0".to_string()); assert_eq!(choice0.text, "Choice 0".to_string());
assert_eq!(choice0.finish_reason, Some("stop".to_string())); assert_eq!(
choice0.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
let choice1 = &response.choices[1]; let choice1 = &response.choices[1];
assert_eq!(choice1.index, 1); assert_eq!(choice1.index, 1);
assert_eq!(choice1.text, "Choice 1".to_string()); assert_eq!(choice1.text, "Choice 1".to_string());
assert_eq!(choice1.finish_reason, Some("stop".to_string())); assert_eq!(
choice1.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
} }
} }
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::{CompletionChoice, CompletionResponse, NvCreateCompletionRequest}; use super::{CompletionResponse, NvCreateCompletionRequest};
use crate::protocols::common; use crate::protocols::common;
impl NvCreateCompletionRequest { impl NvCreateCompletionRequest {
...@@ -82,7 +82,7 @@ impl DeltaGenerator { ...@@ -82,7 +82,7 @@ impl DeltaGenerator {
&self, &self,
index: u64, index: u64,
text: Option<String>, text: Option<String>,
finish_reason: Option<String>, finish_reason: Option<async_openai::types::CompletionFinishReason>,
) -> CompletionResponse { ) -> CompletionResponse {
// todo - update for tool calling // todo - update for tool calling
...@@ -97,9 +97,9 @@ impl DeltaGenerator { ...@@ -97,9 +97,9 @@ impl DeltaGenerator {
created: self.created, created: self.created,
model: self.model.clone(), model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
choices: vec![CompletionChoice { choices: vec![async_openai::types::Choice {
text: text.unwrap_or_default(), text: text.unwrap_or_default(),
index, index: index as u32,
finish_reason, finish_reason,
logprobs: None, logprobs: None,
}], }],
...@@ -122,18 +122,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe ...@@ -122,18 +122,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
self.usage.completion_tokens += delta.token_ids.len() as u32; self.usage.completion_tokens += delta.token_ids.len() as u32;
} }
// todo logprobs // TODO logprobs
let finish_reason = match delta.finish_reason { let finish_reason = delta.finish_reason.map(Into::into);
Some(common::FinishReason::EoS) => Some("stop".to_string()),
Some(common::FinishReason::Stop) => Some("stop".to_string()),
Some(common::FinishReason::Length) => Some("length".to_string()),
Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
Some(common::FinishReason::Error(err_msg)) => {
return Err(anyhow::anyhow!(err_msg));
}
None => None,
};
// create choice // create choice
let index = delta.index.unwrap_or(0).into(); let index = delta.index.unwrap_or(0).into();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment