"worker/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "f1f29171fea13b4bab5138718afc6df091db067c"
Commit c13ea718 authored by Paul Hendricks, committed by GitHub
Browse files

refactor: rename ChatCompletionResponse to NvCreateChatCompletionResponse (#291)

parent 96866f43
...@@ -41,7 +41,7 @@ use super::{ ...@@ -41,7 +41,7 @@ use super::{
}; };
use crate::protocols::openai::{ use crate::protocols::openai::{
chat_completions::ChatCompletionResponse, completions::CompletionResponse, chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
}; };
use crate::types::{ use crate::types::{
openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest}, openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
...@@ -272,7 +272,7 @@ async fn chat_completions( ...@@ -272,7 +272,7 @@ async fn chat_completions(
.keep_alive(KeepAlive::default()) .keep_alive(KeepAlive::default())
.into_response()) .into_response())
} else { } else {
let response = ChatCompletionResponse::from_annotated_stream(stream.into()) let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into())
.await .await
.map_err(|e| { .map_err(|e| {
tracing::error!( tracing::error!(
......
...@@ -35,7 +35,7 @@ pub struct NvCreateChatCompletionRequest { ...@@ -35,7 +35,7 @@ pub struct NvCreateChatCompletionRequest {
} }
#[derive(Serialize, Deserialize, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct ChatCompletionResponse { pub struct NvCreateChatCompletionResponse {
#[serde(flatten)] #[serde(flatten)]
pub inner: async_openai::types::CreateChatCompletionResponse, pub inner: async_openai::types::CreateChatCompletionResponse,
} }
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use super::{ChatCompletionResponse, ChatCompletionResponseDelta}; use super::{ChatCompletionResponseDelta, NvCreateChatCompletionResponse};
use crate::protocols::{ use crate::protocols::{
codec::{Message, SseCodecError}, codec::{Message, SseCodecError},
convert_sse_stream, Annotated, convert_sse_stream, Annotated,
...@@ -69,7 +69,7 @@ impl DeltaAggregator { ...@@ -69,7 +69,7 @@ impl DeltaAggregator {
/// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`]. /// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`].
pub async fn apply( pub async fn apply(
stream: DataStream<Annotated<ChatCompletionResponseDelta>>, stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
) -> Result<ChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let aggregator = stream let aggregator = stream
.fold(DeltaAggregator::new(), |mut aggregator, delta| async move { .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
// these are cheap to move so we do it every time since we are consuming the delta // these are cheap to move so we do it every time since we are consuming the delta
...@@ -153,7 +153,7 @@ impl DeltaAggregator { ...@@ -153,7 +153,7 @@ impl DeltaAggregator {
service_tier: aggregator.service_tier, service_tier: aggregator.service_tier,
}; };
let response = ChatCompletionResponse { inner }; let response = NvCreateChatCompletionResponse { inner };
Ok(response) Ok(response)
} }
...@@ -180,17 +180,17 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice { ...@@ -180,17 +180,17 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice {
} }
} }
impl ChatCompletionResponse { impl NvCreateChatCompletionResponse {
pub async fn from_sse_stream( pub async fn from_sse_stream(
stream: DataStream<Result<Message, SseCodecError>>, stream: DataStream<Result<Message, SseCodecError>>,
) -> Result<ChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
let stream = convert_sse_stream::<ChatCompletionResponseDelta>(stream); let stream = convert_sse_stream::<ChatCompletionResponseDelta>(stream);
ChatCompletionResponse::from_annotated_stream(stream).await NvCreateChatCompletionResponse::from_annotated_stream(stream).await
} }
pub async fn from_annotated_stream( pub async fn from_annotated_stream(
stream: DataStream<Annotated<ChatCompletionResponseDelta>>, stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
) -> Result<ChatCompletionResponse, String> { ) -> Result<NvCreateChatCompletionResponse, String> {
DeltaAggregator::apply(stream).await DeltaAggregator::apply(stream).await
} }
} }
......
...@@ -38,12 +38,13 @@ pub mod openai { ...@@ -38,12 +38,13 @@ pub mod openai {
use super::*; use super::*;
pub use protocols::openai::chat_completions::{ pub use protocols::openai::chat_completions::{
ChatCompletionResponse, ChatCompletionResponseDelta, NvCreateChatCompletionRequest, ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
NvCreateChatCompletionResponse,
}; };
/// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API /// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API
pub type OpenAIChatCompletionsUnaryEngine = pub type OpenAIChatCompletionsUnaryEngine =
UnaryEngine<NvCreateChatCompletionRequest, ChatCompletionResponse>; UnaryEngine<NvCreateChatCompletionRequest, NvCreateChatCompletionResponse>;
/// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API /// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API
pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine< pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine<
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
use futures::StreamExt; use futures::StreamExt;
use triton_distributed_llm::protocols::{ use triton_distributed_llm::protocols::{
codec::{create_message_stream, Message, SseCodecError}, codec::{create_message_stream, Message, SseCodecError},
openai::{chat_completions::ChatCompletionResponse, completions::CompletionResponse}, openai::{chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse},
ContentProvider, DataStream, ContentProvider, DataStream,
}; };
...@@ -34,7 +34,7 @@ async fn test_openai_chat_stream() { ...@@ -34,7 +34,7 @@ async fn test_openai_chat_stream() {
// note: we are only taking the first 16 messages to keep the size of the response small // note: we are only taking the first 16 messages to keep the size of the response small
let stream = create_message_stream(&data).take(16); let stream = create_message_stream(&data).take(16);
let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
.await .await
.unwrap(); .unwrap();
...@@ -57,7 +57,7 @@ async fn test_openai_chat_stream() { ...@@ -57,7 +57,7 @@ async fn test_openai_chat_stream() {
#[tokio::test] #[tokio::test]
async fn test_openai_chat_edge_case_multi_line_data() { async fn test_openai_chat_edge_case_multi_line_data() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data"); let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data");
let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
.await .await
.unwrap(); .unwrap();
...@@ -78,7 +78,7 @@ async fn test_openai_chat_edge_case_multi_line_data() { ...@@ -78,7 +78,7 @@ async fn test_openai_chat_edge_case_multi_line_data() {
#[tokio::test] #[tokio::test]
async fn test_openai_chat_edge_case_comments_per_response() { async fn test_openai_chat_edge_case_comments_per_response() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response"); let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response");
let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream)) let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
.await .await
.unwrap(); .unwrap();
...@@ -99,7 +99,7 @@ async fn test_openai_chat_edge_case_comments_per_response() { ...@@ -99,7 +99,7 @@ async fn test_openai_chat_edge_case_comments_per_response() {
#[tokio::test] #[tokio::test]
async fn test_openai_chat_edge_case_invalid_deserialize_error() { async fn test_openai_chat_edge_case_invalid_deserialize_error() {
let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error"); let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error");
let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream)).await; let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await;
assert!(result.is_err()); assert!(result.is_err());
// insta::assert_debug_snapshot!(result); // insta::assert_debug_snapshot!(result);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment