"git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "1b0b8836694955ad47552175e26602631c8d886f"
Commit 110f3f8c authored by Paul Hendricks, committed by GitHub

refactor: rename ChatCompletionResponseDelta to NvCreateChatCompletionStreamResponse (#292)

parent c13ea718
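
The rename is mechanical: the wrapper type keeps its shape and only its name changes, so every call site swaps one identifier. A minimal before/after sketch (the struct is taken from the diff below; the derive list is abbreviated):

    // Before (old name):
    // pub struct ChatCompletionResponseDelta {
    //     #[serde(flatten)]
    //     pub inner: async_openai::types::CreateChatCompletionStreamResponse,
    // }

    // After (new name mirrors the async_openai type it flattens):
    pub struct NvCreateChatCompletionStreamResponse {
        #[serde(flatten)]
        pub inner: async_openai::types::CreateChatCompletionStreamResponse,
    }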
@@ -19,7 +19,9 @@ use triton_distributed_llm::{
     model_type::ModelType,
     preprocessor::OpenAIPreprocessor,
     types::{
-        openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        openai::chat_completions::{
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
+        },
         Annotated,
     },
 };
@@ -55,7 +57,7 @@ pub async fn run(
         } => {
             let frontend = SegmentSource::<
                 SingleIn<NvCreateChatCompletionRequest>,
-                ManyOut<Annotated<ChatCompletionResponseDelta>>,
+                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
                 .await?
...
@@ -21,7 +21,9 @@ use triton_distributed_llm::{
     model_type::ModelType,
     preprocessor::OpenAIPreprocessor,
     types::{
-        openai::chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        openai::chat_completions::{
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
+        },
         Annotated,
     },
 };
@@ -75,7 +77,7 @@ pub async fn run(
         } => {
             let frontend = ServiceFrontend::<
                 SingleIn<NvCreateChatCompletionRequest>,
-                ManyOut<Annotated<ChatCompletionResponseDelta>>,
+                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
                 .await?
...
@@ -23,7 +23,7 @@ use triton_distributed_llm::{
     preprocessor::OpenAIPreprocessor,
     types::{
         openai::chat_completions::{
-            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
             OpenAIChatCompletionsStreamingEngine,
         },
         Annotated,
@@ -72,7 +72,7 @@ pub async fn run(
         } => {
             let frontend = ServiceFrontend::<
                 SingleIn<NvCreateChatCompletionRequest>,
-                ManyOut<Annotated<ChatCompletionResponseDelta>>,
+                ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
             >::new();
             let preprocessor = OpenAIPreprocessor::new(*card.clone())
                 .await?
...
@@ -21,7 +21,7 @@ use triton_distributed_llm::{
     model_card::model::ModelDeploymentCard,
     types::{
         openai::chat_completions::{
-            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+            NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
             OpenAIChatCompletionsStreamingEngine,
         },
         Annotated,
@@ -113,7 +113,7 @@ pub struct Flags {
 pub enum EngineConfig {
     /// An remote networked engine we don't know about yet
     /// We don't have the pre-processor yet so this is only text requests. Type will change later.
-    Dynamic(Client<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>),
+    Dynamic(Client<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>),
     /// A Full service engine does it's own tokenization and prompt formatting.
     StaticFull {
@@ -223,7 +223,7 @@ pub async fn run(
             .namespace(endpoint.namespace)?
             .component(endpoint.component)?
             .endpoint(endpoint.name)
-            .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
+            .client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
             .await?;
         tracing::info!("Waiting for remote {}...", client.path());
...
@@ -19,7 +19,7 @@ use async_stream::stream;
 use async_trait::async_trait;
 use triton_distributed_llm::protocols::openai::chat_completions::{
-    ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
 };
 use triton_distributed_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
 use triton_distributed_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, ResponseStream};
@@ -41,14 +41,14 @@ pub fn make_engine_full() -> OpenAIChatCompletionsStreamingEngine {
 impl
     AsyncEngine<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for EchoEngineFull
 {
     async fn generate(
         &self,
         incoming_request: SingleIn<NvCreateChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         let (request, context) = incoming_request.transfer(());
         let deltas = request.response_generator();
         let ctx = context.context();
@@ -72,7 +72,7 @@ impl
                 // we are returning characters not tokens, so speed up some
                 tokio::time::sleep(TOKEN_ECHO_DELAY/2).await;
                 let inner = deltas.create_choice(0, Some(c.to_string()), None, None);
-                let response = ChatCompletionResponseDelta {
+                let response = NvCreateChatCompletionStreamResponse {
                     inner,
                 };
                 yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
@@ -80,7 +80,7 @@ impl
             }
             let inner = deltas.create_choice(0, None, Some(async_openai::types::FinishReason::Stop), None);
-            let response = ChatCompletionResponseDelta {
+            let response = NvCreateChatCompletionStreamResponse {
                 inner,
             };
             yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
...
@@ -34,7 +34,7 @@ use triton_distributed_runtime::pipeline::{Error, ManyOut, SingleIn};
 use triton_distributed_runtime::protocols::annotated::Annotated;
 use crate::protocols::openai::chat_completions::{
-    ChatCompletionRequest, ChatCompletionResponseDelta,
+    ChatCompletionRequest, NvCreateChatCompletionStreamResponse,
 };
 use crate::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
@@ -161,14 +161,14 @@ impl MistralRsEngine {
 impl
     AsyncEngine<
         SingleIn<ChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for MistralRsEngine
 {
     async fn generate(
         &self,
         request: SingleIn<ChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         let (request, context) = request.transfer(());
         let ctx = context.context();
         let (tx, mut rx) = channel(10_000);
@@ -286,7 +286,7 @@ impl
                 system_fingerprint: Some(c.system_fingerprint),
                 service_tier: None,
             };
-            let delta = ChatCompletionResponseDelta{inner};
+            let delta = NvCreateChatCompletionStreamResponse{inner};
             let ann = Annotated{
                 id: None,
                 data: Some(delta),
...
@@ -28,7 +28,7 @@ use triton_distributed_runtime::{
 use super::ModelManager;
 use crate::model_type::ModelType;
 use crate::protocols::openai::chat_completions::{
-    ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+    NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
 };
 use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
 use tracing;
@@ -135,7 +135,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
         .namespace(model_entry.endpoint.namespace)?
         .component(model_entry.endpoint.component)?
         .endpoint(model_entry.endpoint.name)
-        .client::<NvCreateChatCompletionRequest, Annotated<ChatCompletionResponseDelta>>()
+        .client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
         .await?;
     state
         .manager
...
@@ -44,7 +44,7 @@ use triton_distributed_runtime::protocols::annotated::{Annotated, AnnotationsPro
 use crate::protocols::{
     common::{SamplingOptionsProvider, StopConditionsProvider},
     openai::{
-        chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
         completions::{CompletionRequest, CompletionResponse},
         nvext::NvExtProvider,
         DeltaGeneratorExt,
@@ -225,7 +225,7 @@ impl OpenAIPreprocessor {
             tracing::trace!(
                 request_id = inner.context.id(),
-                "OpenAI ChatCompletionResponseDelta: {:?}",
+                "OpenAI NvCreateChatCompletionStreamResponse: {:?}",
                 response
             );
@@ -252,7 +252,7 @@ impl OpenAIPreprocessor {
 impl
     Operator<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         SingleIn<BackendInput>,
         ManyOut<Annotated<BackendOutput>>,
     > for OpenAIPreprocessor
@@ -263,7 +263,7 @@ impl
         next: Arc<
             dyn AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<BackendOutput>>, Error>,
         >,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         // unpack the request
         let (request, context) = request.into_parts();
@@ -281,7 +281,7 @@ impl
         let common_request = context.map(|_| common_request);
         // create a stream of annotations this will be prepend to the response stream
-        let annotations: Vec<Annotated<ChatCompletionResponseDelta>> = annotations
+        let annotations: Vec<Annotated<NvCreateChatCompletionStreamResponse>> = annotations
             .into_iter()
             .flat_map(|(k, v)| Annotated::from_annotation(k, &v))
             .collect();
...
@@ -640,7 +640,7 @@ data: [DONE]
     #[tokio::test]
     async fn test_openai_chat_stream() {
-        use crate::protocols::openai::chat_completions::ChatCompletionResponseDelta;
+        use crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse;
         // let cursor = Cursor::new(SAMPLE_CHAT_DATA);
         // let mut framed = FramedRead::new(cursor, SseLineCodec::new());
@@ -652,7 +652,7 @@ data: [DONE]
         loop {
             match stream.next().await {
                 Some(Ok(message)) => {
-                    let delta: ChatCompletionResponseDelta =
+                    let delta: NvCreateChatCompletionStreamResponse =
                         serde_json::from_str(&message.data.unwrap()).unwrap();
                     counter += 1;
                     println!("counter: {}", counter);
...
@@ -47,7 +47,7 @@ pub struct ChatCompletionContent {
 }

 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
-pub struct ChatCompletionResponseDelta {
+pub struct NvCreateChatCompletionStreamResponse {
     #[serde(flatten)]
     pub inner: async_openai::types::CreateChatCompletionStreamResponse,
 }
...
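
Because of #[serde(flatten)], the renamed wrapper still serializes and deserializes as a plain OpenAI chat.completion.chunk object; the rename changes nothing on the wire. A small sketch of that round trip, assuming the async_openai and serde_json crates are available (the parse_chunk helper and the trimmed derive list are illustrative, not from this commit):

    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, Debug, Clone)]
    pub struct NvCreateChatCompletionStreamResponse {
        // flatten means the JSON carries no extra envelope around the inner type
        #[serde(flatten)]
        pub inner: async_openai::types::CreateChatCompletionStreamResponse,
    }

    // Hypothetical helper: a standard OpenAI stream chunk parses directly into
    // the wrapper, which is what the SSE codec test above relies on.
    fn parse_chunk(data: &str) -> serde_json::Result<NvCreateChatCompletionStreamResponse> {
        serde_json::from_str(data)
    }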
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use super::{ChatCompletionResponseDelta, NvCreateChatCompletionResponse};
+use super::{NvCreateChatCompletionResponse, NvCreateChatCompletionStreamResponse};
 use crate::protocols::{
     codec::{Message, SseCodecError},
     convert_sse_stream, Annotated,
@@ -24,7 +24,7 @@ use std::{collections::HashMap, pin::Pin};
 type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;

-/// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`].
+/// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single [`NvCreateChatCompletionResponse`].
 pub struct DeltaAggregator {
     id: String,
     model: String,
@@ -66,9 +66,9 @@ impl DeltaAggregator {
         }
     }

-    /// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`].
+    /// Aggregates a stream of [`NvCreateChatCompletionStreamResponse`]s into a single [`NvCreateChatCompletionResponse`].
     pub async fn apply(
-        stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
+        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
@@ -184,12 +184,12 @@ impl NvCreateChatCompletionResponse {
     pub async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
-        let stream = convert_sse_stream::<ChatCompletionResponseDelta>(stream);
+        let stream = convert_sse_stream::<NvCreateChatCompletionStreamResponse>(stream);
         NvCreateChatCompletionResponse::from_annotated_stream(stream).await
     }

     pub async fn from_annotated_stream(
-        stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
+        stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>>,
     ) -> Result<NvCreateChatCompletionResponse, String> {
         DeltaAggregator::apply(stream).await
     }
@@ -207,7 +207,7 @@ mod tests {
         text: &str,
         role: Option<async_openai::types::Role>,
         finish_reason: Option<async_openai::types::FinishReason>,
-    ) -> Annotated<ChatCompletionResponseDelta> {
+    ) -> Annotated<NvCreateChatCompletionStreamResponse> {
         // ALLOW: function_call is deprecated
         let delta = async_openai::types::ChatCompletionStreamResponseDelta {
             content: Some(text.to_string()),
@@ -234,7 +234,7 @@ mod tests {
             object: "chat.completion".to_string(),
         };

-        let data = ChatCompletionResponseDelta { inner };
+        let data = NvCreateChatCompletionStreamResponse { inner };

         Annotated {
             data: Some(data),
@@ -247,7 +247,8 @@ mod tests {
     #[tokio::test]
     async fn test_empty_stream() {
         // Create an empty stream
-        let stream: DataStream<Annotated<ChatCompletionResponseDelta>> = Box::pin(stream::empty());
+        let stream: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
+            Box::pin(stream::empty());

         // Call DeltaAggregator::apply
         let result = DeltaAggregator::apply(stream).await;
@@ -375,7 +376,7 @@ mod tests {
             object: "chat.completion".to_string(),
         };

-        let data = ChatCompletionResponseDelta { inner: delta };
+        let data = NvCreateChatCompletionStreamResponse { inner: delta };

         // Wrap it in Annotated and create a stream
         let annotated_delta = Annotated {
...
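
For consumers, the aggregation entry points keep the same call shape after the rename. A minimal usage sketch under the aliases defined in this file (DataStream, Annotated), mirroring the empty-stream test above; the aggregate_nothing helper is hypothetical:

    use futures::stream;

    // Collapse a (here: empty) delta stream into one full chat completion,
    // exactly as test_empty_stream does.
    async fn aggregate_nothing() -> Result<NvCreateChatCompletionResponse, String> {
        let deltas: DataStream<Annotated<NvCreateChatCompletionStreamResponse>> =
            Box::pin(stream::empty());
        DeltaAggregator::apply(deltas).await
    }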
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use super::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest};
+use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse};
 use crate::protocols::common;

 impl NvCreateChatCompletionRequest {
@@ -135,11 +135,13 @@ impl DeltaGenerator {
     }
 }

-impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> for DeltaGenerator {
+impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamResponse>
+    for DeltaGenerator
+{
     fn choice_from_postprocessor(
         &mut self,
         delta: crate::protocols::common::llm_backend::BackendOutput,
-    ) -> anyhow::Result<ChatCompletionResponseDelta> {
+    ) -> anyhow::Result<NvCreateChatCompletionStreamResponse> {
         // aggregate usage
         if self.options.enable_usage {
             self.usage.completion_tokens += delta.token_ids.len() as u32;
@@ -163,7 +165,7 @@ impl crate::protocols::openai::DeltaGeneratorExt<ChatCompletionResponseDelta> fo
         let index = 0;
         let stream_response = self.create_choice(index, delta.text, finish_reason, logprobs);

-        Ok(ChatCompletionResponseDelta {
+        Ok(NvCreateChatCompletionStreamResponse {
             inner: stream_response,
         })
     }
...
@@ -38,8 +38,8 @@ pub mod openai {
         use super::*;

         pub use protocols::openai::chat_completions::{
-            ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
-            NvCreateChatCompletionResponse,
+            NvCreateChatCompletionRequest, NvCreateChatCompletionResponse,
+            NvCreateChatCompletionStreamResponse,
         };

         /// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API
@@ -49,7 +49,7 @@ pub mod openai {
         /// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API
         pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine<
             NvCreateChatCompletionRequest,
-            Annotated<ChatCompletionResponseDelta>,
+            Annotated<NvCreateChatCompletionStreamResponse>,
         >;
     }
 }
@@ -26,7 +26,7 @@ use triton_distributed_llm::http::service::{
 };
 use triton_distributed_llm::protocols::{
     openai::{
-        chat_completions::{ChatCompletionResponseDelta, NvCreateChatCompletionRequest},
+        chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
         completions::{CompletionRequest, CompletionResponse},
     },
     Annotated,
@@ -45,21 +45,21 @@ struct CounterEngine {}
 impl
     AsyncEngine<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for CounterEngine
 {
     async fn generate(
         &self,
         request: SingleIn<NvCreateChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         let (request, context) = request.transfer(());
         let ctx = context.context();

         // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
         let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;

-        // let generator = ChatCompletionResponseDelta::generator(request.model.clone());
+        // let generator = NvCreateChatCompletionStreamResponse::generator(request.model.clone());
         let generator = request.response_generator();

         let stream = stream! {
@@ -67,7 +67,7 @@ impl
             for i in 0..10 {
                 let inner = generator.create_choice(i,Some(format!("choice {i}")), None, None);

-                let output = ChatCompletionResponseDelta {
+                let output = NvCreateChatCompletionStreamResponse {
                     inner,
                 };
@@ -85,14 +85,14 @@ struct AlwaysFailEngine {}
 impl
     AsyncEngine<
         SingleIn<NvCreateChatCompletionRequest>,
-        ManyOut<Annotated<ChatCompletionResponseDelta>>,
+        ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>,
         Error,
     > for AlwaysFailEngine
 {
     async fn generate(
         &self,
         _request: SingleIn<NvCreateChatCompletionRequest>,
-    ) -> Result<ManyOut<Annotated<ChatCompletionResponseDelta>>, Error> {
+    ) -> Result<ManyOut<Annotated<NvCreateChatCompletionStreamResponse>>, Error> {
         Err(HttpError {
             code: 403,
             message: "Always fail".to_string(),
...