Unverified commit e3f1bd5d, authored by Paul Hendricks and committed by GitHub
Browse files

refactor: refactored using CompletionResponse (#1658)

parent 7b7b6a6d
...@@ -139,7 +139,7 @@ mod tests { ...@@ -139,7 +139,7 @@ mod tests {
use super::*; use super::*;
use dynamo_llm::types::openai::{ use dynamo_llm::types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionResponse, NvCreateCompletionRequest}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}; };
const HF_PATH: &str = concat!( const HF_PATH: &str = concat!(
...@@ -174,7 +174,8 @@ mod tests { ...@@ -174,7 +174,8 @@ mod tests {
// Build pipeline for completions // Build pipeline for completions
let pipeline = let pipeline =
build_pipeline::<NvCreateCompletionRequest, CompletionResponse>(&card, engine).await?; build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(&card, engine)
.await?;
// Verify pipeline was created // Verify pipeline was created
assert!(Arc::strong_count(&pipeline) >= 1); assert!(Arc::strong_count(&pipeline) >= 1);
......
...@@ -15,7 +15,7 @@ use dynamo_llm::{ ...@@ -15,7 +15,7 @@ use dynamo_llm::{
openai::chat_completions::{ openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
openai::completions::{CompletionResponse, NvCreateCompletionRequest}, openai::completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
...@@ -78,7 +78,7 @@ pub async fn run( ...@@ -78,7 +78,7 @@ pub async fn run(
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
CompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine) >(model.card(), inner_engine)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
......
...@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated; ...@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_llm::protocols::openai::{ use dynamo_llm::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest}, completions::{prompt_to_string, NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
...@@ -467,13 +467,17 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> { ...@@ -467,13 +467,17 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl
for MistralRsEngine AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for MistralRsEngine
{ {
async fn generate( async fn generate(
&self, &self,
request: SingleIn<NvCreateCompletionRequest>, request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
let (request, context) = request.transfer(()); let (request, context) = request.transfer(());
let ctx = context.context(); let ctx = context.context();
let (tx, mut rx) = channel(10_000); let (tx, mut rx) = channel(10_000);
......
...@@ -25,7 +25,7 @@ use crate::{ ...@@ -25,7 +25,7 @@ use crate::{
protocols::openai::chat_completions::{ protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
protocols::openai::completions::{CompletionResponse, NvCreateCompletionRequest}, protocols::openai::completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
...@@ -240,7 +240,7 @@ impl ModelWatcher { ...@@ -240,7 +240,7 @@ impl ModelWatcher {
let frontend = SegmentSource::< let frontend = SegmentSource::<
SingleIn<NvCreateCompletionRequest>, SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>, ManyOut<Annotated<NvCreateCompletionResponse>>,
>::new(); >::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator(); let backend = Backend::from_mdc(card.clone()).await?.into_operator();
...@@ -292,7 +292,7 @@ impl ModelWatcher { ...@@ -292,7 +292,7 @@ impl ModelWatcher {
ModelType::Completion => { ModelType::Completion => {
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
Annotated<CompletionResponse>, Annotated<NvCreateCompletionResponse>,
>::from_client(client, Default::default()) >::from_client(client, Default::default())
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
......
...@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest; ...@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest;
use crate::protocols::common::llm_backend::LLMEngineOutput; use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest}, completions::{prompt_to_string, NvCreateCompletionRequest, NvCreateCompletionResponse},
}; };
use crate::types::openai::embeddings::NvCreateEmbeddingRequest; use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
use crate::types::openai::embeddings::NvCreateEmbeddingResponse; use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
...@@ -142,7 +142,7 @@ pub trait StreamingEngine: Send + Sync { ...@@ -142,7 +142,7 @@ pub trait StreamingEngine: Send + Sync {
async fn handle_completion( async fn handle_completion(
&self, &self,
req: SingleIn<NvCreateCompletionRequest>, req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error>; ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error>;
async fn handle_chat( async fn handle_chat(
&self, &self,
...@@ -219,13 +219,17 @@ impl ...@@ -219,13 +219,17 @@ impl
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl
for EchoEngineFull AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for EchoEngineFull
{ {
async fn generate( async fn generate(
&self, &self,
incoming_request: SingleIn<NvCreateCompletionRequest>, incoming_request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
let (request, context) = incoming_request.transfer(()); let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator(); let deltas = request.response_generator();
let ctx = context.context(); let ctx = context.context();
...@@ -268,7 +272,7 @@ impl<E> StreamingEngine for EngineDispatcher<E> ...@@ -268,7 +272,7 @@ impl<E> StreamingEngine for EngineDispatcher<E>
where where
E: AsyncEngine< E: AsyncEngine<
SingleIn<NvCreateCompletionRequest>, SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>, ManyOut<Annotated<NvCreateCompletionResponse>>,
Error, Error,
> + AsyncEngine< > + AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
...@@ -284,7 +288,7 @@ where ...@@ -284,7 +288,7 @@ where
async fn handle_completion( async fn handle_completion(
&self, &self,
req: SingleIn<NvCreateCompletionRequest>, req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
self.inner.generate(req).await self.inner.generate(req).await
} }
...@@ -347,13 +351,17 @@ impl StreamingEngineAdapter { ...@@ -347,13 +351,17 @@ impl StreamingEngineAdapter {
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl
for StreamingEngineAdapter AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for StreamingEngineAdapter
{ {
async fn generate( async fn generate(
&self, &self,
req: SingleIn<NvCreateCompletionRequest>, req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
self.0.handle_completion(req).await self.0.handle_completion(req).await
} }
} }
......
...@@ -30,7 +30,7 @@ use super::{ ...@@ -30,7 +30,7 @@ use super::{
use crate::preprocessor::LLMMetricAnnotation; use crate::preprocessor::LLMMetricAnnotation;
use crate::protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}; use crate::protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse};
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse, chat_completions::NvCreateChatCompletionResponse, completions::NvCreateCompletionResponse,
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::{ use crate::types::{
...@@ -193,7 +193,7 @@ async fn completions( ...@@ -193,7 +193,7 @@ async fn completions(
Ok(sse_stream.into_response()) Ok(sse_stream.into_response())
} else { } else {
// TODO: report ISL/OSL for non-streaming requests // TODO: report ISL/OSL for non-streaming requests
let response = CompletionResponse::from_annotated_stream(stream.into()) let response = NvCreateCompletionResponse::from_annotated_stream(stream.into())
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
......
...@@ -46,7 +46,7 @@ use crate::protocols::{ ...@@ -46,7 +46,7 @@ use crate::protocols::{
common::{SamplingOptionsProvider, StopConditionsProvider}, common::{SamplingOptionsProvider, StopConditionsProvider},
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionResponse, NvCreateCompletionRequest}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
nvext::NvExtProvider, nvext::NvExtProvider,
DeltaGeneratorExt, DeltaGeneratorExt,
}, },
...@@ -433,7 +433,7 @@ impl ...@@ -433,7 +433,7 @@ impl
impl impl
Operator< Operator<
SingleIn<NvCreateCompletionRequest>, SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>, ManyOut<Annotated<NvCreateCompletionResponse>>,
SingleIn<PreprocessedRequest>, SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>, ManyOut<Annotated<BackendOutput>>,
> for OpenAIPreprocessor > for OpenAIPreprocessor
...@@ -448,7 +448,7 @@ impl ...@@ -448,7 +448,7 @@ impl
Error, Error,
>, >,
>, >,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
// unpack the request // unpack the request
let (request, context) = request.into_parts(); let (request, context) = request.into_parts();
...@@ -465,7 +465,7 @@ impl ...@@ -465,7 +465,7 @@ impl
let common_request = context.map(|_| common_request); let common_request = context.map(|_| common_request);
// create a stream of annotations this will be prepend to the response stream // create a stream of annotations this will be prepend to the response stream
let annotations: Vec<Annotated<CompletionResponse>> = annotations let annotations: Vec<Annotated<NvCreateCompletionResponse>> = annotations
.into_iter() .into_iter()
.flat_map(|(k, v)| Annotated::from_annotation(k, &v)) .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
.collect(); .collect();
......
...@@ -39,41 +39,10 @@ pub struct NvCreateCompletionRequest { ...@@ -39,41 +39,10 @@ pub struct NvCreateCompletionRequest {
pub nvext: Option<NvExt>, pub nvext: Option<NvExt>,
} }
/// Legacy OpenAI CompletionResponse #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
/// Represents a completion response from the API. pub struct NvCreateCompletionResponse {
/// Note: both the streamed and non-streamed response objects share the same #[serde(flatten)]
/// shape (unlike the chat endpoint). pub inner: async_openai::types::CreateCompletionResponse,
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
/// A unique identifier for the completion.
pub id: String,
/// The list of completion choices the model generated for the input prompt.
pub choices: Vec<async_openai::types::Choice>,
/// The Unix timestamp (in seconds) of when the completion was created.
pub created: u64,
/// The model used for completion.
pub model: String,
/// The object type, which is always "text_completion"
pub object: String,
/// Usage statistics for the completion request.
pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with.
/// Can be used in conjunction with the seed request parameter to understand when backend
/// changes have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
// TODO(ryan)
// pub nvext: Option<NimResponseExt>,
} }
impl ContentProvider for async_openai::types::Choice { impl ContentProvider for async_openai::types::Choice {
...@@ -205,16 +174,17 @@ impl ResponseFactory { ...@@ -205,16 +174,17 @@ impl ResponseFactory {
&self, &self,
choice: async_openai::types::Choice, choice: async_openai::types::Choice,
usage: Option<async_openai::types::CompletionUsage>, usage: Option<async_openai::types::CompletionUsage>,
) -> CompletionResponse { ) -> NvCreateCompletionResponse {
CompletionResponse { let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(), id: self.id.clone(),
object: self.object.clone(), object: self.object.clone(),
created: self.created, created: self.created as u32,
model: self.model.clone(), model: self.model.clone(),
choices: vec![choice], choices: vec![choice],
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
usage, usage,
} };
NvCreateCompletionResponse { inner }
} }
} }
......
...@@ -18,7 +18,7 @@ use std::collections::HashMap; ...@@ -18,7 +18,7 @@ use std::collections::HashMap;
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use super::CompletionResponse; use super::NvCreateCompletionResponse;
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
common::FinishReason, common::FinishReason,
...@@ -64,8 +64,8 @@ impl DeltaAggregator { ...@@ -64,8 +64,8 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`]. /// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub async fn apply( pub async fn apply(
stream: DataStream<Annotated<CompletionResponse>>, stream: DataStream<Annotated<NvCreateCompletionResponse>>,
) -> Result<CompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
let delta = match delta.ok() { let delta = match delta.ok() {
...@@ -83,18 +83,18 @@ impl DeltaAggregator { ...@@ -83,18 +83,18 @@ impl DeltaAggregator {
// these are cheap to move so we do it every time since we are consuming the delta // these are cheap to move so we do it every time since we are consuming the delta
let delta = delta.data.unwrap(); let delta = delta.data.unwrap();
aggregator.id = delta.id; aggregator.id = delta.inner.id;
aggregator.model = delta.model; aggregator.model = delta.inner.model;
aggregator.created = delta.created; aggregator.created = delta.inner.created as u64;
if let Some(usage) = delta.usage { if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage); aggregator.usage = Some(usage);
} }
if let Some(system_fingerprint) = delta.system_fingerprint { if let Some(system_fingerprint) = delta.inner.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint); aggregator.system_fingerprint = Some(system_fingerprint);
} }
// handle the choices // handle the choices
for choice in delta.choices { for choice in delta.inner.choices {
let state_choice = let state_choice =
aggregator aggregator
.choices .choices
...@@ -145,15 +145,19 @@ impl DeltaAggregator { ...@@ -145,15 +145,19 @@ impl DeltaAggregator {
choices.sort_by(|a, b| a.index.cmp(&b.index)); choices.sort_by(|a, b| a.index.cmp(&b.index));
Ok(CompletionResponse { let inner = async_openai::types::CreateCompletionResponse {
id: aggregator.id, id: aggregator.id,
created: aggregator.created, created: aggregator.created as u32,
usage: aggregator.usage, usage: aggregator.usage,
model: aggregator.model, model: aggregator.model,
object: "text_completion".to_string(), object: "text_completion".to_string(),
system_fingerprint: aggregator.system_fingerprint, system_fingerprint: aggregator.system_fingerprint,
choices, choices,
}) };
let response = NvCreateCompletionResponse { inner };
Ok(response)
} }
} }
...@@ -170,17 +174,17 @@ impl From<DeltaChoice> for async_openai::types::Choice { ...@@ -170,17 +174,17 @@ impl From<DeltaChoice> for async_openai::types::Choice {
} }
} }
impl CompletionResponse { impl NvCreateCompletionResponse {
pub async fn from_sse_stream( pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>, stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<CompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
let stream = convert_sse_stream::<CompletionResponse>(stream); let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
CompletionResponse::from_annotated_stream(stream).await NvCreateCompletionResponse::from_annotated_stream(stream).await
} }
pub async fn from_annotated_stream( pub async fn from_annotated_stream(
stream: DataStream<Annotated<CompletionResponse>>, stream: DataStream<Annotated<NvCreateCompletionResponse>>,
) -> Result<CompletionResponse> { ) -> Result<NvCreateCompletionResponse> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream).await
} }
} }
...@@ -192,13 +196,13 @@ mod tests { ...@@ -192,13 +196,13 @@ mod tests {
use futures::stream; use futures::stream;
use super::*; use super::*;
use crate::protocols::openai::completions::CompletionResponse; use crate::protocols::openai::completions::NvCreateCompletionResponse;
fn create_test_delta( fn create_test_delta(
index: u64, index: u64,
text: &str, text: &str,
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Annotated<CompletionResponse> { ) -> Annotated<NvCreateCompletionResponse> {
// This will silently discard invalid_finish reason values and fall back // This will silently discard invalid_finish reason values and fall back
// to None - totally fine since this is test code // to None - totally fine since this is test code
let finish_reason = finish_reason let finish_reason = finish_reason
...@@ -206,8 +210,7 @@ mod tests { ...@@ -206,8 +210,7 @@ mod tests {
.and_then(|s| FinishReason::from_str(s).ok()) .and_then(|s| FinishReason::from_str(s).ok())
.map(Into::into); .map(Into::into);
Annotated { let inner = async_openai::types::CreateCompletionResponse {
data: Some(CompletionResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(), model: "meta/llama-3.1-8b".to_string(),
created: 1234567890, created: 1234567890,
...@@ -220,7 +223,12 @@ mod tests { ...@@ -220,7 +223,12 @@ mod tests {
logprobs: None, logprobs: None,
}], }],
object: "text_completion".to_string(), object: "text_completion".to_string(),
}), };
let response = NvCreateCompletionResponse { inner };
Annotated {
data: Some(response),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
event: None, event: None,
comment: None, comment: None,
...@@ -230,7 +238,7 @@ mod tests { ...@@ -230,7 +238,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_empty_stream() { async fn test_empty_stream() {
// Create an empty stream // Create an empty stream
let stream: DataStream<Annotated<CompletionResponse>> = Box::pin(stream::empty()); let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());
// Call DeltaAggregator::apply // Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await; let result = DeltaAggregator::apply(stream).await;
...@@ -240,12 +248,12 @@ mod tests { ...@@ -240,12 +248,12 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// Verify that the response is empty and has default values // Verify that the response is empty and has default values
assert_eq!(response.id, ""); assert_eq!(response.inner.id, "");
assert_eq!(response.model, ""); assert_eq!(response.inner.model, "");
assert_eq!(response.created, 0); assert_eq!(response.inner.created, 0);
assert!(response.usage.is_none()); assert!(response.inner.usage.is_none());
assert!(response.system_fingerprint.is_none()); assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.choices.len(), 0); assert_eq!(response.inner.choices.len(), 0);
} }
#[tokio::test] #[tokio::test]
...@@ -264,19 +272,23 @@ mod tests { ...@@ -264,19 +272,23 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// Verify the response fields // Verify the response fields
assert_eq!(response.id, "test_id"); assert_eq!(response.inner.id, "test_id");
assert_eq!(response.model, "meta/llama-3.1-8b"); assert_eq!(response.inner.model, "meta/llama-3.1-8b");
assert_eq!(response.created, 1234567890); assert_eq!(response.inner.created, 1234567890);
assert!(response.usage.is_none()); assert!(response.inner.usage.is_none());
assert!(response.system_fingerprint.is_none()); assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.choices.len(), 1); assert_eq!(response.inner.choices.len(), 1);
let choice = &response.choices[0]; let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello,".to_string()); assert_eq!(choice.text, "Hello,".to_string());
assert_eq!( assert_eq!(
choice.finish_reason, choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Length) Some(async_openai::types::CompletionFinishReason::Length)
); );
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Length)
);
assert!(choice.logprobs.is_none()); assert!(choice.logprobs.is_none());
} }
...@@ -300,21 +312,24 @@ mod tests { ...@@ -300,21 +312,24 @@ mod tests {
let response = result.unwrap(); let response = result.unwrap();
// Verify the response fields // Verify the response fields
assert_eq!(response.choices.len(), 1); assert_eq!(response.inner.choices.len(), 1);
let choice = &response.choices[0]; let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0); assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello, world!".to_string()); assert_eq!(choice.text, "Hello, world!".to_string());
assert_eq!( assert_eq!(
choice.finish_reason, choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop) Some(async_openai::types::CompletionFinishReason::Stop)
); );
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
} }
#[tokio::test] #[tokio::test]
async fn test_multiple_choices() { async fn test_multiple_choices() {
// Create a delta with multiple choices // Create a delta with multiple choices
let annotated_delta = Annotated { let inner = async_openai::types::CreateCompletionResponse {
data: Some(CompletionResponse {
id: "test_id".to_string(), id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(), model: "meta/llama-3.1-8b".to_string(),
created: 1234567890, created: 1234567890,
...@@ -335,7 +350,12 @@ mod tests { ...@@ -335,7 +350,12 @@ mod tests {
}, },
], ],
object: "text_completion".to_string(), object: "text_completion".to_string(),
}), };
let response = NvCreateCompletionResponse { inner };
let annotated_delta = Annotated {
data: Some(response),
id: Some("test_id".to_string()), id: Some("test_id".to_string()),
event: None, event: None,
comment: None, comment: None,
...@@ -352,22 +372,30 @@ mod tests { ...@@ -352,22 +372,30 @@ mod tests {
let mut response = result.unwrap(); let mut response = result.unwrap();
// Verify the response fields // Verify the response fields
assert_eq!(response.choices.len(), 2); assert_eq!(response.inner.choices.len(), 2);
response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.choices[0]; let choice0 = &response.inner.choices[0];
assert_eq!(choice0.index, 0); assert_eq!(choice0.index, 0);
assert_eq!(choice0.text, "Choice 0".to_string()); assert_eq!(choice0.text, "Choice 0".to_string());
assert_eq!( assert_eq!(
choice0.finish_reason, choice0.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop) Some(async_openai::types::CompletionFinishReason::Stop)
); );
assert_eq!(
choice0.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
let choice1 = &response.choices[1]; let choice1 = &response.inner.choices[1];
assert_eq!(choice1.index, 1); assert_eq!(choice1.index, 1);
assert_eq!(choice1.text, "Choice 1".to_string()); assert_eq!(choice1.text, "Choice 1".to_string());
assert_eq!( assert_eq!(
choice1.finish_reason, choice1.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop) Some(async_openai::types::CompletionFinishReason::Stop)
); );
assert_eq!(
choice1.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
} }
} }
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::{CompletionResponse, NvCreateCompletionRequest}; use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::protocols::common; use crate::protocols::common;
impl NvCreateCompletionRequest { impl NvCreateCompletionRequest {
...@@ -83,7 +83,7 @@ impl DeltaGenerator { ...@@ -83,7 +83,7 @@ impl DeltaGenerator {
index: u64, index: u64,
text: Option<String>, text: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>, finish_reason: Option<async_openai::types::CompletionFinishReason>,
) -> CompletionResponse { ) -> NvCreateCompletionResponse {
// todo - update for tool calling // todo - update for tool calling
let mut usage = self.usage.clone(); let mut usage = self.usage.clone();
...@@ -91,10 +91,10 @@ impl DeltaGenerator { ...@@ -91,10 +91,10 @@ impl DeltaGenerator {
usage.total_tokens = usage.prompt_tokens + usage.completion_tokens; usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
} }
CompletionResponse { let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(), id: self.id.clone(),
object: self.object.clone(), object: self.object.clone(),
created: self.created, created: self.created as u32,
model: self.model.clone(), model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(), system_fingerprint: self.system_fingerprint.clone(),
choices: vec![async_openai::types::Choice { choices: vec![async_openai::types::Choice {
...@@ -108,15 +108,17 @@ impl DeltaGenerator { ...@@ -108,15 +108,17 @@ impl DeltaGenerator {
} else { } else {
None None
}, },
} };
NvCreateCompletionResponse { inner }
} }
} }
impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGenerator { impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
fn choice_from_postprocessor( fn choice_from_postprocessor(
&mut self, &mut self,
delta: common::llm_backend::BackendOutput, delta: common::llm_backend::BackendOutput,
) -> anyhow::Result<CompletionResponse> { ) -> anyhow::Result<NvCreateCompletionResponse> {
// aggregate usage // aggregate usage
if self.options.enable_usage { if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32; self.usage.completion_tokens += delta.token_ids.len() as u32;
......
...@@ -24,15 +24,17 @@ pub mod openai { ...@@ -24,15 +24,17 @@ pub mod openai {
pub mod completions { pub mod completions {
use super::*; use super::*;
pub use protocols::openai::completions::{CompletionResponse, NvCreateCompletionRequest}; pub use protocols::openai::completions::{
NvCreateCompletionRequest, NvCreateCompletionResponse,
};
/// A [`UnaryEngine`] implementation for the OpenAI Completions API /// A [`UnaryEngine`] implementation for the OpenAI Completions API
pub type OpenAICompletionsUnaryEngine = pub type OpenAICompletionsUnaryEngine =
UnaryEngine<NvCreateCompletionRequest, CompletionResponse>; UnaryEngine<NvCreateCompletionRequest, NvCreateCompletionResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API /// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API
pub type OpenAICompletionsStreamingEngine = pub type OpenAICompletionsStreamingEngine =
ServerStreamingEngine<NvCreateCompletionRequest, Annotated<CompletionResponse>>; ServerStreamingEngine<NvCreateCompletionRequest, Annotated<NvCreateCompletionResponse>>;
} }
pub mod chat_completions { pub mod chat_completions {
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
codec::{create_message_stream, Message, SseCodecError}, codec::{create_message_stream, Message, SseCodecError},
openai::{chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse}, openai::{
chat_completions::NvCreateChatCompletionResponse, completions::NvCreateCompletionResponse,
},
ContentProvider, DataStream, ContentProvider, DataStream,
}; };
use futures::StreamExt; use futures::StreamExt;
...@@ -112,13 +114,13 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() { ...@@ -112,13 +114,13 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test] #[tokio::test]
async fn test_openai_cmpl_stream() { async fn test_openai_cmpl_stream() {
let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16); let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16);
let result = CompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream))
.await .await
.unwrap(); .unwrap();
// todo: provide a cleaner way to extract the content from choices // todo: provide a cleaner way to extract the content from choices
assert_eq!( assert_eq!(
result.choices.first().unwrap().content(), result.inner.choices.first().unwrap().content(),
" This is a question that is often asked by those outside of AI research and development" " This is a question that is often asked by those outside of AI research and development"
); );
} }
...@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{ ...@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionResponse, NvCreateCompletionRequest}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
Annotated, Annotated,
}; };
...@@ -101,13 +101,17 @@ impl ...@@ -101,13 +101,17 @@ impl
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl
for AlwaysFailEngine AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for AlwaysFailEngine
{ {
async fn generate( async fn generate(
&self, &self,
_request: SingleIn<NvCreateCompletionRequest>, _request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
Err(HttpError { Err(HttpError {
code: 401, code: 401,
message: "Always fail".to_string(), message: "Always fail".to_string(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment