Commit 199b9a30 (unverified), authored by nachiketb-nvidia, committed by GitHub
Browse files

chore: Bring async-openai into repo as request starter (#2520)


Co-authored-by: Graham King <grahamk@nvidia.com>
parent 26d9f159
......@@ -238,37 +238,9 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4288f83726785267c6f2ef073a3d83dc3f9b81464e9f99898240cced85fce35a"
[[package]]
name = "async-openai"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31acf814d6b499e33ec894bb0fd7ddaf2665b44fbdd42b858d736449271fde0c"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.8.5",
"reqwest 0.12.22",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
]
[[package]]
name = "async-openai-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0289cba6d5143bfe8251d57b4a8cac036adf158525a76533a7082ba65ec76398"
version = "0.4.0+post0"
dependencies = [
"proc-macro2",
"quote",
......@@ -1862,6 +1834,32 @@ dependencies = [
"bytemuck",
]
[[package]]
name = "dynamo-async-openai"
version = "0.4.0+post0"
dependencies = [
"async-openai-macros",
"backoff",
"base64 0.22.1",
"bytes",
"derive_builder",
"eventsource-stream",
"futures",
"rand 0.9.2",
"reqwest 0.12.22",
"reqwest-eventsource",
"secrecy",
"serde",
"serde_json",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-test",
"tokio-tungstenite 0.26.2",
"tokio-util",
"tracing",
]
[[package]]
name = "dynamo-engine-llamacpp"
version = "0.4.0+post0"
......@@ -1879,9 +1877,9 @@ name = "dynamo-engine-mistralrs"
version = "0.4.0+post0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"dynamo-async-openai",
"dynamo-llm",
"dynamo-runtime",
"either",
......@@ -1903,7 +1901,6 @@ dependencies = [
"approx",
"assert_matches",
"async-nats",
"async-openai",
"async-stream",
"async-trait",
"async_zmq",
......@@ -1921,6 +1918,7 @@ dependencies = [
"derive-getters",
"derive_builder",
"dialoguer",
"dynamo-async-openai",
"dynamo-runtime",
"either",
"erased-serde",
......@@ -1980,10 +1978,10 @@ name = "dynamo-run"
version = "0.4.0+post0"
dependencies = [
"anyhow",
"async-openai",
"async-stream",
"async-trait",
"clap 4.5.42",
"dynamo-async-openai",
"dynamo-engine-llamacpp",
"dynamo-engine-mistralrs",
"dynamo-llm",
......@@ -2259,7 +2257,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
......@@ -2426,6 +2424,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]]
name = "foreign-types"
version = "0.5.0"
......@@ -2433,7 +2440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [
"foreign-types-macros",
"foreign-types-shared",
"foreign-types-shared 0.3.1",
]
[[package]]
......@@ -2447,6 +2454,12 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "foreign-types-shared"
version = "0.3.1"
......@@ -3315,6 +3328,22 @@ dependencies = [
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.16"
......@@ -3815,7 +3844,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if 1.0.1",
"windows-targets 0.48.5",
"windows-targets 0.53.3",
]
[[package]]
......@@ -4078,7 +4107,7 @@ dependencies = [
"bitflags 2.9.1",
"block",
"core-graphics-types",
"foreign-types",
"foreign-types 0.5.0",
"log",
"objc",
"paste",
......@@ -4093,7 +4122,7 @@ dependencies = [
"bitflags 2.9.1",
"block",
"core-graphics-types",
"foreign-types",
"foreign-types 0.5.0",
"log",
"objc",
"paste",
......@@ -4343,7 +4372,7 @@ dependencies = [
"tokenizers",
"tokio",
"tokio-rayon",
"tokio-tungstenite",
"tokio-tungstenite 0.24.0",
"toktrie_hf_tokenizers 1.0.0",
"toml",
"tqdm",
......@@ -4369,7 +4398,7 @@ dependencies = [
"serde",
"serde_json",
"tokio",
"tokio-tungstenite",
"tokio-tungstenite 0.24.0",
"tracing",
"utoipa",
"uuid 1.17.0",
......@@ -4472,6 +4501,23 @@ dependencies = [
"typenum",
]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework 2.11.1",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.16.1"
......@@ -4876,12 +4922,60 @@ version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "openssl"
version = "0.10.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8"
dependencies = [
"bitflags 2.9.1",
"cfg-if 1.0.1",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-src"
version = "300.5.2+3.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d270b79e2926f5150189d475bc7e9d2c69f9c4697b185fa917d5a32b792d21b4"
dependencies = [
"cc",
]
[[package]]
name = "openssl-sys"
version = "0.9.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90096e2e47630d78b7d1c20952dc621f957103f8bc2c8359ec81290d75238571"
dependencies = [
"cc",
"libc",
"openssl-src",
"pkg-config",
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
......@@ -5857,11 +5951,13 @@ dependencies = [
"http-body-util",
"hyper 1.6.0",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"mime_guess",
"native-tls",
"percent-encoding",
"pin-project-lite",
"quinn",
......@@ -5873,6 +5969,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tokio-util",
"tower 0.5.2",
......@@ -6118,7 +6215,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.9.4",
"windows-sys 0.59.0",
"windows-sys 0.60.2",
]
[[package]]
......@@ -7451,6 +7548,16 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rayon"
version = "2.1.0"
......@@ -7482,6 +7589,19 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-test"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7"
dependencies = [
"async-stream",
"bytes",
"futures-core",
"tokio",
"tokio-stream",
]
[[package]]
name = "tokio-tungstenite"
version = "0.24.0"
......@@ -7491,7 +7611,19 @@ dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
"tungstenite 0.24.0",
]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.26.2",
]
[[package]]
......@@ -7899,6 +8031,19 @@ dependencies = [
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"log",
"rand 0.9.2",
"thiserror 2.0.12",
"utf-8",
]
[[package]]
name = "typeid"
version = "1.0.3"
......@@ -8192,6 +8337,12 @@ dependencies = [
"uuid 0.8.2",
]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "vergen"
version = "9.0.6"
......@@ -8496,7 +8647,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.59.0",
]
[[package]]
......
......@@ -9,6 +9,8 @@ members = [
"lib/llm",
"lib/runtime",
"lib/tokens",
"lib/async-openai",
"lib/async-openai-macros",
"lib/bindings/c",
"lib/engines/*",
]
......@@ -29,11 +31,11 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed", "dynamo"]
dynamo-runtime = { path = "lib/runtime", version = "0.4.0" }
dynamo-llm = { path = "lib/llm", version = "0.4.0" }
dynamo-tokens = { path = "lib/tokens", version = "0.4.0" }
dynamo-async-openai = { path = "lib/async-openai", version = "0.4.0", features = ["byot", "rustls"]}
# External dependencies
anyhow = { version = "1" }
async-nats = { version = "0.40", features = ["service"] }
async-openai = { version = "0.29.0", features = ["rustls", "byot"] }
async-stream = { version = "0.3" }
async-trait = { version = "0.1" }
async_zmq = { version = "0.4.0" }
......
......@@ -33,7 +33,7 @@ dynamo-engine-llamacpp = { path = "../../lib/engines/llamacpp", optional = true
dynamo-engine-mistralrs = { path = "../../lib/engines/mistralrs", optional = true }
anyhow = { workspace = true }
async-openai = { workspace = true }
dynamo-async-openai = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
either = { workspace = true }
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Based on https://github.com/64bit/async-openai/ by Himanshu Neema
# Original Copyright (c) 2022 Himanshu Neema
# Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
#
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# Licensed under Apache 2.0
[package]
name = "async-openai-macros"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
readme.workspace = true
[lib]
proc-macro = true
[dependencies]
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
token::Comma,
FnArg, GenericParam, Generics, ItemFn, Pat, PatType, TypeParam, WhereClause,
};
// Parsed arguments of the `byot` attribute, written as comma-separated
// `name = value` pairs, e.g.
// #[byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned, stream = "true", where_clause = "R: Send")]
struct BoundArgs {
// (generic parameter name, bound) pairs for every name other than
// `where_clause` and `stream`; each value is parsed as one `TypeParamBound`.
bounds: Vec<(String, syn::TypeParamBound)>,
// Token text of the `where_clause` value, if given; parsed later by `byot`.
where_clause: Option<String>,
// Whether the generated function returns a stream of `R` instead of a single `R`.
stream: bool,
}
impl Parse for BoundArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut bounds = Vec::new();
let mut where_clause = None;
let mut stream = false; // Default to false
let vars = Punctuated::<syn::MetaNameValue, Comma>::parse_terminated(input)?;
for var in vars {
let name = var.path.get_ident().unwrap().to_string();
match name.as_str() {
"where_clause" => {
where_clause = Some(var.value.into_token_stream().to_string());
}
"stream" => {
stream = var.value.into_token_stream().to_string().contains("true");
}
_ => {
let bound: syn::TypeParamBound =
syn::parse_str(&var.value.into_token_stream().to_string())?;
bounds.push((name, bound));
}
}
}
Ok(BoundArgs {
bounds,
where_clause,
stream,
})
}
}
/// Attribute macro that ignores its arguments and returns the annotated item
/// unchanged.
#[proc_macro_attribute]
pub fn byot_passthrough(_args: TokenStream, item: TokenStream) -> TokenStream {
item
}
/// Generates a "bring your own types" twin of the annotated function.
///
/// The annotated function is emitted as written. Next to it, a function named
/// `<name>_byot` is generated with the same visibility, asyncness, attributes
/// and body, in which:
///
/// - every typed argument bound to a plain identifier gets a fresh generic
///   type `T0`, `T1`, ... in declaration order; receivers and arguments with
///   other patterns are kept as written;
/// - the return type is `Result<R, OpenAIError>`, or a pinned boxed stream of
///   `Result<R, OpenAIError>` when `stream` is set;
/// - bounds supplied as `T0 = Bound` / `R = Bound` are attached to the
///   matching generic parameter, and the optional `where_clause` string is
///   appended as a `where` clause.
///
/// Generic parameters declared on the original function are not carried over
/// to the generated one.
#[proc_macro_attribute]
pub fn byot(args: TokenStream, item: TokenStream) -> TokenStream {
    let bounds_args = parse_macro_input!(args as BoundArgs);
    let input = parse_macro_input!(item as ItemFn);

    let mut new_generics = Generics::default();
    let mut param_count = 0;

    // Rewrite the argument list, collecting one generic parameter per rewritten argument.
    let mut new_params = Vec::new();
    let args = input
        .sig
        .inputs
        .iter()
        .map(|arg| match arg {
            FnArg::Receiver(receiver) => receiver.to_token_stream(),
            FnArg::Typed(PatType { pat, .. }) => {
                if let Pat::Ident(pat_ident) = &**pat {
                    let generic_name = format!("T{}", param_count);
                    let generic_ident =
                        syn::Ident::new(&generic_name, proc_macro2::Span::call_site());

                    // Create the type parameter, attaching a bound if one was supplied.
                    let mut type_param = TypeParam::from(generic_ident.clone());
                    if let Some((_, bound)) = bounds_args
                        .bounds
                        .iter()
                        .find(|(name, _)| name == &generic_name)
                    {
                        type_param.bounds.push(bound.clone());
                    }

                    new_params.push(GenericParam::Type(type_param));
                    param_count += 1;
                    quote! { #pat_ident: #generic_ident }
                } else {
                    arg.to_token_stream()
                }
            }
        })
        .collect::<Vec<_>>();

    // The return type parameter `R`, with its optional bound, always comes last.
    let generic_r = syn::Ident::new("R", proc_macro2::Span::call_site());
    let mut return_type_param = TypeParam::from(generic_r.clone());
    if let Some((_, bound)) = bounds_args.bounds.iter().find(|(name, _)| name == "R") {
        return_type_param.bounds.push(bound.clone());
    }
    new_params.push(GenericParam::Type(return_type_param));

    new_generics.params.extend(new_params);

    let fn_name = &input.sig.ident;
    let byot_fn_name = syn::Ident::new(&format!("{}_byot", fn_name), fn_name.span());
    let vis = &input.vis;
    let block = &input.block;
    let attrs = &input.attrs;
    let asyncness = &input.sig.asyncness;

    // The where clause arrives as the token text of a string literal, so the
    // surrounding quotes are stripped before parsing.
    let where_clause = if let Some(where_str) = bounds_args.where_clause {
        match syn::parse_str::<WhereClause>(&format!("where {}", where_str.replace("\"", ""))) {
            Ok(where_clause) => quote! { #where_clause },
            Err(e) => return TokenStream::from(e.to_compile_error()),
        }
    } else {
        quote! {}
    };

    // Generate the return type based on the stream flag.
    let return_type = if bounds_args.stream {
        quote! { Result<::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<R, OpenAIError>> + Send>>, OpenAIError> }
    } else {
        quote! { Result<R, OpenAIError> }
    };

    // `#input` already prints its own attributes; emitting `#attrs` in front of
    // it as well would duplicate doc comments and every other attribute on the
    // original function. Only the generated twin needs them spelled out.
    let expanded = quote! {
        #input

        #(#attrs)*
        #vis #asyncness fn #byot_fn_name #new_generics (#(#args),*) -> #return_type #where_clause #block
    };

    expanded.into()
}
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Based on https://github.com/64bit/async-openai/ by Himanshu Neema
# Original Copyright (c) 2022 Himanshu Neema
# Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
#
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# Licensed under Apache 2.0
[package]
name = "dynamo-async-openai"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
homepage.workspace = true
repository.workspace = true
readme.workspace = true
[features]
default = ["rustls"]
# Enable rustls for TLS support
rustls = ["reqwest/rustls-tls-native-roots"]
# Enable rustls and webpki-roots
rustls-webpki-roots = ["reqwest/rustls-tls-webpki-roots"]
# Enable native-tls for TLS support
native-tls = ["reqwest/native-tls"]
# Remove dependency on OpenSSL
native-tls-vendored = ["reqwest/native-tls-vendored"]
realtime = ["dep:tokio-tungstenite"]
# Bring your own types
byot = []
[dependencies]
async-openai-macros = { path = "../async-openai-macros" }
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
futures = "0.3.31"
rand = "0.9.0"
reqwest = { version = "0.12.12", features = [
"json",
"stream",
"multipart",
], default-features = false }
reqwest-eventsource = "0.6.0"
serde = { version = "1.0.217", features = ["derive", "rc"] }
serde_json = "1.0.135"
thiserror = "2.0.11"
tokio = { version = "1.43.0", features = ["fs", "macros"] }
tokio-stream = "0.1.17"
tokio-util = { version = "0.7.13", features = ["codec", "io-util"] }
tracing = "0.1.41"
derive_builder = "0.20.2"
secrecy = { version = "0.10.3", features = ["serde"] }
bytes = "1.9.0"
eventsource-stream = "0.2.3"
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
[dev-dependencies]
tokio-test = "0.4.4"
serde_json = "1.0"
[[test]]
name = "bring-your-own-type"
required-features = ["byot"]
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
config::Config,
error::OpenAIError,
types::{
AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse,
ModifyAssistantRequest,
},
Client,
};
/// Build assistants that can call models and use tools to perform tasks.
///
/// [Get started with the Assistants API](https://platform.openai.com/docs/assistants)
pub struct Assistants<'c, C: Config> {
// Borrowed client through which every call of this group is made.
client: &'c Client<C>,
}
impl<'c, C: Config> Assistants<'c, C> {
    /// Wraps a borrowed [`Client`] to expose the assistants endpoints.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Create an assistant with a model and instructions.
    ///
    /// Sends `request` to `/assistants`.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create(
        &self,
        request: CreateAssistantRequest,
    ) -> Result<AssistantObject, OpenAIError> {
        let client = self.client;
        client.post("/assistants", request).await
    }

    /// Retrieves the assistant identified by `assistant_id`.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn retrieve(&self, assistant_id: &str) -> Result<AssistantObject, OpenAIError> {
        let path = format!("/assistants/{assistant_id}");
        self.client.get(&path).await
    }

    /// Modifies the assistant identified by `assistant_id` using `request`.
    #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn update(
        &self,
        assistant_id: &str,
        request: ModifyAssistantRequest,
    ) -> Result<AssistantObject, OpenAIError> {
        let path = format!("/assistants/{assistant_id}");
        self.client.post(&path, request).await
    }

    /// Deletes the assistant identified by `assistant_id`.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn delete(&self, assistant_id: &str) -> Result<DeleteAssistantResponse, OpenAIError> {
        let path = format!("/assistants/{assistant_id}");
        self.client.delete(&path).await
    }

    /// Returns a list of assistants, passing `query` along as query parameters.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn list<Q>(&self, query: &Q) -> Result<ListAssistantsResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        let client = self.client;
        client.get_with_query("/assistants", &query).await
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use bytes::Bytes;
use crate::{
config::Config,
error::OpenAIError,
types::{
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson,
CreateTranslationRequest, CreateTranslationResponseJson,
CreateTranslationResponseVerboseJson,
},
Client,
};
/// Turn audio into text or text into audio.
///
/// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
pub struct Audio<'c, C: Config> {
// Borrowed client through which every call of this group is made.
client: &'c Client<C>,
}
impl<'c, C: Config> Audio<'c, C> {
    /// Wraps a borrowed [`Client`] to expose the audio endpoints.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Transcribes audio into the input language.
    ///
    /// The response is deserialized as [`CreateTranscriptionResponseJson`].
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn transcribe(
        &self,
        request: CreateTranscriptionRequest,
    ) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
        self.client
            .post_form("/audio/transcriptions", request)
            .await
    }

    /// Transcribes audio into the input language.
    ///
    /// The response is deserialized as [`CreateTranscriptionResponseVerboseJson`].
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn transcribe_verbose_json(
        &self,
        request: CreateTranscriptionRequest,
    ) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
        self.client
            .post_form("/audio/transcriptions", request)
            .await
    }

    /// Transcribes audio into the input language.
    ///
    /// The result is returned as [`Bytes`] rather than a typed response.
    pub async fn transcribe_raw(
        &self,
        request: CreateTranscriptionRequest,
    ) -> Result<Bytes, OpenAIError> {
        self.client
            .post_form_raw("/audio/transcriptions", request)
            .await
    }

    /// Translates audio into English.
    ///
    /// The response is deserialized as [`CreateTranslationResponseJson`].
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn translate(
        &self,
        request: CreateTranslationRequest,
    ) -> Result<CreateTranslationResponseJson, OpenAIError> {
        self.client.post_form("/audio/translations", request).await
    }

    /// Translates audio into English.
    ///
    /// The response is deserialized as [`CreateTranslationResponseVerboseJson`].
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn translate_verbose_json(
        &self,
        request: CreateTranslationRequest,
    ) -> Result<CreateTranslationResponseVerboseJson, OpenAIError> {
        self.client.post_form("/audio/translations", request).await
    }

    /// Translates audio into English.
    ///
    /// The result is returned as [`Bytes`] rather than a typed response.
    pub async fn translate_raw(
        &self,
        request: CreateTranslationRequest,
    ) -> Result<Bytes, OpenAIError> {
        self.client
            .post_form_raw("/audio/translations", request)
            .await
    }

    /// Generates audio from the input text.
    ///
    /// The bytes returned for `/audio/speech` are wrapped in a
    /// [`CreateSpeechResponse`].
    pub async fn speech(
        &self,
        request: CreateSpeechRequest,
    ) -> Result<CreateSpeechResponse, OpenAIError> {
        let bytes = self.client.post_raw("/audio/speech", request).await?;

        Ok(CreateSpeechResponse { bytes })
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{config::Config, error::OpenAIError, types::ListAuditLogsResponse, Client};
/// Logs of user actions and configuration changes within this organization.
///
/// To log events, you must activate logging in the [Organization Settings](https://platform.openai.com/settings/organization/general).
/// Once activated, for security reasons, logging cannot be deactivated.
pub struct AuditLogs<'c, C: Config> {
// Borrowed client through which every call of this group is made.
client: &'c Client<C>,
}
impl<'c, C: Config> AuditLogs<'c, C> {
    /// Wraps a borrowed [`Client`] to expose the audit log endpoint.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Lists user actions and configuration changes within this organization,
    /// passing `query` along as query parameters.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn get<Q>(&self, query: &Q) -> Result<ListAuditLogsResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        let client = self.client;
        client
            .get_with_query("/organization/audit_logs", &query)
            .await
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
config::Config,
error::OpenAIError,
types::{Batch, BatchRequest, ListBatchesResponse},
Client,
};
/// Create large batches of API requests for asynchronous processing. The Batch API returns completions within 24 hours for a 50% discount.
///
/// Related guide: [Batch](https://platform.openai.com/docs/guides/batch)
pub struct Batches<'c, C: Config> {
// Borrowed client through which every call of this group is made.
client: &'c Client<C>,
}
impl<'c, C: Config> Batches<'c, C> {
    /// Wraps a borrowed [`Client`] to expose the batch endpoints.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Creates and executes a batch from an uploaded file of requests.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create(&self, request: BatchRequest) -> Result<Batch, OpenAIError> {
        let client = self.client;
        client.post("/batches", request).await
    }

    /// Lists your organization's batches, passing `query` along as query parameters.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn list<Q>(&self, query: &Q) -> Result<ListBatchesResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        let client = self.client;
        client.get_with_query("/batches", &query).await
    }

    /// Retrieves the batch identified by `batch_id`.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn retrieve(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
        let path = format!("/batches/{batch_id}");
        self.client.get(&path).await
    }

    /// Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn cancel(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
        let path = format!("/batches/{batch_id}/cancel");
        // The cancel call carries an empty JSON object as its body.
        let empty_body = serde_json::json!({});
        self.client.post(&path, empty_body).await
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
config::Config,
error::OpenAIError,
types::{
ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
},
Client,
};
/// Given a list of messages comprising a conversation, the model will return a response.
///
/// Related guide: [Chat completions](https://platform.openai.com/docs/guides/text-generation)
pub struct Chat<'c, C: Config> {
// Borrowed client through which every call of this group is made.
client: &'c Client<C>,
}
impl<'c, C: Config> Chat<'c, C> {
    /// Wraps a borrowed [`Client`] to expose the chat completion endpoints.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Creates a model response for the given chat conversation. Learn more in
    /// the
    ///
    /// [text generation](https://platform.openai.com/docs/guides/text-generation),
    /// [vision](https://platform.openai.com/docs/guides/vision),
    ///
    /// and [audio](https://platform.openai.com/docs/guides/audio) guides.
    ///
    ///
    /// Parameter support can differ depending on the model used to generate the
    /// response, particularly for newer reasoning models. Parameters that are
    /// only supported for reasoning models are noted below. For the current state
    /// of unsupported parameters in reasoning models,
    ///
    /// [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning).
    ///
    /// # Errors
    ///
    /// Without the `byot` feature, returns [`OpenAIError::InvalidArgument`]
    /// when `request.stream` is `Some(true)`.
    ///
    /// byot: You must ensure "stream: false" in serialized `request`
    #[crate::byot(
        T0 = serde::Serialize,
        R = serde::de::DeserializeOwned
    )]
    pub async fn create(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if matches!(request.stream, Some(true)) {
                return Err(OpenAIError::InvalidArgument(
                    "When stream is true, use Chat::create_stream".into(),
                ));
            }
        }
        self.client.post("/chat/completions", request).await
    }

    /// Creates a completion for the chat message
    ///
    /// partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message.
    ///
    /// [ChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
    ///
    /// # Errors
    ///
    /// Without the `byot` feature, returns [`OpenAIError::InvalidArgument`]
    /// when `request.stream` is `Some(false)`; otherwise `stream` is forced to
    /// `Some(true)` before sending.
    ///
    /// byot: You must ensure "stream: true" in serialized `request`
    #[crate::byot(
        T0 = serde::Serialize,
        R = serde::de::DeserializeOwned,
        stream = "true",
        where_clause = "R: std::marker::Send + 'static"
    )]
    // With the `byot` feature the mutating block below is compiled out, which
    // would otherwise leave `mut request` flagged as unused.
    #[allow(unused_mut)]
    pub async fn create_stream(
        &self,
        mut request: CreateChatCompletionRequest,
    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if matches!(request.stream, Some(false)) {
                return Err(OpenAIError::InvalidArgument(
                    "When stream is false, use Chat::create".into(),
                ));
            }

            request.stream = Some(true);
        }
        Ok(self.client.post_stream("/chat/completions", request).await)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use std::pin::Pin;
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::multipart::Form;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};
use crate::{
config::{Config, OpenAIConfig},
error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
file::Files,
image::Images,
moderation::Moderations,
traits::AsyncTryFrom,
Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
};
#[derive(Debug, Clone, Default)]
/// Client is a container for config, backoff and http_client
/// used to make API calls.
pub struct Client<C: Config> {
// HTTP client; replaceable via `with_http_client`.
http_client: reqwest::Client,
// Configuration value implementing `Config`.
config: C,
// Backoff settings; replaceable via `with_backoff`.
backoff: backoff::ExponentialBackoff,
}
impl Client<OpenAIConfig> {
/// Client with default [OpenAIConfig]
///
/// Equivalent to `Client::<OpenAIConfig>::default()`.
pub fn new() -> Self {
Self::default()
}
}
impl<C: Config> Client<C> {
/// Create client with a custom HTTP client, OpenAI config, and backoff.
///
/// All three parts are stored as given; nothing is defaulted.
pub fn build(
http_client: reqwest::Client,
config: C,
backoff: backoff::ExponentialBackoff,
) -> Self {
Self {
http_client,
config,
backoff,
}
}
/// Create client with [OpenAIConfig] or [crate::config::AzureConfig]
pub fn with_config(config: C) -> Self {
Self {
http_client: reqwest::Client::new(),
config,
backoff: Default::default(),
}
}
/// Provide your own [client] to make HTTP requests with.
///
/// [client]: reqwest::Client
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = http_client;
self
}
/// Exponential backoff for retrying [rate limited](https://platform.openai.com/docs/guides/rate-limits) requests.
pub fn with_backoff(mut self, backoff: backoff::ExponentialBackoff) -> Self {
self.backoff = backoff;
self
}
// API groups
// Each accessor below returns a lightweight handle constructed from `&self`.
/// To call [Models] group related APIs using this client.
pub fn models(&self) -> Models<C> {
Models::new(self)
}
/// To call [Completions] group related APIs using this client.
pub fn completions(&self) -> Completions<C> {
Completions::new(self)
}
/// To call [Chat] group related APIs using this client.
pub fn chat(&self) -> Chat<C> {
Chat::new(self)
}
/// To call [Images] group related APIs using this client.
pub fn images(&self) -> Images<C> {
Images::new(self)
}
/// To call [Moderations] group related APIs using this client.
pub fn moderations(&self) -> Moderations<C> {
Moderations::new(self)
}
/// To call [Files] group related APIs using this client.
pub fn files(&self) -> Files<C> {
Files::new(self)
}
/// To call [Uploads] group related APIs using this client.
pub fn uploads(&self) -> Uploads<C> {
Uploads::new(self)
}
/// To call [FineTuning] group related APIs using this client.
pub fn fine_tuning(&self) -> FineTuning<C> {
FineTuning::new(self)
}
/// To call [Embeddings] group related APIs using this client.
pub fn embeddings(&self) -> Embeddings<C> {
Embeddings::new(self)
}
/// To call [Audio] group related APIs using this client.
pub fn audio(&self) -> Audio<C> {
Audio::new(self)
}
/// To call [Assistants] group related APIs using this client.
pub fn assistants(&self) -> Assistants<C> {
Assistants::new(self)
}
/// To call [Threads] group related APIs using this client.
pub fn threads(&self) -> Threads<C> {
Threads::new(self)
}
/// To call [VectorStores] group related APIs using this client.
pub fn vector_stores(&self) -> VectorStores<C> {
VectorStores::new(self)
}
/// To call [Batches] group related APIs using this client.
pub fn batches(&self) -> Batches<C> {
Batches::new(self)
}
/// To call [AuditLogs] group related APIs using this client.
pub fn audit_logs(&self) -> AuditLogs<C> {
AuditLogs::new(self)
}
/// To call [Invites] group related APIs using this client.
pub fn invites(&self) -> Invites<C> {
Invites::new(self)
}
/// To call [Users] group related APIs using this client.
pub fn users(&self) -> Users<C> {
Users::new(self)
}
/// To call [Projects] group related APIs using this client.
pub fn projects(&self) -> Projects<C> {
Projects::new(self)
}
/// To call [Responses] group related APIs using this client.
pub fn responses(&self) -> Responses<C> {
Responses::new(self)
}
/// Borrow the configuration this client was created with.
pub fn config(&self) -> &C {
&self.config
}
/// Issue a GET to {path} and deserialize the JSON response body into `O`.
pub(crate) async fn get<O>(&self, path: &str) -> Result<O, OpenAIError>
where
    O: DeserializeOwned,
{
    // The closure is re-invoked for every attempt made by the retry loop.
    let request_maker = || async {
        let builder = self
            .http_client
            .get(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers());
        Ok(builder.build()?)
    };
    self.execute(request_maker).await
}
/// Issue a GET to {path} with the caller's `query` appended after the config's
/// own query parameters, then deserialize the JSON response body into `O`.
pub(crate) async fn get_with_query<Q, O>(&self, path: &str, query: &Q) -> Result<O, OpenAIError>
where
    O: DeserializeOwned,
    Q: Serialize + ?Sized,
{
    // The closure is re-invoked for every attempt made by the retry loop.
    let request_maker = || async {
        let builder = self
            .http_client
            .get(self.config.url(path))
            .query(&self.config.query())
            .query(query)
            .headers(self.config.headers());
        Ok(builder.build()?)
    };
    self.execute(request_maker).await
}
/// Issue a DELETE to {path} and deserialize the JSON response body into `O`.
pub(crate) async fn delete<O>(&self, path: &str) -> Result<O, OpenAIError>
where
    O: DeserializeOwned,
{
    // The closure is re-invoked for every attempt made by the retry loop.
    let request_maker = || async {
        let builder = self
            .http_client
            .delete(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers());
        Ok(builder.build()?)
    };
    self.execute(request_maker).await
}
/// Issue a GET to {path} and hand back the response body as raw bytes.
pub(crate) async fn get_raw(&self, path: &str) -> Result<Bytes, OpenAIError> {
    // The closure is re-invoked for every attempt made by the retry loop.
    let request_maker = || async {
        let builder = self
            .http_client
            .get(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers());
        Ok(builder.build()?)
    };
    self.execute_raw(request_maker).await
}
/// POST `request` as JSON to {path} and hand back the response body as raw bytes.
pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
where
    I: Serialize,
{
    // The closure is re-invoked for every attempt made by the retry loop.
    let request_maker = || async {
        let builder = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request);
        Ok(builder.build()?)
    };
    self.execute_raw(request_maker).await
}
/// POST `request` as JSON to {path} and deserialize the JSON response body into `O`.
pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
where
    I: Serialize,
    O: DeserializeOwned,
{
    // The closure is re-invoked for every attempt made by the retry loop.
    let request_maker = || async {
        let builder = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request);
        Ok(builder.build()?)
    };
    self.execute(request_maker).await
}
/// POST a multipart form to {path} and hand back the response body as raw bytes.
pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
where
    Form: AsyncTryFrom<F, Error = OpenAIError>,
    F: Clone,
{
    // The closure is re-invoked for every attempt, so the form is converted
    // from a fresh clone each time.
    let request_maker = || async {
        let builder = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers());
        let multipart = <Form as AsyncTryFrom<F>>::try_from(form.clone()).await?;
        Ok(builder.multipart(multipart).build()?)
    };
    self.execute_raw(request_maker).await
}
/// POST a multipart form to {path} and deserialize the JSON response body into `O`.
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
where
    O: DeserializeOwned,
    Form: AsyncTryFrom<F, Error = OpenAIError>,
    F: Clone,
{
    // The closure is re-invoked for every attempt, so the form is converted
    // from a fresh clone each time.
    let request_maker = || async {
        let builder = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers());
        let multipart = <Form as AsyncTryFrom<F>>::try_from(form.clone()).await?;
        Ok(builder.multipart(multipart).build()?)
    };
    self.execute(request_maker).await
}
/// Execute a HTTP request and retry on rate limit
///
/// request_maker serves one purpose: to be able to create request again
/// to retry API call after getting rate limited. request_maker is async because
/// reqwest::multipart::Form is created by async calls to read files for uploads.
async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
where
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let client = self.http_client.clone();
backoff::future::retry(self.backoff.clone(), || async {
let request = request_maker().await.map_err(backoff::Error::Permanent)?;
let response = client
.execute(request)
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;
if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(ApiError {
message,
r#type: None,
param: None,
code: None,
}),
retry_after: None,
});
}
// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
if status.as_u16() == 429
// API returns 429 also when:
// "You exceeded your current quota, please check your plan and billing details."
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
}
}
Ok(bytes)
})
.await
}
/// Run an HTTP request via [Self::execute_raw] and deserialize the JSON body into `O`.
///
/// `request_maker` exists so the request can be rebuilt for each retry; it is
/// async because reqwest::multipart::Form is created by async calls that read
/// files for uploads.
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
where
    O: DeserializeOwned,
    M: Fn() -> Fut,
    Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
    let body = self.execute_raw(request_maker).await?;
    serde_json::from_slice(body.as_ref()).map_err(|e| map_deserialization_error(e, body.as_ref()))
}
/// Make HTTP POST request to receive SSE
pub(crate) async fn post_stream<I, O>(
&self,
path: &str,
request: I,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.eventsource()
.unwrap();
stream(event_source).await
}
pub(crate) async fn post_stream_mapped_raw_events<I, O>(
&self,
path: &str,
request: I,
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.eventsource()
.unwrap();
stream_mapped_raw_events(event_source, event_mapper).await
}
/// Make HTTP GET request to receive SSE
pub(crate) async fn _get_stream<Q, O>(
&self,
path: &str,
query: &Q,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
Q: Serialize + ?Sized,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
.http_client
.get(self.config.url(path))
.query(query)
.query(&self.config.query())
.headers(self.config.headers())
.eventsource()
.unwrap();
stream(event_source).await
}
}
/// Drive an SSE request and expose its messages as a typed stream.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
///
/// A spawned task pulls events from `event_source` and forwards them over an
/// unbounded channel: each message body is deserialized into `O`, event-source
/// errors become [OpenAIError::StreamError] items, and a `[DONE]` message ends
/// the stream without being forwarded. The task also stops once the receiving
/// side has been dropped.
pub(crate) async fn stream<O>(
    mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        while let Some(ev) = event_source.next().await {
            let item = match ev {
                Ok(Event::Open) => continue,
                Ok(Event::Message(message)) => {
                    if message.data == "[DONE]" {
                        break;
                    }
                    serde_json::from_str::<O>(&message.data)
                        .map_err(|e| map_deserialization_error(e, message.data.as_bytes()))
                }
                Err(e) => Err(OpenAIError::StreamError(e.to_string())),
            };

            if tx.send(item).is_err() {
                // rx dropped
                break;
            }
        }

        event_source.close();
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
/// Drive an SSE request and expose its messages through a caller-supplied mapper.
///
/// A spawned task pulls events from `event_source` and forwards them over an
/// unbounded channel: each message is converted with `event_mapper`, and
/// event-source errors become [OpenAIError::StreamError] items. Unlike
/// [stream], a `[DONE]` message is still passed to `event_mapper` and its
/// result forwarded before the task stops. The task also stops once the
/// receiving side has been dropped.
pub(crate) async fn stream_mapped_raw_events<O>(
    mut event_source: EventSource,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    O: DeserializeOwned + std::marker::Send + 'static,
{
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

    tokio::spawn(async move {
        while let Some(ev) = event_source.next().await {
            let (item, finished) = match ev {
                Ok(Event::Open) => continue,
                Ok(Event::Message(message)) => {
                    // Check the terminator before the message is moved into the mapper.
                    let finished = message.data == "[DONE]";
                    (event_mapper(message), finished)
                }
                Err(e) => (Err(OpenAIError::StreamError(e.to_string())), false),
            };

            // Stop when the receiver is gone or the terminator has been forwarded.
            if tx.send(item).is_err() || finished {
                break;
            }
        }

        event_source.close();
    });

    Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
client::Client,
config::Config,
error::OpenAIError,
types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse},
};
/// Given a prompt, the model will return one or more predicted completions,
/// and can also return the probabilities of alternative tokens at each position.
/// We recommend most users use our Chat completions API.
/// [Learn more](https://platform.openai.com/docs/deprecations/2023-07-06-gpt-and-embeddings)
///
/// Related guide: [Legacy Completions](https://platform.openai.com/docs/guides/gpt/completions-api)
pub struct Completions<'c, C: Config> {
    client: &'c Client<C>,
}

impl<'c, C: Config> Completions<'c, C> {
    /// Wrap a borrowed [Client] to access the completions endpoints.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Creates a completion for the provided prompt and parameters
    ///
    /// Without the `byot` feature, a request with `stream: Some(true)` is
    /// rejected with [OpenAIError::InvalidArgument].
    ///
    /// You must ensure that "stream: false" in serialized `request`
    #[crate::byot(
        T0 = serde::Serialize,
        R = serde::de::DeserializeOwned
    )]
    pub async fn create(
        &self,
        request: CreateCompletionRequest,
    ) -> Result<CreateCompletionResponse, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if request.stream == Some(true) {
                return Err(OpenAIError::InvalidArgument(
                    "When stream is true, use Completion::create_stream".into(),
                ));
            }
        }
        self.client.post("/completions", request).await
    }

    /// Creates a completion request for the provided prompt and parameters
    ///
    /// Stream back partial progress. Tokens will be sent as data-only
    /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
    /// as they become available, with the stream terminated by a data: \[DONE\] message.
    ///
    /// [CompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
    ///
    /// Without the `byot` feature, a request with `stream: Some(false)` is
    /// rejected with [OpenAIError::InvalidArgument]; otherwise `stream` is
    /// set to `Some(true)` before sending.
    ///
    /// You must ensure that "stream: true" in serialized `request`
    #[crate::byot(
        T0 = serde::Serialize,
        R = serde::de::DeserializeOwned,
        stream = "true",
        where_clause = "R: std::marker::Send + 'static"
    )]
    #[allow(unused_mut)]
    pub async fn create_stream(
        &self,
        mut request: CreateCompletionRequest,
    ) -> Result<CompletionResponseStream, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if request.stream == Some(false) {
                return Err(OpenAIError::InvalidArgument(
                    "When stream is false, use Completion::create".into(),
                ));
            }
            request.stream = Some(true);
        }
        let events = self.client.post_stream("/completions", request).await;
        Ok(events)
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
use reqwest::header::{HeaderMap, AUTHORIZATION};
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
/// Default v1 API base url, used by [OpenAIConfig] unless overridden.
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Name of the HTTP header that carries the organization id.
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Name of the HTTP header that carries the project id.
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";
/// Calls to the Assistants API require that you pass a Beta header
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// [crate::Client] relies on this for every API call on OpenAI
/// or Azure OpenAI service
pub trait Config: Send + Sync {
/// Headers attached to every request.
fn headers(&self) -> HeaderMap;
/// Full request URL for the given endpoint `path`.
fn url(&self, path: &str) -> String;
/// Query parameters attached to every request.
fn query(&self) -> Vec<(&str, &str)>;
/// Base URL the configuration was set up with.
fn api_base(&self) -> &str;
/// API key, kept wrapped as a [SecretString].
fn api_key(&self) -> &SecretString;
}
/// Macro to implement Config trait for pointer types with dyn objects
///
/// Every method forwards to the pointee obtained through `as_ref()`.
macro_rules! impl_config_for_ptr {
($t:ty) => {
impl Config for $t {
fn headers(&self) -> HeaderMap {
self.as_ref().headers()
}
fn url(&self, path: &str) -> String {
self.as_ref().url(path)
}
fn query(&self) -> Vec<(&str, &str)> {
self.as_ref().query()
}
fn api_base(&self) -> &str {
self.as_ref().api_base()
}
fn api_key(&self) -> &SecretString {
self.as_ref().api_key()
}
}
};
}
// Lets `Client<Box<dyn Config>>` and `Client<Arc<dyn Config>>` be used for dynamic dispatch.
impl_config_for_ptr!(Box<dyn Config>);
impl_config_for_ptr!(std::sync::Arc<dyn Config>);
/// Configuration for OpenAI API
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
    api_base: String,
    api_key: SecretString,
    org_id: String,
    project_id: String,
}

impl Default for OpenAIConfig {
    /// Uses [OPENAI_API_BASE], reads the API key from the `OPENAI_API_KEY`
    /// environment variable (empty if it cannot be read), and leaves the
    /// organization and project ids empty.
    fn default() -> Self {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
        Self {
            api_base: OPENAI_API_BASE.to_string(),
            api_key: api_key.into(),
            org_id: String::new(),
            project_id: String::new(),
        }
    }
}
impl OpenAIConfig {
    /// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
    pub fn new() -> Self {
        Self::default()
    }

    /// To use a different organization id other than default
    pub fn with_org_id<S: Into<String>>(self, org_id: S) -> Self {
        Self {
            org_id: org_id.into(),
            ..self
        }
    }

    /// Non default project id
    pub fn with_project_id<S: Into<String>>(self, project_id: S) -> Self {
        Self {
            project_id: project_id.into(),
            ..self
        }
    }

    /// To use a different API key different from default OPENAI_API_KEY env var
    pub fn with_api_key<S: Into<String>>(self, api_key: S) -> Self {
        let api_key: String = api_key.into();
        Self {
            api_key: SecretString::from(api_key),
            ..self
        }
    }

    /// To use a API base url different from default [OPENAI_API_BASE]
    pub fn with_api_base<S: Into<String>>(self, api_base: S) -> Self {
        Self {
            api_base: api_base.into(),
            ..self
        }
    }

    /// Organization id currently configured (empty when unset).
    pub fn org_id(&self) -> &str {
        self.org_id.as_str()
    }
}
impl Config for OpenAIConfig {
    /// Builds the per-request headers: the organization and project headers
    /// (only when non-empty), a bearer `Authorization` header, and the
    /// Assistants beta header.
    ///
    /// # Panics
    ///
    /// Panics if the organization id, project id, or API key cannot be parsed
    /// into a header value.
    fn headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();

        let optional = [
            (OPENAI_ORGANIZATION_HEADER, &self.org_id),
            (OPENAI_PROJECT_HEADER, &self.project_id),
        ];
        for (name, value) in optional {
            if !value.is_empty() {
                headers.insert(name, value.parse().unwrap());
            }
        }

        let bearer = format!("Bearer {}", self.api_key.expose_secret());
        headers.insert(AUTHORIZATION, bearer.parse().unwrap());

        // hack for Assistants APIs
        // Calls to the Assistants API require that you pass a Beta header
        headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());

        headers
    }

    /// Concatenates the base URL and `path` without inserting a separator.
    fn url(&self, path: &str) -> String {
        let base = &self.api_base;
        format!("{}{}", base, path)
    }

    fn api_base(&self) -> &str {
        self.api_base.as_str()
    }

    fn api_key(&self) -> &SecretString {
        &self.api_key
    }

    /// OpenAI requests carry no extra query parameters.
    fn query(&self) -> Vec<(&str, &str)> {
        Vec::new()
    }
}
/// Configuration for Azure OpenAI Service
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
    api_version: String,
    deployment_id: String,
    api_base: String,
    api_key: SecretString,
}

impl Default for AzureConfig {
    /// Reads the API key from the `OPENAI_API_KEY` environment variable
    /// (empty if it cannot be read) and leaves every other field empty.
    fn default() -> Self {
        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
        Self {
            api_base: String::new(),
            api_key: api_key.into(),
            deployment_id: String::new(),
            api_version: String::new(),
        }
    }
}
impl AzureConfig {
    /// Same as [AzureConfig::default].
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the value sent as the `api-version` query parameter.
    pub fn with_api_version<S: Into<String>>(self, api_version: S) -> Self {
        Self {
            api_version: api_version.into(),
            ..self
        }
    }

    /// Set the deployment id that is embedded in request URLs.
    pub fn with_deployment_id<S: Into<String>>(self, deployment_id: S) -> Self {
        Self {
            deployment_id: deployment_id.into(),
            ..self
        }
    }

    /// To use a different API key different from default OPENAI_API_KEY env var
    pub fn with_api_key<S: Into<String>>(self, api_key: S) -> Self {
        let api_key: String = api_key.into();
        Self {
            api_key: SecretString::from(api_key),
            ..self
        }
    }

    /// API base url in form of <https://your-resource-name.openai.azure.com>
    pub fn with_api_base<S: Into<String>>(self, api_base: S) -> Self {
        Self {
            api_base: api_base.into(),
            ..self
        }
    }
}
impl Config for AzureConfig {
    /// Sends the API key in the `api-key` header.
    ///
    /// # Panics
    ///
    /// Panics if the API key cannot be parsed into a header value.
    fn headers(&self) -> HeaderMap {
        let mut headers = HeaderMap::new();
        let secret = self.api_key.expose_secret();
        headers.insert("api-key", secret.parse().unwrap());
        headers
    }

    /// Builds `{api_base}/openai/deployments/{deployment_id}{path}`.
    fn url(&self, path: &str) -> String {
        let (base, deployment) = (&self.api_base, &self.deployment_id);
        format!("{}/openai/deployments/{}{}", base, deployment, path)
    }

    fn api_base(&self) -> &str {
        self.api_base.as_str()
    }

    fn api_key(&self) -> &SecretString {
        &self.api_key
    }

    /// Every request carries the configured `api-version`.
    fn query(&self) -> Vec<(&str, &str)> {
        vec![("api-version", self.api_version.as_str())]
    }
}
#[cfg(test)]
mod test {
use super::*;
use crate::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
};
use crate::Client;
use std::sync::Arc;
/// A client can be built from both `Box<dyn Config>` and `Arc<dyn Config>`,
/// and the `Arc` variant can be cloned.
#[test]
fn test_client_creation() {
// SAFETY: NOTE(review) — mutating the process environment is not safe while
// other threads may read it, and `test_dynamic_dispatch` below reads
// OPENAI_API_KEY through `OpenAIConfig::default()`; confirm these tests
// never run concurrently or drop the env mutation.
unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
let openai_config = OpenAIConfig::default();
let config = Box::new(openai_config.clone()) as Box<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let config = Arc::new(openai_config) as Arc<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let cloned_client = client.clone();
assert!(cloned_client.config().url("").ends_with("/v1"));
}
/// Compile-time check only: the request future is built but never awaited,
/// so nothing is sent.
async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
let _ = client.chat().create(CreateChatCompletionRequest {
model: "gpt-4o".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: "Hello, world!".into(),
..Default::default()
},
)],
..Default::default()
});
}
/// Boxed configs of both kinds work behind the same client type, both when
/// awaited in place and when moved into spawned tasks.
#[tokio::test]
async fn test_dynamic_dispatch() {
let openai_config = OpenAIConfig::default();
let azure_config = AzureConfig::default();
let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);
let _ = dynamic_dispatch_compiles(&azure_client).await;
let _ = dynamic_dispatch_compiles(&oai_client).await;
// The join handles are discarded; spawning only has to type-check.
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use std::path::{Path, PathBuf};
use base64::{engine::general_purpose, Engine as _};
use rand::{distr::Alphanumeric, Rng};
use reqwest::Url;
use crate::error::OpenAIError;
/// Maps the path segments of `url` onto `base_dir`.
///
/// Returns `(dir, path)`: `path` is `base_dir` followed by every segment and
/// `dir` is the same without the final segment. When the URL has no path
/// segments both are `base_dir`.
fn create_paths<P: AsRef<Path>>(url: &Url, base_dir: P) -> (PathBuf, PathBuf) {
    let mut dir = base_dir.as_ref().to_path_buf();

    let segments: Vec<&str> = match url.path_segments() {
        Some(parts) => parts.collect(),
        None => Vec::new(),
    };

    let Some((file_name, parents)) = segments.split_last() else {
        let path = dir.clone();
        return (dir, path);
    };

    for parent in parents {
        dir.push(parent);
    }
    let mut path = dir.clone();
    path.push(file_name);

    (dir, path)
}
/// Downloads `url` and stores the body under `dir`, mirroring the URL's path
/// segments as sub-directories, then returns the path of the written file.
///
/// Every failure (URL parsing, the HTTP request, a non-success status,
/// directory creation, reading the body, writing the file) is reported as
/// [OpenAIError::FileSaveError].
pub(crate) async fn download_url<P: AsRef<Path>>(
    url: &str,
    dir: P,
) -> Result<PathBuf, OpenAIError> {
    let parsed_url = Url::parse(url).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

    let response = reqwest::get(url)
        .await
        .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

    let status = response.status();
    if !status.is_success() {
        return Err(OpenAIError::FileSaveError(format!(
            "couldn't download file, status: {}, url: {url}",
            status
        )));
    }

    let (dir, file_path) = create_paths(&parsed_url, dir);

    tokio::fs::create_dir_all(dir.as_path())
        .await
        .map_err(|e| OpenAIError::FileSaveError(format!("{}, dir: {}", e, dir.display())))?;

    let body = response.bytes().await.map_err(|e| {
        OpenAIError::FileSaveError(format!("{}, file path: {}", e, file_path.display()))
    })?;

    tokio::fs::write(file_path.as_path(), body)
        .await
        .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

    Ok(file_path)
}
/// Decodes standard base64 data and writes it to a new file inside `dir`.
///
/// The file name is ten random alphanumeric characters followed by `.png`.
/// Decode and write failures are reported as [OpenAIError::FileSaveError].
pub(crate) async fn save_b64<P: AsRef<Path>>(b64: &str, dir: P) -> Result<PathBuf, OpenAIError> {
    let stem: String = rand::rng()
        .sample_iter(&Alphanumeric)
        .take(10)
        .map(char::from)
        .collect();
    let path = dir.as_ref().join(format!("{stem}.png"));

    let decoded = general_purpose::STANDARD
        .decode(b64)
        .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

    tokio::fs::write(path.as_path(), decoded)
        .await
        .map_err(|e| OpenAIError::FileSaveError(format!("{}, path: {}", e, path.display())))?;

    Ok(path)
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
config::Config,
error::OpenAIError,
types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
Client,
};
#[cfg(not(feature = "byot"))]
use crate::types::EncodingFormat;
/// Get a vector representation of a given input that can be easily
/// consumed by machine learning models and algorithms.
///
/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
pub struct Embeddings<'c, C: Config> {
    client: &'c Client<C>,
}

impl<'c, C: Config> Embeddings<'c, C> {
    /// Wrap a borrowed [Client] to access the embeddings endpoint.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Creates an embedding vector representing the input text.
    ///
    /// Without the `byot` feature, a request whose `encoding_format` is
    /// base64 is rejected with [OpenAIError::InvalidArgument].
    ///
    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create(
        &self,
        request: CreateEmbeddingRequest,
    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            if let Some(EncodingFormat::Base64) = request.encoding_format {
                return Err(OpenAIError::InvalidArgument(
                    "When encoding_format is base64, use Embeddings::create_base64".into(),
                ));
            }
        }
        self.client.post("/embeddings", request).await
    }

    /// Creates an embedding vector representing the input text.
    ///
    /// The response will contain the embedding in base64 format.
    ///
    /// Without the `byot` feature, a request whose `encoding_format` is
    /// anything other than base64 (including unset) is rejected with
    /// [OpenAIError::InvalidArgument].
    ///
    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create_base64(
        &self,
        request: CreateEmbeddingRequest,
    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
        #[cfg(not(feature = "byot"))]
        {
            match request.encoding_format {
                Some(EncodingFormat::Base64) => {}
                _ => {
                    return Err(OpenAIError::InvalidArgument(
                        "When encoding_format is not base64, use Embeddings::create".into(),
                    ));
                }
            }
        }
        self.client.post("/embeddings", request).await
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
use serde::{Deserialize, Serialize};
#[derive(Debug, thiserror::Error)]
/// Every error this crate returns to callers.
pub enum OpenAIError {
/// Underlying error from reqwest library after an API call was made
#[error("http error: {0}")]
Reqwest(#[from] reqwest::Error),
/// OpenAI returns error object with details of API call failure
#[error("{0}")]
ApiError(ApiError),
/// Error when a response cannot be deserialized into a Rust type
#[error("failed to deserialize api response: {0}")]
JSONDeserialize(serde_json::Error),
/// Error on the client side when saving file to file system
#[error("failed to save file: {0}")]
FileSaveError(String),
/// Error on the client side when reading file from file system
#[error("failed to read file: {0}")]
FileReadError(String),
/// Error on SSE streaming
#[error("stream failed: {0}")]
StreamError(String),
/// Error from client side validation
/// or when builder fails to build request before making API call
#[error("invalid args: {0}")]
InvalidArgument(String),
}
/// OpenAI API returns error object on failure
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ApiError {
    pub message: String,
    pub r#type: Option<String>,
    pub param: Option<String>,
    pub code: Option<String>,
}

impl std::fmt::Display for ApiError {
    /// With every field present the output is
    /// `{type}: {message} (param: {param}) (code: {code})`;
    /// absent optional fields are simply left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each optional field contributes zero or one fragment; the message
        // always contributes exactly one.
        let fragments: Vec<String> = self
            .r#type
            .iter()
            .map(|kind| format!("{}:", kind))
            .chain(std::iter::once(self.message.clone()))
            .chain(self.param.iter().map(|param| format!("(param: {param})")))
            .chain(self.code.iter().map(|code| format!("(code: {code})")))
            .collect();
        write!(f, "{}", fragments.join(" "))
    }
}
/// Wrapper to deserialize the error object nested in "error" JSON key
#[derive(Debug, Deserialize, Serialize)]
pub struct WrappedError {
// The error object found under the "error" key.
pub error: ApiError,
}
/// Logs the payload that failed to deserialize (lossily decoded as UTF-8) at
/// error level and wraps `e` in [OpenAIError::JSONDeserialize].
pub(crate) fn map_deserialization_error(e: serde_json::Error, bytes: &[u8]) -> OpenAIError {
tracing::error!(
"failed deserialization of: {}",
String::from_utf8_lossy(bytes)
);
OpenAIError::JSONDeserialize(e)
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use bytes::Bytes;
use serde::Serialize;
use crate::{
config::Config,
error::OpenAIError,
types::{CreateFileRequest, DeleteFileResponse, ListFilesResponse, OpenAIFile},
Client,
};
/// Files are used to upload documents that can be used with features like Assistants and Fine-tuning.
pub struct Files<'c, C: Config> {
    client: &'c Client<C>,
}

impl<'c, C: Config> Files<'c, C> {
    /// Wrap a borrowed [Client] to access the files endpoints.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Upload a file that can be used across various endpoints. Individual files can be up to 512 MB, and the size of all files uploaded by one organization can be up to 100 GB.
    ///
    /// The Assistants API supports files up to 2 million tokens and of specific file types. See the [Assistants Tools guide](https://platform.openai.com/docs/assistants/tools) for details.
    ///
    /// The Fine-tuning API only supports `.jsonl` files. The input also has certain required formats for fine-tuning [chat](https://platform.openai.com/docs/api-reference/fine-tuning/chat-input) or [completions](https://platform.openai.com/docs/api-reference/fine-tuning/completions-input) models.
    ///
    /// The Batch API only supports `.jsonl` files up to 100 MB in size. The input also has a specific required [format](https://platform.openai.com/docs/api-reference/batch/request-input).
    ///
    /// Please [contact us](https://help.openai.com/) if you need to increase these storage limits.
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn create(&self, request: CreateFileRequest) -> Result<OpenAIFile, OpenAIError> {
        // Sent as a multipart form rather than JSON.
        self.client.post_form("/files", request).await
    }

    /// Returns a list of files that belong to the user's organization.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn list<Q>(&self, query: &Q) -> Result<ListFilesResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        self.client.get_with_query("/files", &query).await
    }

    /// Returns information about a specific file.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn retrieve(&self, file_id: &str) -> Result<OpenAIFile, OpenAIError> {
        let path = format!("/files/{file_id}");
        self.client.get(path.as_str()).await
    }

    /// Delete a file.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn delete(&self, file_id: &str) -> Result<DeleteFileResponse, OpenAIError> {
        let path = format!("/files/{file_id}");
        self.client.delete(path.as_str()).await
    }

    /// Returns the contents of the specified file
    pub async fn content(&self, file_id: &str) -> Result<Bytes, OpenAIError> {
        let path = format!("/files/{file_id}/content");
        self.client.get_raw(path.as_str()).await
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
config::Config,
error::OpenAIError,
types::{
CreateFineTuningJobRequest, FineTuningJob, ListFineTuningJobCheckpointsResponse,
ListFineTuningJobEventsResponse, ListPaginatedFineTuningJobsResponse,
},
Client,
};
/// Entry point for the fine-tuning API group: manage fine-tuning jobs to tailor a model to
/// your specific training data.
///
/// Related guide: [Fine-tune models](https://platform.openai.com/docs/guides/fine-tuning)
pub struct FineTuning<'c, C>
where
    C: Config,
{
    /// Borrowed client that the methods of this group issue their requests through.
    client: &'c Client<C>,
}
impl<'c, C: Config> FineTuning<'c, C> {
    /// Wraps a borrowed [`Client`] so fine-tuning requests can be issued through it.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Creates a job that fine-tunes a specified model from a given dataset.
    ///
    /// Response includes details of the enqueued job including job status and the name of the
    /// fine-tuned models once complete.
    ///
    /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create(
        &self,
        request: CreateFineTuningJobRequest,
    ) -> Result<FineTuningJob, OpenAIError> {
        self.client.post("/fine_tuning/jobs", request).await
    }

    /// List your organization's fine-tuning jobs.
    ///
    /// `query` is serialized and handed to the client together with the `/fine_tuning/jobs`
    /// path.
    // Only `T0` is declared: `query` is the single non-receiver parameter of this method, so a
    // `T1` bound would not correspond to any argument.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn list_paginated<Q>(
        &self,
        query: &Q,
    ) -> Result<ListPaginatedFineTuningJobsResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        self.client
            .get_with_query("/fine_tuning/jobs", &query)
            .await
    }

    /// Gets info about the fine-tune job.
    ///
    /// [Learn more about Fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn retrieve(&self, fine_tuning_job_id: &str) -> Result<FineTuningJob, OpenAIError> {
        self.client
            .get(format!("/fine_tuning/jobs/{fine_tuning_job_id}").as_str())
            .await
    }

    /// Immediately cancel a fine-tune job.
    #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
    pub async fn cancel(&self, fine_tuning_job_id: &str) -> Result<FineTuningJob, OpenAIError> {
        // `()` is passed as the request body argument of the POST.
        self.client
            .post(
                format!("/fine_tuning/jobs/{fine_tuning_job_id}/cancel").as_str(),
                (),
            )
            .await
    }

    /// Get fine-grained status updates for a fine-tune job.
    ///
    /// `query` is serialized and handed to the client together with the job's `events` path.
    #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn list_events<Q>(
        &self,
        fine_tuning_job_id: &str,
        query: &Q,
    ) -> Result<ListFineTuningJobEventsResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        self.client
            .get_with_query(
                format!("/fine_tuning/jobs/{fine_tuning_job_id}/events").as_str(),
                &query,
            )
            .await
    }

    /// List the checkpoints of a fine-tuning job.
    ///
    /// `query` is serialized and handed to the client together with the job's `checkpoints`
    /// path.
    #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn list_checkpoints<Q>(
        &self,
        fine_tuning_job_id: &str,
        query: &Q,
    ) -> Result<ListFineTuningJobCheckpointsResponse, OpenAIError>
    where
        Q: Serialize + ?Sized,
    {
        self.client
            .get_with_query(
                format!("/fine_tuning/jobs/{fine_tuning_job_id}/checkpoints").as_str(),
                &query,
            )
            .await
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
config::Config,
error::OpenAIError,
types::{
CreateImageEditRequest, CreateImageRequest, CreateImageVariationRequest, ImagesResponse,
},
Client,
};
/// Entry point for the Images API group: given a prompt and/or an input image, the model will
/// generate a new image.
///
/// Related guide: [Image generation](https://platform.openai.com/docs/guides/images)
pub struct Images<'c, C>
where
    C: Config,
{
    /// Borrowed client that the methods of this group issue their requests through.
    client: &'c Client<C>,
}
impl<'c, C> Images<'c, C>
where
    C: Config,
{
    /// Wraps a borrowed [`Client`] so image requests can be issued through it.
    pub fn new(client: &'c Client<C>) -> Self {
        Self { client }
    }

    /// Generates an image from a prompt.
    ///
    /// The request is passed to the client's `post` with the `/images/generations` path.
    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
    pub async fn create(&self, request: CreateImageRequest) -> Result<ImagesResponse, OpenAIError> {
        let generation = self.client.post("/images/generations", request);
        generation.await
    }

    /// Produces an edited or extended image from an original image and a prompt.
    ///
    /// The request is passed to the client's `post_form` with the `/images/edits` path.
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn create_edit(
        &self,
        request: CreateImageEditRequest,
    ) -> Result<ImagesResponse, OpenAIError> {
        let edit = self.client.post_form("/images/edits", request);
        edit.await
    }

    /// Produces a variation of a given image.
    ///
    /// The request is passed to the client's `post_form` with the `/images/variations` path.
    #[crate::byot(
        T0 = Clone,
        R = serde::de::DeserializeOwned,
        where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
    )]
    pub async fn create_variation(
        &self,
        request: CreateImageVariationRequest,
    ) -> Result<ImagesResponse, OpenAIError> {
        let variation = self.client.post_form("/images/variations", request);
        variation.await
    }
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment