Unverified Commit e3f1bd5d authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: refactored using CompletionResponse (#1658)

parent 7b7b6a6d
......@@ -139,7 +139,7 @@ mod tests {
use super::*;
use dynamo_llm::types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionResponse, NvCreateCompletionRequest},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
};
const HF_PATH: &str = concat!(
......@@ -174,7 +174,8 @@ mod tests {
// Build pipeline for completions
let pipeline =
build_pipeline::<NvCreateCompletionRequest, CompletionResponse>(&card, engine).await?;
build_pipeline::<NvCreateCompletionRequest, NvCreateCompletionResponse>(&card, engine)
.await?;
// Verify pipeline was created
assert!(Arc::strong_count(&pipeline) >= 1);
......
......@@ -15,7 +15,7 @@ use dynamo_llm::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
},
openai::completions::{CompletionResponse, NvCreateCompletionRequest},
openai::completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_runtime::pipeline::RouterMode;
......@@ -78,7 +78,7 @@ pub async fn run(
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
CompletionResponse,
NvCreateCompletionResponse,
>(model.card(), inner_engine)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
......
......@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_llm::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest},
completions::{prompt_to_string, NvCreateCompletionRequest, NvCreateCompletionResponse},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
};
......@@ -467,13 +467,17 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
}
#[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for MistralRsEngine
impl
AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for MistralRsEngine
{
async fn generate(
&self,
request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
let (tx, mut rx) = channel(10_000);
......
......@@ -25,7 +25,7 @@ use crate::{
protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
},
protocols::openai::completions::{CompletionResponse, NvCreateCompletionRequest},
protocols::openai::completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
};
......@@ -240,7 +240,7 @@ impl ModelWatcher {
let frontend = SegmentSource::<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
let backend = Backend::from_mdc(card.clone()).await?.into_operator();
......@@ -292,7 +292,7 @@ impl ModelWatcher {
ModelType::Completion => {
let push_router = PushRouter::<
NvCreateCompletionRequest,
Annotated<CompletionResponse>,
Annotated<NvCreateCompletionResponse>,
>::from_client(client, Default::default())
.await?;
let engine = Arc::new(push_router);
......
......@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest;
use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest},
completions::{prompt_to_string, NvCreateCompletionRequest, NvCreateCompletionResponse},
};
use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
......@@ -142,7 +142,7 @@ pub trait StreamingEngine: Send + Sync {
async fn handle_completion(
&self,
req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error>;
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error>;
async fn handle_chat(
&self,
......@@ -219,13 +219,17 @@ impl
}
#[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for EchoEngineFull
impl
AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for EchoEngineFull
{
async fn generate(
&self,
incoming_request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator();
let ctx = context.context();
......@@ -268,7 +272,7 @@ impl<E> StreamingEngine for EngineDispatcher<E>
where
E: AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> + AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>,
......@@ -284,7 +288,7 @@ where
async fn handle_completion(
&self,
req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
self.inner.generate(req).await
}
......@@ -347,13 +351,17 @@ impl StreamingEngineAdapter {
}
#[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for StreamingEngineAdapter
impl
AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for StreamingEngineAdapter
{
async fn generate(
&self,
req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
self.0.handle_completion(req).await
}
}
......
......@@ -30,7 +30,7 @@ use super::{
use crate::preprocessor::LLMMetricAnnotation;
use crate::protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse};
use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
chat_completions::NvCreateChatCompletionResponse, completions::NvCreateCompletionResponse,
};
use crate::request_template::RequestTemplate;
use crate::types::{
......@@ -193,7 +193,7 @@ async fn completions(
Ok(sse_stream.into_response())
} else {
// TODO: report ISL/OSL for non-streaming requests
let response = CompletionResponse::from_annotated_stream(stream.into())
let response = NvCreateCompletionResponse::from_annotated_stream(stream.into())
.await
.map_err(|e| {
tracing::error!(
......
......@@ -46,7 +46,7 @@ use crate::protocols::{
common::{SamplingOptionsProvider, StopConditionsProvider},
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionResponse, NvCreateCompletionRequest},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
nvext::NvExtProvider,
DeltaGeneratorExt,
},
......@@ -433,7 +433,7 @@ impl
impl
Operator<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>,
> for OpenAIPreprocessor
......@@ -448,7 +448,7 @@ impl
Error,
>,
>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
// unpack the request
let (request, context) = request.into_parts();
......@@ -465,7 +465,7 @@ impl
let common_request = context.map(|_| common_request);
// create a stream of annotations this will be prepend to the response stream
let annotations: Vec<Annotated<CompletionResponse>> = annotations
let annotations: Vec<Annotated<NvCreateCompletionResponse>> = annotations
.into_iter()
.flat_map(|(k, v)| Annotated::from_annotation(k, &v))
.collect();
......
......@@ -39,41 +39,10 @@ pub struct NvCreateCompletionRequest {
pub nvext: Option<NvExt>,
}
/// Legacy OpenAI CompletionResponse
/// Represents a completion response from the API.
/// Note: both the streamed and non-streamed response objects share the same
/// shape (unlike the chat endpoint).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
/// A unique identifier for the completion.
pub id: String,
/// The list of completion choices the model generated for the input prompt.
pub choices: Vec<async_openai::types::Choice>,
/// The Unix timestamp (in seconds) of when the completion was created.
pub created: u64,
/// The model used for completion.
pub model: String,
/// The object type, which is always "text_completion"
pub object: String,
/// Usage statistics for the completion request.
pub usage: Option<async_openai::types::CompletionUsage>,
/// This fingerprint represents the backend configuration that the model runs with.
/// Can be used in conjunction with the seed request parameter to understand when backend
/// changes have been made that might impact determinism.
///
/// NIM Compatibility:
/// This field is not supported by the NIM; however it will be added in the future.
/// The optional nature of this field will be relaxed when it is supported.
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
// TODO(ryan)
// pub nvext: Option<NimResponseExt>,
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct NvCreateCompletionResponse {
#[serde(flatten)]
pub inner: async_openai::types::CreateCompletionResponse,
}
impl ContentProvider for async_openai::types::Choice {
......@@ -205,16 +174,17 @@ impl ResponseFactory {
&self,
choice: async_openai::types::Choice,
usage: Option<async_openai::types::CompletionUsage>,
) -> CompletionResponse {
CompletionResponse {
) -> NvCreateCompletionResponse {
let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
created: self.created as u32,
model: self.model.clone(),
choices: vec![choice],
system_fingerprint: self.system_fingerprint.clone(),
usage,
}
};
NvCreateCompletionResponse { inner }
}
}
......
......@@ -18,7 +18,7 @@ use std::collections::HashMap;
use anyhow::Result;
use futures::StreamExt;
use super::CompletionResponse;
use super::NvCreateCompletionResponse;
use crate::protocols::{
codec::{Message, SseCodecError},
common::FinishReason,
......@@ -64,8 +64,8 @@ impl DeltaAggregator {
/// Aggregates a stream of [`Annotated<CompletionResponse>`]s into a single [`CompletionResponse`].
pub async fn apply(
stream: DataStream<Annotated<CompletionResponse>>,
) -> Result<CompletionResponse> {
stream: DataStream<Annotated<NvCreateCompletionResponse>>,
) -> Result<NvCreateCompletionResponse> {
let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
let delta = match delta.ok() {
......@@ -83,18 +83,18 @@ impl DeltaAggregator {
// these are cheap to move so we do it every time since we are consuming the delta
let delta = delta.data.unwrap();
aggregator.id = delta.id;
aggregator.model = delta.model;
aggregator.created = delta.created;
if let Some(usage) = delta.usage {
aggregator.id = delta.inner.id;
aggregator.model = delta.inner.model;
aggregator.created = delta.inner.created as u64;
if let Some(usage) = delta.inner.usage {
aggregator.usage = Some(usage);
}
if let Some(system_fingerprint) = delta.system_fingerprint {
if let Some(system_fingerprint) = delta.inner.system_fingerprint {
aggregator.system_fingerprint = Some(system_fingerprint);
}
// handle the choices
for choice in delta.choices {
for choice in delta.inner.choices {
let state_choice =
aggregator
.choices
......@@ -145,15 +145,19 @@ impl DeltaAggregator {
choices.sort_by(|a, b| a.index.cmp(&b.index));
Ok(CompletionResponse {
let inner = async_openai::types::CreateCompletionResponse {
id: aggregator.id,
created: aggregator.created,
created: aggregator.created as u32,
usage: aggregator.usage,
model: aggregator.model,
object: "text_completion".to_string(),
system_fingerprint: aggregator.system_fingerprint,
choices,
})
};
let response = NvCreateCompletionResponse { inner };
Ok(response)
}
}
......@@ -170,17 +174,17 @@ impl From<DeltaChoice> for async_openai::types::Choice {
}
}
impl CompletionResponse {
impl NvCreateCompletionResponse {
pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<CompletionResponse> {
let stream = convert_sse_stream::<CompletionResponse>(stream);
CompletionResponse::from_annotated_stream(stream).await
) -> Result<NvCreateCompletionResponse> {
let stream = convert_sse_stream::<NvCreateCompletionResponse>(stream);
NvCreateCompletionResponse::from_annotated_stream(stream).await
}
pub async fn from_annotated_stream(
stream: DataStream<Annotated<CompletionResponse>>,
) -> Result<CompletionResponse> {
stream: DataStream<Annotated<NvCreateCompletionResponse>>,
) -> Result<NvCreateCompletionResponse> {
DeltaAggregator::apply(stream).await
}
}
......@@ -192,13 +196,13 @@ mod tests {
use futures::stream;
use super::*;
use crate::protocols::openai::completions::CompletionResponse;
use crate::protocols::openai::completions::NvCreateCompletionResponse;
fn create_test_delta(
index: u64,
text: &str,
finish_reason: Option<String>,
) -> Annotated<CompletionResponse> {
) -> Annotated<NvCreateCompletionResponse> {
// This will silently discard invalid_finish reason values and fall back
// to None - totally fine since this is test code
let finish_reason = finish_reason
......@@ -206,21 +210,25 @@ mod tests {
.and_then(|s| FinishReason::from_str(s).ok())
.map(Into::into);
let inner = async_openai::types::CreateCompletionResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(),
created: 1234567890,
usage: None,
system_fingerprint: None,
choices: vec![async_openai::types::Choice {
index: index as u32,
text: text.to_string(),
finish_reason,
logprobs: None,
}],
object: "text_completion".to_string(),
};
let response = NvCreateCompletionResponse { inner };
Annotated {
data: Some(CompletionResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(),
created: 1234567890,
usage: None,
system_fingerprint: None,
choices: vec![async_openai::types::Choice {
index: index as u32,
text: text.to_string(),
finish_reason,
logprobs: None,
}],
object: "text_completion".to_string(),
}),
data: Some(response),
id: Some("test_id".to_string()),
event: None,
comment: None,
......@@ -230,7 +238,7 @@ mod tests {
#[tokio::test]
async fn test_empty_stream() {
// Create an empty stream
let stream: DataStream<Annotated<CompletionResponse>> = Box::pin(stream::empty());
let stream: DataStream<Annotated<NvCreateCompletionResponse>> = Box::pin(stream::empty());
// Call DeltaAggregator::apply
let result = DeltaAggregator::apply(stream).await;
......@@ -240,12 +248,12 @@ mod tests {
let response = result.unwrap();
// Verify that the response is empty and has default values
assert_eq!(response.id, "");
assert_eq!(response.model, "");
assert_eq!(response.created, 0);
assert!(response.usage.is_none());
assert!(response.system_fingerprint.is_none());
assert_eq!(response.choices.len(), 0);
assert_eq!(response.inner.id, "");
assert_eq!(response.inner.model, "");
assert_eq!(response.inner.created, 0);
assert!(response.inner.usage.is_none());
assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 0);
}
#[tokio::test]
......@@ -264,19 +272,23 @@ mod tests {
let response = result.unwrap();
// Verify the response fields
assert_eq!(response.id, "test_id");
assert_eq!(response.model, "meta/llama-3.1-8b");
assert_eq!(response.created, 1234567890);
assert!(response.usage.is_none());
assert!(response.system_fingerprint.is_none());
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
assert_eq!(response.inner.id, "test_id");
assert_eq!(response.inner.model, "meta/llama-3.1-8b");
assert_eq!(response.inner.created, 1234567890);
assert!(response.inner.usage.is_none());
assert!(response.inner.system_fingerprint.is_none());
assert_eq!(response.inner.choices.len(), 1);
let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello,".to_string());
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Length)
);
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Length)
);
assert!(choice.logprobs.is_none());
}
......@@ -300,42 +312,50 @@ mod tests {
let response = result.unwrap();
// Verify the response fields
assert_eq!(response.choices.len(), 1);
let choice = &response.choices[0];
assert_eq!(response.inner.choices.len(), 1);
let choice = &response.inner.choices[0];
assert_eq!(choice.index, 0);
assert_eq!(choice.text, "Hello, world!".to_string());
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
assert_eq!(
choice.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
}
#[tokio::test]
async fn test_multiple_choices() {
// Create a delta with multiple choices
let inner = async_openai::types::CreateCompletionResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(),
created: 1234567890,
usage: None,
system_fingerprint: None,
choices: vec![
async_openai::types::Choice {
index: 0,
text: "Choice 0".to_string(),
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None,
},
async_openai::types::Choice {
index: 1,
text: "Choice 1".to_string(),
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None,
},
],
object: "text_completion".to_string(),
};
let response = NvCreateCompletionResponse { inner };
let annotated_delta = Annotated {
data: Some(CompletionResponse {
id: "test_id".to_string(),
model: "meta/llama-3.1-8b".to_string(),
created: 1234567890,
usage: None,
system_fingerprint: None,
choices: vec![
async_openai::types::Choice {
index: 0,
text: "Choice 0".to_string(),
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None,
},
async_openai::types::Choice {
index: 1,
text: "Choice 1".to_string(),
finish_reason: Some(async_openai::types::CompletionFinishReason::Stop),
logprobs: None,
},
],
object: "text_completion".to_string(),
}),
data: Some(response),
id: Some("test_id".to_string()),
event: None,
comment: None,
......@@ -352,22 +372,30 @@ mod tests {
let mut response = result.unwrap();
// Verify the response fields
assert_eq!(response.choices.len(), 2);
response.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.choices[0];
assert_eq!(response.inner.choices.len(), 2);
response.inner.choices.sort_by(|a, b| a.index.cmp(&b.index)); // Ensure the choices are ordered
let choice0 = &response.inner.choices[0];
assert_eq!(choice0.index, 0);
assert_eq!(choice0.text, "Choice 0".to_string());
assert_eq!(
choice0.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
assert_eq!(
choice0.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
let choice1 = &response.choices[1];
let choice1 = &response.inner.choices[1];
assert_eq!(choice1.index, 1);
assert_eq!(choice1.text, "Choice 1".to_string());
assert_eq!(
choice1.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
assert_eq!(
choice1.finish_reason,
Some(async_openai::types::CompletionFinishReason::Stop)
);
}
}
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{CompletionResponse, NvCreateCompletionRequest};
use super::{NvCreateCompletionRequest, NvCreateCompletionResponse};
use crate::protocols::common;
impl NvCreateCompletionRequest {
......@@ -83,7 +83,7 @@ impl DeltaGenerator {
index: u64,
text: Option<String>,
finish_reason: Option<async_openai::types::CompletionFinishReason>,
) -> CompletionResponse {
) -> NvCreateCompletionResponse {
// todo - update for tool calling
let mut usage = self.usage.clone();
......@@ -91,10 +91,10 @@ impl DeltaGenerator {
usage.total_tokens = usage.prompt_tokens + usage.completion_tokens;
}
CompletionResponse {
let inner = async_openai::types::CreateCompletionResponse {
id: self.id.clone(),
object: self.object.clone(),
created: self.created,
created: self.created as u32,
model: self.model.clone(),
system_fingerprint: self.system_fingerprint.clone(),
choices: vec![async_openai::types::Choice {
......@@ -108,15 +108,17 @@ impl DeltaGenerator {
} else {
None
},
}
};
NvCreateCompletionResponse { inner }
}
}
impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGenerator {
impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for DeltaGenerator {
fn choice_from_postprocessor(
&mut self,
delta: common::llm_backend::BackendOutput,
) -> anyhow::Result<CompletionResponse> {
) -> anyhow::Result<NvCreateCompletionResponse> {
// aggregate usage
if self.options.enable_usage {
self.usage.completion_tokens += delta.token_ids.len() as u32;
......
......@@ -24,15 +24,17 @@ pub mod openai {
pub mod completions {
use super::*;
pub use protocols::openai::completions::{CompletionResponse, NvCreateCompletionRequest};
pub use protocols::openai::completions::{
NvCreateCompletionRequest, NvCreateCompletionResponse,
};
/// A [`UnaryEngine`] implementation for the OpenAI Completions API
pub type OpenAICompletionsUnaryEngine =
UnaryEngine<NvCreateCompletionRequest, CompletionResponse>;
UnaryEngine<NvCreateCompletionRequest, NvCreateCompletionResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API
pub type OpenAICompletionsStreamingEngine =
ServerStreamingEngine<NvCreateCompletionRequest, Annotated<CompletionResponse>>;
ServerStreamingEngine<NvCreateCompletionRequest, Annotated<NvCreateCompletionResponse>>;
}
pub mod chat_completions {
......
......@@ -15,7 +15,9 @@
use dynamo_llm::protocols::{
codec::{create_message_stream, Message, SseCodecError},
openai::{chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse},
openai::{
chat_completions::NvCreateChatCompletionResponse, completions::NvCreateCompletionResponse,
},
ContentProvider, DataStream,
};
use futures::StreamExt;
......@@ -112,13 +114,13 @@ async fn test_openai_chat_edge_case_invalid_deserialize_error() {
#[tokio::test]
async fn test_openai_cmpl_stream() {
let stream = create_stream(CMPL_ROOT_PATH, "completion.streaming.1").take(16);
let result = CompletionResponse::from_sse_stream(Box::pin(stream))
let result = NvCreateCompletionResponse::from_sse_stream(Box::pin(stream))
.await
.unwrap();
// todo: provide a cleaner way to extract the content from choices
assert_eq!(
result.choices.first().unwrap().content(),
result.inner.choices.first().unwrap().content(),
" This is a question that is often asked by those outside of AI research and development"
);
}
......@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{
use dynamo_llm::protocols::{
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionResponse, NvCreateCompletionRequest},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
Annotated,
};
......@@ -101,13 +101,17 @@ impl
}
#[async_trait]
impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for AlwaysFailEngine
impl
AsyncEngine<
SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<NvCreateCompletionResponse>>,
Error,
> for AlwaysFailEngine
{
async fn generate(
&self,
_request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
) -> Result<ManyOut<Annotated<NvCreateCompletionResponse>>, Error> {
Err(HttpError {
code: 401,
message: "Always fail".to_string(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment