Unverified Commit 8daacbd7 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: default with lib/memory, media-nixl and kvbm (#5602)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 0862f87b
...@@ -253,7 +253,7 @@ pub fn validate_block_transfer( ...@@ -253,7 +253,7 @@ pub fn validate_block_transfer(
Ok(()) Ok(())
} }
#[cfg(test)] #[cfg(all(test, feature = "testing-nixl"))]
mod tests { mod tests {
use super::super::tests::*; use super::super::tests::*;
use super::*; use super::*;
......
...@@ -41,6 +41,33 @@ impl Default for MediaFetcher { ...@@ -41,6 +41,33 @@ impl Default for MediaFetcher {
} }
} }
impl MediaFetcher {
pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
if !matches!(url.scheme(), "http" | "https" | "data") {
anyhow::bail!("Only HTTP(S) and data URLs are allowed");
}
if url.scheme() == "data" {
return Ok(());
}
if !self.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) {
anyhow::bail!("Direct IP access is not allowed");
}
if !self.allow_direct_port && url.port().is_some() {
anyhow::bail!("Direct port access is not allowed");
}
if let Some(allowed_domains) = &self.allowed_media_domains
&& let Some(host) = url.host_str()
&& !allowed_domains.contains(host)
{
anyhow::bail!("Domain '{host}' is not in allowed list");
}
Ok(())
}
}
pub struct MediaLoader { pub struct MediaLoader {
#[allow(dead_code)] #[allow(dead_code)]
media_decoder: MediaDecoder, media_decoder: MediaDecoder,
...@@ -75,32 +102,6 @@ impl MediaLoader { ...@@ -75,32 +102,6 @@ impl MediaLoader {
}) })
} }
pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
if !matches!(url.scheme(), "http" | "https" | "data") {
anyhow::bail!("Only HTTP(S) and data URLs are allowed");
}
if url.scheme() == "data" {
return Ok(());
}
if !self.media_fetcher.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_)))
{
anyhow::bail!("Direct IP access is not allowed");
}
if !self.media_fetcher.allow_direct_port && url.port().is_some() {
anyhow::bail!("Direct port access is not allowed");
}
if let Some(allowed_domains) = &self.media_fetcher.allowed_media_domains
&& let Some(host) = url.host_str()
&& !allowed_domains.contains(host)
{
anyhow::bail!("Domain '{host}' is not in allowed list");
}
Ok(())
}
pub async fn fetch_and_decode_media_part( pub async fn fetch_and_decode_media_part(
&self, &self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart, oai_content_part: &ChatCompletionRequestUserMessageContentPart,
...@@ -122,7 +123,7 @@ impl MediaLoader { ...@@ -122,7 +123,7 @@ impl MediaLoader {
})?; })?;
let url = &image_part.image_url.url; let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?; self.media_fetcher.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?; let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced // Use runtime decoder if provided, with MDC limits enforced
...@@ -130,6 +131,7 @@ impl MediaLoader { ...@@ -130,6 +131,7 @@ impl MediaLoader {
mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref())); mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
decoder.decode_async(data).await? decoder.decode_async(data).await?
} }
#[allow(unused_variables)]
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => { ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
#[cfg(not(feature = "media-ffmpeg"))] #[cfg(not(feature = "media-ffmpeg"))]
anyhow::bail!( anyhow::bail!(
...@@ -143,7 +145,7 @@ impl MediaLoader { ...@@ -143,7 +145,7 @@ impl MediaLoader {
})?; })?;
let url = &video_part.video_url.url; let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?; self.media_fetcher.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?; let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced // Use runtime decoder if provided, with MDC limits enforced
...@@ -164,7 +166,7 @@ impl MediaLoader { ...@@ -164,7 +166,7 @@ impl MediaLoader {
} }
} }
#[cfg(all(test, feature = "media-nixl"))] #[cfg(all(test, feature = "media-nixl", feature = "testing-nixl"))]
mod tests { mod tests {
use super::super::decoders::ImageDecoder; use super::super::decoders::ImageDecoder;
use super::super::rdma::DataType; use super::super::rdma::DataType;
...@@ -255,10 +257,9 @@ mod tests_non_nixl { ...@@ -255,10 +257,9 @@ mod tests_non_nixl {
allow_direct_ip: false, allow_direct_ip: false,
..Default::default() ..Default::default()
}; };
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap(); let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url); let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err()); assert!(result.is_err());
assert!( assert!(
...@@ -275,10 +276,9 @@ mod tests_non_nixl { ...@@ -275,10 +276,9 @@ mod tests_non_nixl {
allow_direct_port: false, allow_direct_port: false,
..Default::default() ..Default::default()
}; };
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap(); let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url); let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err()); assert!(result.is_err());
assert!( assert!(
...@@ -299,15 +299,14 @@ mod tests_non_nixl { ...@@ -299,15 +299,14 @@ mod tests_non_nixl {
allowed_media_domains: Some(allowed_domains), allowed_media_domains: Some(allowed_domains),
..Default::default() ..Default::default()
}; };
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
// Allowed domain should pass // Allowed domain should pass
let url = url::Url::parse("https://trusted.com/image.jpg").unwrap(); let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
assert!(loader.check_if_url_allowed(&url).is_ok()); assert!(fetcher.check_if_url_allowed(&url).is_ok());
// Disallowed domain should fail // Disallowed domain should fail
let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap(); let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url); let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err()); assert!(result.is_err());
assert!( assert!(
result result
......
...@@ -11,7 +11,7 @@ repository.workspace = true ...@@ -11,7 +11,7 @@ repository.workspace = true
description = "Memory management library for Dynamo" description = "Memory management library for Dynamo"
[features] [features]
default = ["testing-all"] default = []
# feature to enable unsafe slices of memory descriptors # feature to enable unsafe slices of memory descriptors
# for advanced testing in other crates # for advanced testing in other crates
......
...@@ -164,7 +164,7 @@ impl std::ops::Deref for NixlAgent { ...@@ -164,7 +164,7 @@ impl std::ops::Deref for NixlAgent {
} }
} }
#[cfg(test)] #[cfg(all(test, feature = "testing-nixl"))]
mod tests { mod tests {
use super::*; use super::*;
......
...@@ -299,7 +299,7 @@ impl Drop for CudaMemPool { ...@@ -299,7 +299,7 @@ impl Drop for CudaMemPool {
} }
} }
#[cfg(test)] #[cfg(all(test, feature = "testing-cuda"))]
mod tests { mod tests {
use super::*; use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment