Unverified Commit c103d56a authored by Paul Hendricks's avatar Paul Hendricks Committed by GitHub
Browse files

refactor: Rename CompletionRequest to NvCreateCompletionRequest (#1383)

parent cfd12d7f
...@@ -139,7 +139,7 @@ mod tests { ...@@ -139,7 +139,7 @@ mod tests {
use super::*; use super::*;
use dynamo_llm::types::openai::{ use dynamo_llm::types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionRequest, CompletionResponse}, completions::{CompletionResponse, NvCreateCompletionRequest},
}; };
const HF_PATH: &str = concat!( const HF_PATH: &str = concat!(
...@@ -174,7 +174,7 @@ mod tests { ...@@ -174,7 +174,7 @@ mod tests {
// Build pipeline for completions // Build pipeline for completions
let pipeline = let pipeline =
build_pipeline::<CompletionRequest, CompletionResponse>(&card, engine).await?; build_pipeline::<NvCreateCompletionRequest, CompletionResponse>(&card, engine).await?;
// Verify pipeline was created // Verify pipeline was created
assert!(Arc::strong_count(&pipeline) >= 1); assert!(Arc::strong_count(&pipeline) >= 1);
......
...@@ -15,7 +15,7 @@ use dynamo_llm::{ ...@@ -15,7 +15,7 @@ use dynamo_llm::{
openai::chat_completions::{ openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
openai::completions::{CompletionRequest, CompletionResponse}, openai::completions::{CompletionResponse, NvCreateCompletionRequest},
}, },
}; };
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
...@@ -76,10 +76,10 @@ pub async fn run( ...@@ -76,10 +76,10 @@ pub async fn run(
.await?; .await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<CompletionRequest, CompletionResponse>( let cmpl_pipeline = common::build_pipeline::<
model.card(), NvCreateCompletionRequest,
inner_engine, CompletionResponse,
) >(model.card(), inner_engine)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
} }
......
...@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated; ...@@ -25,7 +25,7 @@ use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_llm::protocols::openai::{ use dynamo_llm::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionRequest, CompletionResponse}, completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest},
embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
...@@ -470,12 +470,12 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> { ...@@ -470,12 +470,12 @@ fn to_logit_bias(lb: HashMap<String, serde_json::Value>) -> HashMap<u32, f32> {
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for MistralRsEngine for MistralRsEngine
{ {
async fn generate( async fn generate(
&self, &self,
request: SingleIn<CompletionRequest>, request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
let (request, context) = request.transfer(()); let (request, context) = request.transfer(());
let ctx = context.context(); let ctx = context.context();
......
...@@ -25,7 +25,7 @@ use crate::{ ...@@ -25,7 +25,7 @@ use crate::{
protocols::openai::chat_completions::{ protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
protocols::openai::completions::{CompletionRequest, CompletionResponse}, protocols::openai::completions::{CompletionResponse, NvCreateCompletionRequest},
protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
...@@ -239,7 +239,7 @@ impl ModelWatcher { ...@@ -239,7 +239,7 @@ impl ModelWatcher {
.add_chat_completions_model(&model_entry.name, chat_engine)?; .add_chat_completions_model(&model_entry.name, chat_engine)?;
let frontend = SegmentSource::< let frontend = SegmentSource::<
SingleIn<CompletionRequest>, SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>, ManyOut<Annotated<CompletionResponse>>,
>::new(); >::new();
let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone()).await?.into_operator();
...@@ -290,12 +290,11 @@ impl ModelWatcher { ...@@ -290,12 +290,11 @@ impl ModelWatcher {
.add_chat_completions_model(&model_entry.name, engine)?; .add_chat_completions_model(&model_entry.name, engine)?;
} }
ModelType::Completion => { ModelType::Completion => {
let push_router = let push_router = PushRouter::<
PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client( NvCreateCompletionRequest,
client, Annotated<CompletionResponse>,
Default::default(), >::from_client(client, Default::default())
) .await?;
.await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_completions_model(&model_entry.name, engine)?; .add_completions_model(&model_entry.name, engine)?;
......
...@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest; ...@@ -30,7 +30,7 @@ use crate::preprocessor::PreprocessedRequest;
use crate::protocols::common::llm_backend::LLMEngineOutput; use crate::protocols::common::llm_backend::LLMEngineOutput;
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{prompt_to_string, CompletionRequest, CompletionResponse}, completions::{prompt_to_string, CompletionResponse, NvCreateCompletionRequest},
}; };
use crate::types::openai::embeddings::NvCreateEmbeddingRequest; use crate::types::openai::embeddings::NvCreateEmbeddingRequest;
use crate::types::openai::embeddings::NvCreateEmbeddingResponse; use crate::types::openai::embeddings::NvCreateEmbeddingResponse;
...@@ -140,7 +140,7 @@ impl<E> EngineDispatcher<E> { ...@@ -140,7 +140,7 @@ impl<E> EngineDispatcher<E> {
pub trait StreamingEngine: Send + Sync { pub trait StreamingEngine: Send + Sync {
async fn handle_completion( async fn handle_completion(
&self, &self,
req: SingleIn<CompletionRequest>, req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error>; ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error>;
async fn handle_chat( async fn handle_chat(
...@@ -218,12 +218,12 @@ impl ...@@ -218,12 +218,12 @@ impl
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for EchoEngineFull for EchoEngineFull
{ {
async fn generate( async fn generate(
&self, &self,
incoming_request: SingleIn<CompletionRequest>, incoming_request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
let (request, context) = incoming_request.transfer(()); let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator(); let deltas = request.response_generator();
...@@ -265,8 +265,11 @@ impl ...@@ -265,8 +265,11 @@ impl
#[async_trait] #[async_trait]
impl<E> StreamingEngine for EngineDispatcher<E> impl<E> StreamingEngine for EngineDispatcher<E>
where where
E: AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> E: AsyncEngine<
+ AsyncEngine< SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>,
Error,
> + AsyncEngine<
SingleIn<NvCreateChatCompletionRequest>, SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
Error, Error,
...@@ -279,7 +282,7 @@ where ...@@ -279,7 +282,7 @@ where
{ {
async fn handle_completion( async fn handle_completion(
&self, &self,
req: SingleIn<CompletionRequest>, req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
self.inner.generate(req).await self.inner.generate(req).await
} }
...@@ -343,12 +346,12 @@ impl StreamingEngineAdapter { ...@@ -343,12 +346,12 @@ impl StreamingEngineAdapter {
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for StreamingEngineAdapter for StreamingEngineAdapter
{ {
async fn generate( async fn generate(
&self, &self,
req: SingleIn<CompletionRequest>, req: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
self.0.handle_completion(req).await self.0.handle_completion(req).await
} }
......
...@@ -33,7 +33,9 @@ use crate::protocols::openai::{ ...@@ -33,7 +33,9 @@ use crate::protocols::openai::{
}; };
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use crate::types::{ use crate::types::{
openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest}, openai::{
chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
},
Annotated, Annotated,
}; };
...@@ -120,7 +122,7 @@ impl From<HttpError> for ErrorResponse { ...@@ -120,7 +122,7 @@ impl From<HttpError> for ErrorResponse {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn completions( async fn completions(
State(state): State<Arc<service_v2::State>>, State(state): State<Arc<service_v2::State>>,
Json(request): Json<CompletionRequest>, Json(request): Json<NvCreateCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
check_ready(&state)?; check_ready(&state)?;
...@@ -137,7 +139,7 @@ async fn completions( ...@@ -137,7 +139,7 @@ async fn completions(
..request.inner ..request.inner
}; };
let request = CompletionRequest { let request = NvCreateCompletionRequest {
inner, inner,
nvext: request.nvext, nvext: request.nvext,
}; };
......
...@@ -46,7 +46,7 @@ use crate::protocols::{ ...@@ -46,7 +46,7 @@ use crate::protocols::{
common::{SamplingOptionsProvider, StopConditionsProvider}, common::{SamplingOptionsProvider, StopConditionsProvider},
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionRequest, CompletionResponse}, completions::{CompletionResponse, NvCreateCompletionRequest},
nvext::NvExtProvider, nvext::NvExtProvider,
DeltaGeneratorExt, DeltaGeneratorExt,
}, },
...@@ -341,7 +341,7 @@ impl ...@@ -341,7 +341,7 @@ impl
#[async_trait] #[async_trait]
impl impl
Operator< Operator<
SingleIn<CompletionRequest>, SingleIn<NvCreateCompletionRequest>,
ManyOut<Annotated<CompletionResponse>>, ManyOut<Annotated<CompletionResponse>>,
SingleIn<PreprocessedRequest>, SingleIn<PreprocessedRequest>,
ManyOut<Annotated<BackendOutput>>, ManyOut<Annotated<BackendOutput>>,
...@@ -349,7 +349,7 @@ impl ...@@ -349,7 +349,7 @@ impl
{ {
async fn generate( async fn generate(
&self, &self,
request: SingleIn<CompletionRequest>, request: SingleIn<NvCreateCompletionRequest>,
next: Arc< next: Arc<
dyn AsyncEngine< dyn AsyncEngine<
SingleIn<PreprocessedRequest>, SingleIn<PreprocessedRequest>,
......
...@@ -18,7 +18,7 @@ use super::*; ...@@ -18,7 +18,7 @@ use super::*;
use minijinja::{context, value::Value}; use minijinja::{context, value::Value};
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest, chat_completions::NvCreateChatCompletionRequest, completions::NvCreateCompletionRequest,
}; };
use tracing; use tracing;
...@@ -55,7 +55,7 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -55,7 +55,7 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
} }
} }
impl OAIChatLikeRequest for CompletionRequest { impl OAIChatLikeRequest for NvCreateCompletionRequest {
fn messages(&self) -> minijinja::value::Value { fn messages(&self) -> minijinja::value::Value {
let message = async_openai::types::ChatCompletionRequestMessage::User( let message = async_openai::types::ChatCompletionRequestMessage::User(
async_openai::types::ChatCompletionRequestUserMessage { async_openai::types::ChatCompletionRequestUserMessage {
......
...@@ -34,7 +34,7 @@ use super::{ ...@@ -34,7 +34,7 @@ use super::{
use dynamo_runtime::protocols::annotated::AnnotationsProvider; use dynamo_runtime::protocols::annotated::AnnotationsProvider;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct CompletionRequest { pub struct NvCreateCompletionRequest {
#[serde(flatten)] #[serde(flatten)]
pub inner: async_openai::types::CreateCompletionRequest, pub inner: async_openai::types::CreateCompletionRequest,
...@@ -141,7 +141,7 @@ pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String { ...@@ -141,7 +141,7 @@ pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
} }
} }
impl NvExtProvider for CompletionRequest { impl NvExtProvider for NvCreateCompletionRequest {
fn nvext(&self) -> Option<&NvExt> { fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref() self.nvext.as_ref()
} }
...@@ -158,7 +158,7 @@ impl NvExtProvider for CompletionRequest { ...@@ -158,7 +158,7 @@ impl NvExtProvider for CompletionRequest {
} }
} }
impl AnnotationsProvider for CompletionRequest { impl AnnotationsProvider for NvCreateCompletionRequest {
fn annotations(&self) -> Option<Vec<String>> { fn annotations(&self) -> Option<Vec<String>> {
self.nvext self.nvext
.as_ref() .as_ref()
...@@ -174,7 +174,7 @@ impl AnnotationsProvider for CompletionRequest { ...@@ -174,7 +174,7 @@ impl AnnotationsProvider for CompletionRequest {
} }
} }
impl OpenAISamplingOptionsProvider for CompletionRequest { impl OpenAISamplingOptionsProvider for NvCreateCompletionRequest {
fn get_temperature(&self) -> Option<f32> { fn get_temperature(&self) -> Option<f32> {
self.inner.temperature self.inner.temperature
} }
...@@ -196,7 +196,7 @@ impl OpenAISamplingOptionsProvider for CompletionRequest { ...@@ -196,7 +196,7 @@ impl OpenAISamplingOptionsProvider for CompletionRequest {
} }
} }
impl OpenAIStopConditionsProvider for CompletionRequest { impl OpenAIStopConditionsProvider for NvCreateCompletionRequest {
fn get_max_tokens(&self) -> Option<u32> { fn get_max_tokens(&self) -> Option<u32> {
self.inner.max_tokens self.inner.max_tokens
} }
...@@ -255,10 +255,10 @@ impl ResponseFactory { ...@@ -255,10 +255,10 @@ impl ResponseFactory {
} }
/// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest /// Implements TryFrom for converting an OpenAI's CompletionRequest to an Engine's CompletionRequest
impl TryFrom<CompletionRequest> for common::CompletionRequest { impl TryFrom<NvCreateCompletionRequest> for common::CompletionRequest {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_from(request: CompletionRequest) -> Result<Self, Self::Error> { fn try_from(request: NvCreateCompletionRequest) -> Result<Self, Self::Error> {
// openai_api_rs::v1::completion::CompletionRequest { // openai_api_rs::v1::completion::CompletionRequest {
// NA pub model: String, // NA pub model: String,
// pub prompt: String, // pub prompt: String,
......
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::{CompletionChoice, CompletionRequest, CompletionResponse}; use super::{CompletionChoice, CompletionResponse, NvCreateCompletionRequest};
use crate::protocols::common; use crate::protocols::common;
use crate::protocols::openai::CompletionUsage; use crate::protocols::openai::CompletionUsage;
impl CompletionRequest { impl NvCreateCompletionRequest {
// put this method on the request // put this method on the request
// inspect the request to extract options // inspect the request to extract options
pub fn response_generator(&self) -> DeltaGenerator { pub fn response_generator(&self) -> DeltaGenerator {
......
...@@ -24,14 +24,15 @@ pub mod openai { ...@@ -24,14 +24,15 @@ pub mod openai {
pub mod completions { pub mod completions {
use super::*; use super::*;
pub use protocols::openai::completions::{CompletionRequest, CompletionResponse}; pub use protocols::openai::completions::{CompletionResponse, NvCreateCompletionRequest};
/// A [`UnaryEngine`] implementation for the OpenAI Completions API /// A [`UnaryEngine`] implementation for the OpenAI Completions API
pub type OpenAICompletionsUnaryEngine = UnaryEngine<CompletionRequest, CompletionResponse>; pub type OpenAICompletionsUnaryEngine =
UnaryEngine<NvCreateCompletionRequest, CompletionResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API /// A [`ServerStreamingEngine`] implementation for the OpenAI Completions API
pub type OpenAICompletionsStreamingEngine = pub type OpenAICompletionsStreamingEngine =
ServerStreamingEngine<CompletionRequest, Annotated<CompletionResponse>>; ServerStreamingEngine<NvCreateCompletionRequest, Annotated<CompletionResponse>>;
} }
pub mod chat_completions { pub mod chat_completions {
......
...@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{ ...@@ -24,7 +24,7 @@ use dynamo_llm::http::service::{
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
openai::{ openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{CompletionRequest, CompletionResponse}, completions::{CompletionResponse, NvCreateCompletionRequest},
}, },
Annotated, Annotated,
}; };
...@@ -101,12 +101,12 @@ impl ...@@ -101,12 +101,12 @@ impl
} }
#[async_trait] #[async_trait]
impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error> impl AsyncEngine<SingleIn<NvCreateCompletionRequest>, ManyOut<Annotated<CompletionResponse>>, Error>
for AlwaysFailEngine for AlwaysFailEngine
{ {
async fn generate( async fn generate(
&self, &self,
_request: SingleIn<CompletionRequest>, _request: SingleIn<NvCreateCompletionRequest>,
) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> { ) -> Result<ManyOut<Annotated<CompletionResponse>>, Error> {
Err(HttpError { Err(HttpError {
code: 401, code: 401,
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
// limitations under the License. // limitations under the License.
use async_openai::types::CreateCompletionRequestArgs; use async_openai::types::CreateCompletionRequestArgs;
use dynamo_llm::protocols::openai::{self, completions::CompletionRequest}; use dynamo_llm::protocols::openai::{self, completions::NvCreateCompletionRequest};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Clone)] #[derive(Serialize, Deserialize, Debug, Clone)]
struct CompletionSample { struct CompletionSample {
request: CompletionRequest, request: NvCreateCompletionRequest,
description: String, description: String,
} }
...@@ -36,7 +36,7 @@ impl CompletionSample { ...@@ -36,7 +36,7 @@ impl CompletionSample {
let inner = builder.build().unwrap(); let inner = builder.build().unwrap();
let request = CompletionRequest { inner, nvext: None }; let request = NvCreateCompletionRequest { inner, nvext: None };
Ok(Self { Ok(Self {
request, request,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment