Commit 86aff237 authored by Paul Hendricks, committed by GitHub

refactor: use async_openai


Co-authored-by: Graham King <grahamk@nvidia.com>
parent d694ca6e
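
The diff below swaps the crate's hand-rolled chat-completion types for async_openai's, keeping a thin local wrapper for NVIDIA extensions. As a minimal sketch of the new construction pattern (the wrapper's `inner`/`nvext` fields and the import path appear in the hunks below; the error handling here is otherwise an assumption):

use async_openai::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
};
// Wrapper type from this repository (path taken from the test imports below).
use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;

fn build_request() -> anyhow::Result<ChatCompletionRequest> {
    let message = ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::Text("hi".to_string()),
        name: None,
    });

    // The inner request comes straight from async_openai's builder.
    let inner = CreateChatCompletionRequestArgs::default()
        .model("foo")
        .messages(vec![message])
        .build()?;

    // The crate-local wrapper carries NVIDIA-specific extensions alongside it.
    Ok(ChatCompletionRequest { inner, nvext: None })
}
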
 // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 // SPDX-License-Identifier: Apache-2.0
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 // http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

@@ -21,6 +21,7 @@ use futures::StreamExt;
 use super::{CompletionChoice, CompletionResponse, CompletionUsage, LogprobResult};
 use crate::protocols::{
     codec::{Message, SseCodecError},
+    common::FinishReason,
     convert_sse_stream, Annotated, DataStream,
 };

@@ -38,7 +39,7 @@ pub struct DeltaAggregator {
 struct DeltaChoice {
     index: u64,
     text: String,
-    finish_reason: Option<crate::protocols::openai::chat_completions::FinishReason>,
+    finish_reason: Option<FinishReason>,
     logprobs: Option<LogprobResult>,
 }

@@ -110,11 +111,7 @@ impl DeltaAggregator {
             // todo - handle logprobs
             if let Some(finish_reason) = choice.finish_reason {
-                let reason =
-                    crate::protocols::openai::chat_completions::FinishReason::from_str(
-                        &finish_reason,
-                    )
-                    .ok();
+                let reason = FinishReason::from_str(&finish_reason).ok();
                 state_choice.finish_reason = reason;
             }
         }

@@ -37,6 +37,10 @@ pub mod openai {
     pub mod chat_completions {
         use super::*;

+        // pub use async_openai::types::CreateChatCompletionRequest as ChatCompletionRequest;
+        // pub use protocols::openai::chat_completions::{
+        //     ChatCompletionResponse, ChatCompletionResponseDelta,
+        // };
         pub use protocols::openai::chat_completions::{
             ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseDelta,
         };

@@ -40,8 +40,17 @@ async fn test_openai_chat_stream() {
     // todo: provide a cleaner way to extract the content from choices
     assert_eq!(
-        result.choices.first().unwrap().content(),
+        result
+            .inner
+            .choices
+            .first()
+            .unwrap()
+            .message
+            .content
+            .clone()
+            .expect("there to be content"),
         "Deep learning is a subfield of machine learning that involves the use of artificial"
+            .to_string()
     );
 }

@@ -52,7 +61,18 @@ async fn test_openai_chat_edge_case_multi_line_data() {
         .await
         .unwrap();
-    assert_eq!(result.choices.first().unwrap().content(), "Deep learning");
+    assert_eq!(
+        result
+            .inner
+            .choices
+            .first()
+            .unwrap()
+            .message
+            .content
+            .clone()
+            .expect("there to be content"),
+        "Deep learning".to_string()
+    );
 }

 #[tokio::test]
@@ -62,7 +82,18 @@ async fn test_openai_chat_edge_case_comments_per_response() {
         .await
         .unwrap();
-    assert_eq!(result.choices.first().unwrap().content(), "Deep learning");
+    assert_eq!(
+        result
+            .inner
+            .choices
+            .first()
+            .unwrap()
+            .message
+            .content
+            .clone()
+            .expect("there to be content"),
+        "Deep learning".to_string()
+    );
 }

 #[tokio::test]
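
All three assertions above repeat the same extraction chain, and the `// todo` notes that a cleaner accessor is wanted. One possible helper, sketched under the assumption (which the tests above already rely on) that `ChatCompletionResponse` publicly exposes the `inner` async_openai response:

use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionResponse;

// Hypothetical extension trait; not part of this commit.
trait FirstContent {
    /// Returns the content of the first choice's message, if any.
    fn first_content(&self) -> Option<String>;
}

impl FirstContent for ChatCompletionResponse {
    fn first_content(&self) -> Option<String> {
        self.inner.choices.first()?.message.content.clone()
    }
}

With such a helper, each assertion would shrink to `assert_eq!(result.first_content().unwrap(), "Deep learning");`.
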
@@ -40,6 +40,7 @@ use triton_distributed_runtime::{
 struct CounterEngine {}

+#[allow(deprecated)]
 #[async_trait]
 impl
     AsyncEngine<

@@ -55,7 +56,8 @@ impl
         let (request, context) = request.transfer(());
         let ctx = context.context();

-        let max_tokens = request.max_tokens.unwrap_or(0) as u64;
+        // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
+        let max_tokens = request.inner.max_tokens.unwrap_or(0) as u64;

         // let generator = ChatCompletionResponseDelta::generator(request.model.clone());
         let generator = request.response_generator();

@@ -63,8 +65,13 @@ impl
         let stream = stream! {
             tokio::time::sleep(std::time::Duration::from_millis(max_tokens)).await;
             for i in 0..10 {
-                let choice = generator.create_choice(i as u64, Some(format!("choice {i}")), None, None);
-                yield Annotated::from_data(choice);
+                let inner = generator.create_choice(i, Some(format!("choice {i}")), None, None);
+
+                let output = ChatCompletionResponseDelta {
+                    inner,
+                };
+
+                yield Annotated::from_data(output);
             }
         };

@@ -174,6 +181,7 @@ fn inc_counter(
     expected[index] += 1;
 }

+#[allow(deprecated)]
 #[tokio::test]
 async fn test_http_service() {
     let service = HttpService::builder().port(8989).build().unwrap();

@@ -207,14 +215,31 @@ async fn test_http_service() {
     let client = reqwest::Client::new();

-    let mut request = ChatCompletionRequest::builder()
+    let message = async_openai::types::ChatCompletionRequestMessage::User(
+        async_openai::types::ChatCompletionRequestUserMessage {
+            content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
+                "hi".to_string(),
+            ),
+            name: None,
+        },
+    );
+
+    let mut request = async_openai::types::CreateChatCompletionRequestArgs::default()
         .model("foo")
-        .add_user_message("hi")
+        .messages(vec![message])
         .build()
-        .unwrap();
+        .expect("Failed to build request");
+
+    // let mut request = ChatCompletionRequest::builder()
+    //     .model("foo")
+    //     .add_user_message("hi")
+    //     .build()
+    //     .unwrap();

     // ==== ChatCompletions / Stream / Success ====
     request.stream = Some(true);
+    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
     request.max_tokens = Some(3000);

     let response = client

@@ -293,6 +318,8 @@ async fn test_http_service() {
     // ==== ChatCompletions / Unary / Success ====
     request.stream = Some(false);
+
+    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
     request.max_tokens = Some(0);

     let future = client

@@ -315,6 +342,8 @@ async fn test_http_service() {
     // ==== ChatCompletions / Stream / Error ====
     request.model = "bar".to_string();
+
+    // ALLOW: max_tokens is deprecated in favor of completion_usage_tokens
     request.max_tokens = Some(0);
     request.stream = Some(true);

...@@ -136,7 +136,7 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> { ...@@ -136,7 +136,7 @@ fn build_samples() -> Result<Vec<CompletionSample>, String> {
samples.push(CompletionSample::new( samples.push(CompletionSample::new(
"should have prompt, model, and max_tokens fields", "should have prompt, model, and max_tokens fields",
|builder| builder.max_tokens(10), |builder| builder.max_tokens(10_u32),
)?); )?);
samples.push(CompletionSample::new( samples.push(CompletionSample::new(
......
@@ -18,9 +18,7 @@ use anyhow::Ok;
 use serde::{Deserialize, Serialize};
 use triton_distributed_llm::model_card::model::{ModelDeploymentCard, PromptContextMixin};
 use triton_distributed_llm::preprocessor::prompt::PromptFormatter;
-use triton_distributed_llm::protocols::openai::chat_completions::{
-    ChatCompletionMessage, ChatCompletionRequest, Tool, ToolChoiceType,
-};
+use triton_distributed_llm::protocols::openai::chat_completions::ChatCompletionRequest;

 use hf_hub::{api::tokio::ApiBuilder, Cache, Repo, RepoType};

@@ -217,29 +215,40 @@ const TOOLS: &str = r#"
 ]
 "#;

+// Notes:
+//   protocols::openai::chat_completions::ChatCompletionMessage -> async_openai::types::ChatCompletionRequestMessage
+//   protocols::openai::chat_completions::Tool -> async_openai::types::ChatCompletionTool
+//   protocols::openai::chat_completions::ToolChoiceType -> async_openai::types::ChatCompletionToolChoiceOption
 #[derive(Serialize, Deserialize)]
 struct Request {
-    messages: Vec<ChatCompletionMessage>,
-    tools: Option<Vec<Tool>>,
-    tool_choice: Option<ToolChoiceType>,
+    messages: Vec<async_openai::types::ChatCompletionRequestMessage>,
+    tools: Option<Vec<async_openai::types::ChatCompletionTool>>,
+    tool_choice: Option<async_openai::types::ChatCompletionToolChoiceOption>,
 }

 impl Request {
     fn from(
         messages: &str,
         tools: Option<&str>,
-        tool_choice: Option<ToolChoiceType>,
+        tool_choice: Option<async_openai::types::ChatCompletionToolChoiceOption>,
         model: String,
     ) -> ChatCompletionRequest {
-        let messages: Vec<ChatCompletionMessage> = serde_json::from_str(messages).unwrap();
-        let tools: Option<Vec<Tool>> = tools.map(|x| serde_json::from_str(x).unwrap());
-        ChatCompletionRequest::builder()
+        let messages: Vec<async_openai::types::ChatCompletionRequestMessage> =
+            serde_json::from_str(messages).unwrap();
+        let tools: Option<Vec<async_openai::types::ChatCompletionTool>> =
+            tools.map(|x| serde_json::from_str(x).unwrap());
+
+        let tools = tools.unwrap();
+        let tool_choice = tool_choice.unwrap();
+
+        let inner = async_openai::types::CreateChatCompletionRequestArgs::default()
             .model(model)
             .messages(messages)
             .tools(tools)
             .tool_choice(tool_choice)
             .build()
-            .unwrap()
+            .unwrap();
+
+        ChatCompletionRequest { inner, nvext: None }
     }
 }

@@ -295,7 +304,7 @@ async fn test_single_turn_with_tools() {
     let request = Request::from(
         SINGLE_CHAT_MESSAGE,
         Some(TOOLS),
-        Some(ToolChoiceType::Auto),
+        Some(async_openai::types::ChatCompletionToolChoiceOption::Auto),
         mdc.slug().to_string(),
     );
     let formatted_prompt = formatter.render(&request).unwrap();

@@ -402,7 +411,7 @@ async fn test_multi_turn_with_system_with_tools() {
     let request = Request::from(
         THREE_TURN_CHAT_MESSAGE_WITH_SYSTEM,
         Some(TOOLS),
-        Some(ToolChoiceType::Auto),
+        Some(async_openai::types::ChatCompletionToolChoiceOption::Auto),
         mdc.slug().to_string(),
     );
     let formatted_prompt = formatter.render(&request).unwrap();
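
For reference, the three type migrations called out in the Notes hunk above can be exercised standalone. A hedged sketch (the tool JSON here is illustrative, not the test file's `TOOLS` constant):

use async_openai::types::{ChatCompletionTool, ChatCompletionToolChoiceOption};

fn main() {
    // Tools deserialize directly into async_openai's type, just as the
    // rewritten Request::from does with serde_json::from_str.
    let json = r#"[{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the weather for a city",
            "parameters": {
                "type": "object",
                "properties": { "city": { "type": "string" } },
                "required": ["city"]
            }
        }
    }]"#;
    let tools: Vec<ChatCompletionTool> =
        serde_json::from_str(json).expect("tool JSON should deserialize");

    // The old ToolChoiceType::Auto maps onto this async_openai variant.
    let tool_choice = ChatCompletionToolChoiceOption::Auto;
    println!("{} tool(s), tool_choice = {:?}", tools.len(), tool_choice);
}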