Unverified Commit a8894467 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

fix: multimodal ImageLoader errors (#7703)


Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 4b1d442b
...@@ -18,6 +18,7 @@ import base64 ...@@ -18,6 +18,7 @@ import base64
import binascii import binascii
import logging import logging
import os import os
from collections import OrderedDict
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Final, List from typing import Any, Dict, Final, List
from urllib.parse import urlparse from urllib.parse import urlparse
...@@ -34,7 +35,6 @@ from .http_client import get_http_client ...@@ -34,7 +35,6 @@ from .http_client import get_http_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for multimodal data variants # Constants for multimodal data variants
URL_VARIANT_KEY: Final = "Url" URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded" DECODED_VARIANT_KEY: Final = "Decoded"
...@@ -63,8 +63,9 @@ class ImageLoader: ...@@ -63,8 +63,9 @@ class ImageLoader:
network transport. Defaults to False. network transport. Defaults to False.
""" """
self._http_timeout = http_timeout self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {} self._cache_size = cache_size
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size) self._image_cache: OrderedDict[str, Image.Image] = OrderedDict()
self._inflight: dict[str, asyncio.Task[Image.Image]] = {}
self._enable_frontend_decoding = enable_frontend_decoding self._enable_frontend_decoding = enable_frontend_decoding
# Lazy-init NIXL connector only when frontend decoding is enabled # Lazy-init NIXL connector only when frontend decoding is enabled
self._nixl_connector = None self._nixl_connector = None
...@@ -74,47 +75,112 @@ class ImageLoader: ...@@ -74,47 +75,112 @@ class ImageLoader:
self._nixl_connector.initialize self._nixl_connector.initialize
) # Synchronously wait for async init ) # Synchronously wait for async init
@staticmethod
def _open_image_sync(image_data: BytesIO) -> Image.Image:
"""Open, validate, and decode an image from raw bytes. Runs in a thread."""
image = Image.open(image_data, formats=["JPEG", "PNG", "WEBP"])
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
# Image.open() is lazy — convert() forces the actual pixel decode
return image.convert("RGB")
@staticmethod
async def _open_image(image_data: BytesIO) -> Image.Image:
"""Open and validate an image from raw bytes, converting to RGB."""
with _nvtx.annotate("mm:img:pil_open_convert", color="lime"):
return await asyncio.to_thread(ImageLoader._open_image_sync, image_data)
def _cache_put(self, key: str, image: Image.Image) -> None:
"""Insert into cache if not already present. Sync — no awaits."""
if key not in self._image_cache:
if len(self._image_cache) >= self._cache_size:
self._image_cache.popitem(last=False)
self._image_cache[key] = image
async def _fetch_and_process(self, image_url: str) -> Image.Image:
"""Fetch image via HTTP(S), decode with PIL, return RGB Image.
All exception normalization happens here so shared callers
see identical error types.
"""
try:
with _nvtx.annotate("mm:img:http_fetch", color="lime"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
return await self._open_image(image_data)
except httpx.HTTPStatusError as e:
logger.error(f"HTTP {e.response.status_code} loading image: '{image_url}'")
raise
except httpx.TimeoutException as e:
logger.error(
f"{type(e).__name__} loading image: '{image_url}' "
f"(timeout={self._http_timeout}s)"
)
raise ValueError(f"Timeout loading image: '{image_url}'") from e
except httpx.HTTPError as e:
logger.error(f"{type(e).__name__} loading image: '{image_url}': {e}")
raise
except Exception as e:
logger.error(f"{type(e).__name__} loading image: '{image_url}': {e}")
raise ValueError(f"Failed to load image: '{image_url}': {e}") from e
async def _fetch_and_cache(self, key: str, image_url: str) -> Image.Image:
"""Shared task: fetch, cache, then remove from _inflight."""
try:
image = await self._fetch_and_process(image_url)
self._cache_put(key, image)
return image
finally:
self._inflight.pop(key, None)
@_nvtx.annotate("mm:img:load_image", color="lime") @_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image: async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url) parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"): if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower() key = image_url.lower()
if image_url_lower in self._image_cache:
# Check cache (sync — no await, no interleaving possible)
if key in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}") logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower] self._image_cache.move_to_end(key)
return self._image_cache[key]
# Join existing in-flight task, or start a new one
if key not in self._inflight:
task = asyncio.create_task(self._fetch_and_cache(key, image_url))
# Suppress "exception was never retrieved" if all waiters cancel
task.add_done_callback(
lambda t: t.exception() if not t.cancelled() else None
)
self._inflight[key] = task
# shield so cancelling THIS caller doesn't cancel the shared task
return await asyncio.shield(self._inflight[key])
try: try:
if parsed_url.scheme == "data": if parsed_url.scheme == "data":
with _nvtx.annotate("mm:img:base64_decode", color="lime"): with _nvtx.annotate("mm:img:base64_decode", color="lime"):
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"): if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type") raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1) media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type: if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded") raise ValueError("Data URL must be base64 encoded")
try: try:
image_bytes = base64.b64decode(data) image_bytes = base64.b64decode(data, validate=True)
image_data = BytesIO(image_bytes)
except binascii.Error as e: except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}") raise ValueError(f"Invalid base64 encoding: {e}") from e
elif parsed_url.scheme in ("http", "https"): image_data = BytesIO(image_bytes)
with _nvtx.annotate("mm:img:http_fetch", color="lime"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
elif parsed_url.scheme in ("", "file"): elif parsed_url.scheme in ("", "file"):
# Local file path (plain path or file:// URI)
path = image_url if parsed_url.scheme == "" else parsed_url.path path = image_url if parsed_url.scheme == "" else parsed_url.path
def _read_local_file(p: str) -> bytes: def _read_local_file(p: str) -> bytes:
...@@ -126,38 +192,11 @@ class ImageLoader: ...@@ -126,38 +192,11 @@ class ImageLoader:
else: else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}") raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
with _nvtx.annotate("mm:img:pil_open_convert", color="lime"): return await self._open_image(image_data)
# PIL is sync, so offload to a thread to avoid blocking the event loop
# Restrict to supported formats to prevent PSD parsing (GHSA-cfh3-3jmp-rvhc)
image = await asyncio.to_thread(
Image.open, image_data, formats=["JPEG", "PNG", "WEBP"]
)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e: except Exception as e:
logger.error(f"Error loading image: {e}") logger.error(f"{type(e).__name__} loading image: '{image_url}': {e}")
raise ValueError(f"Failed to load image: {e}") raise ValueError(f"Failed to load image: '{image_url}': {e}") from e
async def load_image_batch( async def load_image_batch(
self, self,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ImageLoader in-flight dedup, cancellation, and error contract."""
import asyncio
from io import BytesIO
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from PIL import Image
from dynamo.common.multimodal.image_loader import ImageLoader
pytestmark = [
pytest.mark.asyncio,
pytest.mark.unit,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
def _make_png_bytes() -> bytes:
"""Create a minimal valid PNG in memory."""
img = Image.new("RGB", (2, 2), color="red")
buf = BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
PNG_BYTES = _make_png_bytes()
def _mock_http_client(
content: bytes = PNG_BYTES,
status_code: int = 200,
delay: float = 0.0,
side_effect: Exception | None = None,
) -> AsyncMock:
"""Return a mock httpx.AsyncClient whose .get() returns a fake response.
Args:
content: Raw bytes returned as the HTTP response body.
status_code: HTTP status code; >=400 triggers raise_for_status().
delay: Seconds to sleep before responding (simulates network latency).
side_effect: If set, .get() raises this exception instead of returning.
"""
async def _get(url: str) -> Any:
if delay > 0:
await asyncio.sleep(delay)
if side_effect is not None:
raise side_effect
resp = MagicMock(spec=httpx.Response)
resp.status_code = status_code
resp.content = content
resp.raise_for_status = MagicMock()
if status_code >= 400:
resp.raise_for_status.side_effect = httpx.HTTPStatusError(
"error", request=MagicMock(), response=resp
)
return resp
client = AsyncMock()
client.get = AsyncMock(side_effect=_get)
return client
@pytest.fixture(autouse=True)
def loader() -> ImageLoader:
return ImageLoader(cache_size=4, http_timeout=30.0)
# --- Concurrent same-URL dedup ---
async def test_concurrent_same_url_deduplicates(loader: ImageLoader) -> None:
"""Two concurrent load_image calls for the same URL should issue only one HTTP fetch."""
mock_client = _mock_http_client(delay=0.05)
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
results = await asyncio.gather(
loader.load_image("https://example.com/img.png"),
loader.load_image("https://example.com/img.png"),
)
assert len(results) == 2
assert results[0].size == results[1].size
# Only one HTTP GET should have been issued
assert mock_client.get.call_count == 1
async def test_concurrent_different_urls_fetch_independently(
loader: ImageLoader,
) -> None:
"""Different URLs should each get their own fetch."""
mock_client = _mock_http_client()
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
await asyncio.gather(
loader.load_image("https://example.com/a.png"),
loader.load_image("https://example.com/b.png"),
)
assert mock_client.get.call_count == 2
# --- Waiter cancellation isolation ---
async def test_waiter_cancellation_does_not_cancel_shared_task(
loader: ImageLoader,
) -> None:
"""Cancelling one waiter should not prevent the other from getting the image."""
mock_client = _mock_http_client(delay=0.1)
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
task_a = asyncio.create_task(loader.load_image("https://example.com/img.png"))
task_b = asyncio.create_task(loader.load_image("https://example.com/img.png"))
await asyncio.sleep(0.01)
task_a.cancel()
with pytest.raises(asyncio.CancelledError):
await task_a
result_b = await task_b
assert isinstance(result_b, Image.Image)
# --- Retry after failure ---
async def test_retry_after_failure(loader: ImageLoader) -> None:
"""After a fetch failure, the next caller should start a fresh fetch."""
fail_client = _mock_http_client(side_effect=httpx.TimeoutException("timeout"))
ok_client = _mock_http_client()
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=fail_client,
):
with pytest.raises(ValueError, match="Timeout"):
await loader.load_image("https://example.com/img.png")
# _inflight should be cleared after failure
assert "https://example.com/img.png" not in loader._inflight
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=ok_client,
):
result = await loader.load_image("https://example.com/img.png")
assert isinstance(result, Image.Image)
# --- Error contract preserved for non-HTTP ---
async def test_file_not_found_normalized(loader: ImageLoader) -> None:
"""file:// path that doesn't exist should raise ValueError, not FileNotFoundError."""
with pytest.raises(ValueError, match="Failed to load image"):
await loader.load_image("file:///nonexistent/path/img.png")
async def test_data_url_invalid_base64_normalized(loader: ImageLoader) -> None:
"""Malformed base64 data URL should raise ValueError."""
with pytest.raises(ValueError, match="Invalid base64"):
await loader.load_image("data:image/png;base64,NOT_VALID!!!")
async def test_data_url_non_image_rejected(loader: ImageLoader) -> None:
"""data: URL with non-image media type should raise ValueError."""
with pytest.raises(ValueError, match="Data URL must be an image type"):
await loader.load_image("data:text/plain;base64,aGVsbG8=")
# --- HTTP error contract ---
async def test_http_timeout_raises_valueerror(loader: ImageLoader) -> None:
"""HTTP timeout should be normalized to ValueError."""
mock_client = _mock_http_client(side_effect=httpx.TimeoutException("timed out"))
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
with pytest.raises(ValueError, match="Timeout loading image"):
await loader.load_image("https://example.com/img.png")
async def test_http_status_error_propagated(loader: ImageLoader) -> None:
"""HTTP 4xx/5xx should propagate as HTTPStatusError."""
mock_client = _mock_http_client(status_code=404)
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
with pytest.raises(httpx.HTTPStatusError):
await loader.load_image("https://example.com/img.png")
# --- Cache behavior ---
async def test_cache_hit_skips_fetch(loader: ImageLoader) -> None:
"""A cached image should be returned without making an HTTP request."""
img = Image.new("RGB", (2, 2))
loader._image_cache["https://example.com/img.png"] = img
result = await loader.load_image("https://example.com/img.png")
assert result is img
async def test_cache_is_lru_not_fifo(loader: ImageLoader) -> None:
"""Accessing a cached entry should protect it from eviction (LRU, not FIFO)."""
loader._cache_size = 3
mock_client = _mock_http_client()
with patch(
"dynamo.common.multimodal.image_loader.get_http_client",
return_value=mock_client,
):
# Fill cache: a, b, c (oldest → newest)
await loader.load_image("https://example.com/a.png")
await loader.load_image("https://example.com/b.png")
await loader.load_image("https://example.com/c.png")
assert len(loader._image_cache) == 3
# Touch "a" so it becomes most-recently-used
await loader.load_image("https://example.com/a.png")
# Insert "d" — should evict "b" (least recently used), not "a"
await loader.load_image("https://example.com/d.png")
assert "https://example.com/a.png" in loader._image_cache
assert "https://example.com/b.png" not in loader._image_cache
assert "https://example.com/c.png" in loader._image_cache
assert "https://example.com/d.png" in loader._image_cache
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment