Unverified Commit 8daacbd7 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: default with lib/memory, media-nixl and kvbm (#5602)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 0862f87b
......@@ -253,7 +253,7 @@ pub fn validate_block_transfer(
Ok(())
}
#[cfg(test)]
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
use super::super::tests::*;
use super::*;
......
......@@ -41,6 +41,33 @@ impl Default for MediaFetcher {
}
}
impl MediaFetcher {
pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
if !matches!(url.scheme(), "http" | "https" | "data") {
anyhow::bail!("Only HTTP(S) and data URLs are allowed");
}
if url.scheme() == "data" {
return Ok(());
}
if !self.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) {
anyhow::bail!("Direct IP access is not allowed");
}
if !self.allow_direct_port && url.port().is_some() {
anyhow::bail!("Direct port access is not allowed");
}
if let Some(allowed_domains) = &self.allowed_media_domains
&& let Some(host) = url.host_str()
&& !allowed_domains.contains(host)
{
anyhow::bail!("Domain '{host}' is not in allowed list");
}
Ok(())
}
}
pub struct MediaLoader {
#[allow(dead_code)]
media_decoder: MediaDecoder,
......@@ -75,32 +102,6 @@ impl MediaLoader {
})
}
pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
if !matches!(url.scheme(), "http" | "https" | "data") {
anyhow::bail!("Only HTTP(S) and data URLs are allowed");
}
if url.scheme() == "data" {
return Ok(());
}
if !self.media_fetcher.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_)))
{
anyhow::bail!("Direct IP access is not allowed");
}
if !self.media_fetcher.allow_direct_port && url.port().is_some() {
anyhow::bail!("Direct port access is not allowed");
}
if let Some(allowed_domains) = &self.media_fetcher.allowed_media_domains
&& let Some(host) = url.host_str()
&& !allowed_domains.contains(host)
{
anyhow::bail!("Domain '{host}' is not in allowed list");
}
Ok(())
}
pub async fn fetch_and_decode_media_part(
&self,
oai_content_part: &ChatCompletionRequestUserMessageContentPart,
......@@ -122,7 +123,7 @@ impl MediaLoader {
})?;
let url = &image_part.image_url.url;
self.check_if_url_allowed(url)?;
self.media_fetcher.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced
......@@ -130,6 +131,7 @@ impl MediaLoader {
mdc_decoder.with_runtime(media_io_kwargs.and_then(|k| k.image.as_ref()));
decoder.decode_async(data).await?
}
#[allow(unused_variables)]
ChatCompletionRequestUserMessageContentPart::VideoUrl(video_part) => {
#[cfg(not(feature = "media-ffmpeg"))]
anyhow::bail!(
......@@ -143,7 +145,7 @@ impl MediaLoader {
})?;
let url = &video_part.video_url.url;
self.check_if_url_allowed(url)?;
self.media_fetcher.check_if_url_allowed(url)?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced
......@@ -164,7 +166,7 @@ impl MediaLoader {
}
}
#[cfg(all(test, feature = "media-nixl"))]
#[cfg(all(test, feature = "media-nixl", feature = "testing-nixl"))]
mod tests {
use super::super::decoders::ImageDecoder;
use super::super::rdma::DataType;
......@@ -255,10 +257,9 @@ mod tests_non_nixl {
allow_direct_ip: false,
..Default::default()
};
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
let url = url::Url::parse("http://192.168.1.1/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
......@@ -275,10 +276,9 @@ mod tests_non_nixl {
allow_direct_port: false,
..Default::default()
};
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
let url = url::Url::parse("http://example.com:8080/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
......@@ -299,15 +299,14 @@ mod tests_non_nixl {
allowed_media_domains: Some(allowed_domains),
..Default::default()
};
let loader = MediaLoader::new(MediaDecoder::default(), Some(fetcher)).unwrap();
// Allowed domain should pass
let url = url::Url::parse("https://trusted.com/image.jpg").unwrap();
assert!(loader.check_if_url_allowed(&url).is_ok());
assert!(fetcher.check_if_url_allowed(&url).is_ok());
// Disallowed domain should fail
let url = url::Url::parse("https://untrusted.com/image.jpg").unwrap();
let result = loader.check_if_url_allowed(&url);
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
result
......
......@@ -11,7 +11,7 @@ repository.workspace = true
description = "Memory management library for Dynamo"
[features]
default = ["testing-all"]
default = []
# feature to enable unsafe slices of memory descriptors
# for advanced testing in other crates
......
......@@ -164,7 +164,7 @@ impl std::ops::Deref for NixlAgent {
}
}
#[cfg(test)]
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
use super::*;
......
......@@ -299,7 +299,7 @@ impl Drop for CudaMemPool {
}
}
#[cfg(test)]
#[cfg(all(test, feature = "testing-cuda"))]
mod tests {
use super::*;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment