Unverified Commit 56f02422 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: Reject local file inputs in ImageLoader (#8158)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 597b7249
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import base64
......@@ -142,6 +130,10 @@ class ImageLoader:
@_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
if parsed_url.scheme in ("", "file"):
raise ValueError(
"Invalid image source scheme: local file access is not allowed"
)
if parsed_url.scheme in ("http", "https"):
key = image_url.lower()
......@@ -164,8 +156,8 @@ class ImageLoader:
# shield so cancelling THIS caller doesn't cancel the shared task
return await asyncio.shield(self._inflight[key])
try:
if parsed_url.scheme == "data":
if parsed_url.scheme == "data":
try:
with _nvtx.annotate("mm:img:base64_decode", color="lime"):
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
......@@ -179,24 +171,13 @@ class ImageLoader:
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}") from e
image_data = BytesIO(image_bytes)
return await self._open_image(image_data)
except Exception as e:
logger.error(f"{type(e).__name__} decoding image: '{image_url}': {e}")
raise ValueError(f"Failed to decoding image: '{image_url}': {e}") from e
elif parsed_url.scheme in ("", "file"):
path = image_url if parsed_url.scheme == "" else parsed_url.path
def _read_local_file(p: str) -> bytes:
with open(p, "rb") as f:
return f.read()
image_bytes = await asyncio.to_thread(_read_local_file, path)
image_data = BytesIO(image_bytes)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
return await self._open_image(image_data)
except Exception as e:
logger.error(f"{type(e).__name__} loading image: '{image_url}': {e}")
raise ValueError(f"Failed to load image: '{image_url}': {e}") from e
# It's not file:, http:, https:, or data:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
async def load_image_batch(
self,
......
......@@ -176,12 +176,24 @@ async def test_retry_after_failure(loader: ImageLoader) -> None:
# --- Error contract preserved for non-HTTP ---
async def test_file_not_found_normalized(loader: ImageLoader) -> None:
"""file:// path that doesn't exist should raise ValueError, not FileNotFoundError."""
with pytest.raises(ValueError, match="Failed to load image"):
async def test_file_url_is_rejected(loader: ImageLoader) -> None:
"""file:// inputs should be rejected before any local file read is attempted."""
with pytest.raises(ValueError, match="Invalid image source scheme"):
await loader.load_image("file:///nonexistent/path/img.png")
@pytest.mark.parametrize("url_factory", [lambda p: p.as_uri(), lambda p: str(p)])
async def test_local_file_inputs_are_rejected(
loader: ImageLoader, tmp_path, url_factory
) -> None:
"""Local filesystem image inputs must be rejected for both file:// and bare paths."""
image_path = tmp_path / "secret.png"
Image.new("RGB", (1, 1), color="red").save(image_path, format="PNG")
with pytest.raises(ValueError, match="Invalid image source scheme"):
await loader.load_image(url_factory(image_path))
async def test_data_url_invalid_base64_normalized(loader: ImageLoader) -> None:
"""Malformed base64 data URL should raise ValueError."""
with pytest.raises(ValueError, match="Invalid base64"):
......
......@@ -17,6 +17,8 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/launch_utils.sh"
MODEL="Wan-AI/Wan2.2-TI2V-5B-Diffusers"
# Not a valid PNG, example only
INPUT_REFERENCE_DATA_URL="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+aX3kAAAAASUVORK5CYII="
# Parse command line arguments
EXTRA_ARGS=()
......@@ -45,7 +47,7 @@ curl -s http://localhost:${HTTP_PORT}/v1/videos \\
-d '{
"model": "${MODEL}",
"prompt": "A bear sleeping",
"input_reference": "/tmp/input.png",
"input_reference": "${INPUT_REFERENCE_DATA_URL}",
"size": "832x480",
"response_format": "url",
"nvext": {
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import base64
import dataclasses
import logging
import os
import tempfile
from dataclasses import dataclass, field
from io import BytesIO
from typing import Any
import pytest
......@@ -91,15 +92,13 @@ class VideoGenerationPayload(BasePayload):
class I2VPayload(VideoGenerationPayload):
"""Payload for image-to-video via /v1/videos with input_reference."""
_tmp_dir: Any = field(default=None, init=False, repr=False, compare=False)
def __post_init__(self):
from PIL import Image
self._tmp_dir = tempfile.TemporaryDirectory()
path = os.path.join(self._tmp_dir.name, "input.png")
Image.new("RGB", (64, 64), color="red").save(path)
self.body["input_reference"] = path
image_buffer = BytesIO()
Image.new("RGB", (64, 64), color="red").save(image_buffer, format="PNG")
image_b64 = base64.b64encode(image_buffer.getvalue()).decode("ascii")
self.body["input_reference"] = f"data:image/png;base64,{image_b64}"
@dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment