Commit 110f3f8c authored by Paul Hendricks, committed by GitHub

refactor: rename ChatCompletionResponseDelta to NvCreateChatCompletionStreamResponse (#292)

parent c13ea718
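
For context, the renamed struct is a thin newtype wrapper; nothing about its shape changes in this commit. A minimal sketch, taken from the definition later in this diff (only the surrounding use statements are assumed here):

    use serde::{Deserialize, Serialize};
    use validator::Validate;

    // The renamed wrapper: #[serde(flatten)] makes it serialize to exactly
    // the same JSON as the inner async_openai stream-response type, so the
    // rename is transparent on the wire.
    #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
    pub struct NvCreateChatCompletionStreamResponse {
        #[serde(flatten)]
        pub inner: async_openai::types::CreateChatCompletionStreamResponse,
    }

Call sites therefore change only the type name, e.g. let response = NvCreateChatCompletionStreamResponse { inner }; where inner is produced by create_choice(...), as in the echo engine below.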
@@ -19,7 +19,9 @@ use triton_distributed_llm::{
     model_type::ModelType,
     preprocessor::OpenAIPreprocessor,
     types::{
-        openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        openai::chat_completions::{
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
+        },
         Annotated,
     },
 };
@@ -55,7 +57,7 @@ pub async fn run(
         } => {
             let frontend = SegmentSource::<
                 SingleIn<NvCreateChatCompletionRequest>,
-                ManyOut<Annotated<ChatCompletionResponseDelta>>,
+                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
                 .await?
...
@@ -21,7 +21,9 @@ use triton_distributed_llm::{
     model_type::ModelType,
     preprocessor::OpenAIPreprocessor,
     types::{
-        openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        openai::chat_completions::{
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
+        },
         Annotated,
     },
 };
@@ -75,7 +77,7 @@ pub async fn run(
         } => {
             let frontend = ServiceFrontend::<
                 SingleIn<NvCreateChatCompletionRequest>,
-                ManyOut<Annotated<ChatCompletionResponseDelta>>,
+                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
                 .await?
...
@@ -23,7 +23,7 @@ use triton_distributed_llm::{
     preprocessor::OpenAIPreprocessor,
     types::{
         openai::chat_completions::{
-            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
             OpenAIChatCompletionsStreamingEngine,
         },
         Annotated,
@@ -72,7 +72,7 @@ pub async fn run(
         } => {
             let frontend = ServiceFrontend::<
                 SingleIn<NvCreateChatCompletionRequest>,
-                ManyOut<Annotated<ChatCompletionResponseDelta>>,
+                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
                 .await?
...
@@ -21,7 +21,7 @@ use triton_distributed_llm::{
     model_card::model::ModelDeploymentCard,
     types::{
         openai::chat_completions::{
-            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
             OpenAIChatCompletionsStreamingEngine,
         },
         Annotated,
@@ -113,7 +113,7 @@ pub struct Flags {
 pub enum EngineConfig {
     /// An remote networked engine we don't know about yet
     /// We don't have the pre-processor yet so this is only text requests. Type will change later.
-    Dynamic(Client<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
+    Dynamic(Client<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>),
     /// A Full service engine does it's own tokenization and prompt formatting.
     StaticFull {
@@ -223,7 +223,7 @@ pub async fn run(
         .namespace(endpoint.namespace)?
         .component(endpoint.component)?
         .endpoint(endpoint.name)
-        .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
+        .client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
         .await?;
     tracing::info!("Waiting for remote {}...", client.path());
...
@@ -19,7 +19,7 @@ use async_stream::stream;
 use async_trait::async_trait;
 use triton_distributed_llm::protocols::openai::chat_completions::{
-    ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
 };
 use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
 use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
@@ -41,14 +41,14 @@ pub fn make_engine_full() -> OpenAIChatCompletionsStreamingEngine {
 impl
     AsyncEngine<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for EchoEngineFull
 {
     async fn generate(
         &self,
         incoming_request: SingleIn<NvCreateChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         let (request, context) = incoming_request.transfer(());
         let deltas = request.response_generator();
         let ctx = context.context();
@@ -72,7 +72,7 @@ impl
                 // we are returning characters not tokens, so speed up some
                 tokio::time::sleep(TOKEN_ECHO_DELAY/2).await;
                 let inner = deltas.create_choice(0, Some(c.to_string()), None, None);
-                let response = ChatCompletionResponseDelta {
+                let response = NvCreateChatCompletionStreamResponse {
                     inner,
                 };
                 yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
@@ -80,7 +80,7 @@ impl
             }
             let inner = deltas.create_choice(0, None, Some(async_openai::types::FinishReason::Stop), None);
-            let response = ChatCompletionResponseDelta {
+            let response = NvCreateChatCompletionStreamResponse {
                 inner,
             };
             yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
...
@@ -34,7 +34,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
 use triton_distributed_runtime::protocols::annotated::Annotated;
 use crate::protocols::openai::chat_completions::{
-    ChatCompletionRequest, ChatCompletionResponseDelta,
+    ChatCompletionRequest, NvCreateChatCompletionStreamResponse,
 };
 use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
@@ -161,14 +161,14 @@ impl MistralRsEngine {
 impl
     AsyncEngine<
         SingleIn<ChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for MistralRsEngine
 {
     async fn generate(
         &self,
         request: SingleIn<ChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         let (request, context) = request.transfer(());
         let ctx = context.context();
         let (tx, mut rx) = channel(10_000);
@@ -286,7 +286,7 @@ impl
                     system_fingerprint: Some(c.system_fingerprint),
                     service_tier: None,
                 };
-                let delta = ChatCompletionResponseDelta{inner};
+                let delta = NvCreateChatCompletionStreamResponse{inner};
                 let ann = Annotated{
                     id: None,
                     data: Some(delta),
...
@@ -28,7 +28,7 @@ use triton_distributed_runtime::{
 use super::ModelManager;
 use crate::model_type::ModelType;
 use crate::protocols::openai::chat_completions::{
-    ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
 };
 use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
 use tracing;
@@ -135,7 +135,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
         .namespace(model_entry.endpoint.namespace)?
         .component(model_entry.endpoint.component)?
         .endpoint(model_entry.endpoint.name)
-        .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
+        .client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
         .await?;
     state
         .manager
...
@@ -44,7 +44,7 @@ use triton_distributed_runtime::protocols::annotated::{Annotated, AnnotationsPro
 use crate::protocols::{
     common::{SamplingOptionsProvider, StopConditionsProvider},
     openai::{
-        chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
         completions::{CompletionRequest, CompletionResponse},
         nvext::NvExtProvider,
         DeltaGeneratorExt,
@@ -225,7 +225,7 @@ impl OpenAIPreprocessor {
         tracing::trace!(
             request_id = inner.context.id(),
-            "OpenAI ChatCompletionResponseDelta: {:?}",
+            "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
             response
         );
@@ -252,7 +252,7 @@ impl OpenAIPreprocessor {
 impl
     Operator<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         SingleIn<BackendInput>,
         ManyOut<Annotated<BackendOutput>>,
     > for OpenAIPreprocessor
@@ -263,7 +263,7 @@ impl
         next: Arc<
             dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
         >,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         // unpack the request
         let (request, context) = request.into_parts();
@@ -281,7 +281,7 @@ impl
         let common_request = context.map(|_| common_request);
         // create a stream of annotations this will be prepend to the response stream
-        let annotations: Vec<Annotated<ChatCompletionResponseDelta>> = annotations
+        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
             .into_iter()
             .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
             .collect();
...
@@ -640,7 +640,7 @@ data: [DONE]
     #[tokio::test]
     async fn test_openai_chat_stream() {
-        use crate::protocols::openai::chat_completions::ChatCompletionResponseDelta;
+        use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
         // let cursor = Cursor::new(SAMPLE_CHAT_DATA);
         // let mut framed = FramedRead::new(cursor, SseLineCodec::new());
@@ -652,7 +652,7 @@ data: [DONE]
         loop {
             match stream.next().await {
                 Some(Ok(message)) => {
-                    let delta: ChatCompletionResponseDelta =
+                    let delta: NvCreateChatCompletionStreamResponse =
                         serde_json::from_str(&message.data.unwrap()).unwrap();
                     counter += 1;
                     println!("counter: {}", counter);
...
@@ -47,7 +47,7 @@ pub struct ChatCompletionContent {
 }
 
 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
-pub struct ChatCompletionResponseDelta {
+pub struct NvCreateChatCompletionStreamResponse {
     #[serde(flatten)]
     pub inner: async_openai::types::CreateChatCompletionStreamResponse,
 }
...
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use super::{ChatCompletionResponseDelta, NvCreateChatCompletionResponse};
+use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
 use crate::protocols::{
     codec::{Message, SseCodecError},
     convert_sse_stream, Annotated,
@@ -24,7 +24,7 @@ use std::{collections::HashMap, pin::Pin};
 type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
 
-/// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`].
+/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single [`NvCreateChatCompletionResponse`].
 pub struct DeltaAggregator {
     id: String,
     model: String,
@@ -66,9 +66,9 @@ impl DeltaAggregator {
         }
     }
 
-    /// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`].
+    /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single [`NvCreateChatCompletionResponse`].
     pub async fn apply(
-        stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
+        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -184,12 +184,12 @@ impl NvCreateChatCompletionResponse {
     pub async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
-        let stream = convert_sse_stream::<ChatCompletionResponseDelta>(stream);
+        let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
         NvCreateChatCompletionResponse::from_annotated_stream(stream).await
     }
 
     pub async fn from_annotated_stream(
-        stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
+        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         DeltaAggregator::apply(stream).await
     }
@@ -207,7 +207,7 @@ mod tests {
         text: &str,
         role: Option<async_openai::types::Role>,
         finish_reason: Option<async_openai::types::FinishReason>,
-    ) -> Annotated<ChatCompletionResponseDelta> {
+    ) -> Annotated<NvCreateChatCompletionStreamResponse> {
         // ALLOW: function_call is deprecated
         let delta = async_openai::types::ChatCompletionStreamResponseDelta {
             content: Some(text.to_string()),
@@ -234,7 +234,7 @@ mod tests {
             object: "chat.completion".to_string(),
         };
 
-        let data = ChatCompletionResponseDelta { inner };
+        let data = NvCreateChatCompletionStreamResponse { inner };
 
         Annotated {
             data: Some(data),
@@ -247,7 +247,8 @@
     #[tokio::test]
     async fn test_empty_stream() {
         // Create an empty stream
-        let stream: DataStream<Annotated<ChatCompletionResponseDelta>> = Box::pin(stream::empty());
+        let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
+            Box::pin(stream::empty());
 
         // Call DeltaAggregator::apply
         let result = DeltaAggregator::apply(stream).await;
@@ -375,7 +376,7 @@ mod tests {
             object: "chat.completion".to_string(),
         };
 
-        let data = ChatCompletionResponseDelta { inner: delta };
+        let data = NvCreateChatCompletionStreamResponse { inner: delta };
 
         // Wrap it in Annotated and create a stream
         let annotated_delta = Annotated {
...
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use super::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest};
+use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
 use crate::protocols::common;
 
 impl NvCreateChatCompletionRequest {
@@ -135,11 +135,13 @@ impl DeltaGenerator {
     }
 }
 
-impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> for DeltaGenerator {
+impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
+    for DeltaGenerator
+{
     fn choice_from_postprocessor(
         &mut self,
         delta: crate::protocols::common::llm_backend::BackendOutput,
-    ) -> anyhow::Result<ChatCompletionResponseDelta> {
+    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
         // aggregate usage
         if self.options.enable_usage {
             self.usage.completion_tokens += delta.token_ids.len() as u32;
@@ -163,7 +165,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> fo
         let index = 0;
         let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);
 
-        Ok(ChatCompletionResponseDelta {
+        Ok(NvCreateChatCompletionStreamResponse {
             inner: stream_response,
         })
     }
...
@@ -38,8 +38,8 @@ pub mod openai {
     use super::*;
 
     pub use protocols::openai::chat_completions::{
-        ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
-        NvCreateChatCompletionResponse,
+        NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
+        NvCreateChatCompletionStreamResponse,
     };
 
     /// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API
@@ -49,7 +49,7 @@ pub mod openai {
     /// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API
     pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine<
         NvCreateChatCompletionRequest,
-        Annotated<ChatCompletionResponseDelta>,
+        Annotated<NvCreateChatCompletionStreamResponse>,
     >;
 }
...
@@ -26,7 +26,7 @@ use triton_distributed_llm::http::service::{
 };
 use triton_distributed_llm::protocols::{
     openai::{
-        chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
         completions::{CompletionRequest, CompletionResponse},
     },
     Annotated,
@@ -45,21 +45,21 @@ struct CounterEngine {}
 impl
     AsyncEngine<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for CounterEngine
 {
     async fn generate(
         &self,
         request: SingleIn<NvCreateChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         let (request, context) = request.transfer(());
         let ctx = context.context();
 
         // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
         let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;
 
-        // let generator = ChatCompletionResponseDelta::generator(request.model.clone());
+        // let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
         let generator = request.response_generator();
 
         let stream = stream! {
@@ -67,7 +67,7 @@ impl
             for i in 0..10 {
                 let inner = generator.create_choice(i,Some(format!("choice {i}")), None, None);
-                let output = ChatCompletionResponseDelta {
+                let output = NvCreateChatCompletionStreamResponse {
                     inner,
                 };
@@ -85,14 +85,14 @@ struct AlwaysFailEngine {}
 impl
     AsyncEngine<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for AlwaysFailEngine
 {
     async fn generate(
         &self,
         _request: SingleIn<NvCreateChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         Err(HttpError {
             code: 403,
             message: "Always fail".to_string(),
...