Unverified Commit e30a3054 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: Media HTTP fetching and b64 decoding (#3967)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 20d1eb2e
......@@ -210,6 +210,16 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0f477b951e452a0b6b4a10b53ccd569042d1d01729b519e02074a9c0958a063"
[[package]]
name = "assert-json-diff"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "assert_matches"
version = "1.5.0"
......@@ -2170,6 +2180,7 @@ dependencies = [
"async_zmq",
"axum 0.8.4",
"axum-server",
"base64 0.22.1",
"bincode 2.0.1",
"bitflags 2.9.4",
"blake3",
......@@ -2201,6 +2212,7 @@ dependencies = [
"lazy_static",
"minijinja",
"minijinja-contrib",
"mockito",
"modelexpress-client",
"modelexpress-common",
"ndarray",
......@@ -4911,6 +4923,30 @@ dependencies = [
"rayon",
]
[[package]]
name = "mockito"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48"
dependencies = [
"assert-json-diff",
"bytes",
"colored",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.7.0",
"hyper-util",
"log",
"rand 0.9.2",
"regex",
"serde_json",
"serde_urlencoded",
"similar",
"tokio",
]
[[package]]
name = "modelexpress-client"
version = "0.2.0"
......
......@@ -140,6 +140,10 @@ minijinja = { version = "2.10.2", features = ["loader"] }
minijinja-contrib = { version = "2.10.2", features = ["pycompat"] }
json-five = { version = "0.3" }
# media loading in the preprocessor
reqwest = { workspace = true }
base64 = { version = "0.22" }
# Publishers
zeromq = "0.4.1"
rmp-serde = "1.3"
......@@ -167,6 +171,7 @@ insta = { version = "1.41", features = [
] }
lazy_static = "1.4"
mockito = "1.7.0"
[build-dependencies]
tonic-build = { version = "0.13.1" }
......@@ -11,6 +11,7 @@
//!
//! The Preprocessor will accept any IngressRequest and transform it to a BackendRequest.
pub mod media;
pub mod prompt;
pub mod tools;
use anyhow::Context;
......@@ -26,11 +27,11 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};
use tracing;
use crate::model_card::{ModelDeploymentCard, ModelInfo};
use crate::preprocessor::media::MediaLoader;
use crate::preprocessor::prompt::OAIChatLikeRequest;
use crate::protocols::common::preprocessor::{
MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder,
};
use crate::tokenizers::Encoding;
use dynamo_parsers::{ReasoningParser, ReasoningParserType};
......@@ -113,6 +114,7 @@ pub struct OpenAIPreprocessor {
/// Per-model runtime configuration propagated to response generator (e.g., reasoning/tool parser)
runtime_config: crate::local_model::runtime_config::ModelRuntimeConfig,
tool_call_parser: Option<String>,
media_loader: Option<MediaLoader>,
}
impl OpenAIPreprocessor {
......@@ -141,7 +143,7 @@ impl OpenAIPreprocessor {
// // Initialize runtime config from the ModelDeploymentCard
let runtime_config = mdc.runtime_config.clone();
let media_loader = None; // TODO: enable with decoder config from MDC
Ok(Arc::new(Self {
formatter,
tokenizer,
......@@ -149,6 +151,7 @@ impl OpenAIPreprocessor {
mdcsum,
runtime_config,
tool_call_parser,
media_loader,
}))
}
/// Encode a string to its tokens
......@@ -162,7 +165,7 @@ impl OpenAIPreprocessor {
/// Annotations evaluated by this method include:
/// - `formatted_prompt`
/// - `token_ids`
pub fn preprocess_request<
pub async fn preprocess_request<
R: OAIChatLikeRequest
+ AnnotationsProvider
+ SamplingOptionsProvider
......@@ -181,6 +184,7 @@ impl OpenAIPreprocessor {
.gather_tokens(request, &mut builder, formatted_prompt)
.with_context(|| "Failed to gather tokens")?;
self.gather_multi_modal_data(request, &mut builder)
.await
.with_context(|| "Failed to gather multimodal data")?;
Ok((builder.build()?, annotations))
......@@ -267,7 +271,7 @@ impl OpenAIPreprocessor {
}
}
pub fn gather_multi_modal_data<R: OAIChatLikeRequest>(
pub async fn gather_multi_modal_data<R: OAIChatLikeRequest>(
&self,
request: &R,
builder: &mut PreprocessedRequestBuilder,
......@@ -275,6 +279,7 @@ impl OpenAIPreprocessor {
let messages = request.messages();
let message_count = messages.len().unwrap_or(0);
let mut media_map: MultimodalDataMap = HashMap::new();
let mut fetch_tasks = Vec::new();
for idx in 0..message_count {
let msg = messages
......@@ -307,10 +312,31 @@ impl OpenAIPreprocessor {
_ => continue,
};
let map_item = media_map.entry(type_str.clone()).or_default();
map_item.push(MultimodalData::Url(url));
if self.media_loader.is_some() {
fetch_tasks.push((type_str, content_part.clone()));
} else {
// No loader, just pass the URL through
media_map
.entry(type_str)
.or_default()
.push(MultimodalData::Url(url));
}
}
}
// Execute all fetch tasks
if !fetch_tasks.is_empty() {
let loader = self.media_loader.as_ref().unwrap();
let _results = futures::future::join_all(
fetch_tasks
.iter()
.map(|(_, content_part)| loader.fetch_media_part(content_part)),
)
.await;
// TODO: decode and pass NIXL descriptors to the media map
}
if !media_map.is_empty() {
builder.multi_modal_data(Some(media_map));
}
......@@ -839,7 +865,7 @@ impl
let response_generator = request.response_generator(context.id().to_string());
// convert the chat completion request to a common completion request
let (common_request, annotations) = self.preprocess_request(&request)?;
let (common_request, annotations) = self.preprocess_request(&request).await?;
let mut response_generator = Box::new(response_generator);
......@@ -974,7 +1000,7 @@ impl
// convert the chat completion request to a common completion request
let mut builder = self.builder(&request)?;
let annotations = self.gather_tokens(&request, &mut builder, None)?;
self.gather_multi_modal_data(&request, &mut builder)?;
self.gather_multi_modal_data(&request, &mut builder).await?;
let common_request = builder.build()?;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod common;
mod loader;
pub use common::EncodedMediaData;
pub use loader::MediaLoader;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use base64::{Engine as _, engine::general_purpose};
/// Raw encoded media data (.png, .mp4, ...), optionally b64-encoded.
///
/// Built by `EncodedMediaData::from_url`; `into_bytes` yields the payload with
/// any base64 layer removed.
#[derive(Debug)]
pub struct EncodedMediaData {
/// Payload as received: base64 text when `b64_encoded` is true, otherwise the
/// bytes downloaded from the URL.
pub(crate) bytes: Vec<u8>,
/// Whether `bytes` still has to be base64-decoded by `into_bytes`.
pub(crate) b64_encoded: bool,
}
impl EncodedMediaData {
    /// Loads the encoded media referenced by `url`.
    ///
    /// * `data:` URLs: the text after the first comma is kept as-is (still
    ///   base64) and the result is flagged `b64_encoded`. Only data URLs that
    ///   carry the `;base64` marker are accepted; the percent-encoded form
    ///   allowed by the data URL scheme is rejected instead of being mistaken
    ///   for base64.
    /// * `http` / `https` URLs: the body is downloaded with `client`.
    ///
    /// # Errors
    ///
    /// Fails when the data URL has no comma, is not base64, or has an empty
    /// payload; when the HTTP request fails, returns a non-success status, or
    /// returns an empty body; and for any other URL scheme.
    pub async fn from_url(url: &url::Url, client: &reqwest::Client) -> Result<Self> {
        let (bytes, b64_encoded) = match url.scheme() {
            "data" => {
                // Layout: data:[<mediatype>][;base64],<payload>
                let (header, payload) = url
                    .as_str()
                    .split_once(',')
                    .ok_or_else(|| anyhow::anyhow!("Invalid media data URL format"))?;
                // The base64 marker, when present, is the last `;`-separated
                // token of the header.
                let is_base64 = header
                    .rsplit_once(';')
                    .is_some_and(|(_, marker)| marker.eq_ignore_ascii_case("base64"));
                anyhow::ensure!(
                    is_base64,
                    "Only base64-encoded media data URLs are supported"
                );
                anyhow::ensure!(!payload.is_empty(), "Media data URL is empty");
                (payload.as_bytes().to_vec(), true)
            }
            "http" | "https" => {
                // NOTE(review): the whole body is buffered with no size cap;
                // consider enforcing a limit for untrusted URLs.
                let body = client
                    .get(url.as_str())
                    .send()
                    .await?
                    .error_for_status()?
                    .bytes()
                    .await?;
                anyhow::ensure!(!body.is_empty(), "Media URL is empty");
                (body.to_vec(), false)
            }
            scheme => anyhow::bail!("Unsupported media URL scheme: {scheme}"),
        };

        Ok(Self { bytes, b64_encoded })
    }

    /// Consumes the value and returns the payload bytes, base64-decoding them
    /// first when they came from a data URL.
    ///
    /// # Errors
    ///
    /// Fails when the stored payload is flagged as base64 but is not valid
    /// standard base64.
    pub fn into_bytes(self) -> Result<Vec<u8>> {
        if self.b64_encoded {
            Ok(general_purpose::STANDARD.decode(self.bytes)?)
        } else {
            Ok(self.bytes)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` and runs it through `EncodedMediaData::from_url` with a
    /// throwaway HTTP client.
    async fn load(raw: &str) -> Result<EncodedMediaData> {
        let parsed = url::Url::parse(raw).unwrap();
        let http = reqwest::Client::new();
        EncodedMediaData::from_url(&parsed, &http).await
    }

    /// A base64 data URL keeps its payload encoded until `into_bytes`.
    #[tokio::test]
    async fn test_from_base64() {
        // "dGVzdA==" is the base64 encoding of "test".
        let media = load("data:text/plain;base64,dGVzdA==").await.unwrap();

        assert!(media.b64_encoded);
        assert_eq!(media.bytes, b"dGVzdA==");
        assert_eq!(media.into_bytes().unwrap(), b"test");
    }

    /// A data URL with nothing after the comma is rejected.
    #[tokio::test]
    async fn test_from_empty_base64() {
        let outcome = load("data:text/plain;base64,").await;
        assert!(outcome.is_err());
    }

    /// A data URL without a comma separator is rejected.
    #[tokio::test]
    async fn test_from_invalid_base64() {
        let outcome = load("data:invalid").await;
        assert!(outcome.is_err());
    }

    /// A successful HTTP download is returned verbatim and not flagged as base64.
    #[tokio::test]
    async fn test_from_url_http() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/image.png")
            .with_status(200)
            .with_body(b"test data")
            .create_async()
            .await;

        let target = format!("{}/image.png", server.url());
        let media = load(&target).await.unwrap();

        assert!(!media.b64_encoded);
        assert_eq!(media.bytes, b"test data");
        assert_eq!(media.into_bytes().unwrap(), b"test data");

        mock.assert_async().await;
    }

    /// A non-success HTTP status surfaces as an error.
    #[tokio::test]
    async fn test_from_url_http_404() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/image.png")
            .with_status(404)
            .create_async()
            .await;

        let target = format!("{}/image.png", server.url());
        let outcome = load(&target).await;

        assert!(outcome.is_err());
        mock.assert_async().await;
    }

    /// Schemes other than data/http/https are rejected with a descriptive message.
    #[tokio::test]
    async fn test_from_unsupported_scheme() {
        let failure = load("ftp://example.com/image.png").await.unwrap_err();
        assert!(failure.to_string().contains("Unsupported media URL scheme"));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet;
use std::time::Duration;
use anyhow::Result;
use dynamo_async_openai::types::ChatCompletionRequestUserMessageContentPart;
use super::common::EncodedMediaData;
/// User-Agent string used by `MediaFetcher::default()`.
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
/// Settings that control how `MediaLoader` fetches remote media and which
/// URLs it accepts (see `MediaLoader::check_if_url_allowed`).
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher {
/// User-Agent configured on the HTTP client built by `MediaLoader::new`.
pub user_agent: String,
/// When false, HTTP(S) URLs whose host is not a domain name are rejected.
pub allow_direct_ip: bool,
/// When false, HTTP(S) URLs for which `Url::port()` returns a value are rejected.
pub allow_direct_port: bool,
/// When set, the URL host must match one of these entries exactly;
/// `None` places no restriction on the host.
pub allowed_media_domains: Option<HashSet<String>>,
/// Timeout passed to the HTTP client builder; `None` leaves the builder's
/// timeout untouched.
pub timeout: Option<Duration>,
}
impl Default for MediaFetcher {
    /// Default settings: the stock user agent, direct IP and direct port
    /// access both disabled, no domain allowlist, and no timeout.
    fn default() -> Self {
        let user_agent = String::from(DEFAULT_HTTP_USER_AGENT);
        Self {
            user_agent,
            allow_direct_ip: false,
            allow_direct_port: false,
            allowed_media_domains: None,
            timeout: None,
        }
    }
}
/// Fetches the media referenced by chat-completion content parts, after
/// checking each URL against the configured `MediaFetcher` rules.
pub struct MediaLoader {
/// HTTP client built in `MediaLoader::new` from the fetcher settings.
http_client: reqwest::Client,
/// Rules consulted by `check_if_url_allowed`.
media_fetcher: MediaFetcher,
// TODO: decoders, NIXL agent
}
impl MediaLoader {
    /// Maximum number of redirect hops followed for a single fetch.
    const MAX_REDIRECTS: usize = 10;

    /// Builds a loader whose HTTP client uses the user agent and optional
    /// timeout from `media_fetcher`.
    ///
    /// The client also validates every redirect hop against the same rules as
    /// `check_if_url_allowed`. Without this, an allowed host could redirect
    /// the fetch to a host, IP or port the configuration forbids.
    ///
    /// # Errors
    ///
    /// Fails when the HTTP client cannot be built.
    pub fn new(media_fetcher: MediaFetcher) -> Result<Self> {
        // The redirect policy must be 'static, so it owns a copy of the rules.
        let redirect_rules = media_fetcher.clone();
        let redirect_policy = reqwest::redirect::Policy::custom(move |attempt| {
            if attempt.previous().len() >= Self::MAX_REDIRECTS {
                return attempt.error("Too many redirects while fetching media");
            }
            let verdict = Self::validate_redirect_target(&redirect_rules, attempt.url());
            match verdict {
                Ok(()) => attempt.follow(),
                Err(reason) => attempt.error(reason),
            }
        });

        let mut http_client_builder = reqwest::Client::builder()
            .user_agent(&media_fetcher.user_agent)
            .redirect(redirect_policy);

        if let Some(timeout) = media_fetcher.timeout {
            http_client_builder = http_client_builder.timeout(timeout);
        }

        let http_client = http_client_builder.build()?;

        Ok(Self {
            http_client,
            media_fetcher,
        })
    }

    /// Checks `url` against the configured fetch rules.
    ///
    /// `data:` URLs are always accepted. HTTP(S) URLs must satisfy the
    /// direct-IP, direct-port and domain-allowlist settings.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first rule the URL violates.
    pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
        Self::validate_url(&self.media_fetcher, url)
    }

    /// Rule check shared by `check_if_url_allowed` and the redirect policy.
    fn validate_url(rules: &MediaFetcher, url: &url::Url) -> Result<()> {
        match url.scheme() {
            // Inline payloads never touch the network.
            "data" => return Ok(()),
            "http" | "https" => {}
            _ => anyhow::bail!("Only HTTP(S) and data URLs are allowed"),
        }

        if !rules.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) {
            anyhow::bail!("Direct IP access is not allowed");
        }

        if !rules.allow_direct_port && url.port().is_some() {
            anyhow::bail!("Direct port access is not allowed");
        }

        if let Some(allowed_domains) = &rules.allowed_media_domains
            && let Some(host) = url.host_str()
            && !allowed_domains.contains(host)
        {
            anyhow::bail!("Domain '{host}' is not in allowed list");
        }

        Ok(())
    }

    /// Validates a redirect hop: it must be HTTP(S) and pass the same rules
    /// as the original URL.
    fn validate_redirect_target(rules: &MediaFetcher, url: &url::Url) -> Result<()> {
        anyhow::ensure!(
            matches!(url.scheme(), "http" | "https"),
            "Redirects to non-HTTP(S) URLs are not allowed"
        );
        Self::validate_url(rules, url)
    }

    /// Fetches the media referenced by an image or video content part.
    ///
    /// # Errors
    ///
    /// Fails for audio parts (not supported yet), for any other content part
    /// type, when the URL is rejected by `check_if_url_allowed`, or when the
    /// fetch itself fails.
    pub async fn fetch_media_part(
        &self,
        oai_content_part: &ChatCompletionRequestUserMessageContentPart,
        // TODO: request-level options
    ) -> Result<EncodedMediaData> {
        // Image and video parts are fetched the same way; only the location of
        // the URL inside the part differs.
        let url: &url::Url = match oai_content_part {
            ChatCompletionRequestUserMessageContentPart::ImageUrl(image_part) => {
                &image_part.image_url.url
            }
            ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
                &video_part.video_url.url
            }
            ChatCompletionRequestUserMessageContentPart::AudioUrl(_) => {
                anyhow::bail!("Audio decoding is not supported yet");
            }
            _ => anyhow::bail!("Unsupported media type"),
        };

        self.check_if_url_allowed(url)?;

        // fetch the media
        // TODO: decode and NIXL-register
        EncodedMediaData::from_url(url, &self.http_client).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` and asks `loader` whether it may be fetched.
    fn verdict(loader: &MediaLoader, raw: &str) -> Result<()> {
        let parsed = url::Url::parse(raw).unwrap();
        loader.check_if_url_allowed(&parsed)
    }

    /// Hosts given as raw IP addresses are refused when direct IP access is off.
    #[test]
    fn test_direct_ip_blocked() {
        let loader = MediaLoader::new(MediaFetcher {
            allow_direct_ip: false,
            ..Default::default()
        })
        .unwrap();

        let outcome = verdict(&loader, "http://192.168.1.1/image.jpg");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Direct IP access is not allowed"));
    }

    /// URLs carrying an explicit port are refused when direct port access is off.
    #[test]
    fn test_direct_port_blocked() {
        let loader = MediaLoader::new(MediaFetcher {
            allow_direct_port: false,
            ..Default::default()
        })
        .unwrap();

        let outcome = verdict(&loader, "http://example.com:8080/image.jpg");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Direct port access is not allowed"));
    }

    /// With an allowlist configured, only listed hosts pass.
    #[test]
    fn test_domain_allowlist() {
        let allowed_domains: HashSet<String> = ["trusted.com", "example.com"]
            .into_iter()
            .map(str::to_string)
            .collect();

        let loader = MediaLoader::new(MediaFetcher {
            allowed_media_domains: Some(allowed_domains),
            ..Default::default()
        })
        .unwrap();

        // Allowed domain should pass
        assert!(verdict(&loader, "https://trusted.com/image.jpg").is_ok());

        // Disallowed domain should fail
        let outcome = verdict(&loader, "https://untrusted.com/image.jpg");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not in allowed list"));
    }
}
......@@ -551,7 +551,7 @@ async fn test_media_url_passthrough(#[case] media_chunks: &[(&str, usize)]) {
let message = build_message("Test multimodal content", media_chunks);
let request = Request::from(&message, None, None, mdc.slug().to_string());
let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).unwrap();
let (preprocessed, _annotations) = preprocessor.preprocess_request(&request).await.unwrap();
// Verify multimodal data handling
if media_chunks.is_empty() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment