Commit 96866f43 authored by Paul Hendricks, committed by GitHub

refactor: rename ChatCompletionRequest to NvCreateChatCompletionRequest (#284)

parent 4b42b232
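The rename is mechanical: the wrapper keeps its `inner` field (the flattened async_openai request) and its optional `nvext` field, so call sites only swap the type name. A minimal sketch of the new construction, assuming the import path shown in the hunks below (the helper function itself is illustrative, not part of the commit):

```rust
use triton_distributed_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;

// Formerly `ChatCompletionRequest { inner, nvext: None }`; only the type name changes,
// the fields stay the same.
fn wrap_request(
    inner: async_openai::types::CreateChatCompletionRequest,
) -> NvCreateChatCompletionRequest {
    NvCreateChatCompletionRequest { inner, nvext: None }
}
```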
......@@ -19,7 +19,7 @@ use triton_distributed_llm::{
model_type::ModelType,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
Annotated,
},
};
......@@ -54,7 +54,7 @@ pub async fn run(
card,
} => {
let frontend = SegmentSource::<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
......
......@@ -21,7 +21,7 @@ use triton_distributed_llm::{
model_type::ModelType,
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
Annotated,
},
};
......@@ -74,7 +74,7 @@ pub async fn run(
card,
} => {
let frontend = ServiceFrontend::<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
......
......@@ -23,7 +23,7 @@ use triton_distributed_llm::{
preprocessor::OpenAIPreprocessor,
types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
......@@ -71,7 +71,7 @@ pub async fn run(
card,
} => {
let frontend = ServiceFrontend::<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
>::new();
let preprocessor = OpenAIPreprocessor::new(*card.clone())
......@@ -165,7 +165,7 @@ async fn main_loop(
// req_builder.min_tokens(8192);
// }
let req = ChatCompletionRequest { inner, nvext: None };
let req = NvCreateChatCompletionRequest { inner, nvext: None };
// Call the model
let mut stream = engine.generate(Context::new(req)).await?;
......
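For context on how the renamed request flows through an engine, here is a hedged sketch of the call pattern in the hunk above: the request wraps an async_openai body, is placed in a `Context`, and the engine streams back annotated deltas. The `Context` import path and the stream trait are assumptions not visible in this diff.

```rust
use futures::StreamExt; // assumption: the returned delta stream implements futures::Stream
use triton_distributed_llm::types::openai::chat_completions::{
    NvCreateChatCompletionRequest, OpenAIChatCompletionsStreamingEngine,
};

async fn drive(
    engine: &OpenAIChatCompletionsStreamingEngine,
    inner: async_openai::types::CreateChatCompletionRequest,
) -> anyhow::Result<()> {
    // Only the type name changed: the request still wraps the async_openai body plus an
    // optional nvext block.
    let req = NvCreateChatCompletionRequest { inner, nvext: None };
    // `Context` comes from the distributed runtime; its exact module path is not shown
    // in this diff.
    let mut stream = engine.generate(Context::new(req)).await?;
    while let Some(_delta) = stream.next().await {
        // each item is an Annotated<ChatCompletionResponseDelta>; aggregate or print it here
    }
    Ok(())
}
```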
......@@ -21,7 +21,7 @@ use triton_distributed_llm::{
model_card::model::ModelDeploymentCard,
types::{
openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
......@@ -113,7 +113,7 @@ pub struct Flags {
pub enum EngineConfig {
    /// A remote networked engine we don't know about yet
    /// We don't have the pre-processor yet, so this only handles text requests. Type will change later.
Dynamic(Client<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
Dynamic(Client<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
    /// A Full service engine does its own tokenization and prompt formatting.
StaticFull {
......@@ -223,7 +223,7 @@ pub async fn run(
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.endpoint(endpoint.name)
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
tracing::info!("Waiting for remote {}...", client.path());
......
......@@ -19,7 +19,7 @@ use async_stream::stream;
use async_trait::async_trait;
use triton_distributed_llm::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
};
use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
......@@ -40,14 +40,14 @@ pub fn make_engine_full() -> OpenAIChatCompletionsStreamingEngine {
#[async_trait]
impl
AsyncEngine<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
Error,
> for EchoEngineFull
{
async fn generate(
&self,
incoming_request: SingleIn<ChatCompletionRequest>,
incoming_request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
let (request, context) = incoming_request.transfer(());
let deltas = request.response_generator();
......
......@@ -28,7 +28,7 @@ use triton_distributed_runtime::{
use super::ModelManager;
use crate::model_type::ModelType;
use crate::protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponseDelta,
ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
};
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use tracing;
......@@ -135,7 +135,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
.await?;
state
.manager
......
......@@ -44,7 +44,7 @@ use crate::protocols::openai::{
chat_completions::ChatCompletionResponse, completions::CompletionResponse,
};
use crate::types::{
openai::{chat_completions::ChatCompletionRequest, completions::CompletionRequest},
openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
Annotated,
};
......@@ -211,7 +211,7 @@ async fn completions(
#[tracing::instrument(skip_all)]
async fn chat_completions(
State(state): State<Arc<DeploymentState>>,
Json(request): Json<ChatCompletionRequest>,
Json(request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// return a 503 if the service is not ready
check_ready(&state)?;
......@@ -227,7 +227,7 @@ async fn chat_completions(
stream: Some(true),
..request.inner
};
let request = ChatCompletionRequest {
let request = NvCreateChatCompletionRequest {
inner: inner_request,
nvext: None,
};
......
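Because `inner` is marked `#[serde(flatten)]` (see the struct definition further down in this diff), the body accepted by the `Json` extractor above is an ordinary OpenAI chat-completions payload, with `nvext` as an optional extra top-level field. A hedged sketch of that shape, with placeholder values:

```rust
use triton_distributed_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;

fn example_body() -> anyhow::Result<()> {
    // #[serde(flatten)] on `inner` means the usual OpenAI fields live at the top level;
    // a missing `nvext` simply deserializes to None.
    let body = r#"{
        "model": "example-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": true
    }"#;
    let request: NvCreateChatCompletionRequest = serde_json::from_str(body)?;
    assert!(request.nvext.is_none());
    Ok(())
}
```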
......@@ -44,7 +44,7 @@ use triton_distributed_runtime::protocols::annotated::{Annotated, AnnotationsPro
use crate::protocols::{
common::{SamplingOptionsProvider, StopConditionsProvider},
openai::{
chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
completions::{CompletionRequest, CompletionResponse},
nvext::NvExtProvider,
DeltaGeneratorExt,
......@@ -251,7 +251,7 @@ impl OpenAIPreprocessor {
#[async_trait]
impl
Operator<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
SingleIn<BackendInput>,
ManyOut<Annotated<BackendOutput>>,
......@@ -259,7 +259,7 @@ impl
{
async fn generate(
&self,
request: SingleIn<ChatCompletionRequest>,
request: SingleIn<NvCreateChatCompletionRequest>,
next: Arc<
dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
>,
......
......@@ -18,11 +18,11 @@ use super::*;
use minijinja::{context, value::Value};
use crate::protocols::openai::{
chat_completions::ChatCompletionRequest, completions::CompletionRequest,
chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest,
};
use tracing;
impl OAIChatLikeRequest for ChatCompletionRequest {
impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn messages(&self) -> Value {
Value::from_serialize(&self.inner.messages)
}
......
......@@ -24,12 +24,11 @@ use validator::Validate;
mod aggregator;
mod delta;
pub use super::{CompletionTokensDetails, CompletionUsage, PromptTokensDetails};
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionRequest {
pub struct NvCreateChatCompletionRequest {
#[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionRequest,
pub nvext: Option<NvExt>,
......@@ -53,7 +52,7 @@ pub struct ChatCompletionResponseDelta {
pub inner: async_openai::types::CreateChatCompletionStreamResponse,
}
impl NvExtProvider for ChatCompletionRequest {
impl NvExtProvider for NvCreateChatCompletionRequest {
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
......@@ -63,7 +62,7 @@ impl NvExtProvider for ChatCompletionRequest {
}
}
impl AnnotationsProvider for ChatCompletionRequest {
impl AnnotationsProvider for NvCreateChatCompletionRequest {
fn annotations(&self) -> Option<Vec<String>> {
self.nvext
.as_ref()
......@@ -79,7 +78,7 @@ impl AnnotationsProvider for ChatCompletionRequest {
}
}
impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
impl OpenAISamplingOptionsProvider for NvCreateChatCompletionRequest {
fn get_temperature(&self) -> Option<f32> {
self.inner.temperature
}
......@@ -102,7 +101,7 @@ impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
}
#[allow(deprecated)]
impl OpenAIStopConditionsProvider for ChatCompletionRequest {
impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest {
fn get_max_tokens(&self) -> Option<u32> {
// ALLOW: max_tokens is deprecated in favor of max_completion_tokens
self.inner.max_tokens
......
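The provider impls above carry no state of their own; they read straight through to `inner` and `nvext`. A small sketch of that pass-through, assuming the async_openai builder API this repo pins (the model name and values are placeholders):

```rust
use triton_distributed_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;

fn example_pass_through() -> anyhow::Result<()> {
    let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
        .model("example-model")
        .temperature(0.7f32)
        .build()?;
    let request = NvCreateChatCompletionRequest { inner, nvext: None };

    // OpenAISamplingOptionsProvider::get_temperature returns self.inner.temperature,
    // and AnnotationsProvider::annotations is derived from the optional nvext block.
    assert_eq!(request.inner.temperature, Some(0.7));
    assert!(request.nvext.is_none());
    Ok(())
}
```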
......@@ -13,10 +13,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::{ChatCompletionRequest, ChatCompletionResponseDelta};
use super::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest};
use crate::protocols::common;
impl ChatCompletionRequest {
impl NvCreateChatCompletionRequest {
// put this method on the request
// inspect the request to extract options
pub fn response_generator(&self) -> DeltaGenerator {
......
......@@ -37,20 +37,18 @@ pub mod openai {
pub mod chat_completions {
use super::*;
// pub use async_openai::types::CreateChatCompletionRequest as ChatCompletionRequest;
// pub use protocols::openai::chat_completions::{
// ChatCompletionResponse, ChatCompletionResponseDelta,
// };
pub use protocols::openai::chat_completions::{
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseDelta,
ChatCompletionResponse, ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
};
/// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API
pub type OpenAIChatCompletionsUnaryEngine =
UnaryEngine<ChatCompletionRequest, ChatCompletionResponse>;
UnaryEngine<NvCreateChatCompletionRequest, ChatCompletionResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API
pub type OpenAIChatCompletionsStreamingEngine =
ServerStreamingEngine<ChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>;
pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine<
NvCreateChatCompletionRequest,
Annotated<ChatCompletionResponseDelta>,
>;
}
}
......@@ -26,7 +26,7 @@ use triton_distributed_llm::http::service::{
};
use triton_distributed_llm::protocols::{
openai::{
chat_completions::{ChatCompletionRequest, ChatCompletionResponseDelta},
chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
completions::{CompletionRequest, CompletionResponse},
},
Annotated,
......@@ -44,14 +44,14 @@ struct CounterEngine {}
#[async_trait]
impl
AsyncEngine<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
Error,
> for CounterEngine
{
async fn generate(
&self,
request: SingleIn<ChatCompletionRequest>,
request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
let (request, context) = request.transfer(());
let ctx = context.context();
......@@ -84,14 +84,14 @@ struct AlwaysFailEngine {}
#[async_trait]
impl
AsyncEngine<
SingleIn<ChatCompletionRequest>,
SingleIn<NvCreateChatCompletionRequest>,
ManyOut<Annotated<ChatCompletionResponseDelta>>,
Error,
> for AlwaysFailEngine
{
async fn generate(
&self,
_request: SingleIn<ChatCompletionRequest>,
_request: SingleIn<NvCreateChatCompletionRequest>,
) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
Err(HttpError {
code: 403,
......
......@@ -18,7 +18,7 @@ use anyhow::Ok;
use serde::{Deserialize, Serialize};
use triton_distributed_llm::model_card::model::{ModelDeploymentCard, PromptContextMixin};
use triton_distributed_llm::preprocessor::prompt::PromptFormatter;
use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;
use triton_distributed_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use hf_hub::{api::tokio::ApiBuilder, Cache, Repo, RepoType};
......@@ -232,7 +232,7 @@ impl Request {
tools: Option<&str>,
tool_choice: Option<async_openai::types::ChatCompletionToolChoiceOption>,
model: String,
) -> ChatCompletionRequest {
) -> NvCreateChatCompletionRequest {
let messages: Vec<async_openai::types::ChatCompletionRequestMessage> =
serde_json::from_str(messages).unwrap();
let tools: Option<Vec<async_openai::types::ChatCompletionTool>> =
......@@ -248,7 +248,7 @@ impl Request {
.build()
.unwrap();
ChatCompletionRequest { inner, nvext: None }
NvCreateChatCompletionRequest { inner, nvext: None }
}
}
......
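For tests like the one above, a request can be assembled the same way outside the helper. This sketch mirrors the hunk, with placeholder message JSON and model name:

```rust
use triton_distributed_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;

// Illustrative test-style constructor: messages come from a JSON literal, the inner
// request from the async_openai builder, and the result is wrapped in the renamed type.
fn example_request() -> NvCreateChatCompletionRequest {
    let messages: Vec<async_openai::types::ChatCompletionRequestMessage> =
        serde_json::from_str(r#"[{"role": "user", "content": "Hello"}]"#).unwrap();
    let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
        .model("example-model")
        .messages(messages)
        .build()
        .unwrap();
    NvCreateChatCompletionRequest { inner, nvext: None }
}
```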