Unverified Commit fd5cc288 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

refactor(3/3): switch dynamo-protocols to upstream async-openai types (#7625)


Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent d517fb80
......@@ -13,25 +13,17 @@ use dynamo_llm::protocols::{
},
};
use dynamo_llm::{
http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
},
service::{
Metrics,
error::HttpError,
metrics::{Endpoint, ErrorType, RequestType, Status},
service_v2::HttpService,
},
http::service::{
Metrics,
error::HttpError,
metrics::{Endpoint, ErrorType, RequestType, Status},
service_v2::HttpService,
},
model_card::ModelDeploymentCard,
};
use dynamo_protocols::config::OpenAIConfig;
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{
CancellationToken,
engine::AsyncEngineContext,
pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
},
......@@ -582,398 +574,10 @@ async fn wait_for_service_ready(port: u16) {
}
}
async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>, u16) {
let port = get_random_port().await;
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let manager = service.model_manager();
let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager
.add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager
.add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
(service, counter, failure, port)
}
fn pure_openai_client(port: u16) -> PureOpenAIClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
PureOpenAIClient::new(config)
}
fn nv_custom_client(port: u16) -> NvCustomClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
NvCustomClient::new(config)
}
fn generic_byot_client(port: u16) -> GenericBYOTClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
GenericBYOTClient::new(config)
}
#[tokio::test]
async fn test_pure_openai_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let pure_openai_client = pure_openai_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(port).await;
// Test successful streaming request
let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(result.is_ok(), "PureOpenAI client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("bar") // This model will fail
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_nv_custom_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let nv_custom_client = nv_custom_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(port).await;
// Test successful streaming request
let inner_request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
assert!(result.is_ok(), "NvCustom client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let inner_request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("bar") // This model will fail
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let inner_request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_generic_byot_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let generic_byot_client = generic_byot_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(port).await;
// Test successful streaming request
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(result.is_ok(), "GenericBYOT client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
println!("Response: {:?}", response);
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let request = serde_json::json!({
"model": "bar", // This model will fail
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
// NOTE: BYOT (Bring Your Own Type) client tests were removed during the
// upstream async-openai migration. They depended on the forked
// dynamo_protocols::config and http::client modules which no longer exist.
// TODO: Rewrite these tests using the upstream async-openai client.
#[tokio::test]
async fn test_client_disconnect_cancellation_unary() {
let port = get_random_port().await;
......
......@@ -381,7 +381,6 @@ fn create_response_with_linear_probs(
index: 0,
delta: ChatCompletionStreamResponseDelta {
content: Some(ChatCompletionMessageContent::Text(_content.to_string())),
#[expect(deprecated)]
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
......@@ -463,7 +462,6 @@ fn create_multi_choice_response(
index: choice_idx as u32,
delta: ChatCompletionStreamResponseDelta {
content: Some(ChatCompletionMessageContent::Text("test".to_string())),
#[expect(deprecated)]
function_call: None,
tool_calls: None,
role: Some(Role::Assistant),
......
......@@ -55,7 +55,10 @@ fn parse_fixture(
let value: Value = serde_json::from_str(line).unwrap();
let chunk: NvCreateChatCompletionStreamResponse =
serde_json::from_value(value.clone()).unwrap();
expected_stream_json.push(value);
// Round-trip through the typed struct so expected JSON matches current serialization
// (upstream async-openai skips None fields that the old fork serialized as null).
let normalized = serde_json::to_value(&chunk).unwrap();
expected_stream_json.push(normalized);
input_chunks.push(chunk);
}
......
......@@ -2493,7 +2493,7 @@ mod parallel_jail_tests {
assert_eq!(
tool_call.r#type,
Some(dynamo_protocols::types::ChatCompletionToolType::Function),
Some(dynamo_protocols::types::FunctionType::Function),
"Tool call {} should be of type 'function'",
i
);
......
......@@ -7,7 +7,7 @@ use dynamo_protocols::types::{
ChatCompletionMessageContent, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage,
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest,
FunctionName,
FunctionName, FunctionType,
};
/// Helper to extract text from ChatCompletionMessageContent
......@@ -172,7 +172,7 @@ async fn test_named_tool_choice_parses_json() {
let tool_call = &tool_calls[0];
assert_eq!(tool_call.index, 0);
assert!(tool_call.id.as_ref().unwrap().starts_with("call-"));
assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function));
assert_eq!(tool_call.r#type, Some(FunctionType::Function));
assert_eq!(
tool_call.function.as_ref().unwrap().name.as_deref(),
Some("get_weather")
......@@ -213,7 +213,7 @@ async fn test_required_tool_choice_parses_json_array() {
assert_eq!(tool_calls[0].index, 0);
assert!(tool_calls[0].id.as_ref().unwrap().starts_with("call-"));
assert_eq!(tool_calls[0].r#type, Some(ChatCompletionToolType::Function));
assert_eq!(tool_calls[0].r#type, Some(FunctionType::Function));
assert_eq!(
tool_calls[0].function.as_ref().unwrap().name.as_deref(),
Some("search")
......@@ -230,7 +230,7 @@ async fn test_required_tool_choice_parses_json_array() {
assert_eq!(tool_calls[1].index, 1);
assert!(tool_calls[1].id.as_ref().unwrap().starts_with("call-"));
assert_eq!(tool_calls[1].r#type, Some(ChatCompletionToolType::Function));
assert_eq!(tool_calls[1].r#type, Some(FunctionType::Function));
assert_eq!(
tool_calls[1].function.as_ref().unwrap().name.as_deref(),
Some("summarize")
......
......@@ -30,7 +30,7 @@ pub async fn try_tool_call_parse_aggregate(
.map(
|parsed| dynamo_protocols::types::ChatCompletionMessageToolCall {
id: parsed.id,
r#type: dynamo_protocols::types::ChatCompletionToolType::Function,
r#type: dynamo_protocols::types::FunctionType::Function,
function: dynamo_protocols::types::FunctionCall {
name: parsed.function.name,
arguments: parsed.function.arguments,
......@@ -65,7 +65,7 @@ pub async fn try_tool_call_parse_stream(
|(idx, parsed)| dynamo_protocols::types::ChatCompletionMessageToolCallChunk {
index: idx as u32,
id: Some(parsed.id),
r#type: Some(dynamo_protocols::types::ChatCompletionToolType::Function),
r#type: Some(dynamo_protocols::types::FunctionType::Function),
function: Some(dynamo_protocols::types::FunctionCallStream {
name: Some(parsed.function.name),
arguments: Some(parsed.function.arguments),
......
......@@ -10,7 +10,7 @@
[package]
name = "dynamo-protocols"
description = "Protocol types for OpenAI-compatible inference APIs, forked from async-openai."
description = "Protocol types for OpenAI-compatible inference APIs with inference-serving extensions."
license = "Apache-2.0 AND MIT"
version.workspace = true
edition.workspace = true
......@@ -19,44 +19,35 @@ homepage.workspace = true
repository.workspace = true
readme.workspace = true
[features]
realtime = ["dep:tokio-tungstenite"]
# Bring your own types
byot = []
[dependencies]
futures = { workspace = true }
rand = { workspace = true }
reqwest = { workspace = true }
# Upstream OpenAI types (types-only, no HTTP client)
async-openai = { version = "0.34", default-features = false, features = [
"chat-completion-types",
"response-types",
"completion-types",
"embedding-types",
"image-types",
] }
# Serialization
serde = { workspace = true }
serde_json = { workspace = true }
derive_builder = { workspace = true }
# Type support
thiserror = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true }
url = { workspace = true }
uuid = { workspace = true }
derive_builder = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true }
eventsource-stream = "0.2.3"
async-openai-macros = "0.1.0"
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
reqwest-eventsource = "0.6.0"
secrecy = { version = "0.10.3", features = ["serde"] }
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
utoipa = { version = "5.3", features = ["url", "uuid"] }
[dev-dependencies]
tokio = { workspace = true }
tokio-test = "0.4.4"
serde_json = { workspace = true }
[[test]]
name = "bring-your-own-type"
required-features = ["byot"]
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse,
ModifyAssistantRequest,
},
};
/// Build assistants that can call models and use tools to perform tasks.
///
/// [Get started with the Assistants API](https://platform.openai.com/docs/assistants)
pub struct Assistants<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Assistants<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Create an assistant with a model and instructions.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(
&self,
request: CreateAssistantRequest,
) -> Result<AssistantObject, OpenAIError> {
self.client.post("/assistants", request).await
}
/// Retrieves an assistant.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, assistant_id: &str) -> Result<AssistantObject, OpenAIError> {
self.client
.get(&format!("/assistants/{assistant_id}"))
.await
}
/// Modifies an assistant.
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn update(
&self,
assistant_id: &str,
request: ModifyAssistantRequest,
) -> Result<AssistantObject, OpenAIError> {
self.client
.post(&format!("/assistants/{assistant_id}"), request)
.await
}
/// Delete an assistant.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn delete(&self, assistant_id: &str) -> Result<DeleteAssistantResponse, OpenAIError> {
self.client
.delete(&format!("/assistants/{assistant_id}"))
.await
}
/// Returns a list of assistants.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list<Q>(&self, query: &Q) -> Result<ListAssistantsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client.get_with_query("/assistants", &query).await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use bytes::Bytes;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson,
CreateTranslationRequest, CreateTranslationResponseJson,
CreateTranslationResponseVerboseJson,
},
};
/// Turn audio into text or text into audio.
/// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
pub struct Audio<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Audio<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Transcribes audio into the input language.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn transcribe(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}
/// Transcribes audio into the input language.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn transcribe_verbose_json(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}
/// Transcribes audio into the input language.
pub async fn transcribe_raw(
&self,
request: CreateTranscriptionRequest,
) -> Result<Bytes, OpenAIError> {
self.client
.post_form_raw("/audio/transcriptions", request)
.await
}
/// Translates audio into English.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn translate(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponseJson, OpenAIError> {
self.client.post_form("/audio/translations", request).await
}
/// Translates audio into English.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn translate_verbose_json(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponseVerboseJson, OpenAIError> {
self.client.post_form("/audio/translations", request).await
}
/// Transcribes audio into the input language.
pub async fn translate_raw(
&self,
request: CreateTranslationRequest,
) -> Result<Bytes, OpenAIError> {
self.client
.post_form_raw("/audio/translations", request)
.await
}
/// Generates audio from the input text.
pub async fn speech(
&self,
request: CreateSpeechRequest,
) -> Result<CreateSpeechResponse, OpenAIError> {
let bytes = self.client.post_raw("/audio/speech", request).await?;
Ok(CreateSpeechResponse { bytes })
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{Client, config::Config, error::OpenAIError, types::ListAuditLogsResponse};
/// Logs of user actions and configuration changes within this organization.
/// To log events, you must activate logging in the [Organization Settings](https://platform.openai.com/settings/organization/general).
/// Once activated, for security reasons, logging cannot be deactivated.
pub struct AuditLogs<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> AuditLogs<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// List user actions and configuration changes within this organization.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn get<Q>(&self, query: &Q) -> Result<ListAuditLogsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client
.get_with_query("/organization/audit_logs", &query)
.await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{Batch, BatchRequest, ListBatchesResponse},
};
/// Create large batches of API requests for asynchronous processing. The Batch API returns completions within 24 hours for a 50% discount.
///
/// Related guide: [Batch](https://platform.openai.com/docs/guides/batch)
pub struct Batches<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Batches<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates and executes a batch from an uploaded file of requests
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(&self, request: BatchRequest) -> Result<Batch, OpenAIError> {
self.client.post("/batches", request).await
}
/// List your organization's batches.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list<Q>(&self, query: &Q) -> Result<ListBatchesResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client.get_with_query("/batches", &query).await
}
/// Retrieves a batch.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
self.client.get(&format!("/batches/{batch_id}")).await
}
/// Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn cancel(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
self.client
.post(
&format!("/batches/{batch_id}/cancel"),
serde_json::json!({}),
)
.await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
},
};
/// Given a list of messages comprising a conversation, the model will return a response.
///
/// Related guide: [Chat completions](https://platform.openai.com//docs/guides/text-generation)
pub struct Chat<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Chat<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates a model response for the given chat conversation. Learn more in
/// the
///
/// [text generation](https://platform.openai.com/docs/guides/text-generation),
/// [vision](https://platform.openai.com/docs/guides/vision),
///
/// and [audio](https://platform.openai.com/docs/guides/audio) guides.
///
///
/// Parameter support can differ depending on the model used to generate the
/// response, particularly for newer reasoning models. Parameters that are
/// only supported for reasoning models are noted below. For the current state
/// of unsupported parameters in reasoning models,
///
/// [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning).
///
/// byot: You must ensure "stream: false" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned
)]
pub async fn create(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is true, use Chat::create_stream".into(),
));
}
}
self.client.post("/chat/completions", request).await
}
/// Creates a completion for the chat message
///
/// partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message.
///
/// [ChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
///
/// byot: You must ensure "stream: true" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned,
stream = "true",
where_clause = "R: std::marker::Send + 'static"
)]
#[allow(unused_mut)]
pub async fn create_stream(
&self,
mut request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && !request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is false, use Chat::create".into(),
));
}
request.stream = Some(true);
}
Ok(self.client.post_stream("/chat/completions", request).await)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use std::pin::Pin;
use bytes::Bytes;
use futures::{Stream, stream::StreamExt};
use reqwest::multipart::Form;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{Serialize, de::DeserializeOwned};
use crate::{
Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
config::{Config, OpenAIConfig},
error::{ApiError, OpenAIError, WrappedError, map_deserialization_error},
file::Files,
image::Images,
moderation::Moderations,
traits::AsyncTryFrom,
};
#[derive(Debug, Clone, Default)]
/// Client is a container for config, backoff and http_client
/// used to make API calls.
pub struct Client<C: Config> {
http_client: reqwest::Client,
config: C,
backoff: backoff::ExponentialBackoff,
}
impl Client<OpenAIConfig> {
/// Client with default [OpenAIConfig]
pub fn new() -> Self {
Self::default()
}
}
impl<C: Config> Client<C> {
/// Create client with a custom HTTP client, OpenAI config, and backoff.
pub fn build(
http_client: reqwest::Client,
config: C,
backoff: backoff::ExponentialBackoff,
) -> Self {
Self {
http_client,
config,
backoff,
}
}
/// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
pub fn with_config(config: C) -> Self {
Self {
http_client: reqwest::Client::new(),
config,
backoff: Default::default(),
}
}
/// Provide your own [client] to make HTTP requests with.
///
/// [client]: reqwest::Client
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = http_client;
self
}
/// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
self.backoff = backoff;
self
}
// API groups
/// To call [Models] group related APIs using this client.
pub fn models(&self) -> Models<C> {
Models::new(self)
}
/// To call [Completions] group related APIs using this client.
pub fn completions(&self) -> Completions<C> {
Completions::new(self)
}
/// To call [Chat] group related APIs using this client.
pub fn chat(&self) -> Chat<C> {
Chat::new(self)
}
/// To call [Images] group related APIs using this client.
pub fn images(&self) -> Images<C> {
Images::new(self)
}
/// To call [Moderations] group related APIs using this client.
pub fn moderations(&self) -> Moderations<C> {
Moderations::new(self)
}
/// To call [Files] group related APIs using this client.
pub fn files(&self) -> Files<C> {
Files::new(self)
}
/// To call [Uploads] group related APIs using this client.
pub fn uploads(&self) -> Uploads<C> {
Uploads::new(self)
}
/// To call [FineTuning] group related APIs using this client.
pub fn fine_tuning(&self) -> FineTuning<C> {
FineTuning::new(self)
}
/// To call [Embeddings] group related APIs using this client.
pub fn embeddings(&self) -> Embeddings<C> {
Embeddings::new(self)
}
/// To call [Audio] group related APIs using this client.
pub fn audio(&self) -> Audio<C> {
Audio::new(self)
}
/// To call [Assistants] group related APIs using this client.
pub fn assistants(&self) -> Assistants<C> {
Assistants::new(self)
}
/// To call [Threads] group related APIs using this client.
pub fn threads(&self) -> Threads<C> {
Threads::new(self)
}
/// To call [VectorStores] group related APIs using this client.
pub fn vector_stores(&self) -> VectorStores<C> {
VectorStores::new(self)
}
/// To call [Batches] group related APIs using this client.
pub fn batches(&self) -> Batches<C> {
Batches::new(self)
}
/// To call [AuditLogs] group related APIs using this client.
pub fn audit_logs(&self) -> AuditLogs<C> {
AuditLogs::new(self)
}
/// To call [Invites] group related APIs using this client.
pub fn invites(&self) -> Invites<C> {
Invites::new(self)
}
/// To call [Users] group related APIs using this client.
pub fn users(&self) -> Users<C> {
Users::new(self)
}
/// To call [Projects] group related APIs using this client.
pub fn projects(&self) -> Projects<C> {
Projects::new(self)
}
/// To call [Responses] group related APIs using this client.
pub fn responses(&self) -> Responses<C> {
Responses::new(self)
}
pub fn config(&self) -> &C {
&self.config
}
/// Make a GET request to {path} and deserialize the response body
pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};
self.execute(request_maker).await
}
/// Make a GET request to {path} with given Query and deserialize the response body
pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
Q: Serialize + ?Sized,
{
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.query(query)
.headers(self.config.headers())
.build()?)
};
self.execute(request_maker).await
}
/// Make a DELETE request to {path} and deserialize the response body
pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.delete(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};
self.execute(request_maker).await
}
/// Make a GET request to {path} and return the response body
pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
let request_maker = || async {
Ok(self
.http_client
.get(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.build()?)
};
self.execute_raw(request_maker).await
}
/// Make a POST request to {path} and return the response body
pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
where
I: Serialize,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?)
};
self.execute_raw(request_maker).await
}
/// Make a POST request to {path} and deserialize the response body
pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
where
I: Serialize,
O: DeserializeOwned,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?)
};
self.execute(request_maker).await
}
/// POST a form at {path} and return the response body
pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
where
Form: AsyncTryFrom<F, Error = OpenAIError>,
F: Clone,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
.build()?)
};
self.execute_raw(request_maker).await
}
/// POST a form at {path} and deserialize the response body
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
Form: AsyncTryFrom<F, Error = OpenAIError>,
F: Clone,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.multipart(<Form as AsyncTryFrom<F>>::try_from(form.clone()).await?)
.build()?)
};
self.execute(request_maker).await
}
/// Execute a HTTP request and retry on rate limit
///
/// request_maker serves one purpose: to be able to create request again
/// to retry API call after getting rate limited. request_maker is async because
/// reqwest::multipart::Form is created by async calls to read files for uploads.
async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
where
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let client = self.http_client.clone();
backoff::future::retry(self.backoff.clone(), || async {
let request = request_maker().await.map_err(backoff::Error::Permanent)?;
let response = client
.execute(request)
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(ApiError {
message,
r#type: None,
param: None,
code: None,
}),
retry_after: None,
});
}
// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
if status.as_u16() == 429
// API returns 429 also when:
// "You exceeded your current quota, please check your plan and billing details."
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
}
}
Ok(bytes)
})
.await
}
/// Execute a HTTP request and retry on rate limit
///
/// request_maker serves one purpose: to be able to create request again
/// to retry API call after getting rate limited. request_maker is async because
/// reqwest::multipart::Form is created by async calls to read files for uploads.
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let bytes = self.execute_raw(request_maker).await?;
let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
Ok(response)
}
/// Make HTTP POST request to receive SSE
pub(crate) async fn post_stream<I, O>(
&self,
path: &str,
request: I,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.eventsource()
.unwrap();
stream(event_source).await
}
pub(crate) async fn post_stream_mapped_raw_events<I, O>(
&self,
path: &str,
request: I,
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.eventsource()
.unwrap();
stream_mapped_raw_events(event_source, event_mapper).await
}
/// Make HTTP GET request to receive SSE
pub(crate) async fn _get_stream<Q, O>(
&self,
path: &str,
query: &Q,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
Q: Serialize + ?Sized,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.get(self.config.url(path))
.query(query)
.query(&self.config.query())
.headers(self.config.headers())
.eventsource()
.unwrap();
stream(event_source).await
}
}
/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
pub(crate) async fn stream<O>(
mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
O: DeserializeOwned + std::marker::Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
// rx dropped
break;
}
}
Ok(event) => match event {
Event::Message(message) => {
if message.data == "[DONE]" {
break;
}
let response = match serde_json::from_str::<O>(&message.data) {
Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
Ok(output) => Ok(output),
};
if let Err(_e) = tx.send(response) {
// rx dropped
break;
}
}
Event::Open => continue,
},
}
}
event_source.close();
});
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
pub(crate) async fn stream_mapped_raw_events<O>(
mut event_source: EventSource,
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
O: DeserializeOwned + std::marker::Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
// rx dropped
break;
}
}
Ok(event) => match event {
Event::Message(message) => {
let mut done = false;
if message.data == "[DONE]" {
done = true;
}
let response = event_mapper(message);
if let Err(_e) = tx.send(response) {
// rx dropped
break;
}
if done {
break;
}
}
Event::Open => continue,
},
}
}
event_source.close();
});
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
client::Client,
config::Config,
error::OpenAIError,
types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse},
};
/// Given a prompt, the model will return one or more predicted completions,
/// and can also return the probabilities of alternative tokens at each position.
/// We recommend most users use our Chat completions API.
/// [Learn more](https://platform.openai.com/docs/deprecations/2023-07-06-gpt-and-embeddings)
///
/// Related guide: [Legacy Completions](https://platform.openai.com/docs/guides/gpt/completions-api)
pub struct Completions<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Completions<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates a completion for the provided prompt and parameters
///
/// You must ensure that "stream: false" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned
)]
pub async fn create(
&self,
request: CreateCompletionRequest,
) -> Result<CreateCompletionResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is true, use Completion::create_stream".into(),
));
}
}
self.client.post("/completions", request).await
}
/// Creates a completion request for the provided prompt and parameters
///
/// Stream back partial progress. Tokens will be sent as data-only
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
/// as they become available, with the stream terminated by a data: \[DONE\] message.
///
/// [CompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
///
/// You must ensure that "stream: true" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned,
stream = "true",
where_clause = "R: std::marker::Send + 'static"
)]
#[allow(unused_mut)]
pub async fn create_stream(
&self,
mut request: CreateCompletionRequest,
) -> Result<CompletionResponseStream, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && !request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is false, use Completion::create".into(),
));
}
request.stream = Some(true);
}
Ok(self.client.post_stream("/completions", request).await)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
use reqwest::header::{AUTHORIZATION, HeaderMap};
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
/// Default v1 API base url
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Organization header
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Project header
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";
/// Calls to the Assistants API require that you pass a Beta header
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// [crate::Client] relies on this for every API call on OpenAI
/// or Azure OpenAI service
pub trait Config: Send + Sync {
fn headers(&self) -> HeaderMap;
fn url(&self, path: &str) -> String;
fn query(&self) -> Vec<(&str, &str)>;
fn api_base(&self) -> &str;
fn api_key(&self) -> &SecretString;
}
/// Macro to implement Config trait for pointer types with dyn objects
macro_rules! impl_config_for_ptr {
($t:ty) => {
impl Config for $t {
fn headers(&self) -> HeaderMap {
self.as_ref().headers()
}
fn url(&self, path: &str) -> String {
self.as_ref().url(path)
}
fn query(&self) -> Vec<(&str, &str)> {
self.as_ref().query()
}
fn api_base(&self) -> &str {
self.as_ref().api_base()
}
fn api_key(&self) -> &SecretString {
self.as_ref().api_key()
}
}
};
}
impl_config_for_ptr!(Box<dyn Config>);
impl_config_for_ptr!(std::sync::Arc<dyn Config>);
/// Configuration for OpenAI API
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
api_base: String,
api_key: SecretString,
org_id: String,
project_id: String,
}
impl Default for OpenAIConfig {
fn default() -> Self {
Self {
api_base: OPENAI_API_BASE.to_string(),
api_key: std::env::var("OPENAI_API_KEY")
.unwrap_or_else(|_| "".to_string())
.into(),
org_id: Default::default(),
project_id: Default::default(),
}
}
}
impl OpenAIConfig {
/// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
pub fn new() -> Self {
Default::default()
}
/// To use a different organization id other than default
pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
self.org_id = org_id.into();
self
}
/// Non default project id
pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
self.project_id = project_id.into();
self
}
/// To use a different API key different from default OPENAI_API_KEY env var
pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
self.api_key = SecretString::from(api_key.into());
self
}
/// To use a API base url different from default [OPENAI_API_BASE]
pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = api_base.into();
self
}
pub fn org_id(&self) -> &str {
&self.org_id
}
}
impl Config for OpenAIConfig {
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if !self.org_id.is_empty() {
headers.insert(
OPENAI_ORGANIZATION_HEADER,
self.org_id.as_str().parse().unwrap(),
);
}
if !self.project_id.is_empty() {
headers.insert(
OPENAI_PROJECT_HEADER,
self.project_id.as_str().parse().unwrap(),
);
}
headers.insert(
AUTHORIZATION,
format!("Bearer {}", self.api_key.expose_secret())
.as_str()
.parse()
.unwrap(),
);
// hack for Assistants APIs
// Calls to the Assistants API require that you pass a Beta header
headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
headers
}
fn url(&self, path: &str) -> String {
format!("{}{}", self.api_base, path)
}
fn api_base(&self) -> &str {
&self.api_base
}
fn api_key(&self) -> &SecretString {
&self.api_key
}
fn query(&self) -> Vec<(&str, &str)> {
vec![]
}
}
/// Configuration for Azure OpenAI Service
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
api_version: String,
deployment_id: String,
api_base: String,
api_key: SecretString,
}
impl Default for AzureConfig {
fn default() -> Self {
Self {
api_base: Default::default(),
api_key: std::env::var("OPENAI_API_KEY")
.unwrap_or_else(|_| "".to_string())
.into(),
deployment_id: Default::default(),
api_version: Default::default(),
}
}
}
impl AzureConfig {
pub fn new() -> Self {
Default::default()
}
pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
self.api_version = api_version.into();
self
}
pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
self.deployment_id = deployment_id.into();
self
}
/// To use a different API key different from default OPENAI_API_KEY env var
pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
self.api_key = SecretString::from(api_key.into());
self
}
/// API base url in form of <https://your-resource-name.openai.azure.com>
pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = api_base.into();
self
}
}
impl Config for AzureConfig {
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
headers
}
fn url(&self, path: &str) -> String {
format!(
"{}/openai/deployments/{}{}",
self.api_base, self.deployment_id, path
)
}
fn api_base(&self) -> &str {
&self.api_base
}
fn api_key(&self) -> &SecretString {
&self.api_key
}
fn query(&self) -> Vec<(&str, &str)> {
vec![("api-version", &self.api_version)]
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::Client;
use crate::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
};
use std::sync::Arc;
#[test]
fn test_client_creation() {
unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
let openai_config = OpenAIConfig::default();
let config = Box::new(openai_config.clone()) as Box<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let config = Arc::new(openai_config) as Arc<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let cloned_client = client.clone();
assert!(cloned_client.config().url("").ends_with("/v1"));
}
async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
let _ = client.chat().create(CreateChatCompletionRequest {
model: "gpt-4o".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: "Hello, world!".into(),
..Default::default()
},
)],
..Default::default()
});
}
#[tokio::test]
async fn test_dynamic_dispatch() {
let openai_config = OpenAIConfig::default();
let azure_config = AzureConfig::default();
let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);
let _ = dynamic_dispatch_compiles(&azure_client).await;
let _ = dynamic_dispatch_compiles(&oai_client).await;
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use std::path::{Path, PathBuf};
use base64::{Engine as _, engine::general_purpose};
use rand::{Rng, distr::Alphanumeric};
use reqwest::Url;
use crate::error::OpenAIError;
fn create_paths<P: AsRef<Path>>(url: &Url, base_dir: P) -> (PathBuf, PathBuf) {
let mut dir = PathBuf::from(base_dir.as_ref());
let mut path = dir.clone();
let segments = url.path_segments().map(|c| c.collect::<Vec<_>>());
if let Some(segments) = segments {
for (idx, segment) in segments.iter().enumerate() {
if idx != segments.len() - 1 {
dir.push(segment);
}
path.push(segment);
}
}
(dir, path)
}
pub(crate) async fn download_url<P: AsRef<Path>>(
url: &str,
dir: P,
) -> Result<PathBuf, OpenAIError> {
let parsed_url = Url::parse(url).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
let response = reqwest::get(url)
.await
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
if !response.status().is_success() {
return Err(OpenAIError::FileSaveError(format!(
"couldn't download file, status: {}, url: {url}",
response.status()
)));
}
let (dir, file_path) = create_paths(&parsed_url, dir);
tokio::fs::create_dir_all(dir.as_path())
.await
.map_err(|e| OpenAIError::FileSaveError(format!("{}, dir: {}", e, dir.display())))?;
tokio::fs::write(
file_path.as_path(),
response.bytes().await.map_err(|e| {
OpenAIError::FileSaveError(format!("{}, file path: {}", e, file_path.display()))
})?,
)
.await
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
Ok(file_path)
}
pub(crate) async fn save_b64<P: AsRef<Path>>(b64: &str, dir: P) -> Result<PathBuf, OpenAIError> {
let filename: String = rand::rng()
.sample_iter(&Alphanumeric)
.take(10)
.map(char::from)
.collect();
let filename = format!("{filename}.png");
let path = PathBuf::from(dir.as_ref()).join(filename);
tokio::fs::write(
path.as_path(),
general_purpose::STANDARD
.decode(b64)
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?,
)
.await
.map_err(|e| OpenAIError::FileSaveError(format!("{}, path: {}", e, path.display())))?;
Ok(path)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
};
#[cfg(not(feature = "byot"))]
use crate::types::EncodingFormat;
/// Get a vector representation of a given input that can be easily
/// consumed by machine learning models and algorithms.
///
/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
pub struct Embeddings<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Embeddings<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates an embedding vector representing the input text.
///
/// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
return Err(OpenAIError::InvalidArgument(
"When encoding_format is base64, use Embeddings::create_base64".into(),
));
}
}
self.client.post("/embeddings", request).await
}
/// Creates an embedding vector representing the input text.
///
/// The response will contain the embedding in base64 format.
///
/// byot: In serialized `request` you must ensure "encoding_format" is "base64"
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create_base64(
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
return Err(OpenAIError::InvalidArgument(
"When encoding_format is not base64, use Embeddings::create".into(),
));
}
}
self.client.post("/embeddings", request).await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
//! Error types for protocol type operations.
use serde::{Deserialize, Serialize};
#[derive(Debug, thiserror::Error)]
pub enum OpenAIError {
/// Underlying error from reqwest library after an API call was made
#[error("http error: {0}")]
Reqwest(#[from] reqwest::Error),
/// OpenAI returns error object with details of API call failure
#[error("{0}")]
ApiError(ApiError),
/// Error when a response cannot be deserialized into a Rust type
#[error("failed to deserialize api response: {0}")]
JSONDeserialize(serde_json::Error),
/// Error on the client side when saving file to file system
#[error("failed to save file: {0}")]
FileSaveError(String),
/// Error on the client side when reading file from file system
#[error("failed to read file: {0}")]
FileReadError(String),
/// Error on SSE streaming
#[error("stream failed: {0}")]
StreamError(String),
/// Error from client side validation
/// or when builder fails to build request before making API call
#[error("invalid args: {0}")]
......@@ -47,9 +28,6 @@ pub struct ApiError {
}
impl std::fmt::Display for ApiError {
/// If all fields are available, `ApiError` is formatted as:
/// `{type}: {message} (param: {param}) (code: {code})`
/// Otherwise, missing fields will be ignored.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut parts = Vec::new();
......@@ -76,11 +54,3 @@ impl std::fmt::Display for ApiError {
pub struct WrappedError {
pub error: ApiError,
}
pub(crate) fn map_deserialization_error(e: serde_json::Error, bytes: &[u8]) -> OpenAIError {
tracing::error!(
"failed deserialization of: {}",
String::from_utf8_lossy(bytes)
);
OpenAIError::JSONDeserialize(e)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use bytes::Bytes;
use serde::Serialize;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{CreateFileRequest, DeleteFileResponse, ListFilesResponse, OpenAIFile},
};
/// Files are used to upload documents that can be used with features like Assistants and Fine-tuning.
pub struct Files<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Files<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Upload a file that can be used across various endpoints. Individual files can be up to 512 MB, and the size of all files uploaded by one organization can be up to 100 GB.
///
/// The Assistants API supports files up to 2 million tokens and of specific file types. See the [Assistants Tools guide](https://platform.openai.com/docs/assistants/tools) for details.
///
/// The Fine-tuning API only supports `.jsonl` files. The input also has certain required formats for fine-tuning [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input) or [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) models.
///
///The Batch API only supports `.jsonl` files up to 100 MB in size. The input also has a specific required [format](https://platform.openai.com/docs/api-reference/batch/request-input).
///
/// Please [contact us](https://help.openai.com/) if you need to increase these storage limits.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn create(&self, request: CreateFileRequest) -> Result<OpenAIFile, OpenAIError> {
self.client.post_form("/files", request).await
}
/// Returns a list of files that belong to the user's organization.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list<Q>(&self, query: &Q) -> Result<ListFilesResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client.get_with_query("/files", &query).await
}
/// Returns information about a specific file.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, file_id: &str) -> Result<OpenAIFile, OpenAIError> {
self.client.get(format!("/files/{file_id}").as_str()).await
}
/// Delete a file.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn delete(&self, file_id: &str) -> Result<DeleteFileResponse, OpenAIError> {
self.client
.delete(format!("/files/{file_id}").as_str())
.await
}
/// Returns the contents of the specified file
pub async fn content(&self, file_id: &str) -> Result<Bytes, OpenAIError> {
self.client
.get_raw(format!("/files/{file_id}/content").as_str())
.await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
CreateFineTuningJobRequest, FineTuningJob, ListFineTuningJobCheckpointsResponse,
ListFineTuningJobEventsResponse, ListPaginatedFineTuningJobsResponse,
},
};
/// Manage fine-tuning jobs to tailor a model to your specific training data.
///
/// Related guide: [Fine-tune models](https://platform.openai.com/docs/guides/fine-tuning)
pub struct FineTuning<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> FineTuning<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates a job that fine-tunes a specified model from a given dataset.
///
/// Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.
///
/// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(
&self,
request: CreateFineTuningJobRequest,
) -> Result<FineTuningJob, OpenAIError> {
self.client.post("/fine_tuning/jobs", request).await
}
/// List your organization's fine-tuning jobs
#[crate::byot(T0 = serde::Serialize, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list_paginated<Q>(
&self,
query: &Q,
) -> Result<ListPaginatedFineTuningJobsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client
.get_with_query("/fine_tuning/jobs", &query)
.await
}
/// Gets info about the fine-tune job.
///
/// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, fine_tuning_job_id: &str) -> Result<FineTuningJob, OpenAIError> {
self.client
.get(format!("/fine_tuning/jobs/{fine_tuning_job_id}").as_str())
.await
}
/// Immediately cancel a fine-tune job.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn cancel(&self, fine_tuning_job_id: &str) -> Result<FineTuningJob, OpenAIError> {
self.client
.post(
format!("/fine_tuning/jobs/{fine_tuning_job_id}/cancel").as_str(),
(),
)
.await
}
/// Get fine-grained status updates for a fine-tune job.
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list_events<Q>(
&self,
fine_tuning_job_id: &str,
query: &Q,
) -> Result<ListFineTuningJobEventsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client
.get_with_query(
format!("/fine_tuning/jobs/{fine_tuning_job_id}/events").as_str(),
&query,
)
.await
}
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list_checkpoints<Q>(
&self,
fine_tuning_job_id: &str,
query: &Q,
) -> Result<ListFineTuningJobCheckpointsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client
.get_with_query(
format!("/fine_tuning/jobs/{fine_tuning_job_id}/checkpoints").as_str(),
&query,
)
.await
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment