Unverified Commit 61b0ffcd authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

fix: harden Rust MediaFetcher against redirect bypass and DNS rebinding (#8569)


Signed-off-by: zhongdaor <zhongdaor@nvidia.com>
parent 12e144ae
...@@ -2431,6 +2431,7 @@ dependencies = [ ...@@ -2431,6 +2431,7 @@ dependencies = [
"image", "image",
"indicatif 0.18.4", "indicatif 0.18.4",
"insta", "insta",
"ipnet",
"json-five", "json-five",
"lazy_static", "lazy_static",
"memfile", "memfile",
......
...@@ -95,6 +95,7 @@ modelexpress-common = { version = "0.3.0" } ...@@ -95,6 +95,7 @@ modelexpress-common = { version = "0.3.0" }
humantime = { version = "2.2.0" } humantime = { version = "2.2.0" }
indexmap = { version = "2" } indexmap = { version = "2" }
ipnet = { version = "2" }
libc = { version = "0.2" } libc = { version = "0.2" }
oneshot = { version = "0.1.13", features = ["std", "async"] } oneshot = { version = "0.1.13", features = ["std", "async"] }
ordered-float = "4" ordered-float = "4"
......
...@@ -1563,6 +1563,7 @@ dependencies = [ ...@@ -1563,6 +1563,7 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"hf-hub", "hf-hub",
"image", "image",
"ipnet",
"json-five", "json-five",
"minijinja", "minijinja",
"minijinja-contrib", "minijinja-contrib",
......
...@@ -1576,6 +1576,7 @@ dependencies = [ ...@@ -1576,6 +1576,7 @@ dependencies = [
"galil-seiferas", "galil-seiferas",
"hf-hub", "hf-hub",
"image", "image",
"ipnet",
"json-five", "json-five",
"memfile", "memfile",
"minijinja", "minijinja",
......
...@@ -49,8 +49,11 @@ pub struct MediaFetcher { ...@@ -49,8 +49,11 @@ pub struct MediaFetcher {
impl MediaFetcher { impl MediaFetcher {
#[new] #[new]
fn new() -> Self { fn new() -> Self {
// Use from_env so DYN_MM_ALLOW_INTERNAL is honored by the
// Rust-side frontend-decode fetch path, matching the Python
// UrlValidationPolicy.from_env() behavior on the backend.
Self { Self {
inner: RsMediaFetcher::default(), inner: RsMediaFetcher::from_env(),
} }
} }
fn user_agent(&mut self, user_agent: String) { fn user_agent(&mut self, user_agent: String) {
...@@ -65,6 +68,10 @@ impl MediaFetcher { ...@@ -65,6 +68,10 @@ impl MediaFetcher {
self.inner.allow_direct_port = allow; self.inner.allow_direct_port = allow;
} }
/// Set the `allow_private_ips` flag on the wrapped fetcher (`self.inner`).
///
/// `allow` is stored as-is; no validation is performed here.
fn allow_private_ips(&mut self, allow: bool) {
self.inner.allow_private_ips = allow;
}
fn allowed_media_domains(&mut self, domains: Vec<String>) { fn allowed_media_domains(&mut self, domains: Vec<String>) {
self.inner.allowed_media_domains = Some(domains.into_iter().collect()); self.inner.allowed_media_domains = Some(domains.into_iter().collect());
} }
......
...@@ -76,6 +76,7 @@ either = { workspace = true } ...@@ -76,6 +76,7 @@ either = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
futures-util = { workspace = true } futures-util = { workspace = true }
hf-hub = { workspace = true } hf-hub = { workspace = true }
ipnet = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
oneshot = { workspace = true } oneshot = { workspace = true }
parking_lot = { workspace = true } parking_lot = { workspace = true }
......
...@@ -2,9 +2,14 @@ ...@@ -2,9 +2,14 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet; use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, LazyLock};
use std::time::Duration; use std::time::Duration;
use anyhow::Result; use anyhow::Result;
use ipnet::IpNet;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use reqwest::redirect::Policy;
use dynamo_memory::nixl::NixlAgent; use dynamo_memory::nixl::NixlAgent;
use dynamo_protocols::types::ChatCompletionRequestUserMessageContentPart; use dynamo_protocols::types::ChatCompletionRequestUserMessageContentPart;
...@@ -15,12 +20,90 @@ use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent}; ...@@ -15,12 +20,90 @@ use super::rdma::{RdmaMediaDataDescriptor, get_nixl_agent};
const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo"; const DEFAULT_HTTP_USER_AGENT: &str = "dynamo-ai/dynamo";
const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30);
// Upper bound on redirect hops: the HTTP client's redirect policy fails the
// request with a "too many redirects" error once this many hops were taken.
const MAX_REDIRECTS: usize = 3;
// IP ranges that must never be reachable from a user-controlled URL.
// Source: RFC1918 (private), RFC6598 (CGNAT), RFC5735 (loopback, link-local,
// 0.0.0.0/8), RFC4193 (ULA), RFC4291 (IPv6 loopback / link-local), RFC6890
// (reserved). Link-local 169.254/16 covers the AWS / OpenStack metadata IP.
//
// Keep this list in sync with the Python counterpart
// (components/src/dynamo/common/multimodal/url_validator.py::_BLOCKED_IP_NETWORKS).
static BLOCKED_IP_NETWORKS: LazyLock<Vec<IpNet>> = LazyLock::new(|| {
[
"0.0.0.0/8",
"10.0.0.0/8",
"100.64.0.0/10",
"127.0.0.0/8",
"169.254.0.0/16",
"172.16.0.0/12",
"192.0.0.0/24",
"192.0.2.0/24",
"192.168.0.0/16",
"198.18.0.0/15",
"198.51.100.0/24",
"203.0.113.0/24",
"224.0.0.0/4",
"240.0.0.0/4",
"255.255.255.255/32",
"::/128",
"::1/128",
"::ffff:0:0/96",
"fc00::/7",
"fe80::/10",
"ff00::/8",
]
.iter()
.map(|s| s.parse().expect("invalid CIDR in BLOCKED_IP_NETWORKS"))
.collect()
});
// Hostnames rejected by literal match, before any DNS lookup happens. This
// defends against /etc/hosts tricks or malicious resolvers that alias
// metadata / internal-service names to attacker IPs. Entries are stored in
// lowercase; the caller lowercases the host before the lookup, which makes the
// match case-insensitive.
//
// This set must stay in sync with the Python counterpart
// (components/src/dynamo/common/multimodal/url_validator.py::_BLOCKED_HOSTS).
static BLOCKED_HOSTS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        "localhost",
        "localhost.localdomain",
        "ip6-localhost",
        "ip6-loopback",
        "metadata",
        "metadata.google.internal",
        "metadata.goog",
        "kubernetes.default",
        "kubernetes.default.svc",
    ])
});
/// Report whether `ip` lies inside at least one network of
/// `BLOCKED_IP_NETWORKS`.
pub fn is_blocked_ip(ip: &IpAddr) -> bool {
    for network in BLOCKED_IP_NETWORKS.iter() {
        if network.contains(ip) {
            return true;
        }
    }
    false
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct MediaFetcher { pub struct MediaFetcher {
pub user_agent: String, pub user_agent: String,
pub allow_direct_ip: bool, pub allow_direct_ip: bool,
pub allow_direct_port: bool, pub allow_direct_port: bool,
/// When `false` (default), reject URLs that target blocked locations:
/// an IP literal in the RFC-range blocklist (`BLOCKED_IP_NETWORKS`),
/// a hostname in the literal blocklist (`BLOCKED_HOSTS`, e.g.
/// `localhost` / `metadata.google.internal`), or — in
/// `check_if_url_allowed_with_dns` — a hostname that DNS-resolves
/// to a blocked IP. The name reads "IP" but semantically this is a
/// single "allow internal / on-prem targets" switch that covers
/// both IP and hostname blocklists together: real on-prem
/// deployments need both at once (private CIDRs *and* internal
/// service names), and splitting them would give no useful config
/// while doubling the footgun surface. **Never** set on anything
/// public-facing.
pub allow_private_ips: bool,
pub allowed_media_domains: Option<HashSet<String>>, pub allowed_media_domains: Option<HashSet<String>>,
pub timeout: Option<Duration>, pub timeout: Option<Duration>,
} }
...@@ -31,12 +114,34 @@ impl Default for MediaFetcher { ...@@ -31,12 +114,34 @@ impl Default for MediaFetcher {
user_agent: DEFAULT_HTTP_USER_AGENT.to_string(), user_agent: DEFAULT_HTTP_USER_AGENT.to_string(),
allow_direct_ip: false, allow_direct_ip: false,
allow_direct_port: false, allow_direct_port: false,
allow_private_ips: false,
allowed_media_domains: None, allowed_media_domains: None,
timeout: Some(DEFAULT_HTTP_TIMEOUT), timeout: Some(DEFAULT_HTTP_TIMEOUT),
} }
} }
} }
impl MediaFetcher {
    /// Construct a `MediaFetcher` whose permission flags are derived from the
    /// shared `DYN_MM_ALLOW_INTERNAL` environment variable.
    ///
    /// Intended to mirror the Python `UrlValidationPolicy.from_env()` so that
    /// both fetch paths (frontend decode in Rust, backend decode in Python)
    /// honor the same on-prem opt-in flag.
    ///
    /// Only the exact value `1` enables the opt-in; it sets
    /// `allow_direct_ip`, `allow_direct_port`, and `allow_private_ips` to
    /// `true` together. Any other value (or an unset / non-UTF-8 variable)
    /// leaves all three `false`. Remaining fields come from `Self::default()`.
    pub fn from_env() -> Self {
        let allow_internal = matches!(
            std::env::var("DYN_MM_ALLOW_INTERNAL").as_deref(),
            Ok("1")
        );

        Self {
            allow_private_ips: allow_internal,
            allow_direct_port: allow_internal,
            allow_direct_ip: allow_internal,
            ..Self::default()
        }
    }
}
impl MediaFetcher { impl MediaFetcher {
pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> { pub fn check_if_url_allowed(&self, url: &url::Url) -> Result<()> {
if !matches!(url.scheme(), "http" | "https" | "data") { if !matches!(url.scheme(), "http" | "https" | "data") {
...@@ -47,21 +152,107 @@ impl MediaFetcher { ...@@ -47,21 +152,107 @@ impl MediaFetcher {
return Ok(()); return Ok(());
} }
if !self.allow_direct_ip && !matches!(url.host(), Some(url::Host::Domain(_))) { let host = url
.host()
.ok_or_else(|| anyhow::anyhow!("URL has no host component"))?;
if !self.allow_direct_ip && !matches!(host, url::Host::Domain(_)) {
anyhow::bail!("Direct IP access is not allowed"); anyhow::bail!("Direct IP access is not allowed");
} }
if !self.allow_direct_port && url.port().is_some() { if !self.allow_direct_port && url.port().is_some() {
anyhow::bail!("Direct port access is not allowed"); anyhow::bail!("Direct port access is not allowed");
} }
// Host-level checks: blocked hostnames and IP literals in blocked
// ranges. DNS-resolved IPs are checked in `check_if_url_allowed_with_dns`.
if !self.allow_private_ips {
let ip_literal = match host {
url::Host::Domain(domain) => {
let lowered = domain.trim_end_matches('.').to_ascii_lowercase();
if BLOCKED_HOSTS.contains(lowered.as_str()) {
anyhow::bail!("Host '{domain}' is blocked (resolves to internal service)");
}
None
}
url::Host::Ipv4(ip) => Some(IpAddr::V4(ip)),
url::Host::Ipv6(ip) => Some(IpAddr::V6(ip)),
};
if let Some(ip) = ip_literal
&& is_blocked_ip(&ip)
{
anyhow::bail!("IP literal '{ip}' is in a blocked range");
}
}
if let Some(allowed_domains) = &self.allowed_media_domains if let Some(allowed_domains) = &self.allowed_media_domains
&& let Some(host) = url.host_str() && let Some(host_str) = url.host_str()
&& !allowed_domains.contains(host) && !allowed_domains.contains(host_str)
{ {
anyhow::bail!("Domain '{host}' is not in allowed list"); anyhow::bail!("Host '{host_str}' is not in the allowed_media_domains list");
} }
Ok(()) Ok(())
} }
/// Full policy check: runs `check_if_url_allowed` and, for hostname
/// URLs, resolves DNS and checks each resulting IP against the blocked ranges.
pub async fn check_if_url_allowed_with_dns(&self, url: &url::Url) -> Result<()> {
self.check_if_url_allowed(url)?;
// Only hostnames need DNS resolution; IP-literal hosts were already
// checked against the blocklist above.
if self.allow_private_ips || url.scheme() == "data" {
return Ok(());
}
let Some(url::Host::Domain(host)) = url.host() else {
return Ok(());
};
let port = url.port_or_known_default().unwrap_or(0);
let iter = tokio::net::lookup_host((host, port))
.await
.map_err(|e| anyhow::anyhow!("Could not resolve host '{host}': {e}"))?;
for sock_addr in iter {
let ip = sock_addr.ip();
if is_blocked_ip(&ip) {
anyhow::bail!("Host '{host}' resolves to blocked IP '{ip}'");
}
}
Ok(())
}
}
/// DNS resolver that drops blocked IP ranges from lookup results before
/// reqwest sees them.
///
/// It is installed on the shared `reqwest::Client` through
/// `ClientBuilder::dns_resolver`. reqwest calls it for every hostname it
/// needs to resolve — including redirect targets — so DNS rebinding cannot
/// slip a blocked IP past the policy: reqwest never learns about the blocked
/// addresses.
struct BlocklistResolver {
/// When `true`, resolved addresses are passed through unfiltered.
allow_private_ips: bool,
}
impl Resolve for BlocklistResolver {
/// Resolve `name` and hand reqwest only the addresses permitted by the
/// blocklist; fails when no permitted address remains.
fn resolve(&self, name: Name) -> Resolving {
    let allow_private = self.allow_private_ips;
    let host = name.as_str().to_owned();

    Box::pin(async move {
        // The port passed to the lookup is a placeholder (0).
        let resolved = tokio::net::lookup_host((host.as_str(), 0_u16)).await?;

        // With the opt-in set every address passes; otherwise blocked ones
        // are dropped.
        let addrs: Vec<SocketAddr> = resolved
            .filter(|sa| allow_private || !is_blocked_ip(&sa.ip()))
            .collect();

        if addrs.is_empty() {
            let err: Box<dyn std::error::Error + Send + Sync> = Box::new(std::io::Error::new(
                std::io::ErrorKind::AddrNotAvailable,
                format!("no non-blocked addresses for host '{host}'"),
            ));
            return Err(err);
        }

        let addrs: Addrs = Box::new(addrs.into_iter());
        Ok(addrs)
    })
}
} }
pub struct MediaLoader { pub struct MediaLoader {
...@@ -76,9 +267,31 @@ pub struct MediaLoader { ...@@ -76,9 +267,31 @@ pub struct MediaLoader {
impl MediaLoader { impl MediaLoader {
pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> { pub fn new(media_decoder: MediaDecoder, media_fetcher: Option<MediaFetcher>) -> Result<Self> {
let media_fetcher = media_fetcher.unwrap_or_default(); // Fall back to env-aware defaults so `DYN_MM_ALLOW_INTERNAL=1` is
let mut http_client_builder: reqwest::ClientBuilder = // honored even when the caller doesn't pass an explicit fetcher.
reqwest::Client::builder().user_agent(&media_fetcher.user_agent); let media_fetcher = media_fetcher.unwrap_or_else(MediaFetcher::from_env);
// Redirect policy: revalidate the policy-visible part of the URL
// (scheme, IP literals, hostname blocklist, direct-IP / direct-port
// rules) on every hop. DNS-based attacks on redirect targets are
// handled by the custom resolver below.
let fetcher_for_redirects = media_fetcher.clone();
let redirect_policy = Policy::custom(move |attempt| {
if attempt.previous().len() >= MAX_REDIRECTS {
return attempt.error(anyhow::anyhow!("too many redirects (max={MAX_REDIRECTS})"));
}
match fetcher_for_redirects.check_if_url_allowed(attempt.url()) {
Ok(()) => attempt.follow(),
Err(e) => attempt.error(e),
}
});
let mut http_client_builder = reqwest::Client::builder()
.user_agent(&media_fetcher.user_agent)
.redirect(redirect_policy)
.dns_resolver(Arc::new(BlocklistResolver {
allow_private_ips: media_fetcher.allow_private_ips,
}));
if let Some(timeout) = media_fetcher.timeout { if let Some(timeout) = media_fetcher.timeout {
http_client_builder = http_client_builder.timeout(timeout); http_client_builder = http_client_builder.timeout(timeout);
...@@ -111,7 +324,9 @@ impl MediaLoader { ...@@ -111,7 +324,9 @@ impl MediaLoader {
.ok_or_else(|| anyhow::anyhow!("Model does not support image inputs"))?; .ok_or_else(|| anyhow::anyhow!("Model does not support image inputs"))?;
let url = &image_part.image_url.url; let url = &image_part.image_url.url;
self.media_fetcher.check_if_url_allowed(url)?; self.media_fetcher
.check_if_url_allowed_with_dns(url)
.await?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?; let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced // Use runtime decoder if provided, with MDC limits enforced
...@@ -132,7 +347,9 @@ impl MediaLoader { ...@@ -132,7 +347,9 @@ impl MediaLoader {
})?; })?;
let url = &video_part.video_url.url; let url = &video_part.video_url.url;
self.media_fetcher.check_if_url_allowed(url)?; self.media_fetcher
.check_if_url_allowed_with_dns(url)
.await?;
let data = EncodedMediaData::from_url(url, &self.http_client).await?; let data = EncodedMediaData::from_url(url, &self.http_client).await?;
// Use runtime decoder if provided, with MDC limits enforced // Use runtime decoder if provided, with MDC limits enforced
...@@ -181,6 +398,8 @@ mod tests { ...@@ -181,6 +398,8 @@ mod tests {
let fetcher = MediaFetcher { let fetcher = MediaFetcher {
allow_direct_ip: true, allow_direct_ip: true,
allow_direct_port: true, allow_direct_port: true,
// mockito serves on 127.0.0.1 which is in the loopback blocklist.
allow_private_ips: true,
..Default::default() ..Default::default()
}; };
...@@ -307,7 +526,153 @@ mod tests_non_nixl { ...@@ -307,7 +526,153 @@ mod tests_non_nixl {
result result
.unwrap_err() .unwrap_err()
.to_string() .to_string()
.contains("not in allowed list") .contains("allowed_media_domains")
);
}
#[test]
fn test_is_blocked_ip_ranges() {
    // One sample per blocked family.
    const BLOCKED: [&str; 9] = [
        "127.0.0.1",
        "10.0.0.1",
        "172.16.5.5",
        "192.168.1.1",
        "169.254.169.254", // AWS metadata
        "100.64.0.1",      // CGNAT
        "::1",
        "fe80::1",
        "fc00::1",
    ];
    // Publicly routable addresses that must pass.
    const PUBLIC: [&str; 3] = ["8.8.8.8", "1.1.1.1", "2606:4700:4700::1111"];

    for ip in BLOCKED {
        let addr = ip.parse::<IpAddr>().unwrap();
        assert!(is_blocked_ip(&addr), "{ip} should be blocked");
    }
    for ip in PUBLIC {
        let addr = ip.parse::<IpAddr>().unwrap();
        assert!(!is_blocked_ip(&addr), "{ip} should not be blocked");
    }
}
#[test]
fn test_blocked_ip_literal_rejected_even_when_direct_ip_allowed() {
// allow_direct_ip=true lets IP-literal URLs through the early check,
// but the RFC-range blocklist must still reject cloud-metadata IPs.
let fetcher = MediaFetcher {
allow_direct_ip: true,
..Default::default()
};
let url = url::Url::parse("http://169.254.169.254/latest/meta-data/").unwrap();
let result = fetcher.check_if_url_allowed(&url);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("is in a blocked range")
); );
} }
#[test]
fn test_blocked_hostname_rejected() {
    // Each of these names is on the literal hostname blocklist and must be
    // refused by a default-configured fetcher.
    let fetcher = MediaFetcher::default();
    let hosts = [
        "localhost",
        "metadata.google.internal",
        "kubernetes.default.svc",
    ];

    for host in hosts {
        let url = url::Url::parse(&format!("https://{host}/x")).unwrap();
        match fetcher.check_if_url_allowed(&url) {
            Ok(()) => panic!("{host} should be blocked"),
            Err(err) => assert!(
                err.to_string().contains("blocked"),
                "{host} error should mention 'blocked'"
            ),
        }
    }
}
#[test]
fn test_allow_private_ips_bypasses_blocklist() {
    // `allow_private_ips = true` is the escape hatch for on-prem / dev envs.
    let fetcher = MediaFetcher {
        allow_direct_ip: true,
        allow_private_ips: true,
        ..Default::default()
    };

    // With the opt-in set, both an IP literal from a blocked range and a
    // blocklisted hostname must be accepted.
    for target in ["http://10.0.0.5/x", "https://localhost/x"] {
        let url = url::Url::parse(target).unwrap();
        assert!(fetcher.check_if_url_allowed(&url).is_ok());
    }
}
#[test]
fn test_hostname_blocklist_case_insensitive() {
    // A mixed-case spelling of a blocklisted host must still be rejected.
    let mixed_case = url::Url::parse("https://Metadata.Google.Internal/x").unwrap();
    let outcome = MediaFetcher::default().check_if_url_allowed(&mixed_case);
    assert!(outcome.is_err());
}
#[test]
fn test_from_env_default() {
    // Saving and restoring environment variables is racy when tests run in
    // parallel, so only the "variable unset" case is asserted here.
    // SAFETY: assumes no other thread touches the process environment while
    // the variable is removed.
    // NOTE(review): the default test harness runs tests on multiple threads —
    // confirm this assumption holds.
    unsafe {
        std::env::remove_var("DYN_MM_ALLOW_INTERNAL");
    }

    let MediaFetcher {
        allow_private_ips,
        allow_direct_ip,
        allow_direct_port,
        ..
    } = MediaFetcher::from_env();

    assert!(!allow_private_ips);
    assert!(!allow_direct_ip);
    assert!(!allow_direct_port);
}
#[test]
fn test_hostname_blocklist_strips_trailing_dot() {
    // The fully-qualified spelling `metadata.google.internal.` names the same
    // host as `metadata.google.internal` at the DNS layer, so the trailing
    // dot must not let it escape the blocklist.
    let fqdn_url = url::Url::parse("https://metadata.google.internal./x").unwrap();
    let outcome = MediaFetcher::default().check_if_url_allowed(&fqdn_url);
    assert!(outcome.is_err(), "FQDN with trailing dot should be rejected");
}
#[tokio::test]
async fn test_check_with_dns_data_url_skips_resolution() {
    // `data:` URLs never touch the network, so the async check has to return
    // early instead of attempting a lookup.
    let data_url = url::Url::parse("data:image/png;base64,iVBORw0KGgoAAAA=").unwrap();
    let default_fetcher = MediaFetcher::default();
    default_fetcher
        .check_if_url_allowed_with_dns(&data_url)
        .await
        .unwrap();
}
#[tokio::test]
async fn test_check_with_dns_public_ip_literal_passes() {
    // IP literals are fully handled by the synchronous pass; the async stage
    // has nothing left to resolve for them.
    let literal_url = url::Url::parse("https://8.8.8.8/x").unwrap();
    let direct_ip_fetcher = MediaFetcher {
        allow_direct_ip: true,
        ..Default::default()
    };
    direct_ip_fetcher
        .check_if_url_allowed_with_dns(&literal_url)
        .await
        .unwrap();
}
#[tokio::test]
async fn test_check_with_dns_blocked_hostname_fails_before_resolution() {
    // The synchronous hostname-blocklist check must fire before any DNS is
    // attempted. `localhost` would typically also be rejected by the DNS
    // stage (it normally resolves to a loopback address), so a bare
    // `is_err()` cannot tell which stage produced the failure. Pin the
    // hostname-blocklist error text to prove the sync stage rejected it.
    let fetcher = MediaFetcher::default();
    let url = url::Url::parse("https://localhost/x").unwrap();

    let err = fetcher
        .check_if_url_allowed_with_dns(&url)
        .await
        .expect_err("localhost must be rejected");

    let message = err.to_string();
    assert!(
        message.contains("is blocked (resolves to internal service)"),
        "expected the hostname-blocklist error, got: {message}"
    );
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment