Unverified Commit 8fba4f56 authored by zhongdaor-nv, committed by GitHub
Browse files

fix: Validate multimodal media URLs and load paths (#8282)


Signed-off-by: zhongdaor <zhongdaor@nvidia.com>
parent f701319e
......@@ -3,13 +3,18 @@
import asyncio
import logging
from pathlib import Path
from typing import Any, Awaitable, Dict, Final, List
from urllib.parse import urlparse
import numpy as np
import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.http_client import get_http_client
from dynamo.common.multimodal.url_validator import (
UrlValidationPolicy,
fetch_with_revalidation,
validate_media_url,
)
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
......@@ -57,37 +62,28 @@ class AudioLoader:
self,
http_timeout: float = 30.0,
enable_frontend_decoding: bool = False,
url_policy: UrlValidationPolicy | None = None,
) -> None:
if http_timeout <= 0:
raise ValueError(f"http_timeout must be positive, got {http_timeout}")
self._http_timeout = http_timeout
self._enable_frontend_decoding = enable_frontend_decoding
self._url_policy = url_policy or UrlValidationPolicy.from_env()
self._nixl_connector = None
self._vllm_media_connector = None
if self._enable_frontend_decoding:
self._nixl_connector = nixl_connect.Connector()
run_async(self._nixl_connector.initialize)
@staticmethod
def _normalize_audio_url(audio_url: str) -> str:
"""Convert bare filesystem paths to file:// URIs.
HTTP(S) and data: URLs are returned unchanged.
"""
parsed_url = urlparse(audio_url)
if parsed_url.scheme or not audio_url:
return audio_url
file_path = Path(audio_url).expanduser()
if not file_path.exists():
raise FileNotFoundError(f"Error reading file: {file_path}")
return file_path.resolve().as_uri()
def _get_vllm_media_connector(self) -> Any:
if self._vllm_media_connector is None:
MediaConnector, _ = _require_vllm_audio_media()
self._vllm_media_connector = MediaConnector(allowed_local_media_path="/")
# Confine vLLM's own local-path access to the same prefix we enforce.
# Empty string matches vLLM's secure default (no local access).
allowed = self._url_policy.allowed_local_path or ""
self._vllm_media_connector = MediaConnector(
allowed_local_media_path=allowed
)
return self._vllm_media_connector
......@@ -97,14 +93,23 @@ class AudioLoader:
@_nvtx.annotate("mm:audio:load_with_vllm", color="cyan")
async def _load_audio_with_vllm(self, audio_url: str) -> tuple[np.ndarray, float]:
normalized_url = await validate_media_url(audio_url, self._url_policy)
media_io = self._create_vllm_audio_io()
# HTTP(S) goes through our SSRF-safe fetcher so each redirect hop is
# revalidated; vLLM's own fetcher honors redirects without re-checking.
# data: and file:// never touch the network, so vLLM can handle them.
if urlparse(normalized_url).scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await fetch_with_revalidation(
http_client, normalized_url, self._url_policy
)
response.raise_for_status()
return await asyncio.to_thread(media_io.load_bytes, response.content)
connector = self._get_vllm_media_connector()
normalized_url = self._normalize_audio_url(audio_url)
# TODO: Add caching for repeated remote `audio_url` downloads to avoid
# refetching the same asset across requests.
return await connector.load_from_url_async(
normalized_url,
self._create_vllm_audio_io(),
fetch_timeout=self._http_timeout,
normalized_url, media_io, fetch_timeout=self._http_timeout
)
@_nvtx.annotate("mm:audio:load_audio", color="cyan")
......
......@@ -6,12 +6,6 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
......@@ -20,28 +14,32 @@ import httpx
logger = logging.getLogger(__name__)
# Global HTTP client instance
# The shared client has ``follow_redirects=False`` to prevent redirect-based
# SSRF filter bypass. Callers must follow redirects manually via
# :func:`dynamo.common.multimodal.url_validator.fetch_with_revalidation` so
# that each hop is re-validated against the SSRF policy.
_global_http_client: Optional[httpx.AsyncClient] = None
def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
"""
Get or create a shared HTTP client instance.
"""Return a shared async HTTP client for media fetches.
Args:
timeout: Timeout for HTTP requests
Returns:
Shared HTTP client instance
The client intentionally disables automatic redirect following. Callers
that need to follow redirects must route the request through
:func:`fetch_with_revalidation`, which revalidates every redirect hop
against the SSRF policy.
"""
global _global_http_client
if _global_http_client is None or _global_http_client.is_closed:
_global_http_client = httpx.AsyncClient(
timeout=timeout,
follow_redirects=True,
follow_redirects=False,
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
logger.info(
"Shared HTTP client initialized (timeout=%ss, follow_redirects=False)",
timeout,
)
return _global_http_client
......@@ -20,6 +20,11 @@ from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
from .http_client import get_http_client
from .url_validator import (
UrlValidationPolicy,
fetch_with_revalidation,
validate_media_url,
)
logger = logging.getLogger(__name__)
......@@ -36,6 +41,7 @@ class ImageLoader:
cache_size: int = CACHE_SIZE_MAXIMUM,
http_timeout: float = 30.0,
enable_frontend_decoding: bool = False,
url_policy: UrlValidationPolicy | None = None,
):
"""
Initialize the ImageLoader with caching, HTTP settings, and optional NIXL config for
......@@ -49,12 +55,14 @@ class ImageLoader:
enable_frontend_decoding: If True, enables NIXL RDMA for transferring
decoded images directly from frontend memory, bypassing standard
network transport. Defaults to False.
url_policy: Policy for validating URLs. Defaults to UrlValidationPolicy.from_env().
"""
self._http_timeout = http_timeout
self._cache_size = cache_size
self._image_cache: OrderedDict[str, Image.Image] = OrderedDict()
self._inflight: dict[str, asyncio.Task[Image.Image]] = {}
self._enable_frontend_decoding = enable_frontend_decoding
self._url_policy = url_policy or UrlValidationPolicy.from_env()
# Lazy-init NIXL connector only when frontend decoding is enabled
self._nixl_connector = None
if self._enable_frontend_decoding:
......@@ -94,7 +102,9 @@ class ImageLoader:
try:
with _nvtx.annotate("mm:img:http_fetch", color="lime"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url)
response = await fetch_with_revalidation(
http_client, image_url, self._url_policy
)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
......@@ -134,26 +144,25 @@ class ImageLoader:
raise ValueError(
"Invalid image source scheme: local file access is not allowed"
)
normalized_url = await validate_media_url(image_url, self._url_policy)
parsed_url = urlparse(normalized_url)
if parsed_url.scheme in ("http", "https"):
key = image_url.lower()
key = normalized_url.lower()
# Check cache (sync — no await, no interleaving possible)
if key in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
self._image_cache.move_to_end(key)
return self._image_cache[key]
# Join existing in-flight task, or start a new one
if key not in self._inflight:
task = asyncio.create_task(self._fetch_and_cache(key, image_url))
task = asyncio.create_task(self._fetch_and_cache(key, normalized_url))
# Suppress "exception was never retrieved" if all waiters cancel
task.add_done_callback(
lambda t: t.exception() if not t.cancelled() else None
)
self._inflight[key] = task
# shield so cancelling THIS caller doesn't cancel the shared task
return await asyncio.shield(self._inflight[key])
if parsed_url.scheme == "data":
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""URL / path validation and SSRF-safe HTTP fetching for multimodal loaders.
By default (``UrlValidationPolicy()``), only ``https://`` and ``data:`` URLs
are allowed; private / internal IPs and local filesystem access are both
blocked. Individual loaders can add stricter rules on top — ``ImageLoader``,
for example, refuses every local input regardless of policy.
To loosen the defaults, either build a ``UrlValidationPolicy(...)`` directly
or call ``UrlValidationPolicy.from_env()`` to pick up the ``DYN_MM_*`` vars
below.
"""
import asyncio
import ipaddress
import os
import socket
from dataclasses import dataclass
from pathlib import Path
from urllib.parse import urlparse
import httpx
class UrlValidationError(ValueError):
    """Signal that a media URL or local path was rejected by the active policy.

    Derives from ``ValueError`` so callers that already catch ``ValueError``
    also catch policy rejections.
    """
# IP ranges that must never be reachable from a user-controlled URL.
# Source: RFC1918 (private), RFC6598 (CGNAT), RFC5735 (loopback, link-local,
# 0.0.0.0/8), RFC4193 (ULA), RFC4291 (IPv6 loopback / link-local), RFC6890
# (reserved). Link-local 169.254/16 covers the AWS / OpenStack metadata IP.
_BLOCKED_IP_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("100.64.0.0/10"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.0.0.0/24"),
ipaddress.ip_network("192.0.2.0/24"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("198.18.0.0/15"),
ipaddress.ip_network("198.51.100.0/24"),
ipaddress.ip_network("203.0.113.0/24"),
ipaddress.ip_network("224.0.0.0/4"),
ipaddress.ip_network("240.0.0.0/4"),
ipaddress.ip_network("255.255.255.255/32"),
ipaddress.ip_network("::/128"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("::ffff:0:0/96"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
ipaddress.ip_network("ff00::/8"),
)
# Hostnames that resolve to cloud metadata / internal services regardless of
# DNS records. Matched case-insensitively.
_BLOCKED_HOSTS: frozenset[str] = frozenset(
{
"localhost",
"localhost.localdomain",
"ip6-localhost",
"ip6-loopback",
"metadata",
"metadata.google.internal",
"metadata.goog",
"kubernetes.default",
"kubernetes.default.svc",
}
)
def is_blocked_ip(ip_text: str) -> bool:
    """Tell whether ``ip_text`` is an IP literal inside a blocked range.

    Text that does not parse as an IPv4/IPv6 address (a hostname, an empty
    string, ...) is not classified and yields ``False``.
    """
    try:
        candidate = ipaddress.ip_address(ip_text)
    except ValueError:
        # Not an IP literal at all, so there is nothing to match.
        return False

    for network in _BLOCKED_IP_NETWORKS:
        if candidate in network:
            return True
    return False
@dataclass(frozen=True)
class UrlValidationPolicy:
"""Frozen policy describing which media URLs and local paths are allowed."""
allow_http: bool = False
allow_private_ips: bool = False
allowed_local_path: str | None = None
@classmethod
def from_env(cls) -> "UrlValidationPolicy":
"""Build a policy by reading the ``DYN_MM_*`` environment variables."""
allow_internal = os.getenv("DYN_MM_ALLOW_INTERNAL", "0") == "1"
return cls(
allow_http=allow_internal,
allow_private_ips=allow_internal,
allowed_local_path=os.getenv("DYN_MM_LOCAL_PATH", "").strip() or None,
)
async def validate_url(url: str, policy: UrlValidationPolicy) -> str:
    """Check ``url`` against ``policy`` and return it unchanged if it passes.

    ``data:`` URLs always pass. ``https://`` is always an acceptable scheme
    and ``http://`` needs ``allow_http=True``; any other scheme is rejected
    outright. Unless ``allow_private_ips`` is set, HTTP(S) URLs must also
    pass the host checks:

    * a host on the blocked-hostname list is rejected (a trailing root dot,
      as in ``localhost.``, is ignored for this comparison);
    * an IP literal inside a blocked range is rejected;
    * any other hostname is resolved here (off the event loop via
      ``loop.getaddrinfo``) and every returned address is checked against
      the blocked ranges.

    The resolution step catches obvious DNS rebinding but not an attacker
    who changes their answer between this lookup and httpx's actual connect.

    Args:
        url: The URL to check.
        policy: The policy to enforce.

    Returns:
        ``url`` exactly as given.

    Raises:
        UrlValidationError: On any policy violation, on a URL that cannot be
            parsed, or when the hostname cannot be resolved.
    """
    if not url:
        raise UrlValidationError("URL is empty")

    # urlparse raises a bare ValueError on malformed input such as an
    # unterminated "[" IPv6 literal; report it as a validation failure so
    # callers see one exception type.
    try:
        parsed = urlparse(url)
    except ValueError as exc:
        raise UrlValidationError(f"Malformed URL {url!r}: {exc}") from exc

    scheme = parsed.scheme.lower()
    if scheme == "data":
        return url
    if scheme not in ("http", "https"):
        raise UrlValidationError(f"URL scheme '{scheme}' not allowed")
    if scheme == "http" and not policy.allow_http:
        raise UrlValidationError(
            "http:// URLs are not allowed; set DYN_MM_ALLOW_INTERNAL=1 to enable"
        )

    host = (parsed.hostname or "").lower()
    if not host:
        raise UrlValidationError(f"URL has no host component: {url!r}")

    # "localhost." names the same host as "localhost"; drop the root dot so
    # the fully-qualified spelling cannot sidestep the name blocklist.
    blocklist_key = host.rstrip(".")
    if not policy.allow_private_ips and blocklist_key in _BLOCKED_HOSTS:
        raise UrlValidationError(
            f"Host '{host}' is blocked (resolves to internal service)"
        )

    try:
        ipaddress.ip_address(host)
    except ValueError:
        pass  # Not an IP literal; fall through to DNS resolution.
    else:
        if not policy.allow_private_ips and is_blocked_ip(host):
            raise UrlValidationError(f"IP literal '{host}' is in a blocked range")
        return url

    if policy.allow_private_ips:
        return url

    loop = asyncio.get_running_loop()
    try:
        infos = await loop.getaddrinfo(host, None)
    except socket.gaierror as exc:
        raise UrlValidationError(f"Could not resolve host '{host}': {exc}") from exc
    except (OSError, UnicodeError) as exc:
        # UnicodeError: the host cannot be IDNA-encoded (e.g. an over-long
        # label). OSError: any other resolver failure.
        raise UrlValidationError(f"Could not resolve host '{host}': {exc}") from exc

    for info in infos:
        addr = info[4][0]
        if is_blocked_ip(addr):
            raise UrlValidationError(f"Host '{host}' resolves to blocked IP '{addr}'")
    return url
def validate_local_path(path: str, policy: UrlValidationPolicy) -> Path:
    """Resolve ``path`` and confirm it sits inside ``allowed_local_path``.

    ``Path.resolve()`` runs first, so symlinks that point outside the
    allowed prefix are caught. Local access is refused outright when
    ``allowed_local_path`` is unset (the default).

    Args:
        path: Filesystem path supplied by the caller (``~`` is expanded).
        policy: Policy carrying the allowed directory prefix.

    Returns:
        The fully resolved path.

    Raises:
        UrlValidationError: If local access is disabled, either path cannot
            be resolved, or the resolved path escapes the prefix.
    """
    if not policy.allowed_local_path:
        raise UrlValidationError(
            "Local media paths are not permitted; set DYN_MM_LOCAL_PATH to enable"
        )

    try:
        resolved = Path(path).expanduser().resolve(strict=True)
    except FileNotFoundError as exc:
        raise UrlValidationError(f"File not found: {path}") from exc
    except (OSError, RuntimeError, ValueError) as exc:
        # RuntimeError: expanduser() cannot determine a home directory
        # (e.g. "~unknown_user/x"). ValueError: the path is rejected before
        # reaching the OS, e.g. an embedded NUL byte on some Python versions.
        raise UrlValidationError(f"Could not resolve path '{path}': {exc}") from exc

    try:
        allowed = Path(policy.allowed_local_path).expanduser().resolve(strict=True)
    except FileNotFoundError as exc:
        raise UrlValidationError(
            f"Configured allowed_local_path does not exist: {policy.allowed_local_path}"
        ) from exc
    except (OSError, RuntimeError, ValueError) as exc:
        # A misconfigured prefix (permission denied, unknown ~user, ...) is
        # reported the same way as every other policy failure.
        raise UrlValidationError(
            f"Could not resolve allowed_local_path "
            f"'{policy.allowed_local_path}': {exc}"
        ) from exc

    try:
        resolved.relative_to(allowed)
    except ValueError as exc:
        raise UrlValidationError(
            f"Path '{path}' is outside the allowed directory '{policy.allowed_local_path}'"
        ) from exc
    return resolved
async def validate_media_url(url: str, policy: UrlValidationPolicy) -> str:
    """Validate any media input and return a canonical URL string.

    Bare filesystem paths and ``file://`` URIs go through
    ``validate_local_path`` and come back as a resolved ``file://`` URI.
    The path of a ``file://`` URI is percent-decoded first, so URIs produced
    by ``Path.as_uri()`` (which escapes spaces and non-ASCII characters)
    map back to the real file name. Everything else goes through
    ``validate_url`` and is returned unchanged. Callers can still reject the
    result afterwards.

    Args:
        url: A URL, ``file://`` URI, or bare filesystem path.
        policy: The policy to enforce.

    Returns:
        A ``file://`` URI for local inputs, otherwise ``url`` itself.

    Raises:
        UrlValidationError: On any policy violation or unparseable input.
    """
    # Local import: only the file:// branch needs percent-decoding.
    from urllib.parse import unquote

    if not url:
        raise UrlValidationError("URL is empty")

    # urlparse raises a bare ValueError on malformed input (e.g. an
    # unterminated "[" host); report it as a validation failure.
    try:
        parsed = urlparse(url)
    except ValueError as exc:
        raise UrlValidationError(f"Malformed URL {url!r}: {exc}") from exc

    scheme = parsed.scheme.lower()
    if scheme in ("", "file"):
        # Bare paths are used verbatim; URI paths are percent-encoded.
        raw_path = unquote(parsed.path) if scheme == "file" else url
        resolved = validate_local_path(raw_path, policy)
        return resolved.as_uri()
    return await validate_url(url, policy)
_MAX_REDIRECTS = 3


async def fetch_with_revalidation(
    client: httpx.AsyncClient,
    url: str,
    policy: UrlValidationPolicy,
) -> httpx.Response:
    """GET ``url``, re-checking the policy before every request in the chain.

    Redirects are followed by hand: each request is sent with
    ``follow_redirects=False`` and the ``Location`` target is run through
    ``validate_url`` before it is requested. At most ``_MAX_REDIRECTS``
    redirects are followed. ``client`` is expected to have automatic
    redirect following disabled.

    Only a plain ``GET`` with no custom headers is issued.

    Returns the first response that is not a redirect (or is a redirect
    without a ``Location`` header).

    Raises ``UrlValidationError`` when any URL in the chain violates the
    policy or when the chain exceeds ``_MAX_REDIRECTS``.
    """
    chain: list[str] = []
    target = url
    for hops_used in range(_MAX_REDIRECTS + 1):
        # Every URL, including redirect targets, is validated before sending.
        await validate_url(target, policy)
        chain.append(target)
        response = await client.send(
            client.build_request("GET", target), follow_redirects=False
        )

        location = response.headers.get("location") if response.is_redirect else None
        if not location:
            return response

        if hops_used == _MAX_REDIRECTS:
            # Redirect budget is spent and the server still wants to redirect.
            break

        # Resolve a possibly-relative Location against the responding URL
        # before releasing the response.
        target = str(response.url.join(location))
        await response.aclose()

    await response.aclose()
    raise UrlValidationError(
        f"Too many redirects (max={_MAX_REDIRECTS}); chain={chain}"
    )
......@@ -16,13 +16,18 @@
import asyncio
import logging
import os
from pathlib import Path
from typing import Any, Awaitable, Dict, Final, List
from urllib.parse import urlparse
import numpy as np
import dynamo.nixl_connect as nixl_connect
from dynamo.common.multimodal.http_client import get_http_client
from dynamo.common.multimodal.url_validator import (
UrlValidationPolicy,
fetch_with_revalidation,
validate_media_url,
)
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.runtime import run_async
......@@ -53,33 +58,27 @@ class VideoLoader:
http_timeout: float = 60.0,
num_frames: int = NUM_FRAMES_DEFAULT,
enable_frontend_decoding: bool = False,
url_policy: UrlValidationPolicy | None = None,
) -> None:
self._http_timeout = int(http_timeout)
self._num_frames = num_frames
self._enable_frontend_decoding = enable_frontend_decoding
self._url_policy = url_policy or UrlValidationPolicy.from_env()
self._nixl_connector = None
self._vllm_media_connector = None
if self._enable_frontend_decoding:
self._nixl_connector = nixl_connect.Connector()
run_async(self._nixl_connector.initialize)
@staticmethod
def _normalize_video_url(video_url: str) -> str:
parsed_url = urlparse(video_url)
if parsed_url.scheme or not video_url:
return video_url
file_path = Path(video_url).expanduser()
if not file_path.exists():
raise FileNotFoundError(f"Error reading file: {file_path}")
return file_path.resolve().as_uri()
def _get_vllm_media_connector(self) -> Any:
if self._vllm_media_connector is None:
MediaConnector, _, _ = _require_vllm_video_media()
# Match the previous backend behavior and allow direct local file paths.
self._vllm_media_connector = MediaConnector(allowed_local_media_path="/")
# Confine vLLM's own local-path access to the same prefix we enforce.
# Empty string matches vLLM's secure default (no local access).
allowed = self._url_policy.allowed_local_path or ""
self._vllm_media_connector = MediaConnector(
allowed_local_media_path=allowed
)
return self._vllm_media_connector
......@@ -93,14 +92,23 @@ class VideoLoader:
async def _load_video_with_vllm(
self, video_url: str
) -> tuple[np.ndarray, Dict[str, Any]]:
normalized_url = await validate_media_url(video_url, self._url_policy)
media_io = self._create_vllm_video_io()
# HTTP(S) goes through our SSRF-safe fetcher so each redirect hop is
# revalidated; vLLM's own fetcher honors redirects without re-checking.
# data: and file:// never touch the network, so vLLM can handle them.
if urlparse(normalized_url).scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await fetch_with_revalidation(
http_client, normalized_url, self._url_policy
)
response.raise_for_status()
return await asyncio.to_thread(media_io.load_bytes, response.content)
connector = self._get_vllm_media_connector()
normalized_url = self._normalize_video_url(video_url)
# TODO: Add caching for repeated remote `video_url` downloads to avoid
# refetching the same asset across requests.
return await connector.load_from_url_async(
normalized_url,
self._create_vllm_video_io(),
fetch_timeout=self._http_timeout,
normalized_url, media_io, fetch_timeout=self._http_timeout
)
async def load_video(self, video_url: str) -> tuple[np.ndarray, Dict[str, Any]]:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import numpy as np
import pytest
import dynamo.common.multimodal.audio_loader as audio_loader_module
from dynamo.common.multimodal.audio_loader import AudioLoader
from dynamo.common.multimodal.url_validator import (
UrlValidationError,
UrlValidationPolicy,
validate_media_url,
)
pytestmark = [
pytest.mark.unit,
......@@ -16,34 +22,116 @@ pytestmark = [
]
def test_normalize_audio_url_converts_local_paths(tmp_path):
def _permissive_http_policy() -> UrlValidationPolicy:
    """Build a policy under which the https://example.com/... test URLs validate.

    Both ``allow_http`` and ``allow_private_ips`` are switched on, so the
    private/loopback and DNS checks are bypassed and the tests never depend
    on resolving example.com for real.
    """
    relaxed = UrlValidationPolicy(allow_http=True, allow_private_ips=True)
    return relaxed
async def test_normalize_audio_url_converts_local_paths(tmp_path):
audio_path = tmp_path / "sample.wav"
audio_path.write_bytes(b"RIFF")
policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
assert (
AudioLoader._normalize_audio_url(str(audio_path))
await validate_media_url(str(audio_path), policy)
== audio_path.resolve().as_uri()
)
def test_normalize_audio_url_preserves_data_urls():
async def test_normalize_audio_url_preserves_data_urls():
data_url = "data:audio/wav;base64,UklGRg=="
assert AudioLoader._normalize_audio_url(data_url) == data_url
policy = UrlValidationPolicy()
assert await validate_media_url(data_url, policy) == data_url
def test_normalize_audio_url_preserves_http_urls():
async def test_normalize_audio_url_preserves_http_urls():
url = "https://example.com/audio.wav"
assert AudioLoader._normalize_audio_url(url) == url
policy = _permissive_http_policy()
assert await validate_media_url(url, policy) == url
async def test_normalize_audio_url_rejects_bare_path_by_default(tmp_path):
    """A bare filesystem path is refused when no local prefix is configured."""
    wav_file = tmp_path / "sample.wav"
    wav_file.write_bytes(b"RIFF")
    default_policy = UrlValidationPolicy()

    with pytest.raises(UrlValidationError, match="Local media paths are not permitted"):
        await validate_media_url(str(wav_file), default_policy)
async def test_normalize_audio_url_rejects_private_ip():
    """The link-local metadata address is refused under the default policy."""
    metadata_url = "https://169.254.169.254/audio.wav"
    default_policy = UrlValidationPolicy()

    with pytest.raises(UrlValidationError):
        await validate_media_url(metadata_url, default_policy)
def test_normalize_audio_url_raises_on_missing_file():
with pytest.raises(FileNotFoundError, match="Error reading file"):
AudioLoader._normalize_audio_url("/nonexistent/audio.wav")
async def test_normalize_audio_url_accepts_file_uri_inside_prefix(tmp_path):
    """A file:// URI under the allowed prefix round-trips unchanged."""
    wav_file = tmp_path / "sample.wav"
    wav_file.write_bytes(b"RIFF")
    expected_uri = wav_file.resolve().as_uri()
    prefix_policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))

    validated = await validate_media_url(expected_uri, prefix_policy)

    assert validated == expected_uri
async def test_normalize_audio_url_rejects_file_uri_outside_prefix(tmp_path):
    """A file:// URI pointing next to, not inside, the allowed directory is refused."""
    media_dir = tmp_path / "media"
    media_dir.mkdir()
    outsider = tmp_path / "secret.wav"
    outsider.write_bytes(b"RIFF")
    outsider_uri = outsider.resolve().as_uri()
    prefix_policy = UrlValidationPolicy(allowed_local_path=str(media_dir))

    with pytest.raises(UrlValidationError, match="outside the allowed directory"):
        await validate_media_url(outsider_uri, prefix_policy)
@pytest.mark.asyncio
async def test_load_audio_rejects_http_by_default():
    """Plain http:// is refused by a loader built with the default policy."""
    strict_loader = AudioLoader(url_policy=UrlValidationPolicy())
    insecure_url = "http://example.com/x.wav"

    with pytest.raises(ValueError, match="not allowed"):
        await strict_loader.load_audio(insecure_url)
@pytest.mark.asyncio
async def test_load_audio_blocks_redirect_to_private_ip():
    """A 302 to a blocked IP must be rejected per-hop, not only the initial URL."""
    start_url = "https://8.8.8.8/a.wav"

    loader = AudioLoader(url_policy=UrlValidationPolicy())
    loader._create_vllm_audio_io = MagicMock(return_value=MagicMock())  # type: ignore[method-assign]

    # Fake 302 whose Location points at the link-local metadata address.
    redirect = MagicMock(
        spec=httpx.Response,
        status_code=302,
        is_redirect=True,
        headers={"location": "https://169.254.169.254/evil"},
        url=httpx.URL(start_url),
        aclose=AsyncMock(),
    )

    # Client stub exposing only the two calls the fetch helper makes.
    client = MagicMock(
        spec=httpx.AsyncClient,
        build_request=MagicMock(return_value=MagicMock(spec=httpx.Request)),
        send=AsyncMock(return_value=redirect),
    )

    with patch(
        "dynamo.common.multimodal.audio_loader.get_http_client",
        return_value=client,
    ), pytest.raises(ValueError, match="blocked range"):
        await loader.load_audio(start_url)
@pytest.mark.asyncio
async def test_load_audio_uses_vllm_media_connector():
loader = AudioLoader()
loader._url_policy = UrlValidationPolicy()
waveform = np.random.randn(16000).astype(np.float32)
sr = 44100.0
loader._load_audio_with_vllm = AsyncMock( # type: ignore[method-assign]
......@@ -60,7 +148,7 @@ async def test_load_audio_uses_vllm_media_connector():
@pytest.mark.asyncio
async def test_load_audio_rejects_empty_waveform():
loader = AudioLoader()
loader = AudioLoader(url_policy=_permissive_http_policy())
loader._load_audio_with_vllm = AsyncMock( # type: ignore[method-assign]
return_value=(np.array([], dtype=np.float32), 16000.0)
)
......
......@@ -25,6 +25,7 @@ import pytest
from PIL import Image
from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.multimodal.url_validator import UrlValidationPolicy
pytestmark = [
pytest.mark.asyncio,
......@@ -45,29 +46,42 @@ def _make_png_bytes() -> bytes:
PNG_BYTES = _make_png_bytes()
def _permissive_policy(
    allowed_local_path: str | None = None,
) -> UrlValidationPolicy:
    """Build a policy that accepts the schemes the tests use, with no DNS lookups.

    ``allow_http`` and ``allow_private_ips`` are both enabled; the optional
    ``allowed_local_path`` is passed straight through.
    """
    settings = {
        "allow_http": True,
        "allow_private_ips": True,
        "allowed_local_path": allowed_local_path,
    }
    return UrlValidationPolicy(**settings)
def _mock_http_client(
content: bytes = PNG_BYTES,
status_code: int = 200,
delay: float = 0.0,
side_effect: Exception | None = None,
) -> AsyncMock:
"""Return a mock httpx.AsyncClient whose .get() returns a fake response.
"""Return a mock httpx.AsyncClient compatible with fetch_with_revalidation.
``fetch_with_revalidation`` uses ``client.build_request(...)`` then
``client.send(request, follow_redirects=False)``. We stub both and also
keep ``client.get`` behaviour for any legacy callers.
Args:
content: Raw bytes returned as the HTTP response body.
status_code: HTTP status code; >=400 triggers raise_for_status().
delay: Seconds to sleep before responding (simulates network latency).
side_effect: If set, .get() raises this exception instead of returning.
side_effect: If set, .send()/.get() raises this exception instead.
"""
async def _get(url: str) -> Any:
if delay > 0:
await asyncio.sleep(delay)
if side_effect is not None:
raise side_effect
def _build_response() -> Any:
resp = MagicMock(spec=httpx.Response)
resp.status_code = status_code
resp.content = content
resp.is_redirect = False
resp.headers = {}
resp.raise_for_status = MagicMock()
if status_code >= 400:
resp.raise_for_status.side_effect = httpx.HTTPStatusError(
......@@ -75,14 +89,27 @@ def _mock_http_client(
)
return resp
async def _respond(*_args: Any, **_kwargs: Any) -> Any:
if delay > 0:
await asyncio.sleep(delay)
if side_effect is not None:
raise side_effect
return _build_response()
client = AsyncMock()
client.get = AsyncMock(side_effect=_get)
client.get = AsyncMock(side_effect=_respond)
client.build_request = MagicMock(return_value=MagicMock(spec=httpx.Request))
client.send = AsyncMock(side_effect=_respond)
return client
@pytest.fixture(autouse=True)
def loader() -> ImageLoader:
return ImageLoader(cache_size=4, http_timeout=30.0)
return ImageLoader(
cache_size=4,
http_timeout=30.0,
url_policy=_permissive_policy(),
)
# --- Concurrent same-URL dedup ---
......@@ -103,7 +130,7 @@ async def test_concurrent_same_url_deduplicates(loader: ImageLoader) -> None:
assert len(results) == 2
assert results[0].size == results[1].size
# Only one HTTP GET should have been issued
assert mock_client.get.call_count == 1
assert mock_client.send.call_count == 1
async def test_concurrent_different_urls_fetch_independently(
......@@ -120,7 +147,7 @@ async def test_concurrent_different_urls_fetch_independently(
loader.load_image("https://example.com/b.png"),
)
assert mock_client.get.call_count == 2
assert mock_client.send.call_count == 2
# --- Waiter cancellation isolation ---
......@@ -206,6 +233,46 @@ async def test_data_url_non_image_rejected(loader: ImageLoader) -> None:
await loader.load_image("data:text/plain;base64,aGVsbG8=")
# --- SSRF / scheme rejection ---
async def test_http_scheme_rejected_by_default(monkeypatch) -> None:
    """With default env policy, http:// URLs must be rejected before any fetch."""
    # Clear both knobs so the loader's from_env() policy is the strict default.
    for env_var in ("DYN_MM_ALLOW_INTERNAL", "DYN_MM_LOCAL_PATH"):
        monkeypatch.delenv(env_var, raising=False)

    default_loader = ImageLoader(cache_size=4, http_timeout=30.0)
    mock_client = _mock_http_client()

    with patch(
        "dynamo.common.multimodal.image_loader.get_http_client",
        return_value=mock_client,
    ), pytest.raises(ValueError, match="scheme|not allowed"):
        await default_loader.load_image("http://example.com/x.png")

    # The shared HTTP client must not be touched when the URL is rejected.
    assert mock_client.send.call_count == 0
    assert mock_client.get.call_count == 0
async def test_blocked_private_ip_rejected(monkeypatch) -> None:
    """Cloud metadata / private IPs must be rejected even over https."""
    monkeypatch.delenv("DYN_MM_ALLOW_INTERNAL", raising=False)

    # http is allowed here, so only the private-IP check can cause the rejection.
    no_private_ips = UrlValidationPolicy(allow_http=True, allow_private_ips=False)
    strict_loader = ImageLoader(
        cache_size=4,
        http_timeout=30.0,
        url_policy=no_private_ips,
    )
    metadata_url = "https://169.254.169.254/latest/meta-data/"

    with pytest.raises(ValueError, match="blocked range"):
        await strict_loader.load_image(metadata_url)
# --- HTTP error contract ---
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for ``dynamo.common.multimodal.url_validator``.
These cover scheme / IP / hostname / path / redirect logic in isolation of
the media loaders, so they run quickly with no network and no vLLM imports.
"""
from __future__ import annotations
import socket
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from dynamo.common.multimodal.url_validator import (
UrlValidationError,
UrlValidationPolicy,
fetch_with_revalidation,
is_blocked_ip,
validate_local_path,
validate_url,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.pre_merge,
pytest.mark.gpu_0,
]
# Default policy: every field left at its dataclass default (http off,
# private IPs blocked, no local path).
STRICT_HTTPS = UrlValidationPolicy()
# Opted-in policy: http:// accepted and the private-IP / DNS checks skipped,
# so tests using it never need real name resolution.
PERMISSIVE = UrlValidationPolicy(
    allow_http=True,
    allow_private_ips=True,
)
# ---------------------------------------------------------------------------
# is_blocked_ip()
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    "ip",
    [
        "127.0.0.1",
        "10.0.0.1",
        "172.16.5.5",
        "192.168.1.1",
        "169.254.169.254",  # AWS metadata
        "100.64.0.1",  # CGNAT
        "0.0.0.0",
        "::1",
        "fe80::1",
        "fc00::1",
        "240.0.0.1",  # reserved
    ],
)
def test_is_blocked_ip_blocks_known_ranges(ip: str) -> None:
    """Each sampled loopback, private, link-local or reserved address is blocked."""
    verdict = is_blocked_ip(ip)
    assert verdict is True
@pytest.mark.parametrize(
    "ip",
    [
        "8.8.8.8",
        "1.1.1.1",
        "93.184.216.34",  # example.com
        "2606:4700:4700::1111",  # Cloudflare
    ],
)
def test_is_blocked_ip_allows_public(ip: str) -> None:
    """Publicly routable IPv4 and IPv6 addresses are not flagged."""
    verdict = is_blocked_ip(ip)
    assert verdict is False
def test_is_blocked_ip_non_ip_literal_returns_false() -> None:
    """Text that is not an IP literal is never reported as blocked."""
    # A hostname, not an IP — is_blocked_ip only classifies literals.
    hostname = "example.com"
    verdict = is_blocked_ip(hostname)
    assert verdict is False
# ---------------------------------------------------------------------------
# validate_url() — scheme handling
# ---------------------------------------------------------------------------
async def test_validate_url_rejects_empty() -> None:
    """An empty string is rejected with an "empty" error."""
    empty_url = ""
    with pytest.raises(UrlValidationError, match="empty"):
        await validate_url(empty_url, STRICT_HTTPS)
async def test_validate_url_rejects_http_by_default() -> None:
    """Plain http:// is refused unless the policy opts in."""
    insecure_url = "http://example.com/x.png"
    with pytest.raises(UrlValidationError, match="not allowed"):
        await validate_url(insecure_url, STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_rejects_file_scheme() -> None:
    """``file://`` URLs are refused by ``validate_url`` and the scheme is named."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with pytest.raises(UrlValidationError, match="scheme 'file'"):
        await validate_url("file:///etc/passwd", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_rejects_ftp() -> None:
    """``ftp://`` URLs are refused and the error names the offending scheme."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with pytest.raises(UrlValidationError, match="scheme 'ftp'"):
        await validate_url("ftp://example.com/x.png", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_accepts_data_url_by_default() -> None:
    """A ``data:`` URL passes validation under the default policy."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    # data: URLs never touch the network — we allow them without further checks.
    await validate_url("data:image/png;base64,iVBORw0KGgoAAAA=", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_http_allowed_when_opted_in() -> None:
    """``http://`` is accepted once the policy sets ``allow_http=True``."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    # PERMISSIVE also sets allow_private_ips=True, which keeps DNS out of the test.
    await validate_url("http://example.com/x.png", PERMISSIVE)
# ---------------------------------------------------------------------------
# validate_url() — IP literal handling
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "url",
    [
        "https://127.0.0.1/x.png",
        "https://169.254.169.254/latest/meta-data/",
        "https://10.0.0.5/x.png",
        "https://192.168.1.1/x.png",
        "https://[::1]/x.png",
    ],
)
async def test_validate_url_rejects_private_ip_literal(url: str) -> None:
    """URLs whose host is a non-public IP literal are rejected under the strict policy."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with pytest.raises(UrlValidationError, match="blocked range"):
        await validate_url(url, STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_allows_public_ip_literal() -> None:
    """A public IP literal passes validation under the strict policy."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    # Public IP literal and allow_private_ips=False — should pass the IP test
    # and skip DNS resolution (the IP path short-circuits).
    await validate_url("https://8.8.8.8/x.png", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_allows_private_ip_when_opted_in() -> None:
    """A loopback literal is accepted once ``allow_private_ips=True``."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    await validate_url("https://127.0.0.1/x.png", PERMISSIVE)
# ---------------------------------------------------------------------------
# validate_url() — blocked hostnames
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "host",
    [
        "localhost",
        "metadata.google.internal",
        "metadata",
        "kubernetes.default.svc",
    ],
)
async def test_validate_url_rejects_blocked_hostname(host: str) -> None:
    """Each listed internal hostname is rejected under the strict policy."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with pytest.raises(UrlValidationError, match="blocked"):
        await validate_url(f"https://{host}/path", STRICT_HTTPS)
# ---------------------------------------------------------------------------
# validate_url() — DNS resolution
# ---------------------------------------------------------------------------
def _fake_getaddrinfo(addrs: list[str]):
    """Build a stand-in for ``socket.getaddrinfo`` that reports fixed addresses.

    Args:
        addrs: IP literals returned for every lookup, in order. A literal
            containing ``:`` is reported as IPv6; anything else as IPv4.

    Returns:
        A callable that ignores its arguments and returns one
        ``(family, type, proto, canonname, sockaddr)`` tuple per address.
    """

    def _entry(addr: str):
        if ":" in addr:
            # IPv6 sockaddr is (host, port, flowinfo, scope_id).
            return (socket.AF_INET6, socket.SOCK_STREAM, 6, "", (addr, 0, 0, 0))
        return (socket.AF_INET, socket.SOCK_STREAM, 6, "", (addr, 0))

    def _impl(host: str, *_args: Any, **_kwargs: Any):
        # The queried host is deliberately ignored: the answer is canned.
        return [_entry(addr) for addr in addrs]

    return _impl
@pytest.mark.asyncio
async def test_validate_url_rejects_host_resolving_to_private_ip() -> None:
    """A hostname that resolves to a private address is rejected."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with patch(
        "dynamo.common.multimodal.url_validator.socket.getaddrinfo",
        side_effect=_fake_getaddrinfo(["10.0.0.5"]),
    ):
        with pytest.raises(UrlValidationError, match="blocked IP"):
            await validate_url("https://attacker.example.com/x.png", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_rejects_host_if_any_ip_is_private() -> None:
    """One blocked address among several resolved addresses is enough to reject."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    # Even if the host resolves to a public IP too, any blocked IP is fatal.
    with patch(
        "dynamo.common.multimodal.url_validator.socket.getaddrinfo",
        side_effect=_fake_getaddrinfo(["8.8.8.8", "169.254.169.254"]),
    ):
        with pytest.raises(UrlValidationError, match="169.254.169.254"):
            await validate_url("https://mixed.example.com/x.png", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_accepts_public_host() -> None:
    """A hostname that resolves only to a public address passes validation."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with patch(
        "dynamo.common.multimodal.url_validator.socket.getaddrinfo",
        side_effect=_fake_getaddrinfo(["93.184.216.34"]),
    ):
        await validate_url("https://example.com/x.png", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_resolution_failure_raises() -> None:
    """A resolver ``gaierror`` surfaces as ``UrlValidationError``."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    with patch(
        "dynamo.common.multimodal.url_validator.socket.getaddrinfo",
        side_effect=socket.gaierror("nodename nor servname provided"),
    ):
        with pytest.raises(UrlValidationError, match="Could not resolve"):
            await validate_url("https://does-not-exist.invalid/x.png", STRICT_HTTPS)
@pytest.mark.asyncio
async def test_validate_url_skips_resolution_when_private_allowed() -> None:
    """With ``allow_private_ips=True`` the resolver must not be called at all."""
    # Marked explicitly, like the fetch_with_revalidation tests in this module,
    # so the coroutine is awaited regardless of the pytest-asyncio mode.
    # In developer mode we short-circuit DNS to keep tests deterministic.
    with patch("dynamo.common.multimodal.url_validator.socket.getaddrinfo") as resolver:
        await validate_url("https://example.com/x.png", PERMISSIVE)
        resolver.assert_not_called()
# ---------------------------------------------------------------------------
# validate_local_path()
# ---------------------------------------------------------------------------
def test_validate_local_path_rejected_when_disabled() -> None:
    """Local paths are refused when the policy names no allowed directory."""
    expect_rejection = pytest.raises(UrlValidationError, match="not permitted")
    with expect_rejection:
        validate_local_path("/etc/passwd", STRICT_HTTPS)
def test_validate_local_path_accepts_inside_prefix(tmp_path) -> None:
    """A file inside the allowed directory validates to its resolved path."""
    media_dir = tmp_path / "media"
    media_dir.mkdir()
    sample = media_dir / "sample.png"
    sample.write_bytes(b"\x89PNG\r\n")

    allow_media = UrlValidationPolicy(allowed_local_path=str(media_dir))

    assert validate_local_path(str(sample), allow_media) == sample.resolve()
def test_validate_local_path_rejects_outside_prefix(tmp_path) -> None:
    """A file in a sibling directory of the allowed one is rejected."""
    media_dir = tmp_path / "media"
    media_dir.mkdir()
    secrets_dir = tmp_path / "secrets"
    secrets_dir.mkdir()
    creds = secrets_dir / "creds.txt"
    creds.write_text("hunter2")

    allow_media = UrlValidationPolicy(allowed_local_path=str(media_dir))

    expect_rejection = pytest.raises(
        UrlValidationError, match="outside the allowed directory"
    )
    with expect_rejection:
        validate_local_path(str(creds), allow_media)
def test_validate_local_path_rejects_symlink_escape(tmp_path) -> None:
    """A symlink inside the allowed directory that points outside it is rejected."""
    media_dir = tmp_path / "media"
    media_dir.mkdir()
    real_target = tmp_path / "outside.txt"
    real_target.write_text("secret")
    escape_link = media_dir / "link.png"
    escape_link.symlink_to(real_target)

    allow_media = UrlValidationPolicy(allowed_local_path=str(media_dir))

    # Path.resolve() follows the symlink; the target is outside the prefix.
    expect_rejection = pytest.raises(
        UrlValidationError, match="outside the allowed directory"
    )
    with expect_rejection:
        validate_local_path(str(escape_link), allow_media)
def test_validate_local_path_missing_file(tmp_path) -> None:
    """A path to a non-existent file inside the allowed directory is rejected."""
    absent = tmp_path / "nope.png"
    allow_tmp = UrlValidationPolicy(allowed_local_path=str(tmp_path))
    expect_rejection = pytest.raises(UrlValidationError, match="File not found")
    with expect_rejection:
        validate_local_path(str(absent), allow_tmp)
def test_validate_local_path_missing_prefix(tmp_path) -> None:
    """Validation fails when the configured allowed directory does not exist."""
    sample = tmp_path / "sample.png"
    sample.write_bytes(b"x")
    missing_prefix = tmp_path / "does-not-exist"

    bad_policy = UrlValidationPolicy(allowed_local_path=str(missing_prefix))

    expect_rejection = pytest.raises(
        UrlValidationError, match="allowed_local_path does not exist"
    )
    with expect_rejection:
        validate_local_path(str(sample), bad_policy)
# ---------------------------------------------------------------------------
# UrlValidationPolicy.from_env()
# ---------------------------------------------------------------------------
def test_policy_from_env_defaults(monkeypatch) -> None:
    """With both variables unset, ``from_env`` yields the locked-down policy."""
    for var_name in ("DYN_MM_ALLOW_INTERNAL", "DYN_MM_LOCAL_PATH"):
        monkeypatch.delenv(var_name, raising=False)

    from_env = UrlValidationPolicy.from_env()

    assert from_env.allow_http is False
    assert from_env.allow_private_ips is False
    assert from_env.allowed_local_path is None
def test_policy_from_env_allow_internal(monkeypatch) -> None:
    """Both environment variables are reflected in the policy ``from_env`` builds."""
    overrides = {
        "DYN_MM_ALLOW_INTERNAL": "1",
        "DYN_MM_LOCAL_PATH": "/data/media",
    }
    for var_name, value in overrides.items():
        monkeypatch.setenv(var_name, value)

    from_env = UrlValidationPolicy.from_env()

    assert from_env.allow_http is True
    assert from_env.allow_private_ips is True
    assert from_env.allowed_local_path == "/data/media"
# ---------------------------------------------------------------------------
# fetch_with_revalidation()
# ---------------------------------------------------------------------------
def _mock_response(
    status_code: int = 200,
    location: str | None = None,
    *,
    request_url: str = "https://example.com/x.png",
) -> MagicMock:
    """Create a ``MagicMock`` shaped like an ``httpx.Response``.

    Args:
        status_code: Value exposed as ``status_code``; also decides ``is_redirect``.
        location: When given, exposed as the ``location`` header.
        request_url: URL exposed as the response's ``url``.

    Returns:
        The mock, with an awaitable ``aclose``.
    """
    redirect_codes = (301, 302, 303, 307, 308)
    header_map = {} if location is None else {"location": location}

    mock = MagicMock(spec=httpx.Response)
    mock.status_code = status_code
    mock.headers = header_map
    mock.url = httpx.URL(request_url)
    mock.is_redirect = status_code in redirect_codes
    mock.aclose = AsyncMock()
    return mock
def _mock_client(responses: list[MagicMock]) -> MagicMock:
    """Create a ``MagicMock`` shaped like an ``httpx.AsyncClient``.

    ``build_request`` hands back a fresh request mock on every call, and
    ``send`` yields the given responses one per await, in order.
    """

    def _new_request(method, url, headers=None):
        return MagicMock(spec=httpx.Request)

    mock = MagicMock(spec=httpx.AsyncClient)
    mock.build_request = MagicMock(side_effect=_new_request)
    # Copy the list so the caller's sequence is not consumed by the mock.
    mock.send = AsyncMock(side_effect=list(responses))
    return mock
@pytest.mark.asyncio
async def test_fetch_with_revalidation_returns_first_response() -> None:
    """A non-redirect response is returned as-is after a single send."""
    only_response = _mock_response(status_code=200)
    client = _mock_client([only_response])

    fetched = await fetch_with_revalidation(
        client, "https://example.com/x.png", PERMISSIVE
    )

    assert fetched is only_response
    assert client.send.await_count == 1
@pytest.mark.asyncio
async def test_fetch_with_revalidation_follows_safe_redirect() -> None:
    """A redirect to an allowed URL is followed and the first response is closed."""
    start_url = "https://example.com/x.png"
    end_url = "https://example.com/final.png"
    first_hop = _mock_response(
        status_code=302,
        location=end_url,
        request_url=start_url,
    )
    last_hop = _mock_response(status_code=200, request_url=end_url)
    client = _mock_client([first_hop, last_hop])

    fetched = await fetch_with_revalidation(client, start_url, PERMISSIVE)

    assert fetched is last_hop
    assert client.send.await_count == 2
    first_hop.aclose.assert_awaited()
@pytest.mark.asyncio
async def test_fetch_with_revalidation_blocks_redirect_to_private_ip() -> None:
    """A redirect whose target is a blocked IP literal must abort the fetch."""
    # Strict policy — first hop is OK (public-IP literal), redirect target is blocked.
    strict = UrlValidationPolicy(allow_private_ips=False)
    redirect = _mock_response(
        status_code=302,
        # https:// on purpose: with http:// the strict policy would already
        # reject the scheme and the IP-range check would never be exercised.
        location="https://169.254.169.254/latest/meta-data/",
        request_url="https://8.8.8.8/x.png",
    )
    client = _mock_client([redirect])
    # Pin the failure reason so a scheme rejection cannot satisfy this test.
    with pytest.raises(UrlValidationError, match="blocked range"):
        await fetch_with_revalidation(client, "https://8.8.8.8/x.png", strict)
    # Only one send — the redirect target is rejected before any further fetch.
    assert client.send.await_count == 1
@pytest.mark.asyncio
async def test_fetch_with_revalidation_enforces_redirect_limit() -> None:
    """A chain of four consecutive redirects is refused as too long."""
    # _MAX_REDIRECTS is hardcoded at 3; we need 4 redirect responses to trip it.
    policy = UrlValidationPolicy(allow_private_ips=True)  # keep DNS out of this test
    chain = [f"https://example.com/{leaf}" for leaf in "abcde"]
    # Pair each URL with its successor: a->b, b->c, c->d, d->e.
    hops = [
        _mock_response(status_code=302, location=dst, request_url=src)
        for src, dst in zip(chain, chain[1:])
    ]
    client = _mock_client(hops)

    with pytest.raises(UrlValidationError, match="Too many redirects"):
        await fetch_with_revalidation(client, "https://example.com/a", policy)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import numpy as np
import pytest
import dynamo.common.multimodal.video_loader as video_loader_module
from dynamo.common.multimodal.url_validator import (
UrlValidationError,
UrlValidationPolicy,
validate_media_url,
)
from dynamo.common.multimodal.video_loader import VideoLoader
pytestmark = [
......@@ -16,25 +22,103 @@ pytestmark = [
]
@pytest.mark.asyncio
async def test_normalize_video_url_converts_local_paths(tmp_path):
    """A bare path inside the allowed directory is returned as a ``file://`` URI."""
    # Only the post-change form is kept: the superseded synchronous ``def`` and
    # the call to ``VideoLoader._normalize_video_url`` were leftover diff lines.
    video_path = tmp_path / "sample.webm"
    video_path.write_bytes(b"video")
    policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
    assert (
        await validate_media_url(str(video_path), policy)
        == video_path.resolve().as_uri()
    )
@pytest.mark.asyncio
async def test_normalize_video_url_preserves_data_urls():
    """A ``data:`` URL is returned unchanged under the default policy."""
    # Only the post-change form is kept: the superseded synchronous ``def``
    # was a leftover diff line.
    data_url = "data:video/webm;base64,Zm9v"
    policy = UrlValidationPolicy()
    assert await validate_media_url(data_url, policy) == data_url
@pytest.mark.asyncio
async def test_normalize_video_url_rejects_bare_path_by_default(tmp_path):
    """A bare filesystem path is refused when no allowed directory is configured."""
    video_path = tmp_path / "sample.webm"
    video_path.write_bytes(b"video")
    # Default policy has no allowed_local_path -> local paths rejected.
    # (The stray ``VideoLoader._normalize_video_url`` assertion that sat here
    # was a leftover diff line referencing an undefined ``data_url``.)
    policy = UrlValidationPolicy()
    with pytest.raises(UrlValidationError, match="Local media paths are not permitted"):
        await validate_media_url(str(video_path), policy)
@pytest.mark.asyncio
async def test_normalize_video_url_rejects_private_ip():
    """A URL whose host is a link-local IP literal is refused by the default policy."""
    # Marked explicitly, like the load_video tests in this module, so the
    # coroutine is awaited regardless of the pytest-asyncio mode.
    policy = UrlValidationPolicy()
    with pytest.raises(UrlValidationError):
        await validate_media_url("https://169.254.169.254/video.mp4", policy)
@pytest.mark.asyncio
async def test_normalize_video_url_accepts_file_uri_inside_prefix(tmp_path):
    """A ``file://`` URI inside the allowed directory is returned unchanged."""
    # Marked explicitly, like the load_video tests in this module, so the
    # coroutine is awaited regardless of the pytest-asyncio mode.
    video_path = tmp_path / "sample.webm"
    video_path.write_bytes(b"video")
    policy = UrlValidationPolicy(allowed_local_path=str(tmp_path))
    file_uri = video_path.resolve().as_uri()
    assert await validate_media_url(file_uri, policy) == file_uri
@pytest.mark.asyncio
async def test_normalize_video_url_rejects_file_uri_outside_prefix(tmp_path):
    """A ``file://`` URI that lies outside the allowed directory is rejected."""
    # Marked explicitly, like the load_video tests in this module, so the
    # coroutine is awaited regardless of the pytest-asyncio mode.
    allowed = tmp_path / "media"
    allowed.mkdir()
    other = tmp_path / "secret.webm"
    other.write_bytes(b"video")
    policy = UrlValidationPolicy(allowed_local_path=str(allowed))
    with pytest.raises(UrlValidationError, match="outside the allowed directory"):
        await validate_media_url(other.resolve().as_uri(), policy)
@pytest.mark.asyncio
async def test_load_video_rejects_http_by_default():
    """``load_video`` raises ``ValueError`` for ``http://`` under a default-constructed policy."""
    # A policy built with no overrides is passed explicitly, so http is
    # disabled whatever the environment says.
    default_policy = UrlValidationPolicy()
    loader = VideoLoader(url_policy=default_policy)
    expect_rejection = pytest.raises(ValueError, match="not allowed")
    with expect_rejection:
        await loader.load_video("http://example.com/x.mp4")
@pytest.mark.asyncio
async def test_load_video_blocks_redirect_to_private_ip():
    """A 302 to a blocked IP must be rejected per-hop, not only the initial URL."""
    start_url = "https://8.8.8.8/v.mp4"

    loader = VideoLoader(url_policy=UrlValidationPolicy())
    loader._create_vllm_video_io = MagicMock(return_value=MagicMock())  # type: ignore[method-assign]

    # Every send answers with the same 302 pointing at a link-local address.
    hop = MagicMock(spec=httpx.Response)
    hop.status_code = 302
    hop.is_redirect = True
    hop.headers = {"location": "https://169.254.169.254/evil"}
    hop.url = httpx.URL(start_url)
    hop.aclose = AsyncMock()

    fake_client = MagicMock(spec=httpx.AsyncClient)
    fake_client.build_request = MagicMock(return_value=MagicMock(spec=httpx.Request))
    fake_client.send = AsyncMock(return_value=hop)

    client_patch = patch(
        "dynamo.common.multimodal.video_loader.get_http_client",
        return_value=fake_client,
    )
    with client_patch, pytest.raises(ValueError, match="blocked range"):
        await loader.load_video(start_url)
@pytest.mark.asyncio
async def test_load_video_uses_vllm_media_connector():
loader = VideoLoader()
# data: scheme is in the default allowlist regardless of env flags.
loader._url_policy = UrlValidationPolicy()
frames = np.arange(24, dtype=np.uint8).reshape(1, 2, 4, 3)[:, :, ::-1, :]
metadata = {"fps": 4.0, "frames_indices": [0], "total_num_frames": 1}
loader._load_video_with_vllm = AsyncMock( # type: ignore[method-assign]
......
......@@ -48,6 +48,25 @@ Dynamo provides support for improving latency and throughput for vision-and-lang
**Status:** ✅ Supported | 🧪 Experimental | ❌ Not supported
## Security: URL Validation
All multimodal loaders route remote fetches through a shared URL policy
(`dynamo.common.multimodal.url_validator`). Only
`https://` and `data:` URLs are allowed by default, private / internal IPs are blocked,
and local file access is disabled. Every HTTP redirect hop is re-validated
against the policy.
Two environment variables loosen the defaults for non-public deployments:
| Variable | Default | Effect |
|----------|---------|--------|
| `DYN_MM_ALLOW_INTERNAL` | `0` | Set to `1` to allow `http://` and private / internal IP targets. Intended for on-prem or local-dev setups where media lives on an internal network. |
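| `DYN_MM_LOCAL_PATH` | *(empty)* | Absolute directory prefix, which must exist. When set, `file://` URIs and bare paths are allowed if they resolve inside this prefix; symlinks are resolved first, so a link that points outside the prefix is rejected. |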
<Warning>
**Never set `DYN_MM_ALLOW_INTERNAL=1` on public-facing deployments.** It opens SSRF paths to cloud metadata endpoints (AWS IMDS, GCE, Azure) and other internal services.
</Warning>
## Example Workflows
Reference implementations for deploying multimodal models:
......
......@@ -81,6 +81,7 @@ def _make_process_env(log_level: str = "debug", **extra) -> dict[str, str]:
env["DYN_LOG"] = log_level
env["DYN_NAMESPACE"] = NAMESPACE
env["DYN_REQUEST_PLANE"] = "tcp"
env["DYN_MM_ALLOW_INTERNAL"] = "1"
env.update(extra)
return env
......
......@@ -179,6 +179,11 @@ def make_multimodal_configs(
marks.extend(profile.marks)
key = f"mm_{topology}_{profile.short_name}"
worker_env = {
"DYN_MM_ALLOW_INTERNAL": "1",
"DYN_MM_LOCAL_PATH": str(WORKSPACE_DIR),
**topo_cfg.env,
}
configs[key] = config_cls(
name=key,
directory=topo_cfg.directory or directory,
......@@ -188,6 +193,6 @@ def make_multimodal_configs(
marks=marks,
delayed_start=topo_cfg.delayed_start,
request_payloads=profile.request_payloads,
env=topo_cfg.env,
env=worker_env,
)
return configs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment