Commit 96866f43 authored by Paul Hendricks, committed by GitHub

refactor: rename ChatCompletionRequest to NvCreateChatCompletionRequest (#284)

parent 4b42b232
@@ -19,7 +19,7 @@ use triton_distributed_llm::{
     model_type::ModelType,
     preprocessor::OpenAIPreprocessor,
     types::{
-        openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
+        openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
         Annotated,
     },
 };
@@ -54,7 +54,7 @@ pub async fn run(
             card,
         } => {
             let frontend = SegmentSource::<
-                SingleIn<ChatCompletionRequest>,
+                SingleIn<NvCreateChatCompletionRequest>,
                 ManyOut<Annotated<ChatCompletionResponseDelta>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
...
@@ -21,7 +21,7 @@ use triton_distributed_llm::{
     model_type::ModelType,
     preprocessor::OpenAIPreprocessor,
     types::{
-        openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
+        openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
         Annotated,
     },
 };
@@ -74,7 +74,7 @@ pub async fn run(
             card,
         } => {
             let frontend = ServiceFrontend::<
-                SingleIn<ChatCompletionRequest>,
+                SingleIn<NvCreateChatCompletionRequest>,
                 ManyOut<Annotated<ChatCompletionResponseDelta>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
...
@@ -23,7 +23,7 @@ use triton_distributed_llm::{
     preprocessor::OpenAIPreprocessor,
     types::{
         openai::chat_completions::{
-            ChatCompletionRequest, ChatCompletionResponseDelta,
+            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
             OpenAIChatCompletionsStreamingEngine,
         },
         Annotated,
@@ -71,7 +71,7 @@ pub async fn run(
             card,
         } => {
             let frontend = ServiceFrontend::<
-                SingleIn<ChatCompletionRequest>,
+                SingleIn<NvCreateChatCompletionRequest>,
                 ManyOut<Annotated<ChatCompletionResponseDelta>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
@@ -165,7 +165,7 @@ async fn main_loop(
         // req_builder.min_tokens(8192);
         // }
-        let req = ChatCompletionRequest { inner, nvext: None };
+        let req = NvCreateChatCompletionRequest { inner, nvext: None };
 
         // Call the model
         let mut stream = engine.generate(Context::new(req)).await?;
...
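For orientation, a hedged sketch of the request-construction pattern the last hunk lands on: build the flattened async_openai request with its builder, wrap it in the renamed struct, and hand it to the engine. The model name, function name, and the exact import path of Context are assumptions standing in for the surrounding file, not part of this commit.

// Sketch under stated assumptions; mirrors `let req = NvCreateChatCompletionRequest { ... }` above.
use async_openai::types::CreateChatCompletionRequestArgs;
use triton_distributed_llm::types::openai::chat_completions::{
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};

async fn send_once(engine: &OpenAIChatCompletionsStreamingEngine) -> anyhow::Result<()> {
    let inner = CreateChatCompletionRequestArgs::default()
        .model("example-model") // placeholder model name
        .build()?;
    let req = NvCreateChatCompletionRequest { inner, nvext: None };
    // Call the model, as in the hunk above; Context is the runtime's request
    // wrapper (import path assumed), and the stream yields Annotated deltas.
    let mut _stream = engine.generate(Context::new(req)).await?;
    Ok(())
}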
@@ -21,7 +21,7 @@ use triton_distributed_llm::{
     model_card::model::ModelDeploymentCard,
     types::{
         openai::chat_completions::{
-            ChatCompletionRequest, ChatCompletionResponseDelta,
+            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
             OpenAIChatCompletionsStreamingEngine,
         },
         Annotated,
@@ -113,7 +113,7 @@ pub struct Flags {
 pub enum EngineConfig {
     /// An remote networked engine we don't know about yet
     /// We don't have the pre-processor yet so this is only text requests. Type will change later.
-    Dynamic(Client<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
+    Dynamic(Client<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
 
     /// A Full service engine does it's own tokenization and prompt formatting.
     StaticFull {
@@ -223,7 +223,7 @@ pub async fn run(
                 .namespace(endpoint.namespace)?
                 .component(endpoint.component)?
                 .endpoint(endpoint.name)
-                .client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
+                .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
                 .await?;
             tracing::info!("Waiting for remote {}...", client.path());
...
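The dynamic-engine hunk above is the other side of the rename: the typed client is parameterized by the request/response pair, so the new name becomes part of the wire contract. A minimal sketch of the same call chain; the namespace, component, and endpoint names are illustrative placeholders, and `distributed` is an assumed runtime handle.

// Sketch: assumes an async context with `?` available, as in the hunk above.
let client = distributed
    .namespace("example-namespace")?
    .component("example-component")?
    .endpoint("generate")
    .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
    .await?;
tracing::info!("Waiting for remote {}...", client.path());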
@@ -19,7 +19,7 @@ use async_stream::stream;
 use async_trait::async_trait;
 use triton_distributed_llm::protocols::openai::chat_completions::{
-    ChatCompletionRequest, ChatCompletionResponseDelta,
+    ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
 };
 use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
 use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
@@ -40,14 +40,14 @@ pub fn make_engine_full() -> OpenAIChatCompletionsStreamingEngine {
 #[async_trait]
 impl
     AsyncEngine<
-        SingleIn<ChatCompletionRequest>,
+        SingleIn<NvCreateChatCompletionRequest>,
         ManyOut<Annotated<ChatCompletionResponseDelta>>,
         Error,
     > for EchoEngineFull
 {
     async fn generate(
         &self,
-        incoming_request: SingleIn<ChatCompletionRequest>,
+        incoming_request: SingleIn<NvCreateChatCompletionRequest>,
     ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
         let (request, context) = incoming_request.transfer(());
         let deltas = request.response_generator();
...
@@ -28,7 +28,7 @@ use triton_distributed_runtime::{
 use super::ModelManager;
 use crate::model_type::ModelType;
 use crate::protocols::openai::chat_completions::{
-    ChatCompletionRequest, ChatCompletionResponseDelta,
+    ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
 };
 use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
 use tracing;
@@ -135,7 +135,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
         .namespace(model_entry.endpoint.namespace)?
         .component(model_entry.endpoint.component)?
         .endpoint(model_entry.endpoint.name)
-        .client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
+        .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
         .await?;
     state
         .manager
...
@@ -44,7 +44,7 @@ use crate::protocols::openai::{
     chat_completions::ChatCompletionResponse, completions::CompletionResponse,
 };
 use crate::types::{
-    openai::{chat_completions::ChatCompletionRequest, completions::CompletionRequest},
+    openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
     Annotated,
 };
@@ -211,7 +211,7 @@ async fn completions(
 #[tracing::instrument(skip_all)]
 async fn chat_completions(
     State(state): State<Arc<DeploymentState>>,
-    Json(request): Json<ChatCompletionRequest>,
+    Json(request): Json<NvCreateChatCompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     // return a 503 if the service is not ready
     check_ready(&state)?;
@@ -227,7 +227,7 @@ async fn chat_completions(
         stream: Some(true),
         ..request.inner
     };
-    let request = ChatCompletionRequest {
+    let request = NvCreateChatCompletionRequest {
         inner: inner_request,
         nvext: None,
     };
...
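Since the request type flattens its inner async_openai request (see the protocol definitions later in this diff), the Json&lt;NvCreateChatCompletionRequest&gt; extractor above keeps accepting a plain OpenAI-style body, with nvext as an optional sibling field rather than a nested object. A minimal sketch, with illustrative field values:

// Sketch: `#[serde(flatten)]` means the JSON carries no "inner" key.
fn parse_example() -> serde_json::Result<NvCreateChatCompletionRequest> {
    let body = r#"{"model":"example-model","messages":[]}"#; // illustrative body
    let request: NvCreateChatCompletionRequest = serde_json::from_str(body)?;
    assert!(request.nvext.is_none()); // no extension block supplied
    Ok(request)
}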
@@ -44,7 +44,7 @@ use triton_distributed_runtime::protocols::annotated::{Annotated, AnnotationsPro
 use crate::protocols::{
     common::{SamplingOptionsProvider, StopConditionsProvider},
     openai::{
-        chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
+        chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
         completions::{CompletionRequest, CompletionResponse},
         nvext::NvExtProvider,
         DeltaGeneratorExt,
@@ -251,7 +251,7 @@ impl OpenAIPreprocessor {
 #[async_trait]
 impl
     Operator<
-        SingleIn<ChatCompletionRequest>,
+        SingleIn<NvCreateChatCompletionRequest>,
         ManyOut<Annotated<ChatCompletionResponseDelta>>,
         SingleIn<BackendInput>,
         ManyOut<Annotated<BackendOutput>>,
@@ -259,7 +259,7 @@ impl
 {
     async fn generate(
         &self,
-        request: SingleIn<ChatCompletionRequest>,
+        request: SingleIn<NvCreateChatCompletionRequest>,
         next: Arc<
             dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
         >,
...
@@ -18,11 +18,11 @@ use super::*;
 use minijinja::{context, value::Value};
 
 use crate::protocols::openai::{
-    chat_completions::ChatCompletionRequest, completions::CompletionRequest,
+    chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest,
 };
 use tracing;
 
-impl OAIChatLikeRequest for ChatCompletionRequest {
+impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
     fn messages(&self) -> Value {
         Value::from_serialize(&self.inner.messages)
     }
...
@@ -24,12 +24,11 @@ use validator::Validate;
 mod aggregator;
 mod delta;
 
-pub use super::{CompletionTokensDetails, CompletionUsage, PromptTokensDetails};
 pub use aggregator::DeltaAggregator;
 pub use delta::DeltaGenerator;
 
 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
-pub struct ChatCompletionRequest {
+pub struct NvCreateChatCompletionRequest {
     #[serde(flatten)]
     pub inner: async_openai::types::CreateChatCompletionRequest,
     pub nvext: Option<NvExt>,
@@ -53,7 +52,7 @@ pub struct ChatCompletionResponseDelta {
     pub inner: async_openai::types::CreateChatCompletionStreamResponse,
 }
 
-impl NvExtProvider for ChatCompletionRequest {
+impl NvExtProvider for NvCreateChatCompletionRequest {
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
     }
@@ -63,7 +62,7 @@ impl NvExtProvider for ChatCompletionRequest {
     }
 }
 
-impl AnnotationsProvider for ChatCompletionRequest {
+impl AnnotationsProvider for NvCreateChatCompletionRequest {
     fn annotations(&self) -> Option<Vec<String>> {
         self.nvext
             .as_ref()
@@ -79,7 +78,7 @@ impl AnnotationsProvider for ChatCompletionRequest {
     }
 }
 
-impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
+impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
     fn get_temperature(&self) -> Option<f32> {
         self.inner.temperature
     }
@@ -102,7 +101,7 @@ impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
 }
 
 #[allow(deprecated)]
-impl OpenAIStopConditionsProvider for ChatCompletionRequest {
+impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
     fn get_max_tokens(&self) -> Option<u32> {
         // ALLOW: max_tokens is deprecated in favor of max_completion_tokens
         self.inner.max_tokens
...
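The provider impls above are mechanical renames; the wrapper's API surface is unchanged. Sampling and stop-condition accessors delegate to the flattened inner request, and nvext() exposes the optional NVIDIA extension block. A small sketch of the accessors (the helper function itself is illustrative, not from this commit):

// Illustrative helper exercising the traits renamed above; assumes the
// provider traits are in scope as in the preprocessor file earlier.
fn inspect(req: &NvCreateChatCompletionRequest) {
    let _temperature = req.get_temperature(); // OpenAISamplingOptionsProvider -> Option<f32>
    #[allow(deprecated)] // max_tokens is deprecated upstream, per the impl above
    let _max_tokens = req.get_max_tokens(); // OpenAIStopConditionsProvider -> Option<u32>
    let _nvext = req.nvext(); // NvExtProvider -> Option<&NvExt>
    let _annotations = req.annotations(); // AnnotationsProvider -> Option<Vec<String>>
}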
@@ -13,10 +13,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use super::{ChatCompletionRequest, ChatCompletionResponseDelta};
+use super::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest};
 use crate::protocols::common;
 
-impl ChatCompletionRequest {
+impl NvCreateChatCompletionRequest {
     // put this method on the request
     // inspect the request to extract options
     pub fn response_generator(&self) -> DeltaGenerator {
...
@@ -37,20 +37,18 @@ pub mod openai {
     pub mod chat_completions {
         use super::*;
 
-        // pub use async_openai::types::CreateChatCompletionRequest as ChatCompletionRequest;
-        // pub use protocols::openai::chat_completions::{
-        //     ChatCompletionResponse, ChatCompletionResponseDelta,
-        // };
         pub use protocols::openai::chat_completions::{
-            ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseDelta,
+            ChatCompletionResponse, ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
        };
 
         /// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API
         pub type OpenAIChatCompletionsUnaryEngine =
-            UnaryEngine<ChatCompletionRequest, ChatCompletionResponse>;
+            UnaryEngine<NvCreateChatCompletionRequest, ChatCompletionResponse>;
 
         /// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API
-        pub type OpenAIChatCompletionsStreamingEngine =
-            ServerStreamingEngine<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>;
+        pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine<
+            NvCreateChatCompletionRequest,
+            Annotated<ChatCompletionResponseDelta>,
+        >;
     }
 }
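The reflowed streaming alias above is the same type make_engine_full() returns in the echo-engine hunk earlier, so downstream code that only names the alias is untouched by the rename. A one-line sketch:

// Sketch: binding the alias; make_engine_full() is defined earlier in this diff.
fn example_engine() -> OpenAIChatCompletionsStreamingEngine {
    make_engine_full()
}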
@@ -26,7 +26,7 @@ use triton_distributed_llm::http::service::{
 };
 use triton_distributed_llm::protocols::{
     openai::{
-        chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
+        chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
         completions::{CompletionRequest, CompletionResponse},
     },
     Annotated,
@@ -44,14 +44,14 @@ struct CounterEngine {}
 #[async_trait]
 impl
     AsyncEngine<
-        SingleIn<ChatCompletionRequest>,
+        SingleIn<NvCreateChatCompletionRequest>,
         ManyOut<Annotated<ChatCompletionResponseDelta>>,
         Error,
     > for CounterEngine
 {
     async fn generate(
         &self,
-        request: SingleIn<ChatCompletionRequest>,
+        request: SingleIn<NvCreateChatCompletionRequest>,
     ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
         let (request, context) = request.transfer(());
         let ctx = context.context();
@@ -84,14 +84,14 @@ struct AlwaysFailEngine {}
 #[async_trait]
 impl
     AsyncEngine<
-        SingleIn<ChatCompletionRequest>,
+        SingleIn<NvCreateChatCompletionRequest>,
         ManyOut<Annotated<ChatCompletionResponseDelta>>,
         Error,
     > for AlwaysFailEngine
 {
     async fn generate(
         &self,
-        _request: SingleIn<ChatCompletionRequest>,
+        _request: SingleIn<NvCreateChatCompletionRequest>,
     ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
         Err(HttpError {
             code: 403,
...
@@ -18,7 +18,7 @@ use anyhow::Ok;
 use serde::{Deserialize, Serialize};
 use triton_distributed_llm::model_card::model::{ModelDeploymentCard, PromptContextMixin};
 use triton_distributed_llm::preprocessor::prompt::PromptFormatter;
-use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;
+use triton_distributed_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
 
 use hf_hub::{api::tokio::ApiBuilder, Cache, Repo, RepoType};
@@ -232,7 +232,7 @@ impl Request {
         tools: Option<&str>,
         tool_choice: Option<async_openai::types::ChatCompletionToolChoiceOption>,
         model: String,
-    ) -> ChatCompletionRequest {
+    ) -> NvCreateChatCompletionRequest {
         let messages: Vec<async_openai::types::ChatCompletionRequestMessage> =
             serde_json::from_str(messages).unwrap();
         let tools: Option<Vec<async_openai::types::ChatCompletionTool>> =
@@ -248,7 +248,7 @@ impl Request {
             .build()
             .unwrap();
 
-        ChatCompletionRequest { inner, nvext: None }
+        NvCreateChatCompletionRequest { inner, nvext: None }
     }
 }
...