Commit c13ea718 authored by Paul Hendricks, committed by GitHub

refactor: rename ChatCompletionResponse to NvCreateChatCompletionResponse (#291)

parent 96866f43
@@ -41,7 +41,7 @@ use super::{
 };
 use crate::protocols::openai::{
-    chat_completions::ChatCompletionResponse, completions::CompletionResponse,
+    chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse,
 };
 use crate::types::{
     openai::{chat_completions::NvCreateChatCompletionRequest, completions::CompletionRequest},
@@ -272,7 +272,7 @@ async fn chat_completions(
            .keep_alive(KeepAlive::default())
            .into_response())
    } else {
-        let response = ChatCompletionResponse::from_annotated_stream(stream.into())
+        let response = NvCreateChatCompletionResponse::from_annotated_stream(stream.into())
            .await
            .map_err(|e| {
                tracing::error!(
......
@@ -35,7 +35,7 @@ pub struct NvCreateChatCompletionRequest {
 }

 #[derive(Serialize, Deserialize, Validate, Debug, Clone)]
-pub struct ChatCompletionResponse {
+pub struct NvCreateChatCompletionResponse {
     #[serde(flatten)]
     pub inner: async_openai::types::CreateChatCompletionResponse,
 }
......
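The `#[serde(flatten)]` attribute is what keeps the renamed wrapper wire-compatible: the fields of the inner `async_openai` response are serialized at the top level, so the JSON looks like a plain OpenAI chat completion response with no extra nesting. A minimal sketch of that behavior, using a simplified hypothetical inner type rather than the real `async_openai::types::CreateChatCompletionResponse`:

```rust
use serde::{Deserialize, Serialize};

// Hypothetical stand-in for async_openai::types::CreateChatCompletionResponse,
// reduced to two fields purely for illustration.
#[derive(Serialize, Deserialize, Debug)]
struct InnerResponse {
    id: String,
    model: String,
}

// Mirrors the shape of NvCreateChatCompletionResponse: the wrapper adds no
// extra nesting because the inner fields are flattened into the top level.
#[derive(Serialize, Deserialize, Debug)]
struct Wrapper {
    #[serde(flatten)]
    inner: InnerResponse,
}

fn main() {
    let wrapped = Wrapper {
        inner: InnerResponse {
            id: "chatcmpl-123".to_string(),
            model: "demo-model".to_string(),
        },
    };
    // Prints {"id":"chatcmpl-123","model":"demo-model"} -- no "inner" key,
    // so clients deserializing the plain OpenAI shape keep working.
    println!("{}", serde_json::to_string(&wrapped).unwrap());
}
```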
@@ -13,7 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use super::{ChatCompletionResponse, ChatCompletionResponseDelta};
+use super::{ChatCompletionResponseDelta, NvCreateChatCompletionResponse};
 use crate::protocols::{
     codec::{Message, SseCodecError},
     convert_sse_stream, Annotated,
@@ -69,7 +69,7 @@ impl DeltaAggregator {
     /// Aggregates a stream of [`ChatCompletionResponseDelta`]s into a single [`ChatCompletionResponse`].
     pub async fn apply(
         stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
-    ) -> Result<ChatCompletionResponse, String> {
+    ) -> Result<NvCreateChatCompletionResponse, String> {
         let aggregator = stream
             .fold(DeltaAggregator::new(), |mut aggregator, delta| async move {
                 // these are cheap to move so we do it every time since we are consuming the delta
@@ -153,7 +153,7 @@ impl DeltaAggregator {
             service_tier: aggregator.service_tier,
         };
-        let response = ChatCompletionResponse { inner };
+        let response = NvCreateChatCompletionResponse { inner };
         Ok(response)
     }
@@ -180,17 +180,17 @@ impl From<DeltaChoice> for async_openai::types::ChatChoice {
     }
 }

-impl ChatCompletionResponse {
+impl NvCreateChatCompletionResponse {
     pub async fn from_sse_stream(
         stream: DataStream<Result<Message, SseCodecError>>,
-    ) -> Result<ChatCompletionResponse, String> {
+    ) -> Result<NvCreateChatCompletionResponse, String> {
         let stream = convert_sse_stream::<ChatCompletionResponseDelta>(stream);
-        ChatCompletionResponse::from_annotated_stream(stream).await
+        NvCreateChatCompletionResponse::from_annotated_stream(stream).await
     }

     pub async fn from_annotated_stream(
         stream: DataStream<Annotated<ChatCompletionResponseDelta>>,
-    ) -> Result<ChatCompletionResponse, String> {
+    ) -> Result<NvCreateChatCompletionResponse, String> {
         DeltaAggregator::apply(stream).await
     }
 }
......
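With the rename in place, the aggregation entry points hang off the new type name: `from_sse_stream` decodes raw SSE messages into `ChatCompletionResponseDelta`s, and `from_annotated_stream` folds them through `DeltaAggregator::apply`. A rough usage sketch, mirroring the tests further down and assuming `data` holds a recorded SSE chat-completion stream:

```rust
use triton_distributed_llm::protocols::{
    codec::create_message_stream,
    openai::chat_completions::NvCreateChatCompletionResponse,
};

// Hypothetical helper: collapse a recorded SSE chat-completion stream held in
// `data` into a single aggregated response.
async fn aggregate_recorded_stream(
    data: &str,
) -> Result<NvCreateChatCompletionResponse, String> {
    // Parse the raw SSE text into framed messages, then let the renamed type
    // decode the deltas and fold them into one response.
    let stream = create_message_stream(data);
    NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await
}
```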
@@ -38,12 +38,13 @@ pub mod openai {
     use super::*;

     pub use protocols::openai::chat_completions::{
-        ChatCompletionResponse, ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+        ChatCompletionResponseDelta, NvCreateChatCompletionRequest,
+        NvCreateChatCompletionResponse,
     };

     /// A [`UnaryEngine`] implementation for the OpenAI Chat Completions API
     pub type OpenAIChatCompletionsUnaryEngine =
-        UnaryEngine<NvCreateChatCompletionRequest, ChatCompletionResponse>;
+        UnaryEngine<NvCreateChatCompletionRequest, NvCreateChatCompletionResponse>;

     /// A [`ServerStreamingEngine`] implementation for the OpenAI Chat Completions API
     pub type OpenAIChatCompletionsStreamingEngine = ServerStreamingEngine<
......
@@ -16,7 +16,7 @@
 use futures::StreamExt;
 use triton_distributed_llm::protocols::{
     codec::{create_message_stream, Message, SseCodecError},
-    openai::{chat_completions::ChatCompletionResponse, completions::CompletionResponse},
+    openai::{chat_completions::NvCreateChatCompletionResponse, completions::CompletionResponse},
     ContentProvider, DataStream,
 };
@@ -34,7 +34,7 @@ async fn test_openai_chat_stream() {
     // note: we are only taking the first 16 messages to keep the size of the response small
     let stream = create_message_stream(&data).take(16);

-    let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream))
+    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
         .await
         .unwrap();
@@ -57,7 +57,7 @@ async fn test_openai_chat_stream() {

 #[tokio::test]
 async fn test_openai_chat_edge_case_multi_line_data() {
     let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-multi-line-data");
-    let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream))
+    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
         .await
         .unwrap();
@@ -78,7 +78,7 @@ async fn test_openai_chat_edge_case_multi_line_data() {

 #[tokio::test]
 async fn test_openai_chat_edge_case_comments_per_response() {
     let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/valid-comments_per_response");
-    let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream))
+    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream))
         .await
         .unwrap();
@@ -99,7 +99,7 @@ async fn test_openai_chat_edge_case_comments_per_response() {

 #[tokio::test]
 async fn test_openai_chat_edge_case_invalid_deserialize_error() {
     let stream = create_stream(CHAT_ROOT_PATH, "edge_cases/invalid-deserialize_error");
-    let result = ChatCompletionResponse::from_sse_stream(Box::pin(stream)).await;
+    let result = NvCreateChatCompletionResponse::from_sse_stream(Box::pin(stream)).await;
     assert!(result.is_err());
     // insta::assert_debug_snapshot!(result);
......