Unverified commit 435803ea, authored by ryan-lempka and committed by GitHub
Browse files

chore: add additional param support for multimodal models (#3042)


Signed-off-by: Ryan Lempka <rlempka@nvidia.com>
parent fa6feee4
...@@ -1942,6 +1942,8 @@ dependencies = [ ...@@ -1942,6 +1942,8 @@ dependencies = [
"tokio-tungstenite 0.26.2", "tokio-tungstenite 0.26.2",
"tokio-util", "tokio-util",
"tracing", "tracing",
"url",
"uuid 1.18.0",
] ]
[[package]] [[package]]
......
...@@ -52,6 +52,8 @@ tokio = { version = "1.43.0", features = ["fs", "macros"] } ...@@ -52,6 +52,8 @@ tokio = { version = "1.43.0", features = ["fs", "macros"] }
tokio-stream = "0.1.17" tokio-stream = "0.1.17"
tokio-util = { version = "0.7.13", features = ["codec", "io-util"] } tokio-util = { version = "0.7.13", features = ["codec", "io-util"] }
tracing = "0.1.41" tracing = "0.1.41"
url = { workspace = true }
uuid = { workspace = true }
derive_builder = "0.20.2" derive_builder = "0.20.2"
secrecy = { version = "0.10.3", features = ["serde"] } secrecy = { version = "0.10.3", features = ["serde"] }
bytes = "1.9.0" bytes = "1.9.0"
......
...@@ -14,6 +14,9 @@ use derive_builder::Builder; ...@@ -14,6 +14,9 @@ use derive_builder::Builder;
use futures::Stream; use futures::Stream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use url::Url;
use uuid::{Uuid, uuid};
use crate::error::OpenAIError; use crate::error::OpenAIError;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
...@@ -199,52 +202,82 @@ pub enum ImageDetail { ...@@ -199,52 +202,82 @@ pub enum ImageDetail {
High, High,
} }
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, Builder, PartialEq)]
#[builder(name = "ImageUrlArgs")] #[builder(name = "ImageUrlArgs")]
#[builder(pattern = "mutable")] #[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)] #[builder(setter(into, strip_option))]
#[builder(derive(Debug))] #[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))] #[builder(build_fn(error = "OpenAIError"))]
pub struct ImageUrl { pub struct ImageUrl {
/// Either a URL of the image or the base64 encoded image data. /// Either a URL of the image or the base64 encoded image data.
pub url: String, pub url: url::Url,
/// Specifies the detail level of the image. Learn more in the [Vision guide](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding). /// Specifies the detail level of the image. Learn more in the [Vision guide](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding).
pub detail: Option<ImageDetail>, pub detail: Option<ImageDetail>,
/// Optional unique identifier for the image.
#[serde(skip_serializing_if = "Option::is_none")]
pub uuid: Option<uuid::Uuid>,
} }
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, Builder, PartialEq)]
#[builder(name = "VideoUrlArgs")] #[builder(name = "VideoUrlArgs")]
#[builder(pattern = "mutable")] #[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)] #[builder(setter(into, strip_option))]
#[builder(derive(Debug))] #[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))] #[builder(build_fn(error = "OpenAIError"))]
pub struct VideoUrl { pub struct VideoUrl {
/// Either a URL of the video or the base64 encoded video data. /// Either a URL of the video or the base64 encoded video data.
pub url: String, pub url: url::Url,
/// Specifies the detail level of the video processing. /// Specifies the detail level of the video processing.
pub detail: Option<ImageDetail>, pub detail: Option<ImageDetail>,
/// Optional unique identifier for the video.
#[serde(skip_serializing_if = "Option::is_none")]
pub uuid: Option<uuid::Uuid>,
} }
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestMessageContentPartImageArgs")] #[builder(name = "ChatCompletionRequestMessageContentPartImageArgs")]
#[builder(pattern = "mutable")] #[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)] #[builder(setter(into, strip_option))]
#[builder(derive(Debug))] #[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))] #[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestMessageContentPartImage { pub struct ChatCompletionRequestMessageContentPartImage {
pub image_url: ImageUrl, pub image_url: ImageUrl,
} }
#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestMessageContentPartVideoArgs")] #[builder(name = "ChatCompletionRequestMessageContentPartVideoArgs")]
#[builder(pattern = "mutable")] #[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)] #[builder(setter(into, strip_option))]
#[builder(derive(Debug))] #[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))] #[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestMessageContentPartVideo { pub struct ChatCompletionRequestMessageContentPartVideo {
pub video_url: VideoUrl, pub video_url: VideoUrl,
} }
/// An audio resource referenced by URL, with an optional identifier.
///
/// Construct it directly or through the generated `AudioUrlArgs` builder.
#[derive(Debug, Serialize, Deserialize, Clone, Builder, PartialEq)]
#[builder(name = "AudioUrlArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option))]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct AudioUrl {
    /// URL of the audio file.
    pub url: url::Url,
    /// Optional unique identifier for the audio.
    ///
    /// Left out of the serialized output when it is `None`.
    // The field-level `default` keeps this field genuinely optional in the builder:
    // without it, `AudioUrlArgs::build` reports an uninitialized-field error whenever
    // `uuid` was never set.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default)]
    pub uuid: Option<uuid::Uuid>,
}
/// Content part that wraps an [`AudioUrl`].
///
/// This is the payload of the `AudioUrl` variant of
/// `ChatCompletionRequestUserMessageContentPart`. A builder named
/// `ChatCompletionRequestMessageContentPartAudioUrlArgs` is generated for it.
#[derive(Debug, Serialize, Deserialize, Clone, Builder, PartialEq)]
#[builder(name = "ChatCompletionRequestMessageContentPartAudioUrlArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option))]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct ChatCompletionRequestMessageContentPartAudioUrl {
/// The audio resource this content part refers to.
pub audio_url: AudioUrl,
}
#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)] #[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum InputAudioFormat { pub enum InputAudioFormat {
...@@ -279,6 +312,7 @@ pub enum ChatCompletionRequestUserMessageContentPart { ...@@ -279,6 +312,7 @@ pub enum ChatCompletionRequestUserMessageContentPart {
Text(ChatCompletionRequestMessageContentPartText), Text(ChatCompletionRequestMessageContentPartText),
ImageUrl(ChatCompletionRequestMessageContentPartImage), ImageUrl(ChatCompletionRequestMessageContentPartImage),
VideoUrl(ChatCompletionRequestMessageContentPartVideo), VideoUrl(ChatCompletionRequestMessageContentPartVideo),
AudioUrl(ChatCompletionRequestMessageContentPartAudioUrl),
InputAudio(ChatCompletionRequestMessageContentPartAudio), InputAudio(ChatCompletionRequestMessageContentPartAudio),
} }
...@@ -752,6 +786,10 @@ pub struct CreateChatCompletionRequest { ...@@ -752,6 +786,10 @@ pub struct CreateChatCompletionRequest {
/// See the [model endpoint compatibility](https://platform.openai.com/docs/models#model-endpoint-compatibility) table for details on which models work with the Chat API. /// See the [model endpoint compatibility](https://platform.openai.com/docs/models#model-endpoint-compatibility) table for details on which models work with the Chat API.
pub model: String, pub model: String,
/// Multimodal processor configuration parameters
#[serde(skip_serializing_if = "Option::is_none")]
pub mm_processor_kwargs: Option<serde_json::Value>,
/// Whether or not to store the output of this chat completion request /// Whether or not to store the output of this chat completion request
/// ///
/// for use in our [model distillation](https://platform.openai.com/docs/guides/distillation) or [evals](https://platform.openai.com/docs/guides/evals) products. /// for use in our [model distillation](https://platform.openai.com/docs/guides/distillation) or [evals](https://platform.openai.com/docs/guides/evals) products.
...@@ -1099,3 +1137,43 @@ pub struct CreateChatCompletionStreamResponse { ...@@ -1099,3 +1137,43 @@ pub struct CreateChatCompletionStreamResponse {
/// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. /// When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.
pub usage: Option<CompletionUsage>, pub usage: Option<CompletionUsage>,
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// An `audio_url` content part deserializes into the `AudioUrl` variant with its
    /// URL and UUID preserved.
    #[test]
    fn test_audio_url_content_part_json() {
        let json = r#"{"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3", "uuid": "67e55044-10b1-426f-9247-bb680e5fe0c8"}}"#;

        let content_part: ChatCompletionRequestUserMessageContentPart =
            serde_json::from_str(json).unwrap();

        let ChatCompletionRequestUserMessageContentPart::AudioUrl(part) = content_part else {
            panic!("Expected AudioUrl variant");
        };

        assert_eq!(part.audio_url.url.as_str(), "https://example.com/audio.mp3");
        assert_eq!(
            part.audio_url.uuid,
            Some(uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"))
        );
    }

    /// `mm_processor_kwargs` is serialized under its own key with the exact value supplied.
    #[test]
    fn test_mm_processor_kwargs() {
        let request = CreateChatCompletionRequest {
            messages: vec![],
            model: "test-model".to_string(),
            mm_processor_kwargs: Some(serde_json::json!({"max_pixels": 768})),
            ..Default::default()
        };

        // Compare the structured value rather than searching the JSON text, so a
        // wrongly serialized payload cannot slip through.
        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value["mm_processor_kwargs"],
            serde_json::json!({"max_pixels": 768})
        );
    }

    /// When unset, `mm_processor_kwargs` is skipped entirely instead of being emitted as `null`.
    #[test]
    fn test_mm_processor_kwargs_omitted_when_unset() {
        let request = CreateChatCompletionRequest {
            messages: vec![],
            model: "test-model".to_string(),
            ..Default::default()
        };

        let value = serde_json::to_value(&request).unwrap();
        assert!(value.get("mm_processor_kwargs").is_none());
    }
}
...@@ -24,12 +24,13 @@ use crate::{ ...@@ -24,12 +24,13 @@ use crate::{
use bytes::Bytes; use bytes::Bytes;
use super::{ use super::{
AddUploadPartRequest, AudioInput, AudioResponseFormat, ChatCompletionFunctionCall, AddUploadPartRequest, AudioInput, AudioResponseFormat, AudioUrl, ChatCompletionFunctionCall,
ChatCompletionFunctions, ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage, ChatCompletionFunctions, ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage,
ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestDeveloperMessage, ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestDeveloperMessage,
ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestFunctionMessage, ChatCompletionRequestDeveloperMessageContent, ChatCompletionRequestFunctionMessage,
ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartAudio,
ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartAudioUrl, ChatCompletionRequestMessageContentPartImage,
ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartVideo,
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
...@@ -38,7 +39,7 @@ use super::{ ...@@ -38,7 +39,7 @@ use super::{
CreateSpeechResponse, CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize, CreateSpeechResponse, CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize,
EmbeddingInput, FileInput, FilePurpose, FunctionName, Image, ImageInput, ImageModel, EmbeddingInput, FileInput, FilePurpose, FunctionName, Image, ImageInput, ImageModel,
ImageResponseFormat, ImageSize, ImageUrl, ImagesResponse, ModerationInput, Prompt, Role, Stop, ImageResponseFormat, ImageSize, ImageUrl, ImagesResponse, ModerationInput, Prompt, Role, Stop,
TimestampGranularity, TimestampGranularity, VideoUrl,
responses::{CodeInterpreterContainer, Input, InputContent, Role as ResponsesRole}, responses::{CodeInterpreterContainer, Input, InputContent, Role as ResponsesRole},
}; };
...@@ -765,6 +766,22 @@ impl From<ChatCompletionRequestMessageContentPartAudio> ...@@ -765,6 +766,22 @@ impl From<ChatCompletionRequestMessageContentPartAudio>
} }
} }
impl From<ChatCompletionRequestMessageContentPartVideo>
for ChatCompletionRequestUserMessageContentPart
{
fn from(value: ChatCompletionRequestMessageContentPartVideo) -> Self {
ChatCompletionRequestUserMessageContentPart::VideoUrl(value)
}
}
impl From<ChatCompletionRequestMessageContentPartAudioUrl>
for ChatCompletionRequestUserMessageContentPart
{
fn from(value: ChatCompletionRequestMessageContentPartAudioUrl) -> Self {
ChatCompletionRequestUserMessageContentPart::AudioUrl(value)
}
}
impl From<&str> for ChatCompletionRequestMessageContentPartText { impl From<&str> for ChatCompletionRequestMessageContentPartText {
fn from(value: &str) -> Self { fn from(value: &str) -> Self {
ChatCompletionRequestMessageContentPartText { text: value.into() } ChatCompletionRequestMessageContentPartText { text: value.into() }
...@@ -780,8 +797,9 @@ impl From<String> for ChatCompletionRequestMessageContentPartText { ...@@ -780,8 +797,9 @@ impl From<String> for ChatCompletionRequestMessageContentPartText {
impl From<&str> for ImageUrl { impl From<&str> for ImageUrl {
fn from(value: &str) -> Self { fn from(value: &str) -> Self {
Self { Self {
url: value.into(), url: value.parse().expect("Invalid URL"),
detail: Default::default(), detail: Default::default(),
uuid: None,
} }
} }
} }
...@@ -789,8 +807,47 @@ impl From<&str> for ImageUrl { ...@@ -789,8 +807,47 @@ impl From<&str> for ImageUrl {
impl From<String> for ImageUrl { impl From<String> for ImageUrl {
fn from(value: String) -> Self { fn from(value: String) -> Self {
Self { Self {
url: value, url: value.parse().expect("Invalid URL"),
detail: Default::default(),
uuid: None,
}
}
}
/// Builds a [`VideoUrl`] from a string slice, leaving `detail` and `uuid` unset.
///
/// # Panics
///
/// Panics with the message `Invalid URL` if `value` cannot be parsed as a URL.
impl From<&str> for VideoUrl {
    fn from(value: &str) -> Self {
        let url = value.parse().expect("Invalid URL");
        Self {
            url,
            detail: None,
            uuid: None,
        }
    }
}
impl From<String> for VideoUrl {
fn from(value: String) -> Self {
Self {
url: value.parse().expect("Invalid URL"),
detail: Default::default(), detail: Default::default(),
uuid: None,
}
}
}
/// Builds an [`AudioUrl`] from a string slice, leaving `uuid` unset.
///
/// # Panics
///
/// Panics with the message `Invalid URL` if `value` cannot be parsed as a URL.
impl From<&str> for AudioUrl {
    fn from(value: &str) -> Self {
        let url = value.parse().expect("Invalid URL");
        Self { url, uuid: None }
    }
}
impl From<String> for AudioUrl {
fn from(value: String) -> Self {
Self {
url: value.parse().expect("Invalid URL"),
uuid: None,
} }
} }
} }
......
...@@ -1373,6 +1373,8 @@ dependencies = [ ...@@ -1373,6 +1373,8 @@ dependencies = [
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tracing", "tracing",
"url",
"uuid",
] ]
[[package]] [[package]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment