Commit 86aff237 authored by Paul Hendricks, committed by GitHub

refactor: using async_openai


Co-authored-by: Graham King <grahamk@nvidia.com>
parent d694ca6e
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -21,6 +21,7 @@ use futures::StreamExt;
 use super::{CompletionChoice, CompletionResponse, CompletionUsage, LogprobResult};
 use crate::protocols::{
     codec::{Message, SseCodecError},
+    common::FinishReason,
     convert_sse_stream, Annotated, DataStream,
 };
@@ -38,7 +39,7 @@ pub struct DeltaAggregator {
 struct DeltaChoice {
     index: u64,
     text: String,
-    finish_reason: Option<crate::protocols::openai::chat_completions::FinishReason>,
+    finish_reason: Option<FinishReason>,
     logprobs: Option<LogprobResult>,
 }
@@ -110,11 +111,7 @@
                 // todo - handle logprobs
                 if let Some(finish_reason) = choice.finish_reason {
-                    let reason =
-                        crate::protocols::openai::chat_completions::FinishReason::from_str(
-                            &finish_reason,
-                        )
-                        .ok();
+                    let reason = FinishReason::from_str(&finish_reason).ok();
                     state_choice.finish_reason = reason;
                 }
             }
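The hunk above collapses a multi-line parse into `FinishReason::from_str(&finish_reason).ok()`. Here is a self-contained sketch of that pattern, using a hypothetical stand-in enum; the crate's real `FinishReason` lives in `crate::protocols::common` and its variants are not shown in this diff:

```rust
use std::str::FromStr;

// Hypothetical stand-in for the crate's FinishReason; variants are assumed.
#[derive(Debug, PartialEq)]
enum FinishReason {
    Stop,
    Length,
}

impl FromStr for FinishReason {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "stop" => Ok(Self::Stop),
            "length" => Ok(Self::Length),
            other => Err(format!("unknown finish reason: {other}")),
        }
    }
}

fn main() {
    // `.ok()` turns a failed parse into `None`, so an unrecognized
    // finish_reason string leaves the aggregated choice's field unset
    // rather than aborting aggregation.
    assert_eq!(FinishReason::from_str("stop").ok(), Some(FinishReason::Stop));
    assert_eq!(FinishReason::from_str("bogus").ok(), None);
}
```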
@@ -37,6 +37,10 @@ pub mod openai {
     pub mod chat_completions {
         use super::*;
+        // pub use async_openai::types::CreateChatCompletionRequest as ChatCompletionRequest;
+        // pub use protocols::openai::chat_completions::{
+        //     ChatCompletionResponse, ChatCompletionResponseDelta,
+        // };
         pub use protocols::openai::chat_completions::{
             ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseDelta,
         };
@@ -40,8 +40,17 @@ async fn test_openai_chat_stream() {
     // todo: provide a cleaner way to extract the content from choices
     assert_eq!(
-        result.choices.first().unwrap().content(),
+        result
+            .inner
+            .choices
+            .first()
+            .unwrap()
+            .message
+            .content
+            .clone()
+            .expect("there to be content"),
         "Deep learning is a subfield of machine learning that involves the use of artificial"
+            .to_string()
     );
 }
@@ -52,7 +61,18 @@ async fn test_openai_chat_edge_case_multi_line_data() {
         .await
         .unwrap();
-    assert_eq!(result.choices.first().unwrap().content(), "Deep learning");
+    assert_eq!(
+        result
+            .inner
+            .choices
+            .first()
+            .unwrap()
+            .message
+            .content
+            .clone()
+            .expect("there to be content"),
+        "Deep learning".to_string()
+    );
 }

 #[tokio::test]
@@ -62,7 +82,18 @@ async fn test_openai_chat_edge_case_comments_per_response() {
         .await
         .unwrap();
-    assert_eq!(result.choices.first().unwrap().content(), "Deep learning");
+    assert_eq!(
+        result
+            .inner
+            .choices
+            .first()
+            .unwrap()
+            .message
+            .content
+            .clone()
+            .expect("there to be content"),
+        "Deep learning".to_string()
+    );
 }

 #[tokio::test]
@@ -40,6 +40,7 @@ use triton_distributed_runtime::{
 struct CounterEngine {}

+#[allow(deprecated)]
 #[async_trait]
 impl
     AsyncEngine<
@@ -55,7 +56,8 @@ impl
         let (request, context) = request.transfer(());
         let ctx = context.context();
-        let max_tokens = request.max_tokens.unwrap_or(0) as u64;
+        // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
+        let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;

         // let generator = ChatCompletionResponseDelta::generator(request.model.clone());
         let generator = request.response_generator();
@@ -63,8 +65,13 @@
         let stream = stream! {
             tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
             for i in 0..10 {
-                let choice = generator.create_choice(i as u64, Some(format!("choice {i}")), None, None);
-                yield Annotated::from_data(choice);
+                let inner = generator.create_choice(i, Some(format!("choice {i}")), None, None);
+                let output = ChatCompletionResponseDelta {
+                    inner,
+                };
+                yield Annotated::from_data(output);
             }
         };
@@ -174,6 +181,7 @@ fn inc_counter(
     expected[index] += 1;
 }

+#[allow(deprecated)]
 #[tokio::test]
 async fn test_http_service() {
     let service = HttpService::builder().port(8989).build().unwrap();
@@ -207,14 +215,31 @@ async fn test_http_service() {
     let client = reqwest::Client::new();

-    let mut request = ChatCompletionRequest::builder()
-        .model("foo")
-        .add_user_message("hi")
-        .build()
-        .unwrap();
+    let message = async_openai::types::ChatCompletionRequestMessage::User(
+        async_openai::types::ChatCompletionRequestUserMessage {
+            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                "hi".to_string(),
+            ),
+            name: None,
+        },
+    );
+
+    let mut request = async_openai::types::CreateChatCompletionRequestArgs::default()
+        .model("foo")
+        .messages(vec![message])
+        .build()
+        .expect("Failed to build request");
+
+    // let mut request = ChatCompletionRequest::builder()
+    //     .model("foo")
+    //     .add_user_message("hi")
+    //     .build()
+    //     .unwrap();

     // ==== ChatCompletions / Stream / Success ====
     request.stream = Some(true);
+    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
     request.max_tokens = Some(3000);

     let response = client
@@ -293,6 +318,8 @@ async fn test_http_service() {
     // ==== ChatCompletions / Unary / Success ====
     request.stream = Some(false);
+    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
+    request.max_tokens = Some(0);

     let future = client
@@ -315,6 +342,8 @@
     // ==== ChatCompletions / Stream / Error ====
     request.model = "bar".to_string();
+    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
+    request.max_tokens = Some(0);
     request.stream = Some(true);
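For orientation, here is the request-building pattern the updated test uses, pulled out of the hunks above: a stock `async_openai` request that the repo then wraps in its own type. Everything in this sketch is standard `async_openai` API except the wrapper noted in the comment, which is taken from the hunks below; exact field sets can vary across `async_openai` versions.

```rust
use async_openai::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Stock async_openai message and request, exactly as the test builds them.
    let message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("hi".to_string()),
        name: None,
    });
    let mut inner = CreateChatCompletionRequestArgs::default()
        .model("foo")
        .messages(vec![message])
        .build()?;
    inner.stream = Some(true);

    // In the repo, this plain request is then wrapped as
    // `ChatCompletionRequest { inner, nvext: None }` (see the hunks below),
    // so NVIDIA-specific extensions ride alongside the standard OpenAI fields.
    println!("model = {}", inner.model);
    Ok(())
}
```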
@@ -136,7 +136,7 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
     samples.push(CompletionSample::new(
         "should have prompt, model, and max_tokens fields",
-        |builder| builder.max_tokens(10),
+        |builder| builder.max_tokens(10_u32),
     )?);
     samples.push(CompletionSample::new(
@@ -18,9 +18,7 @@ use anyhow::Ok;
 use serde::{Deserialize, Serialize};
 use triton_distributed_llm::model_card::model::{ModelDeploymentCard, PromptContextMixin};
 use triton_distributed_llm::preprocessor::prompt::PromptFormatter;
-use triton_distributed_llm::protocols::openai::chat_completions::{
-    ChatCompletionMessage, ChatCompletionRequest, Tool, ToolChoiceType,
-};
+use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;

 use hf_hub::{api::tokio::ApiBuilder, Cache, Repo, RepoType};
@@ -217,29 +215,40 @@ const TOOLS: &str = r#"
 ]
 "#;

+// Notes:
+// protocols::openai::chat_completions::ChatCompletionMessage -> async_openai::types::ChatCompletionRequestMessage
+// protocols::openai::chat_completions::Tool -> async_openai::types::ChatCompletionTool
+// protocols::openai::chat_completions::ToolChoiceType -> async_openai::types::ChatCompletionToolChoiceOption
+
 #[derive(Serialize, Deserialize)]
 struct Request {
-    messages: Vec<ChatCompletionMessage>,
-    tools: Option<Vec<Tool>>,
-    tool_choice: Option<ToolChoiceType>,
+    messages: Vec<async_openai::types::ChatCompletionRequestMessage>,
+    tools: Option<Vec<async_openai::types::ChatCompletionTool>>,
+    tool_choice: Option<async_openai::types::ChatCompletionToolChoiceOption>,
 }

 impl Request {
     fn from(
         messages: &str,
         tools: Option<&str>,
-        tool_choice: Option<ToolChoiceType>,
+        tool_choice: Option<async_openai::types::ChatCompletionToolChoiceOption>,
         model: String,
     ) -> ChatCompletionRequest {
-        let messages: Vec<ChatCompletionMessage> = serde_json::from_str(messages).unwrap();
-        let tools: Option<Vec<Tool>> = tools.map(|x| serde_json::from_str(x).unwrap());
-        ChatCompletionRequest::builder()
+        let messages: Vec<async_openai::types::ChatCompletionRequestMessage> =
+            serde_json::from_str(messages).unwrap();
+        let tools: Option<Vec<async_openai::types::ChatCompletionTool>> =
+            tools.map(|x| serde_json::from_str(x).unwrap());
+        let tools = tools.unwrap();
+        let tool_choice = tool_choice.unwrap();
+
+        let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
             .model(model)
             .messages(messages)
             .tools(tools)
             .tool_choice(tool_choice)
             .build()
-            .unwrap()
+            .unwrap();
+
+        ChatCompletionRequest { inner, nvext: None }
     }
 }
@@ -295,7 +304,7 @@ async fn test_single_turn_with_tools() {
     let request = Request::from(
         SINGLE_CHAT_MESSAGE,
         Some(TOOLS),
-        Some(ToolChoiceType::Auto),
+        Some(async_openai::types::ChatCompletionToolChoiceOption::Auto),
         mdc.slug().to_string(),
     );
     let formatted_prompt = formatter.render(&request).unwrap();
@@ -402,7 +411,7 @@ async fn test_multi_turn_with_system_with_tools() {
     let request = Request::from(
         THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM,
         Some(TOOLS),
-        Some(ToolChoiceType::Auto),
+        Some(async_openai::types::ChatCompletionToolChoiceOption::Auto),
         mdc.slug().to_string(),
     );
     let formatted_prompt = formatter.render(&request).unwrap();
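The type-mapping notes in the hunks above boil down to deserializing OpenAI wire-format JSON straight into `async_openai` types, which is what lets the crate-local `Tool`/`ToolChoiceType` definitions be deleted. A minimal self-contained sketch; the tool JSON here is illustrative, not the test file's `TOOLS` constant:

```rust
use async_openai::types::{ChatCompletionTool, ChatCompletionToolChoiceOption};

fn main() {
    // Illustrative single-tool definition in the OpenAI wire format.
    let tools_json = r#"[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather",
                "parameters": {
                    "type": "object",
                    "properties": { "city": { "type": "string" } },
                    "required": ["city"]
                }
            }
        }
    ]"#;

    // serde does the conversion work the old crate-local types used to do.
    let tools: Vec<ChatCompletionTool> =
        serde_json::from_str(tools_json).expect("tool JSON should match the schema");
    let tool_choice = ChatCompletionToolChoiceOption::Auto;

    println!("parsed {} tool(s), tool_choice = {:?}", tools.len(), tool_choice);
}
```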