"examples/tutorial/vscode:/vscode.git/clone" did not exist on "554aa9592ea6568c933b38b5235ec1e8a663bd9f"
Unverified Commit 8fba4f56 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

fix: Validate multimodal media URLs and load paths (#8282)


Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
parent f701319e
...@@ -3,13 +3,18 @@ ...@@ -3,13 +3,18 @@
import asyncio import asyncio
import logging import logging
from pathlib import Path
from typing import Any, Awaitable, Dict, Final, List from typing import Any, Awaitable, Dict, Final, List
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.http_client import get_http_client
from dynamo.common.multimodal.url_validator import (
UrlValidationPolicy,
fetch_with_revalidation,
validate_media_url,
)
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async from dynamo.common.utils.runtime import run_async
...@@ -57,37 +62,28 @@ class AudioLoader: ...@@ -57,37 +62,28 @@ class AudioLoader:
self, self,
http_timeout: float = 30.0, http_timeout: float = 30.0,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
url_policy: UrlValidationPolicy | None = None,
) -> None: ) -> None:
if http_timeout <= 0: if http_timeout <= 0:
raise ValueError(f"http_timeout must be positive, got {http_timeout}") raise ValueError(f"http_timeout must be positive, got {http_timeout}")
self._http_timeout = http_timeout self._http_timeout = http_timeout
self._enable_frontend_decoding = enable_frontend_decoding self._enable_frontend_decoding = enable_frontend_decoding
self._url_policy = url_policy or UrlValidationPolicy.from_env()
self._nixl_connector = None self._nixl_connector = None
self._vllm_media_connector = None self._vllm_media_connector = None
if self._enable_frontend_decoding: if self._enable_frontend_decoding:
self._nixl_connector = nixl_connect.Connector() self._nixl_connector = nixl_connect.Connector()
run_async(self._nixl_connector.initialize) run_async(self._nixl_connector.initialize)
@staticmethod
def _normalize_audio_url(audio_url: str) -> str:
"""Convert bare filesystem paths to file:// URIs.
HTTP(S) and data: URLs are returned unchanged.
"""
parsed_url = urlparse(audio_url)
if parsed_url.scheme or not audio_url:
return audio_url
file_path = Path(audio_url).expanduser()
if not file_path.exists():
raise FileNotFoundError(f"Error reading file: {file_path}")
return file_path.resolve().as_uri()
def _get_vllm_media_connector(self) -> Any: def _get_vllm_media_connector(self) -> Any:
if self._vllm_media_connector is None: if self._vllm_media_connector is None:
MediaConnector, _ = _require_vllm_audio_media() MediaConnector, _ = _require_vllm_audio_media()
self._vllm_media_connector = MediaConnector(allowed_local_media_path="/") # Confine vLLM's own local-path access to the same prefix we enforce.
# Empty string matches vLLM's secure default (no local access).
allowed = self._url_policy.allowed_local_path or ""
self._vllm_media_connector = MediaConnector(
allowed_local_media_path=allowed
)
return self._vllm_media_connector return self._vllm_media_connector
...@@ -97,14 +93,23 @@ class AudioLoader: ...@@ -97,14 +93,23 @@ class AudioLoader:
@_nvtx.annotate("mm:audio:load_with_vllm", color="cyan") @_nvtx.annotate("mm:audio:load_with_vllm", color="cyan")
async def _load_audio_with_vllm(self, audio_url: str) -> tuple[np.ndarray, float]: async def _load_audio_with_vllm(self, audio_url: str) -> tuple[np.ndarray, float]:
normalized_url = await validate_media_url(audio_url, self._url_policy)
media_io = self._create_vllm_audio_io()
# HTTP(S) goes through our SSRF-safe fetcher so each redirect hop is
# revalidated; vLLM's own fetcher honors redirects without re-checking.
# data: and file:// never touch the network, so vLLM can handle them.
if urlparse(normalized_url).scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await fetch_with_revalidation(
http_client, normalized_url, self._url_policy
)
response.raise_for_status()
return await asyncio.to_thread(media_io.load_bytes, response.content)
connector = self._get_vllm_media_connector() connector = self._get_vllm_media_connector()
normalized_url = self._normalize_audio_url(audio_url)
# TODO: Add caching for repeated remote `audio_url` downloads to avoid
# refetching the same asset across requests.
return await connector.load_from_url_async( return await connector.load_from_url_async(
normalized_url, normalized_url, media_io, fetch_timeout=self._http_timeout
self._create_vllm_audio_io(),
fetch_timeout=self._http_timeout,
) )
@_nvtx.annotate("mm:audio:load_audio", color="cyan") @_nvtx.annotate("mm:audio:load_audio", color="cyan")
......
...@@ -6,12 +6,6 @@ ...@@ -6,12 +6,6 @@
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
from typing import Optional from typing import Optional
...@@ -20,28 +14,32 @@ import httpx ...@@ -20,28 +14,32 @@ import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global HTTP client instance # The shared client has ``follow_redirects=False`` to prevent redirect-based
# SSRF filter bypass. Callers must follow redirects manually via
# :func:`dynamo.common.multimodal.url_validator.fetch_with_revalidation` so
# that each hop is re-validated against the SSRF policy.
_global_http_client: Optional[httpx.AsyncClient] = None _global_http_client: Optional[httpx.AsyncClient] = None
def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient: def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
""" """Return a shared async HTTP client for media fetches.
Get or create a shared HTTP client instance.
Args: The client intentionally disables automatic redirect following. Callers
timeout: Timeout for HTTP requests that need to follow redirects must route the request through
:func:`fetch_with_revalidation`, which revalidates every redirect hop
Returns: against the SSRF policy.
Shared HTTP client instance
""" """
global _global_http_client global _global_http_client
if _global_http_client is None or _global_http_client.is_closed: if _global_http_client is None or _global_http_client.is_closed:
_global_http_client = httpx.AsyncClient( _global_http_client = httpx.AsyncClient(
timeout=timeout, timeout=timeout,
follow_redirects=True, follow_redirects=False,
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100), limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
) )
logger.info(f"Shared HTTP client initialized with timeout={timeout}s") logger.info(
"Shared HTTP client initialized (timeout=%ss, follow_redirects=False)",
timeout,
)
return _global_http_client return _global_http_client
...@@ -20,6 +20,11 @@ from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl ...@@ -20,6 +20,11 @@ from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async from dynamo.common.utils.runtime import run_async
from .http_client import get_http_client from .http_client import get_http_client
from .url_validator import (
UrlValidationPolicy,
fetch_with_revalidation,
validate_media_url,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -36,6 +41,7 @@ class ImageLoader: ...@@ -36,6 +41,7 @@ class ImageLoader:
cache_size: int = CACHE_SIZE_MAXIMUM, cache_size: int = CACHE_SIZE_MAXIMUM,
http_timeout: float = 30.0, http_timeout: float = 30.0,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
url_policy: UrlValidationPolicy | None = None,
): ):
""" """
Initialize the ImageLoader with caching, HTTP settings, and optional NIXL config for Initialize the ImageLoader with caching, HTTP settings, and optional NIXL config for
...@@ -49,12 +55,14 @@ class ImageLoader: ...@@ -49,12 +55,14 @@ class ImageLoader:
enable_frontend_decoding: If True, enables NIXL RDMA for transferring enable_frontend_decoding: If True, enables NIXL RDMA for transferring
decoded images directly from frontend memory, bypassing standard decoded images directly from frontend memory, bypassing standard
network transport. Defaults to False. network transport. Defaults to False.
url_policy: Policy for validating URLs. Defaults to UrlValidationPolicy.from_env().
""" """
self._http_timeout = http_timeout self._http_timeout = http_timeout
self._cache_size = cache_size self._cache_size = cache_size
self._image_cache: OrderedDict[str, Image.Image] = OrderedDict() self._image_cache: OrderedDict[str, Image.Image] = OrderedDict()
self._inflight: dict[str, asyncio.Task[Image.Image]] = {} self._inflight: dict[str, asyncio.Task[Image.Image]] = {}
self._enable_frontend_decoding = enable_frontend_decoding self._enable_frontend_decoding = enable_frontend_decoding
self._url_policy = url_policy or UrlValidationPolicy.from_env()
# Lazy-init NIXL connector only when frontend decoding is enabled # Lazy-init NIXL connector only when frontend decoding is enabled
self._nixl_connector = None self._nixl_connector = None
if self._enable_frontend_decoding: if self._enable_frontend_decoding:
...@@ -94,7 +102,9 @@ class ImageLoader: ...@@ -94,7 +102,9 @@ class ImageLoader:
try: try:
with _nvtx.annotate("mm:img:http_fetch", color="lime"): with _nvtx.annotate("mm:img:http_fetch", color="lime"):
http_client = get_http_client(self._http_timeout) http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url) response = await fetch_with_revalidation(
http_client, image_url, self._url_policy
)
response.raise_for_status() response.raise_for_status()
if not response.content: if not response.content:
raise ValueError("Empty response content from image URL") raise ValueError("Empty response content from image URL")
...@@ -134,26 +144,25 @@ class ImageLoader: ...@@ -134,26 +144,25 @@ class ImageLoader:
raise ValueError( raise ValueError(
"Invalid image source scheme: local file access is not allowed" "Invalid image source scheme: local file access is not allowed"
) )
normalized_url = await validate_media_url(image_url, self._url_policy)
parsed_url = urlparse(normalized_url)
if parsed_url.scheme in ("http", "https"): if parsed_url.scheme in ("http", "https"):
key = image_url.lower() key = normalized_url.lower()
# Check cache (sync — no await, no interleaving possible)
if key in self._image_cache: if key in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}") logger.debug(f"Image found in cache for URL: {image_url}")
self._image_cache.move_to_end(key) self._image_cache.move_to_end(key)
return self._image_cache[key] return self._image_cache[key]
# Join existing in-flight task, or start a new one
if key not in self._inflight: if key not in self._inflight:
task = asyncio.create_task(self._fetch_and_cache(key, image_url)) task = asyncio.create_task(self._fetch_and_cache(key, normalized_url))
# Suppress "exception was never retrieved" if all waiters cancel # Suppress "exception was never retrieved" if all waiters cancel
task.add_done_callback( task.add_done_callback(
lambda t: t.exception() if not t.cancelled() else None lambda t: t.exception() if not t.cancelled() else None
) )
self._inflight[key] = task self._inflight[key] = task
# shield so cancelling THIS caller doesn't cancel the shared task
return await asyncio.shield(self._inflight[key]) return await asyncio.shield(self._inflight[key])
if parsed_url.scheme == "data": if parsed_url.scheme == "data":
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""URL / path validation and SSRF-safe HTTP fetching for multimodal loaders.
By default (``UrlValidationPolicy()``), only ``https://`` and ``data:`` URLs
are allowed; private / internal IPs and local filesystem access are both
blocked. Individual loaders can add stricter rules on top — ``ImageLoader``,
for example, refuses every local input regardless of policy.
To loosen the defaults, either build a ``UrlValidationPolicy(...)`` directly
or call ``UrlValidationPolicy.from_env()`` to pick up the ``DYN_MM_*`` vars
below.
"""
import asyncio
import ipaddress
import os
import socket
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlparse
import httpx
class UrlValidationError(ValueError):
"""Raised when a URL or filesystem path fails the configured policy."""
# IP ranges that must never be reachable from a user-controlled URL.
# Source: RFC1918 (private), RFC6598 (CGNAT), RFC5735 (loopback, link-local,
# 0.0.0.0/8), RFC4193 (ULA), RFC4291 (IPv6 loopback / link-local), RFC6890
# (reserved). Link-local 169.254/16 covers the AWS / OpenStack metadata IP.
_BLOCKED_IP_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("100.64.0.0/10"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.0.0.0/24"),
ipaddress.ip_network("192.0.2.0/24"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("198.18.0.0/15"),
ipaddress.ip_network("198.51.100.0/24"),
ipaddress.ip_network("203.0.113.0/24"),
ipaddress.ip_network("224.0.0.0/4"),
ipaddress.ip_network("240.0.0.0/4"),
ipaddress.ip_network("255.255.255.255/32"),
ipaddress.ip_network("::/128"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("::ffff:0:0/96"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("ff00::/8"),
)
# Hostnames that resolve to cloud metadata / internal services regardless of
# DNS records. Matched case-insensitively.
_BLOCKED_HOSTS: frozenset[str] = frozenset(
{
"localhost",
"localhost.localdomain",
"ip6-localhost",
"ip6-loopback",
"metadata",
"metadata.google.internal",
"metadata.goog",
"kubernetes.default",
"kubernetes.default.svc",
}
)
def is_blocked_ip(ip_text: str) -> bool:
"""Return True if ``ip_text`` parses as an IP inside one of the blocked ranges."""
try:
ip = ipaddress.ip_address(ip_text)
except ValueError:
return False
return any(ip in net for net in _BLOCKED_IP_NETWORKS)
@dataclass(frozen=True)
class UrlValidationPolicy:
"""Frozen policy describing which media URLs and local paths are allowed."""
allow_http: bool = False
allow_private_ips: bool = False
allowed_local_path: str | None = None
@classmethod
def from_env(cls) -> "UrlValidationPolicy":
"""Build a policy by reading the ``DYN_MM_*`` environment variables."""
allow_internal = os.getenv("DYN_MM_ALLOW_INTERNAL", "0") == "1"
return cls(
allow_http=allow_internal,
allow_private_ips=allow_internal,
allowed_local_path=os.getenv("DYN_MM_LOCAL_PATH", "").strip() or None,
)
async def validate_url(url: str, policy: UrlValidationPolicy) -> str:
"""Check ``url`` against ``policy`` and return it unchanged if it passes.
``https://`` and ``data:`` always pass. ``http://`` needs
``allow_http=True``. Anything else is rejected outright.
For URLs with a hostname, we resolve it here (off the event loop via
``loop.getaddrinfo``) and check the resulting IPs against the blocked
ranges. This catches obvious DNS rebinding but not an attacker who
changes their answer between this lookup and httpx's actual connect.
Raises ``UrlValidationError`` on any policy violation.
"""
if not url:
raise UrlValidationError("URL is empty")
parsed = urlparse(url)
scheme = parsed.scheme.lower()
if scheme == "data":
return url
if scheme not in ("http", "https"):
raise UrlValidationError(f"URL scheme '{scheme}' not allowed")
if scheme == "http" and not policy.allow_http:
raise UrlValidationError(
"http:// URLs are not allowed; set DYN_MM_ALLOW_INTERNAL=1 to enable"
)
host = (parsed.hostname or "").lower()
if not host:
raise UrlValidationError(f"URL has no host component: {url!r}")
if not policy.allow_private_ips and host in _BLOCKED_HOSTS:
raise UrlValidationError(
f"Host '{host}' is blocked (resolves to internal service)"
)
try:
ipaddress.ip_address(host)
except ValueError:
pass
else:
if not policy.allow_private_ips and is_blocked_ip(host):
raise UrlValidationError(f"IP literal '{host}' is in a blocked range")
return url
if policy.allow_private_ips:
return url
loop = asyncio.get_running_loop()
try:
infos = await loop.getaddrinfo(host, None)
except socket.gaierror as exc:
raise UrlValidationError(f"Could not resolve host '{host}': {exc}") from exc
for info in infos:
addr = info[4][0]
if is_blocked_ip(addr):
raise UrlValidationError(f"Host '{host}' resolves to blocked IP '{addr}'")
return url
def validate_local_path(path: str, policy: UrlValidationPolicy) -> Path:
"""Resolve ``path`` and confirm it sits inside ``allowed_local_path``.
We call ``Path.resolve()`` first, so symlinks that point outside the
allowed prefix are caught. Local access is refused outright when
``allowed_local_path`` is unset (the default).
Raises ``UrlValidationError`` if the feature is off or the resolved
path escapes the prefix.
"""
if not policy.allowed_local_path:
raise UrlValidationError(
"Local media paths are not permitted; set " "DYN_MM_LOCAL_PATH to enable"
)
try:
resolved = Path(path).expanduser().resolve(strict=True)
except FileNotFoundError as exc:
raise UrlValidationError(f"File not found: {path}") from exc
except OSError as exc:
raise UrlValidationError(f"Could not resolve path '{path}': {exc}") from exc
try:
allowed = Path(policy.allowed_local_path).expanduser().resolve(strict=True)
except FileNotFoundError as exc:
raise UrlValidationError(
f"Configured allowed_local_path does not exist: {policy.allowed_local_path}"
) from exc
try:
resolved.relative_to(allowed)
except ValueError as exc:
raise UrlValidationError(
f"Path '{path}' is outside the allowed directory '{policy.allowed_local_path}'"
) from exc
return resolved
async def validate_media_url(url: str, policy: UrlValidationPolicy) -> str:
"""Validate any media input and return a canonical URL string.
Bare filesystem paths and ``file://`` URIs go through
``validate_local_path`` and come back as a resolved ``file://`` URI.
Everything else goes through ``validate_url`` and is returned
unchanged. Callers can still reject the result afterwards —
``ImageLoader``, for example, refuses local files regardless.
Raises ``UrlValidationError`` on any policy violation.
"""
if not url:
raise UrlValidationError("URL is empty")
parsed = urlparse(url)
scheme = parsed.scheme.lower()
if scheme in ("", "file"):
raw_path = parsed.path if scheme == "file" else url
resolved = validate_local_path(raw_path, policy)
return resolved.as_uri()
return await validate_url(url, policy)
_MAX_REDIRECTS = 3
async def fetch_with_revalidation(
client: httpx.AsyncClient,
url: str,
policy: UrlValidationPolicy,
) -> httpx.Response:
"""Safely fetch a URL while checking security policy at every redirect.
Only ``_MAX_REDIRECTS`` hops allowed. ``client`` must have
``follow_redirects=False`` (the default from ``get_http_client``).
We follow redirects ourselves and validate each ``Location`` header
against the policy first.
Only plain ``GET`` with no custom headers is supported. httpx normally
strips credentials on cross-origin redirects only when
``follow_redirects=True``.
Raises ``UrlValidationError`` on any policy violation or when the
redirect chain exceeds ``_MAX_REDIRECTS``.
"""
current_url = url
hops_remaining = _MAX_REDIRECTS
visited: list[str] = []
while True:
await validate_url(current_url, policy)
visited.append(current_url)
request = client.build_request("GET", current_url)
response = await client.send(request, follow_redirects=False)
if not response.is_redirect:
return response
location = response.headers.get("location")
if not location:
return response
if hops_remaining <= 0:
await response.aclose()
raise UrlValidationError(
f"Too many redirects (max={_MAX_REDIRECTS}); chain={visited}"
)
hops_remaining -= 1
next_url = str(response.url.join(location))
await response.aclose()
current_url = next_url
...@@ -16,13 +16,18 @@ ...@@ -16,13 +16,18 @@
import asyncio import asyncio
import logging import logging
import os import os
from pathlib import Path
from typing import Any, Awaitable, Dict, Final, List from typing import Any, Awaitable, Dict, Final, List
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
import dynamo.nixl_connect as nixl_connect import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.http_client import get_http_client
from dynamo.common.multimodal.url_validator import (
UrlValidationPolicy,
fetch_with_revalidation,
validate_media_url,
)
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async from dynamo.common.utils.runtime import run_async
...@@ -53,33 +58,27 @@ class VideoLoader: ...@@ -53,33 +58,27 @@ class VideoLoader:
http_timeout: float = 60.0, http_timeout: float = 60.0,
num_frames: int = NUM_FRAMES_DEFAULT, num_frames: int = NUM_FRAMES_DEFAULT,
enable_frontend_decoding: bool = False, enable_frontend_decoding: bool = False,
url_policy: UrlValidationPolicy | None = None,
) -> None: ) -> None:
self._http_timeout = int(http_timeout) self._http_timeout = int(http_timeout)
self._num_frames = num_frames self._num_frames = num_frames
self._enable_frontend_decoding = enable_frontend_decoding self._enable_frontend_decoding = enable_frontend_decoding
self._url_policy = url_policy or UrlValidationPolicy.from_env()
self._nixl_connector = None self._nixl_connector = None
self._vllm_media_connector = None self._vllm_media_connector = None
if self._enable_frontend_decoding: if self._enable_frontend_decoding:
self._nixl_connector = nixl_connect.Connector() self._nixl_connector = nixl_connect.Connector()
run_async(self._nixl_connector.initialize) run_async(self._nixl_connector.initialize)
@staticmethod
def _normalize_video_url(video_url: str) -> str:
parsed_url = urlparse(video_url)
if parsed_url.scheme or not video_url:
return video_url
file_path = Path(video_url).expanduser()
if not file_path.exists():
raise FileNotFoundError(f"Error reading file: {file_path}")
return file_path.resolve().as_uri()
def _get_vllm_media_connector(self) -> Any: def _get_vllm_media_connector(self) -> Any:
if self._vllm_media_connector is None: if self._vllm_media_connector is None:
MediaConnector, _, _ = _require_vllm_video_media() MediaConnector, _, _ = _require_vllm_video_media()
# Match the previous backend behavior and allow direct local file paths. # Confine vLLM's own local-path access to the same prefix we enforce.
self._vllm_media_connector = MediaConnector(allowed_local_media_path="/") # Empty string matches vLLM's secure default (no local access).
allowed = self._url_policy.allowed_local_path or ""
self._vllm_media_connector = MediaConnector(
allowed_local_media_path=allowed
)
return self._vllm_media_connector return self._vllm_media_connector
...@@ -93,14 +92,23 @@ class VideoLoader: ...@@ -93,14 +92,23 @@ class VideoLoader:
async def _load_video_with_vllm( async def _load_video_with_vllm(
self, video_url: str self, video_url: str
) -> tuple[np.ndarray, Dict[str, Any]]: ) -> tuple[np.ndarray, Dict[str, Any]]:
normalized_url = await validate_media_url(video_url, self._url_policy)
media_io = self._create_vllm_video_io()
# HTTP(S) goes through our SSRF-safe fetcher so each redirect hop is
# revalidated; vLLM's own fetcher honors redirects without re-checking.
# data: and file:// never touch the network, so vLLM can handle them.
if urlparse(normalized_url).scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await fetch_with_revalidation(
http_client, normalized_url, self._url_policy
)
response.raise_for_status()
return await asyncio.to_thread(media_io.load_bytes, response.content)
connector = self._get_vllm_media_connector() connector = self._get_vllm_media_connector()
normalized_url = self._normalize_video_url(video_url)
# TODO: Add caching for repeated remote `video_url` downloads to avoid
# refetching the same asset across requests.
return await connector.load_from_url_async( return await connector.load_from_url_async(
normalized_url, normalized_url, media_io, fetch_timeout=self._http_timeout
self._create_vllm_video_io(),
fetch_timeout=self._http_timeout,
) )
async def load_video(self, video_url: str) -> tuple[np.ndarray, Dict[str, Any]]: async def load_video(self, video_url: str) -> tuple[np.ndarray, Dict[str, Any]]:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import numpy as np import numpy as np
import pytest import pytest
import dynamo.common.multimodal.audio_loader as audio_loader_module import dynamo.common.multimodal.audio_loader as audio_loader_module
from dynamo.common.multimodal.audio_loader import AudioLoader from dynamo.common.multimodal.audio_loader import AudioLoader
from dynamo.common.multimodal.url_validator import (
UrlValidationError,
UrlValidationPolicy,
validate_media_url,
)
pytestmark = [ pytestmark = [
pytest.mark.unit, pytest.mark.unit,
...@@ -16,34 +22,116 @@ pytestmark = [ ...@@ -16,34 +22,116 @@ pytestmark = [
] ]
def test_normalize_audio_url_converts_local_paths(tmp_path): def _permissive_http_policy() -> UrlValidationPolicy:
"""Policy that lets existing tests keep using https://example.com/... URLs.
Private/loopback IPs and DNS checks are bypassed so tests don't depend on
real DNS resolution of example.com.
"""
return UrlValidationPolicy(
allow_http=True,
allow_private_ips=True,
)
async def test_normalize_audio_url_converts_local_paths(tmp_path):
audio_path = tmp_path / "sample.wav" audio_path = tmp_path / "sample.wav"
audio_path.write_bytes(b"RIFF") audio_path.write_bytes(b"RIFF")
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
assert ( assert (
AudioLoader._normalize_audio_url(str(audio_path)) await validate_media_url(str(audio_path), policy)
== audio_path.resolve().as_uri() == audio_path.resolve().as_uri()
) )
def test_normalize_audio_url_preserves_data_urls(): async def test_normalize_audio_url_preserves_data_urls():
data_url = "data:audio/wav;base64,UklGRg==" data_url = "data:audio/wav;base64,UklGRg=="
assert AudioLoader._normalize_audio_url(data_url) == data_url policy = UrlValidationPolicy()
assert await validate_media_url(data_url, policy) == data_url
def test_normalize_audio_url_preserves_http_urls(): async def test_normalize_audio_url_preserves_http_urls():
url = "https://example.com/audio.wav" url = "https://example.com/audio.wav"
assert AudioLoader._normalize_audio_url(url) == url policy = _permissive_http_policy()
assert await validate_media_url(url, policy) == url
async def test_normalize_audio_url_rejects_bare_path_by_default(tmp_path):
audio_path = tmp_path / "sample.wav"
audio_path.write_bytes(b"RIFF")
policy = UrlValidationPolicy()
with pytest.raises(UrlValidationError, match="Local media paths are not permitted"):
await validate_media_url(str(audio_path), policy)
async def test_normalize_audio_url_rejects_private_ip():
policy = UrlValidationPolicy()
with pytest.raises(UrlValidationError):
await validate_media_url("https://169.254.169.254/audio.wav", policy)
def test_normalize_audio_url_raises_on_missing_file(): async def test_normalize_audio_url_accepts_file_uri_inside_prefix(tmp_path):
with pytest.raises(FileNotFoundError, match="Error reading file"): audio_path = tmp_path / "sample.wav"
AudioLoader._normalize_audio_url("/nonexistent/audio.wav") audio_path.write_bytes(b"RIFF")
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
file_uri = audio_path.resolve().as_uri()
assert await validate_media_url(file_uri, policy) == file_uri
async def test_normalize_audio_url_rejects_file_uri_outside_prefix(tmp_path):
allowed = tmp_path / "media"
allowed.mkdir()
other = tmp_path / "secret.wav"
other.write_bytes(b"RIFF")
policy = UrlValidationPolicy(allowed_local_path=str(allowed))
with pytest.raises(UrlValidationError, match="outside the allowed directory"):
await validate_media_url(other.resolve().as_uri(), policy)
@pytest.mark.asyncio
async def test_load_audio_rejects_http_by_default():
loader = AudioLoader(url_policy=UrlValidationPolicy())
with pytest.raises(ValueError, match="not allowed"):
await loader.load_audio("http://example.com/x.wav")
@pytest.mark.asyncio
async def test_load_audio_blocks_redirect_to_private_ip():
"""A 302 to a blocked IP must be rejected per-hop, not only the initial URL."""
loader = AudioLoader(url_policy=UrlValidationPolicy())
loader._create_vllm_audio_io = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
redirect = MagicMock(spec=httpx.Response)
redirect.status_code = 302
redirect.is_redirect = True
redirect.headers = {"location": "https://169.254.169.254/evil"}
redirect.url = httpx.URL("https://8.8.8.8/a.wav")
redirect.aclose = AsyncMock()
client = MagicMock(spec=httpx.AsyncClient)
client.build_request = MagicMock(return_value=MagicMock(spec=httpx.Request))
client.send = AsyncMock(return_value=redirect)
with patch(
"dynamo.common.multimodal.audio_loader.get_http_client",
return_value=client,
):
with pytest.raises(ValueError, match="blocked range"):
await loader.load_audio("https://8.8.8.8/a.wav")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_audio_uses_vllm_media_connector(): async def test_load_audio_uses_vllm_media_connector():
loader = AudioLoader() loader = AudioLoader()
loader._url_policy = UrlValidationPolicy()
waveform = np.random.randn(16000).astype(np.float32) waveform = np.random.randn(16000).astype(np.float32)
sr = 44100.0 sr = 44100.0
loader._load_audio_with_vllm = AsyncMock( # type: ignore[method-assign] loader._load_audio_with_vllm = AsyncMock( # type: ignore[method-assign]
...@@ -60,7 +148,7 @@ async def test_load_audio_uses_vllm_media_connector(): ...@@ -60,7 +148,7 @@ async def test_load_audio_uses_vllm_media_connector():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_audio_rejects_empty_waveform(): async def test_load_audio_rejects_empty_waveform():
loader = AudioLoader() loader = AudioLoader(url_policy=_permissive_http_policy())
loader._load_audio_with_vllm = AsyncMock( # type: ignore[method-assign] loader._load_audio_with_vllm = AsyncMock( # type: ignore[method-assign]
return_value=(np.array([], dtype=np.float32), 16000.0) return_value=(np.array([], dtype=np.float32), 16000.0)
) )
......
...@@ -25,6 +25,7 @@ import pytest ...@@ -25,6 +25,7 @@ import pytest
from PIL import Image from PIL import Image
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.multimodal.url_validator import UrlValidationPolicy
pytestmark = [ pytestmark = [
pytest.mark.asyncio, pytest.mark.asyncio,
...@@ -45,29 +46,42 @@ def _make_png_bytes() -> bytes: ...@@ -45,29 +46,42 @@ def _make_png_bytes() -> bytes:
PNG_BYTES = _make_png_bytes() PNG_BYTES = _make_png_bytes()
def _permissive_policy(
allowed_local_path: str | None = None,
) -> UrlValidationPolicy:
"""Return a policy that permits the schemes used by tests without DNS hits."""
return UrlValidationPolicy(
allow_http=True,
allow_private_ips=True,
allowed_local_path=allowed_local_path,
)
def _mock_http_client( def _mock_http_client(
content: bytes = PNG_BYTES, content: bytes = PNG_BYTES,
status_code: int = 200, status_code: int = 200,
delay: float = 0.0, delay: float = 0.0,
side_effect: Exception | None = None, side_effect: Exception | None = None,
) -> AsyncMock: ) -> AsyncMock:
"""Return a mock httpx.AsyncClient whose .get() returns a fake response. """Return a mock httpx.AsyncClient compatible with fetch_with_revalidation.
``fetch_with_revalidation`` uses ``client.build_request(...)`` then
``client.send(request, follow_redirects=False)``. We stub both and also
keep ``client.get`` behaviour for any legacy callers.
Args: Args:
content: Raw bytes returned as the HTTP response body. content: Raw bytes returned as the HTTP response body.
status_code: HTTP status code; >=400 triggers raise_for_status(). status_code: HTTP status code; >=400 triggers raise_for_status().
delay: Seconds to sleep before responding (simulates network latency). delay: Seconds to sleep before responding (simulates network latency).
side_effect: If set, .get() raises this exception instead of returning. side_effect: If set, .send()/.get() raises this exception instead.
""" """
async def _get(url: str) -> Any: def _build_response() -> Any:
if delay > 0:
await asyncio.sleep(delay)
if side_effect is not None:
raise side_effect
resp = MagicMock(spec=httpx.Response) resp = MagicMock(spec=httpx.Response)
resp.status_code = status_code resp.status_code = status_code
resp.content = content resp.content = content
resp.is_redirect = False
resp.headers = {}
resp.raise_for_status = MagicMock() resp.raise_for_status = MagicMock()
if status_code >= 400: if status_code >= 400:
resp.raise_for_status.side_effect = httpx.HTTPStatusError( resp.raise_for_status.side_effect = httpx.HTTPStatusError(
...@@ -75,14 +89,27 @@ def _mock_http_client( ...@@ -75,14 +89,27 @@ def _mock_http_client(
) )
return resp return resp
async def _respond(*_args: Any, **_kwargs: Any) -> Any:
if delay > 0:
await asyncio.sleep(delay)
if side_effect is not None:
raise side_effect
return _build_response()
client = AsyncMock() client = AsyncMock()
client.get = AsyncMock(side_effect=_get) client.get = AsyncMock(side_effect=_respond)
client.build_request = MagicMock(return_value=MagicMock(spec=httpx.Request))
client.send = AsyncMock(side_effect=_respond)
return client return client
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def loader() -> ImageLoader: def loader() -> ImageLoader:
return ImageLoader(cache_size=4, http_timeout=30.0) return ImageLoader(
cache_size=4,
http_timeout=30.0,
url_policy=_permissive_policy(),
)
# --- Concurrent same-URL dedup --- # --- Concurrent same-URL dedup ---
...@@ -103,7 +130,7 @@ async def test_concurrent_same_url_deduplicates(loader: ImageLoader) -> None: ...@@ -103,7 +130,7 @@ async def test_concurrent_same_url_deduplicates(loader: ImageLoader) -> None:
assert len(results) == 2 assert len(results) == 2
assert results[0].size == results[1].size assert results[0].size == results[1].size
# Only one HTTP GET should have been issued # Only one HTTP GET should have been issued
assert mock_client.get.call_count == 1 assert mock_client.send.call_count == 1
async def test_concurrent_different_urls_fetch_independently( async def test_concurrent_different_urls_fetch_independently(
...@@ -120,7 +147,7 @@ async def test_concurrent_different_urls_fetch_independently( ...@@ -120,7 +147,7 @@ async def test_concurrent_different_urls_fetch_independently(
loader.load_image("https://example.com/b.png"), loader.load_image("https://example.com/b.png"),
) )
assert mock_client.get.call_count == 2 assert mock_client.send.call_count == 2
# --- Waiter cancellation isolation --- # --- Waiter cancellation isolation ---
...@@ -206,6 +233,46 @@ async def test_data_url_non_image_rejected(loader: ImageLoader) -> None: ...@@ -206,6 +233,46 @@ async def test_data_url_non_image_rejected(loader: ImageLoader) -> None:
await loader.load_image("data:text/plain;base64,aGVsbG8=") await loader.load_image("data:text/plain;base64,aGVsbG8=")
# --- SSRF / scheme rejection ---
async def test_http_scheme_rejected_by_default(monkeypatch) -> None:
"""With default env policy, http:// URLs must be rejected before any fetch."""
monkeypatch.delenv("DYN_MM_ALLOW_INTERNAL", raising=False)
monkeypatch.delenv("DYN_MM_LOCAL_PATH", raising=False)
default_loader = ImageLoader(cache_size=4, http_timeout=30.0)
mock_client = _mock_http_client()
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
with pytest.raises(ValueError, match="scheme|not allowed"):
await default_loader.load_image("http://example.com/x.png")
# The shared HTTP client must not be touched when the URL is rejected.
assert mock_client.send.call_count == 0
assert mock_client.get.call_count == 0
async def test_blocked_private_ip_rejected(monkeypatch) -> None:
"""Cloud metadata / private IPs must be rejected even over https."""
monkeypatch.delenv("DYN_MM_ALLOW_INTERNAL", raising=False)
strict_loader = ImageLoader(
cache_size=4,
http_timeout=30.0,
url_policy=UrlValidationPolicy(
allow_http=True,
allow_private_ips=False,
),
)
with pytest.raises(ValueError, match="blocked range"):
await strict_loader.load_image("https://169.254.169.254/latest/meta-data/")
# --- HTTP error contract --- # --- HTTP error contract ---
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ``dynamo.common.multimodal.url_validator``.
These cover scheme / IP / hostname / path / redirect logic in isolation of
the media loaders, so they run quickly with no network and no vLLM imports.
"""
from __future__ import annotations
import socket
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from dynamo.common.multimodal.url_validator import (
UrlValidationError,
UrlValidationPolicy,
fetch_with_revalidation,
is_blocked_ip,
validate_local_path,
validate_url,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
STRICT_HTTPS = UrlValidationPolicy()
PERMISSIVE = UrlValidationPolicy(
allow_http=True,
allow_private_ips=True,
)
# ---------------------------------------------------------------------------
# is_blocked_ip()
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"ip",
[
"127.0.0.1",
"10.0.0.1",
"172.16.5.5",
"192.168.1.1",
"169.254.169.254", # AWS metadata
"100.64.0.1", # CGNAT
"0.0.0.0",
"::1",
"fe80::1",
"fc00::1",
"240.0.0.1", # reserved
],
)
def test_is_blocked_ip_blocks_known_ranges(ip: str) -> None:
assert is_blocked_ip(ip) is True
@pytest.mark.parametrize(
"ip",
[
"8.8.8.8",
"1.1.1.1",
"93.184.216.34", # example.com
"2606:4700:4700::1111", # Cloudflare
],
)
def test_is_blocked_ip_allows_public(ip: str) -> None:
assert is_blocked_ip(ip) is False
def test_is_blocked_ip_non_ip_literal_returns_false() -> None:
# A hostname, not an IP — is_blocked_ip only classifies literals.
assert is_blocked_ip("example.com") is False
# ---------------------------------------------------------------------------
# validate_url() — scheme handling
# ---------------------------------------------------------------------------
async def test_validate_url_rejects_empty() -> None:
with pytest.raises(UrlValidationError, match="empty"):
await validate_url("", STRICT_HTTPS)
async def test_validate_url_rejects_http_by_default() -> None:
with pytest.raises(UrlValidationError, match="not allowed"):
await validate_url("http://example.com/x.png", STRICT_HTTPS)
async def test_validate_url_rejects_file_scheme() -> None:
with pytest.raises(UrlValidationError, match="scheme 'file'"):
await validate_url("file:///etc/passwd", STRICT_HTTPS)
async def test_validate_url_rejects_ftp() -> None:
with pytest.raises(UrlValidationError, match="scheme 'ftp'"):
await validate_url("ftp://example.com/x.png", STRICT_HTTPS)
async def test_validate_url_accepts_data_url_by_default() -> None:
# data: URLs never touch the network — we allow them without further checks.
await validate_url("data:image/png;base64,iVBORw0KGgoAAAA=", STRICT_HTTPS)
async def test_validate_url_http_allowed_when_opted_in() -> None:
policy = PERMISSIVE
# With public hostname + allow_private_ips=True (keeps DNS out of the test)
await validate_url("http://example.com/x.png", policy)
# ---------------------------------------------------------------------------
# validate_url() — IP literal handling
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"url",
[
"https://127.0.0.1/x.png",
"https://169.254.169.254/latest/meta-data/",
"https://10.0.0.5/x.png",
"https://192.168.1.1/x.png",
"https://[::1]/x.png",
],
)
async def test_validate_url_rejects_private_ip_literal(url: str) -> None:
with pytest.raises(UrlValidationError, match="blocked range"):
await validate_url(url, STRICT_HTTPS)
async def test_validate_url_allows_public_ip_literal() -> None:
# Public IP literal and allow_private_ips=False — should pass the IP test
# and skip DNS resolution (the IP path short-circuits).
await validate_url("https://8.8.8.8/x.png", STRICT_HTTPS)
async def test_validate_url_allows_private_ip_when_opted_in() -> None:
await validate_url("https://127.0.0.1/x.png", PERMISSIVE)
# ---------------------------------------------------------------------------
# validate_url() — blocked hostnames
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"host",
[
"localhost",
"metadata.google.internal",
"metadata",
"kubernetes.default.svc",
],
)
async def test_validate_url_rejects_blocked_hostname(host: str) -> None:
with pytest.raises(UrlValidationError, match="blocked"):
await validate_url(f"https://{host}/path", STRICT_HTTPS)
# ---------------------------------------------------------------------------
# validate_url() — DNS resolution
# ---------------------------------------------------------------------------
def _fake_getaddrinfo(addrs: list[str]):
def _impl(host: str, *_args: Any, **_kwargs: Any):
return [
(socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 0)) for addr in addrs
]
return _impl
async def test_validate_url_rejects_host_resolving_to_private_ip() -> None:
with patch(
"dynamo.common.multimodal.url_validator.socket.getaddrinfo",
side_effect=_fake_getaddrinfo(["10.0.0.5"]),
):
with pytest.raises(UrlValidationError, match="blocked IP"):
await validate_url("https://attacker.example.com/x.png", STRICT_HTTPS)
async def test_validate_url_rejects_host_if_any_ip_is_private() -> None:
# Even if the host resolves to a public IP too, any blocked IP is fatal.
with patch(
"dynamo.common.multimodal.url_validator.socket.getaddrinfo",
side_effect=_fake_getaddrinfo(["8.8.8.8", "169.254.169.254"]),
):
with pytest.raises(UrlValidationError, match="169.254.169.254"):
await validate_url("https://mixed.example.com/x.png", STRICT_HTTPS)
async def test_validate_url_accepts_public_host() -> None:
with patch(
"dynamo.common.multimodal.url_validator.socket.getaddrinfo",
side_effect=_fake_getaddrinfo(["93.184.216.34"]),
):
await validate_url("https://example.com/x.png", STRICT_HTTPS)
async def test_validate_url_resolution_failure_raises() -> None:
with patch(
"dynamo.common.multimodal.url_validator.socket.getaddrinfo",
side_effect=socket.gaierror("nodename nor servname provided"),
):
with pytest.raises(UrlValidationError, match="Could not resolve"):
await validate_url("https://does-not-exist.invalid/x.png", STRICT_HTTPS)
async def test_validate_url_skips_resolution_when_private_allowed() -> None:
# In developer mode we short-circuit DNS to keep tests deterministic.
with patch("dynamo.common.multimodal.url_validator.socket.getaddrinfo") as resolver:
await validate_url("https://example.com/x.png", PERMISSIVE)
resolver.assert_not_called()
# ---------------------------------------------------------------------------
# validate_local_path()
# ---------------------------------------------------------------------------
def test_validate_local_path_rejected_when_disabled() -> None:
with pytest.raises(UrlValidationError, match="not permitted"):
validate_local_path("/etc/passwd", STRICT_HTTPS)
def test_validate_local_path_accepts_inside_prefix(tmp_path) -> None:
media = tmp_path / "media"
media.mkdir()
target = media / "sample.png"
target.write_bytes(b"\x89PNG\r\n")
policy = UrlValidationPolicy(allowed_local_path=str(media))
resolved = validate_local_path(str(target), policy)
assert resolved == target.resolve()
def test_validate_local_path_rejects_outside_prefix(tmp_path) -> None:
media = tmp_path / "media"
media.mkdir()
other = tmp_path / "secrets"
other.mkdir()
secret = other / "creds.txt"
secret.write_text("hunter2")
policy = UrlValidationPolicy(allowed_local_path=str(media))
with pytest.raises(UrlValidationError, match="outside the allowed directory"):
validate_local_path(str(secret), policy)
def test_validate_local_path_rejects_symlink_escape(tmp_path) -> None:
media = tmp_path / "media"
media.mkdir()
outside = tmp_path / "outside.txt"
outside.write_text("secret")
link = media / "link.png"
link.symlink_to(outside)
policy = UrlValidationPolicy(allowed_local_path=str(media))
# Path.resolve() follows the symlink; the target is outside the prefix.
with pytest.raises(UrlValidationError, match="outside the allowed directory"):
validate_local_path(str(link), policy)
def test_validate_local_path_missing_file(tmp_path) -> None:
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
with pytest.raises(UrlValidationError, match="File not found"):
validate_local_path(str(tmp_path / "nope.png"), policy)
def test_validate_local_path_missing_prefix(tmp_path) -> None:
target = tmp_path / "sample.png"
target.write_bytes(b"x")
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path / "does-not-exist"))
with pytest.raises(UrlValidationError, match="allowed_local_path does not exist"):
validate_local_path(str(target), policy)
# ---------------------------------------------------------------------------
# UrlValidationPolicy.from_env()
# ---------------------------------------------------------------------------
def test_policy_from_env_defaults(monkeypatch) -> None:
monkeypatch.delenv("DYN_MM_ALLOW_INTERNAL", raising=False)
monkeypatch.delenv("DYN_MM_LOCAL_PATH", raising=False)
policy = UrlValidationPolicy.from_env()
assert policy.allow_http is False
assert policy.allow_private_ips is False
assert policy.allowed_local_path is None
def test_policy_from_env_allow_internal(monkeypatch) -> None:
monkeypatch.setenv("DYN_MM_ALLOW_INTERNAL", "1")
monkeypatch.setenv("DYN_MM_LOCAL_PATH", "/data/media")
policy = UrlValidationPolicy.from_env()
assert policy.allow_http is True
assert policy.allow_private_ips is True
assert policy.allowed_local_path == "/data/media"
# ---------------------------------------------------------------------------
# fetch_with_revalidation()
# ---------------------------------------------------------------------------
def _mock_response(
status_code: int = 200,
location: str | None = None,
*,
request_url: str = "https://example.com/x.png",
) -> MagicMock:
resp = MagicMock(spec=httpx.Response)
resp.status_code = status_code
resp.headers = {}
resp.url = httpx.URL(request_url)
resp.is_redirect = status_code in (301, 302, 303, 307, 308)
if location is not None:
resp.headers = {"location": location}
resp.aclose = AsyncMock()
return resp
def _mock_client(responses: list[MagicMock]) -> MagicMock:
client = MagicMock(spec=httpx.AsyncClient)
client.build_request = MagicMock(
side_effect=lambda method, url, headers=None: MagicMock(spec=httpx.Request)
)
client.send = AsyncMock(side_effect=list(responses))
return client
@pytest.mark.asyncio
async def test_fetch_with_revalidation_returns_first_response() -> None:
policy = PERMISSIVE
resp = _mock_response(status_code=200)
client = _mock_client([resp])
result = await fetch_with_revalidation(client, "https://example.com/x.png", policy)
assert result is resp
assert client.send.await_count == 1
@pytest.mark.asyncio
async def test_fetch_with_revalidation_follows_safe_redirect() -> None:
policy = PERMISSIVE
redirect = _mock_response(
status_code=302,
location="https://example.com/final.png",
request_url="https://example.com/x.png",
)
final = _mock_response(status_code=200, request_url="https://example.com/final.png")
client = _mock_client([redirect, final])
result = await fetch_with_revalidation(client, "https://example.com/x.png", policy)
assert result is final
assert client.send.await_count == 2
redirect.aclose.assert_awaited()
@pytest.mark.asyncio
async def test_fetch_with_revalidation_blocks_redirect_to_private_ip() -> None:
# Strict policy — first hop is OK (public-IP literal), redirect target is blocked.
strict = UrlValidationPolicy(allow_private_ips=False)
redirect = _mock_response(
status_code=302,
location="http://169.254.169.254/latest/meta-data/",
request_url="https://8.8.8.8/x.png",
)
client = _mock_client([redirect])
with pytest.raises(UrlValidationError):
await fetch_with_revalidation(client, "https://8.8.8.8/x.png", strict)
# Only one send — the redirect target is rejected before any further fetch.
assert client.send.await_count == 1
@pytest.mark.asyncio
async def test_fetch_with_revalidation_enforces_redirect_limit() -> None:
# _MAX_REDIRECTS is hardcoded at 3; we need 4 redirect responses to trip it.
policy = UrlValidationPolicy(allow_private_ips=True) # keep DNS out of this test
def _hop(src: str, dst: str) -> MagicMock:
return _mock_response(status_code=302, location=dst, request_url=src)
client = _mock_client(
[
_hop("https://example.com/a", "https://example.com/b"),
_hop("https://example.com/b", "https://example.com/c"),
_hop("https://example.com/c", "https://example.com/d"),
_hop("https://example.com/d", "https://example.com/e"),
]
)
with pytest.raises(UrlValidationError, match="Too many redirects"):
await fetch_with_revalidation(client, "https://example.com/a", policy)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import numpy as np import numpy as np
import pytest import pytest
import dynamo.common.multimodal.video_loader as video_loader_module import dynamo.common.multimodal.video_loader as video_loader_module
from dynamo.common.multimodal.url_validator import (
UrlValidationError,
UrlValidationPolicy,
validate_media_url,
)
from dynamo.common.multimodal.video_loader import VideoLoader from dynamo.common.multimodal.video_loader import VideoLoader
pytestmark = [ pytestmark = [
...@@ -16,25 +22,103 @@ pytestmark = [ ...@@ -16,25 +22,103 @@ pytestmark = [
] ]
def test_normalize_video_url_converts_local_paths(tmp_path): async def test_normalize_video_url_converts_local_paths(tmp_path):
video_path = tmp_path / "sample.webm" video_path = tmp_path / "sample.webm"
video_path.write_bytes(b"video") video_path.write_bytes(b"video")
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
assert ( assert (
VideoLoader._normalize_video_url(str(video_path)) await validate_media_url(str(video_path), policy)
== video_path.resolve().as_uri() == video_path.resolve().as_uri()
) )
def test_normalize_video_url_preserves_data_urls(): async def test_normalize_video_url_preserves_data_urls():
data_url = "data:video/webm;base64,Zm9v" data_url = "data:video/webm;base64,Zm9v"
policy = UrlValidationPolicy()
assert await validate_media_url(data_url, policy) == data_url
async def test_normalize_video_url_rejects_bare_path_by_default(tmp_path):
video_path = tmp_path / "sample.webm"
video_path.write_bytes(b"video")
# Default policy has no allowed_local_path -> local paths rejected.
policy = UrlValidationPolicy()
assert VideoLoader._normalize_video_url(data_url) == data_url with pytest.raises(UrlValidationError, match="Local media paths are not permitted"):
await validate_media_url(str(video_path), policy)
async def test_normalize_video_url_rejects_private_ip():
policy = UrlValidationPolicy()
with pytest.raises(UrlValidationError):
await validate_media_url("https://169.254.169.254/video.mp4", policy)
async def test_normalize_video_url_accepts_file_uri_inside_prefix(tmp_path):
video_path = tmp_path / "sample.webm"
video_path.write_bytes(b"video")
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
file_uri = video_path.resolve().as_uri()
assert await validate_media_url(file_uri, policy) == file_uri
async def test_normalize_video_url_rejects_file_uri_outside_prefix(tmp_path):
allowed = tmp_path / "media"
allowed.mkdir()
other = tmp_path / "secret.webm"
other.write_bytes(b"video")
policy = UrlValidationPolicy(allowed_local_path=str(allowed))
with pytest.raises(UrlValidationError, match="outside the allowed directory"):
await validate_media_url(other.resolve().as_uri(), policy)
@pytest.mark.asyncio
async def test_load_video_rejects_http_by_default():
# Default env policy: http is disabled, so validation should reject this
# before any fetch is attempted.
loader = VideoLoader(url_policy=UrlValidationPolicy())
with pytest.raises(ValueError, match="not allowed"):
await loader.load_video("http://example.com/x.mp4")
@pytest.mark.asyncio
async def test_load_video_blocks_redirect_to_private_ip():
"""A 302 to a blocked IP must be rejected per-hop, not only the initial URL."""
loader = VideoLoader(url_policy=UrlValidationPolicy())
loader._create_vllm_video_io = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
redirect = MagicMock(spec=httpx.Response)
redirect.status_code = 302
redirect.is_redirect = True
redirect.headers = {"location": "https://169.254.169.254/evil"}
redirect.url = httpx.URL("https://8.8.8.8/v.mp4")
redirect.aclose = AsyncMock()
client = MagicMock(spec=httpx.AsyncClient)
client.build_request = MagicMock(return_value=MagicMock(spec=httpx.Request))
client.send = AsyncMock(return_value=redirect)
with patch(
"dynamo.common.multimodal.video_loader.get_http_client",
return_value=client,
):
with pytest.raises(ValueError, match="blocked range"):
await loader.load_video("https://8.8.8.8/v.mp4")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_video_uses_vllm_media_connector(): async def test_load_video_uses_vllm_media_connector():
loader = VideoLoader() loader = VideoLoader()
# data: scheme is in the default allowlist regardless of env flags.
loader._url_policy = UrlValidationPolicy()
frames = np.arange(24, dtype=np.uint8).reshape(1, 2, 4, 3)[:, :, ::-1, :] frames = np.arange(24, dtype=np.uint8).reshape(1, 2, 4, 3)[:, :, ::-1, :]
metadata = {"fps": 4.0, "frames_indices": [0], "total_num_frames": 1} metadata = {"fps": 4.0, "frames_indices": [0], "total_num_frames": 1}
loader._load_video_with_vllm = AsyncMock( # type: ignore[method-assign] loader._load_video_with_vllm = AsyncMock( # type: ignore[method-assign]
......
...@@ -48,6 +48,25 @@ Dynamo provides support for improving latency and throughput for vision-and-lang ...@@ -48,6 +48,25 @@ Dynamo provides support for improving latency and throughput for vision-and-lang
**Status:** ✅ Supported | 🧪 Experimental | ❌ Not supported **Status:** ✅ Supported | 🧪 Experimental | ❌ Not supported
## Security: URL Validation
All multimodal loaders route remote fetches through a shared URL policy
(`dynamo.common.multimodal.url_validator`). Only
`https://` and `data:` URLs are allowed by default, private / internal IPs are blocked,
and local file access is disabled. Every HTTP redirect hop is re-validated
against the policy.
Two environment variables loosen the defaults for non-public deployments:
| Variable | Default | Effect |
|----------|---------|--------|
| `DYN_MM_ALLOW_INTERNAL` | `0` | Set to `1` to allow `http://` and private / internal IP targets. Intended for on-prem or local-dev setups where media lives on an internal network. |
| `DYN_MM_LOCAL_PATH` | *(empty)* | Absolute directory prefix. When set, `file://` URIs and bare paths are allowed if they resolve inside this prefix. |
<Warning>
**Never set `DYN_MM_ALLOW_INTERNAL=1` on public-facing deployments.** It opens SSRF paths to cloud metadata endpoints (AWS IMDS, GCE, Azure) and other internal services.
</Warning>
## Example Workflows ## Example Workflows
Reference implementations for deploying multimodal models: Reference implementations for deploying multimodal models:
......
...@@ -81,6 +81,7 @@ def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]: ...@@ -81,6 +81,7 @@ def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]:
env["DYN_LOG"] = log_level env["DYN_LOG"] = log_level
env["DYN_NAMESPACE"] = NAMESPACE env["DYN_NAMESPACE"] = NAMESPACE
env["DYN_REQUEST_PLANE"] = "tcp" env["DYN_REQUEST_PLANE"] = "tcp"
env["DYN_MM_ALLOW_INTERNAL"] = "1"
env.update(extra) env.update(extra)
return env return env
......
...@@ -179,6 +179,11 @@ def make_multimodal_configs( ...@@ -179,6 +179,11 @@ def make_multimodal_configs(
marks.extend(profile.marks) marks.extend(profile.marks)
key = f"mm_{topology}_{profile.short_name}" key = f"mm_{topology}_{profile.short_name}"
worker_env = {
"DYN_MM_ALLOW_INTERNAL": "1",
"DYN_MM_LOCAL_PATH": str(WORKSPACE_DIR),
**topo_cfg.env,
}
configs[key] = config_cls( configs[key] = config_cls(
name=key, name=key,
directory=topo_cfg.directory or directory, directory=topo_cfg.directory or directory,
...@@ -188,6 +193,6 @@ def make_multimodal_configs( ...@@ -188,6 +193,6 @@ def make_multimodal_configs(
marks=marks, marks=marks,
delayed_start=topo_cfg.delayed_start, delayed_start=topo_cfg.delayed_start,
request_payloads=profile.request_payloads, request_payloads=profile.request_payloads,
env=topo_cfg.env, env=worker_env,
) )
return configs return configs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment