Unverified Commit fd5cc288 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

refactor(3/3): switch dynamo-protocols to upstream async-openai types (#7625)


Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent d517fb80
...@@ -13,25 +13,17 @@ use dynamo_llm::protocols::{ ...@@ -13,25 +13,17 @@ use dynamo_llm::protocols::{
}, },
}; };
use dynamo_llm::{ use dynamo_llm::{
http::{ http::service::{
client::{ Metrics,
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, error::HttpError,
PureOpenAIClient, metrics::{Endpoint, ErrorType, RequestType, Status},
}, service_v2::HttpService,
service::{
Metrics,
error::HttpError,
metrics::{Endpoint, ErrorType, RequestType, Status},
service_v2::HttpService,
},
}, },
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
}; };
use dynamo_protocols::config::OpenAIConfig;
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix}; use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{ use dynamo_runtime::{
CancellationToken, CancellationToken,
engine::AsyncEngineContext,
pipeline::{ pipeline::{
AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait, AsyncEngine, AsyncEngineContextProvider, ManyOut, ResponseStream, SingleIn, async_trait,
}, },
...@@ -582,398 +574,10 @@ async fn wait_for_service_ready(port: u16) { ...@@ -582,398 +574,10 @@ async fn wait_for_service_ready(port: u16) {
} }
} }
async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>, u16) { // NOTE: BYOT (Bring Your Own Type) client tests were removed during the
let port = get_random_port().await; // upstream async-openai migration. They depended on the forked
let service = HttpService::builder() // dynamo_protocols::config and http::client modules which no longer exist.
.enable_chat_endpoints(true) // TODO: Rewrite these tests using the upstream async-openai client.
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let manager = service.model_manager();
let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager
.add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager
.add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
(service, counter, failure, port)
}
fn pure_openai_client(port: u16) -> PureOpenAIClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
PureOpenAIClient::new(config)
}
fn nv_custom_client(port: u16) -> NvCustomClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
NvCustomClient::new(config)
}
fn generic_byot_client(port: u16) -> GenericBYOTClient {
let config = HttpClientConfig {
openai_config: OpenAIConfig::new().with_api_base(format!("http://localhost:{}/v1", port)),
verbose: false,
};
GenericBYOTClient::new(config)
}
#[tokio::test]
async fn test_pure_openai_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let pure_openai_client = pure_openai_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(port).await;
// Test successful streaming request
let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(result.is_ok(), "PureOpenAI client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("bar") // This model will fail
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let result = pure_openai_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_nv_custom_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let nv_custom_client = nv_custom_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(port).await;
// Test successful streaming request
let inner_request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
assert!(result.is_ok(), "NvCustom client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let inner_request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("bar") // This model will fail
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let inner_request = dynamo_protocols::types::CreateChatCompletionRequestArgs::default()
.model("foo")
.messages(vec![
dynamo_protocols::types::ChatCompletionRequestMessage::User(
dynamo_protocols::types::ChatCompletionRequestUserMessage {
content: dynamo_protocols::types::ChatCompletionRequestUserMessageContent::Text(
"Hi".to_string(),
),
name: None,
},
),
])
.stream(true)
.max_tokens(50u32)
.build()
.unwrap();
let request = NvCreateChatCompletionRequest {
inner: inner_request,
common: Default::default(),
nvext: None,
chat_template_args: None,
media_io_kwargs: None,
unsupported_fields: Default::default(),
};
let result = nv_custom_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test]
async fn test_generic_byot_client() {
let (service, _counter, _failure, port) = service_with_engines().await;
let generic_byot_client = generic_byot_client(port);
let token = CancellationToken::new();
let cancel_token = token.clone();
// Start the service
let task = tokio::spawn(async move { service.run(token).await });
// Wait for service to be ready
wait_for_service_ready(port).await;
// Test successful streaming request
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(result.is_ok(), "GenericBYOT client should succeed");
let (mut stream, _context) = result.unwrap().dissolve();
let mut count = 0;
while let Some(response) = stream.next().await {
println!("Response: {:?}", response);
count += 1;
assert!(response.is_ok(), "Response should be ok");
if count >= 3 {
break; // Don't consume entire stream
}
}
assert!(count > 0, "Should receive at least one response");
// Test error case with invalid model
let request = serde_json::json!({
"model": "bar", // This model will fail
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client.chat_stream(request).await;
assert!(
result.is_ok(),
"Client should return stream even for failing model"
);
let (mut stream, _context) = result.unwrap().dissolve();
if let Some(response) = stream.next().await {
assert!(
response.is_err(),
"Response should be error for failing model"
);
}
// Test context management
let ctx = HttpRequestContext::new();
let request = serde_json::json!({
"model": "foo",
"messages": [
{
"role": "user",
"content": "Hi"
}
],
"stream": true,
"max_tokens": 50
});
let result = generic_byot_client
.chat_stream_with_context(request, ctx.clone())
.await;
assert!(result.is_ok(), "Context-based request should succeed");
let (_stream, context) = result.unwrap().dissolve();
assert_eq!(context.id(), ctx.id(), "Context ID should match");
cancel_token.cancel();
task.await.unwrap().unwrap();
}
#[tokio::test] #[tokio::test]
async fn test_client_disconnect_cancellation_unary() { async fn test_client_disconnect_cancellation_unary() {
let port = get_random_port().await; let port = get_random_port().await;
......
...@@ -381,7 +381,6 @@ fn create_response_with_linear_probs( ...@@ -381,7 +381,6 @@ fn create_response_with_linear_probs(
index: 0, index: 0,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some(ChatCompletionMessageContent::Text(_content.to_string())), content: Some(ChatCompletionMessageContent::Text(_content.to_string())),
#[expect(deprecated)]
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
...@@ -463,7 +462,6 @@ fn create_multi_choice_response( ...@@ -463,7 +462,6 @@ fn create_multi_choice_response(
index: choice_idx as u32, index: choice_idx as u32,
delta: ChatCompletionStreamResponseDelta { delta: ChatCompletionStreamResponseDelta {
content: Some(ChatCompletionMessageContent::Text("test".to_string())), content: Some(ChatCompletionMessageContent::Text("test".to_string())),
#[expect(deprecated)]
function_call: None, function_call: None,
tool_calls: None, tool_calls: None,
role: Some(Role::Assistant), role: Some(Role::Assistant),
......
...@@ -55,7 +55,10 @@ fn parse_fixture( ...@@ -55,7 +55,10 @@ fn parse_fixture(
let value: Value = serde_json::from_str(line).unwrap(); let value: Value = serde_json::from_str(line).unwrap();
let chunk: NvCreateChatCompletionStreamResponse = let chunk: NvCreateChatCompletionStreamResponse =
serde_json::from_value(value.clone()).unwrap(); serde_json::from_value(value.clone()).unwrap();
expected_stream_json.push(value); // Round-trip through the typed struct so expected JSON matches current serialization
// (upstream async-openai skips None fields that the old fork serialized as null).
let normalized = serde_json::to_value(&chunk).unwrap();
expected_stream_json.push(normalized);
input_chunks.push(chunk); input_chunks.push(chunk);
} }
......
...@@ -2493,7 +2493,7 @@ mod parallel_jail_tests { ...@@ -2493,7 +2493,7 @@ mod parallel_jail_tests {
assert_eq!( assert_eq!(
tool_call.r#type, tool_call.r#type,
Some(dynamo_protocols::types::ChatCompletionToolType::Function), Some(dynamo_protocols::types::FunctionType::Function),
"Tool call {} should be of type 'function'", "Tool call {} should be of type 'function'",
i i
); );
......
...@@ -7,7 +7,7 @@ use dynamo_protocols::types::{ ...@@ -7,7 +7,7 @@ use dynamo_protocols::types::{
ChatCompletionMessageContent, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionMessageContent, ChatCompletionNamedToolChoice, ChatCompletionRequestMessage,
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest, ChatCompletionToolChoiceOption, ChatCompletionToolType, CreateChatCompletionRequest,
FunctionName, FunctionName, FunctionType,
}; };
/// Helper to extract text from ChatCompletionMessageContent /// Helper to extract text from ChatCompletionMessageContent
...@@ -172,7 +172,7 @@ async fn test_named_tool_choice_parses_json() { ...@@ -172,7 +172,7 @@ async fn test_named_tool_choice_parses_json() {
let tool_call = &tool_calls[0]; let tool_call = &tool_calls[0];
assert_eq!(tool_call.index, 0); assert_eq!(tool_call.index, 0);
assert!(tool_call.id.as_ref().unwrap().starts_with("call-")); assert!(tool_call.id.as_ref().unwrap().starts_with("call-"));
assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function)); assert_eq!(tool_call.r#type, Some(FunctionType::Function));
assert_eq!( assert_eq!(
tool_call.function.as_ref().unwrap().name.as_deref(), tool_call.function.as_ref().unwrap().name.as_deref(),
Some("get_weather") Some("get_weather")
...@@ -213,7 +213,7 @@ async fn test_required_tool_choice_parses_json_array() { ...@@ -213,7 +213,7 @@ async fn test_required_tool_choice_parses_json_array() {
assert_eq!(tool_calls[0].index, 0); assert_eq!(tool_calls[0].index, 0);
assert!(tool_calls[0].id.as_ref().unwrap().starts_with("call-")); assert!(tool_calls[0].id.as_ref().unwrap().starts_with("call-"));
assert_eq!(tool_calls[0].r#type, Some(ChatCompletionToolType::Function)); assert_eq!(tool_calls[0].r#type, Some(FunctionType::Function));
assert_eq!( assert_eq!(
tool_calls[0].function.as_ref().unwrap().name.as_deref(), tool_calls[0].function.as_ref().unwrap().name.as_deref(),
Some("search") Some("search")
...@@ -230,7 +230,7 @@ async fn test_required_tool_choice_parses_json_array() { ...@@ -230,7 +230,7 @@ async fn test_required_tool_choice_parses_json_array() {
assert_eq!(tool_calls[1].index, 1); assert_eq!(tool_calls[1].index, 1);
assert!(tool_calls[1].id.as_ref().unwrap().starts_with("call-")); assert!(tool_calls[1].id.as_ref().unwrap().starts_with("call-"));
assert_eq!(tool_calls[1].r#type, Some(ChatCompletionToolType::Function)); assert_eq!(tool_calls[1].r#type, Some(FunctionType::Function));
assert_eq!( assert_eq!(
tool_calls[1].function.as_ref().unwrap().name.as_deref(), tool_calls[1].function.as_ref().unwrap().name.as_deref(),
Some("summarize") Some("summarize")
......
...@@ -30,7 +30,7 @@ pub async fn try_tool_call_parse_aggregate( ...@@ -30,7 +30,7 @@ pub async fn try_tool_call_parse_aggregate(
.map( .map(
|parsed| dynamo_protocols::types::ChatCompletionMessageToolCall { |parsed| dynamo_protocols::types::ChatCompletionMessageToolCall {
id: parsed.id, id: parsed.id,
r#type: dynamo_protocols::types::ChatCompletionToolType::Function, r#type: dynamo_protocols::types::FunctionType::Function,
function: dynamo_protocols::types::FunctionCall { function: dynamo_protocols::types::FunctionCall {
name: parsed.function.name, name: parsed.function.name,
arguments: parsed.function.arguments, arguments: parsed.function.arguments,
...@@ -65,7 +65,7 @@ pub async fn try_tool_call_parse_stream( ...@@ -65,7 +65,7 @@ pub async fn try_tool_call_parse_stream(
|(idx, parsed)| dynamo_protocols::types::ChatCompletionMessageToolCallChunk { |(idx, parsed)| dynamo_protocols::types::ChatCompletionMessageToolCallChunk {
index: idx as u32, index: idx as u32,
id: Some(parsed.id), id: Some(parsed.id),
r#type: Some(dynamo_protocols::types::ChatCompletionToolType::Function), r#type: Some(dynamo_protocols::types::FunctionType::Function),
function: Some(dynamo_protocols::types::FunctionCallStream { function: Some(dynamo_protocols::types::FunctionCallStream {
name: Some(parsed.function.name), name: Some(parsed.function.name),
arguments: Some(parsed.function.arguments), arguments: Some(parsed.function.arguments),
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
[package] [package]
name = "dynamo-protocols" name = "dynamo-protocols"
description = "Protocol types for OpenAI-compatible inference APIs, forked from async-openai." description = "Protocol types for OpenAI-compatible inference APIs with inference-serving extensions."
license = "Apache-2.0 AND MIT" license = "Apache-2.0 AND MIT"
version.workspace = true version.workspace = true
edition.workspace = true edition.workspace = true
...@@ -19,44 +19,35 @@ homepage.workspace = true ...@@ -19,44 +19,35 @@ homepage.workspace = true
repository.workspace = true repository.workspace = true
readme.workspace = true readme.workspace = true
[features]
realtime = ["dep:tokio-tungstenite"]
# Bring your own types
byot = []
[dependencies] [dependencies]
futures = { workspace = true } # Upstream OpenAI types (types-only, no HTTP client)
rand = { workspace = true } async-openai = { version = "0.34", default-features = false, features = [
reqwest = { workspace = true } "chat-completion-types",
"response-types",
"completion-types",
"embedding-types",
"image-types",
] }
# Serialization
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
derive_builder = { workspace = true }
# Type support
thiserror = { workspace = true } thiserror = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
url = { workspace = true } url = { workspace = true }
uuid = { workspace = true } uuid = { workspace = true }
derive_builder = { workspace = true } futures = { workspace = true }
bytes = { workspace = true }
eventsource-stream = "0.2.3"
async-openai-macros = "0.1.0"
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
reqwest-eventsource = "0.6.0"
secrecy = { version = "0.10.3", features = ["serde"] }
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
utoipa = { version = "5.3", features = ["url", "uuid"] }
[dev-dependencies] [dev-dependencies]
tokio = { workspace = true }
tokio-test = "0.4.4" tokio-test = "0.4.4"
serde_json = { workspace = true } serde_json = { workspace = true }
[[test]]
name = "bring-your-own-type"
required-features = ["byot"]
[package.metadata.docs.rs] [package.metadata.docs.rs]
all-features = true all-features = true
rustdoc-args = ["--cfg", "docsrs"] rustdoc-args = ["--cfg", "docsrs"]
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
AssistantObject, CreateAssistantRequest, DeleteAssistantResponse, ListAssistantsResponse,
ModifyAssistantRequest,
},
};
/// Build assistants that can call models and use tools to perform tasks.
///
/// [Get started with the Assistants API](https://platform.openai.com/docs/assistants)
pub struct Assistants<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Assistants<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Create an assistant with a model and instructions.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(
&self,
request: CreateAssistantRequest,
) -> Result<AssistantObject, OpenAIError> {
self.client.post("/assistants", request).await
}
/// Retrieves an assistant.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, assistant_id: &str) -> Result<AssistantObject, OpenAIError> {
self.client
.get(&format!("/assistants/{assistant_id}"))
.await
}
/// Modifies an assistant.
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn update(
&self,
assistant_id: &str,
request: ModifyAssistantRequest,
) -> Result<AssistantObject, OpenAIError> {
self.client
.post(&format!("/assistants/{assistant_id}"), request)
.await
}
/// Delete an assistant.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn delete(&self, assistant_id: &str) -> Result<DeleteAssistantResponse, OpenAIError> {
self.client
.delete(&format!("/assistants/{assistant_id}"))
.await
}
/// Returns a list of assistants.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list<Q>(&self, query: &Q) -> Result<ListAssistantsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client.get_with_query("/assistants", &query).await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use bytes::Bytes;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson,
CreateTranslationRequest, CreateTranslationResponseJson,
CreateTranslationResponseVerboseJson,
},
};
/// Turn audio into text or text into audio.
/// Related guide: [Speech to text](https://platform.openai.com/docs/guides/speech-to-text)
pub struct Audio<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Audio<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Transcribes audio into the input language.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn transcribe(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}
/// Transcribes audio into the input language.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn transcribe_verbose_json(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}
/// Transcribes audio into the input language.
pub async fn transcribe_raw(
&self,
request: CreateTranscriptionRequest,
) -> Result<Bytes, OpenAIError> {
self.client
.post_form_raw("/audio/transcriptions", request)
.await
}
/// Translates audio into English.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn translate(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponseJson, OpenAIError> {
self.client.post_form("/audio/translations", request).await
}
/// Translates audio into English.
#[crate::byot(
T0 = Clone,
R = serde::de::DeserializeOwned,
where_clause = "reqwest::multipart::Form: crate::traits::AsyncTryFrom<T0, Error = OpenAIError>",
)]
pub async fn translate_verbose_json(
&self,
request: CreateTranslationRequest,
) -> Result<CreateTranslationResponseVerboseJson, OpenAIError> {
self.client.post_form("/audio/translations", request).await
}
/// Transcribes audio into the input language.
pub async fn translate_raw(
&self,
request: CreateTranslationRequest,
) -> Result<Bytes, OpenAIError> {
self.client
.post_form_raw("/audio/translations", request)
.await
}
/// Generates audio from the input text.
pub async fn speech(
&self,
request: CreateSpeechRequest,
) -> Result<CreateSpeechResponse, OpenAIError> {
let bytes = self.client.post_raw("/audio/speech", request).await?;
Ok(CreateSpeechResponse { bytes })
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{Client, config::Config, error::OpenAIError, types::ListAuditLogsResponse};
/// Logs of user actions and configuration changes within this organization.
/// To log events, you must activate logging in the [Organization Settings](https://platform.openai.com/settings/organization/general).
/// Once activated, for security reasons, logging cannot be deactivated.
pub struct AuditLogs<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> AuditLogs<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// List user actions and configuration changes within this organization.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn get<Q>(&self, query: &Q) -> Result<ListAuditLogsResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client
.get_with_query("/organization/audit_logs", &query)
.await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use serde::Serialize;
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{Batch, BatchRequest, ListBatchesResponse},
};
/// Create large batches of API requests for asynchronous processing. The Batch API returns completions within 24 hours for a 50% discount.
///
/// Related guide: [Batch](https://platform.openai.com/docs/guides/batch)
pub struct Batches<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Batches<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates and executes a batch from an uploaded file of requests
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(&self, request: BatchRequest) -> Result<Batch, OpenAIError> {
self.client.post("/batches", request).await
}
/// List your organization's batches.
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn list<Q>(&self, query: &Q) -> Result<ListBatchesResponse, OpenAIError>
where
Q: Serialize + ?Sized,
{
self.client.get_with_query("/batches", &query).await
}
/// Retrieves a batch.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn retrieve(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
self.client.get(&format!("/batches/{batch_id}")).await
}
/// Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file.
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
pub async fn cancel(&self, batch_id: &str) -> Result<Batch, OpenAIError> {
self.client
.post(
&format!("/batches/{batch_id}/cancel"),
serde_json::json!({}),
)
.await
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{
ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
},
};
/// Given a list of messages comprising a conversation, the model will return a response.
///
/// Related guide: [Chat completions](https://platform.openai.com//docs/guides/text-generation)
pub struct Chat<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Chat<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates a model response for the given chat conversation. Learn more in
/// the
///
/// [text generation](https://platform.openai.com/docs/guides/text-generation),
/// [vision](https://platform.openai.com/docs/guides/vision),
///
/// and [audio](https://platform.openai.com/docs/guides/audio) guides.
///
///
/// Parameter support can differ depending on the model used to generate the
/// response, particularly for newer reasoning models. Parameters that are
/// only supported for reasoning models are noted below. For the current state
/// of unsupported parameters in reasoning models,
///
/// [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning).
///
/// byot: You must ensure "stream: false" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned
)]
pub async fn create(
&self,
request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is true, use Chat::create_stream".into(),
));
}
}
self.client.post("/chat/completions", request).await
}
/// Creates a completion for the chat message
///
/// partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message.
///
/// [ChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
///
/// byot: You must ensure "stream: true" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned,
stream = "true",
where_clause = "R: std::marker::Send + 'static"
)]
#[allow(unused_mut)]
pub async fn create_stream(
&self,
mut request: CreateChatCompletionRequest,
) -> Result<ChatCompletionResponseStream, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && !request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is false, use Chat::create".into(),
));
}
request.stream = Some(true);
}
Ok(self.client.post_stream("/chat/completions", request).await)
}
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
client::Client,
config::Config,
error::OpenAIError,
types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse},
};
/// Given a prompt, the model will return one or more predicted completions,
/// and can also return the probabilities of alternative tokens at each position.
/// We recommend most users use our Chat completions API.
/// [Learn more](https://platform.openai.com/docs/deprecations/2023-07-06-gpt-and-embeddings)
///
/// Related guide: [Legacy Completions](https://platform.openai.com/docs/guides/gpt/completions-api)
pub struct Completions<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Completions<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates a completion for the provided prompt and parameters
///
/// You must ensure that "stream: false" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned
)]
pub async fn create(
&self,
request: CreateCompletionRequest,
) -> Result<CreateCompletionResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is true, use Completion::create_stream".into(),
));
}
}
self.client.post("/completions", request).await
}
/// Creates a completion request for the provided prompt and parameters
///
/// Stream back partial progress. Tokens will be sent as data-only
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
/// as they become available, with the stream terminated by a data: \[DONE\] message.
///
/// [CompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
///
/// You must ensure that "stream: true" in serialized `request`
#[crate::byot(
T0 = serde::Serialize,
R = serde::de::DeserializeOwned,
stream = "true",
where_clause = "R: std::marker::Send + 'static"
)]
#[allow(unused_mut)]
pub async fn create_stream(
&self,
mut request: CreateCompletionRequest,
) -> Result<CompletionResponseStream, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if request.stream.is_some() && !request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is false, use Completion::create".into(),
));
}
request.stream = Some(true);
}
Ok(self.client.post_stream("/completions", request).await)
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
//! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service.
use reqwest::header::{AUTHORIZATION, HeaderMap};
use secrecy::{ExposeSecret, SecretString};
use serde::Deserialize;
/// Default v1 API base url
pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Organization header
pub const OPENAI_ORGANIZATION_HEADER: &str = "OpenAI-Organization";
/// Project header
pub const OPENAI_PROJECT_HEADER: &str = "OpenAI-Project";
/// Calls to the Assistants API require that you pass a Beta header
pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";
/// [crate::Client] relies on this for every API call on OpenAI
/// or Azure OpenAI service
pub trait Config: Send + Sync {
fn headers(&self) -> HeaderMap;
fn url(&self, path: &str) -> String;
fn query(&self) -> Vec<(&str, &str)>;
fn api_base(&self) -> &str;
fn api_key(&self) -> &SecretString;
}
/// Macro to implement Config trait for pointer types with dyn objects
macro_rules! impl_config_for_ptr {
($t:ty) => {
impl Config for $t {
fn headers(&self) -> HeaderMap {
self.as_ref().headers()
}
fn url(&self, path: &str) -> String {
self.as_ref().url(path)
}
fn query(&self) -> Vec<(&str, &str)> {
self.as_ref().query()
}
fn api_base(&self) -> &str {
self.as_ref().api_base()
}
fn api_key(&self) -> &SecretString {
self.as_ref().api_key()
}
}
};
}
impl_config_for_ptr!(Box<dyn Config>);
impl_config_for_ptr!(std::sync::Arc<dyn Config>);
/// Configuration for OpenAI API
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OpenAIConfig {
api_base: String,
api_key: SecretString,
org_id: String,
project_id: String,
}
impl Default for OpenAIConfig {
fn default() -> Self {
Self {
api_base: OPENAI_API_BASE.to_string(),
api_key: std::env::var("OPENAI_API_KEY")
.unwrap_or_else(|_| "".to_string())
.into(),
org_id: Default::default(),
project_id: Default::default(),
}
}
}
impl OpenAIConfig {
/// Create client with default [OPENAI_API_BASE] url and default API key from OPENAI_API_KEY env var
pub fn new() -> Self {
Default::default()
}
/// To use a different organization id other than default
pub fn with_org_id<S: Into<String>>(mut self, org_id: S) -> Self {
self.org_id = org_id.into();
self
}
/// Non default project id
pub fn with_project_id<S: Into<String>>(mut self, project_id: S) -> Self {
self.project_id = project_id.into();
self
}
/// To use a different API key different from default OPENAI_API_KEY env var
pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
self.api_key = SecretString::from(api_key.into());
self
}
/// To use a API base url different from default [OPENAI_API_BASE]
pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = api_base.into();
self
}
pub fn org_id(&self) -> &str {
&self.org_id
}
}
impl Config for OpenAIConfig {
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
if !self.org_id.is_empty() {
headers.insert(
OPENAI_ORGANIZATION_HEADER,
self.org_id.as_str().parse().unwrap(),
);
}
if !self.project_id.is_empty() {
headers.insert(
OPENAI_PROJECT_HEADER,
self.project_id.as_str().parse().unwrap(),
);
}
headers.insert(
AUTHORIZATION,
format!("Bearer {}", self.api_key.expose_secret())
.as_str()
.parse()
.unwrap(),
);
// hack for Assistants APIs
// Calls to the Assistants API require that you pass a Beta header
headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap());
headers
}
fn url(&self, path: &str) -> String {
format!("{}{}", self.api_base, path)
}
fn api_base(&self) -> &str {
&self.api_base
}
fn api_key(&self) -> &SecretString {
&self.api_key
}
fn query(&self) -> Vec<(&str, &str)> {
vec![]
}
}
/// Configuration for Azure OpenAI Service
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct AzureConfig {
api_version: String,
deployment_id: String,
api_base: String,
api_key: SecretString,
}
impl Default for AzureConfig {
fn default() -> Self {
Self {
api_base: Default::default(),
api_key: std::env::var("OPENAI_API_KEY")
.unwrap_or_else(|_| "".to_string())
.into(),
deployment_id: Default::default(),
api_version: Default::default(),
}
}
}
impl AzureConfig {
pub fn new() -> Self {
Default::default()
}
pub fn with_api_version<S: Into<String>>(mut self, api_version: S) -> Self {
self.api_version = api_version.into();
self
}
pub fn with_deployment_id<S: Into<String>>(mut self, deployment_id: S) -> Self {
self.deployment_id = deployment_id.into();
self
}
/// To use a different API key different from default OPENAI_API_KEY env var
pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
self.api_key = SecretString::from(api_key.into());
self
}
/// API base url in form of <https://your-resource-name.openai.azure.com>
pub fn with_api_base<S: Into<String>>(mut self, api_base: S) -> Self {
self.api_base = api_base.into();
self
}
}
impl Config for AzureConfig {
fn headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert("api-key", self.api_key.expose_secret().parse().unwrap());
headers
}
fn url(&self, path: &str) -> String {
format!(
"{}/openai/deployments/{}{}",
self.api_base, self.deployment_id, path
)
}
fn api_base(&self) -> &str {
&self.api_base
}
fn api_key(&self) -> &SecretString {
&self.api_key
}
fn query(&self) -> Vec<(&str, &str)> {
vec![("api-version", &self.api_version)]
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::Client;
use crate::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, CreateChatCompletionRequest,
};
use std::sync::Arc;
#[test]
fn test_client_creation() {
unsafe { std::env::set_var("OPENAI_API_KEY", "test") }
let openai_config = OpenAIConfig::default();
let config = Box::new(openai_config.clone()) as Box<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let config = Arc::new(openai_config) as Arc<dyn Config>;
let client = Client::with_config(config);
assert!(client.config().url("").ends_with("/v1"));
let cloned_client = client.clone();
assert!(cloned_client.config().url("").ends_with("/v1"));
}
async fn dynamic_dispatch_compiles(client: &Client<Box<dyn Config>>) {
let _ = client.chat().create(CreateChatCompletionRequest {
model: "gpt-4o".to_string(),
messages: vec![ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: "Hello, world!".into(),
..Default::default()
},
)],
..Default::default()
});
}
#[tokio::test]
async fn test_dynamic_dispatch() {
let openai_config = OpenAIConfig::default();
let azure_config = AzureConfig::default();
let azure_client = Client::with_config(Box::new(azure_config.clone()) as Box<dyn Config>);
let oai_client = Client::with_config(Box::new(openai_config.clone()) as Box<dyn Config>);
let _ = dynamic_dispatch_compiles(&azure_client).await;
let _ = dynamic_dispatch_compiles(&oai_client).await;
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&azure_client).await });
let _ = tokio::spawn(async move { dynamic_dispatch_compiles(&oai_client).await });
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use std::path::{Path, PathBuf};
use base64::{Engine as _, engine::general_purpose};
use rand::{Rng, distr::Alphanumeric};
use reqwest::Url;
use crate::error::OpenAIError;
fn create_paths<P: AsRef<Path>>(url: &Url, base_dir: P) -> (PathBuf, PathBuf) {
let mut dir = PathBuf::from(base_dir.as_ref());
let mut path = dir.clone();
let segments = url.path_segments().map(|c| c.collect::<Vec<_>>());
if let Some(segments) = segments {
for (idx, segment) in segments.iter().enumerate() {
if idx != segments.len() - 1 {
dir.push(segment);
}
path.push(segment);
}
}
(dir, path)
}
pub(crate) async fn download_url<P: AsRef<Path>>(
url: &str,
dir: P,
) -> Result<PathBuf, OpenAIError> {
let parsed_url = Url::parse(url).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
let response = reqwest::get(url)
.await
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
if !response.status().is_success() {
return Err(OpenAIError::FileSaveError(format!(
"couldn't download file, status: {}, url: {url}",
response.status()
)));
}
let (dir, file_path) = create_paths(&parsed_url, dir);
tokio::fs::create_dir_all(dir.as_path())
.await
.map_err(|e| OpenAIError::FileSaveError(format!("{}, dir: {}", e, dir.display())))?;
tokio::fs::write(
file_path.as_path(),
response.bytes().await.map_err(|e| {
OpenAIError::FileSaveError(format!("{}, file path: {}", e, file_path.display()))
})?,
)
.await
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
Ok(file_path)
}
pub(crate) async fn save_b64<P: AsRef<Path>>(b64: &str, dir: P) -> Result<PathBuf, OpenAIError> {
let filename: String = rand::rng()
.sample_iter(&Alphanumeric)
.take(10)
.map(char::from)
.collect();
let filename = format!("{filename}.png");
let path = PathBuf::from(dir.as_ref()).join(filename);
tokio::fs::write(
path.as_path(),
general_purpose::STANDARD
.decode(b64)
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?,
)
.await
.map_err(|e| OpenAIError::FileSaveError(format!("{}, path: {}", e, path.display())))?;
Ok(path)
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0
use crate::{
Client,
config::Config,
error::OpenAIError,
types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
};
#[cfg(not(feature = "byot"))]
use crate::types::EncodingFormat;
/// Get a vector representation of a given input that can be easily
/// consumed by machine learning models and algorithms.
///
/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
pub struct Embeddings<'c, C: Config> {
client: &'c Client<C>,
}
impl<'c, C: Config> Embeddings<'c, C> {
pub fn new(client: &'c Client<C>) -> Self {
Self { client }
}
/// Creates an embedding vector representing the input text.
///
/// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create(
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateEmbeddingResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
return Err(OpenAIError::InvalidArgument(
"When encoding_format is base64, use Embeddings::create_base64".into(),
));
}
}
self.client.post("/embeddings", request).await
}
/// Creates an embedding vector representing the input text.
///
/// The response will contain the embedding in base64 format.
///
/// byot: In serialized `request` you must ensure "encoding_format" is "base64"
#[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
pub async fn create_base64(
&self,
request: CreateEmbeddingRequest,
) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
#[cfg(not(feature = "byot"))]
{
if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
return Err(OpenAIError::InvalidArgument(
"When encoding_format is not base64, use Embeddings::create".into(),
));
}
}
self.client.post("/embeddings", request).await
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment