Unverified Commit 56f02422 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: Reject local file inputs in ImageLoader (#8158)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 597b7249
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio import asyncio
import base64 import base64
...@@ -142,6 +130,10 @@ class ImageLoader: ...@@ -142,6 +130,10 @@ class ImageLoader:
@_nvtx.annotate("mm:img:load_image", color="lime") @_nvtx.annotate("mm:img:load_image", color="lime")
async def load_image(self, image_url: str) -> Image.Image: async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url) parsed_url = urlparse(image_url)
if parsed_url.scheme in ("", "file"):
raise ValueError(
"Invalid image source scheme: local file access is not allowed"
)
if parsed_url.scheme in ("http", "https"): if parsed_url.scheme in ("http", "https"):
key = image_url.lower() key = image_url.lower()
...@@ -164,8 +156,8 @@ class ImageLoader: ...@@ -164,8 +156,8 @@ class ImageLoader:
# shield so cancelling THIS caller doesn't cancel the shared task # shield so cancelling THIS caller doesn't cancel the shared task
return await asyncio.shield(self._inflight[key]) return await asyncio.shield(self._inflight[key])
try: if parsed_url.scheme == "data":
if parsed_url.scheme == "data": try:
with _nvtx.annotate("mm:img:base64_decode", color="lime"): with _nvtx.annotate("mm:img:base64_decode", color="lime"):
if not parsed_url.path.startswith("image/"): if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type") raise ValueError("Data URL must be an image type")
...@@ -179,24 +171,13 @@ class ImageLoader: ...@@ -179,24 +171,13 @@ class ImageLoader:
except binascii.Error as e: except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}") from e raise ValueError(f"Invalid base64 encoding: {e}") from e
image_data = BytesIO(image_bytes) image_data = BytesIO(image_bytes)
return await self._open_image(image_data)
except Exception as e:
logger.error(f"{type(e).__name__} decoding image: '{image_url}': {e}")
raise ValueError(f"Failed to decoding image: '{image_url}': {e}") from e
elif parsed_url.scheme in ("", "file"): # It's not file:, http:, https:, or data:
path = image_url if parsed_url.scheme == "" else parsed_url.path raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
def _read_local_file(p: str) -> bytes:
with open(p, "rb") as f:
return f.read()
image_bytes = await asyncio.to_thread(_read_local_file, path)
image_data = BytesIO(image_bytes)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
return await self._open_image(image_data)
except Exception as e:
logger.error(f"{type(e).__name__} loading image: '{image_url}': {e}")
raise ValueError(f"Failed to load image: '{image_url}': {e}") from e
async def load_image_batch( async def load_image_batch(
self, self,
......
...@@ -176,12 +176,24 @@ async def test_retry_after_failure(loader: ImageLoader) -> None: ...@@ -176,12 +176,24 @@ async def test_retry_after_failure(loader: ImageLoader) -> None:
# --- Error contract preserved for non-HTTP --- # --- Error contract preserved for non-HTTP ---
async def test_file_not_found_normalized(loader: ImageLoader) -> None: async def test_file_url_is_rejected(loader: ImageLoader) -> None:
"""file:// path that doesn't exist should raise ValueError, not FileNotFoundError.""" """file:// inputs should be rejected before any local file read is attempted."""
with pytest.raises(ValueError, match="Failed to load image"): with pytest.raises(ValueError, match="Invalid image source scheme"):
await loader.load_image("file:///nonexistent/path/img.png") await loader.load_image("file:///nonexistent/path/img.png")
@pytest.mark.parametrize("url_factory", [lambda p: p.as_uri(), lambda p: str(p)])
async def test_local_file_inputs_are_rejected(
loader: ImageLoader, tmp_path, url_factory
) -> None:
"""Local filesystem image inputs must be rejected for both file:// and bare paths."""
image_path = tmp_path / "secret.png"
Image.new("RGB", (1, 1), color="red").save(image_path, format="PNG")
with pytest.raises(ValueError, match="Invalid image source scheme"):
await loader.load_image(url_factory(image_path))
async def test_data_url_invalid_base64_normalized(loader: ImageLoader) -> None: async def test_data_url_invalid_base64_normalized(loader: ImageLoader) -> None:
"""Malformed base64 data URL should raise ValueError.""" """Malformed base64 data URL should raise ValueError."""
with pytest.raises(ValueError, match="Invalid base64"): with pytest.raises(ValueError, match="Invalid base64"):
......
...@@ -17,6 +17,8 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" ...@@ -17,6 +17,8 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
source "$SCRIPT_DIR/../../../common/launch_utils.sh" source "$SCRIPT_DIR/../../../common/launch_utils.sh"
MODEL="Wan-AI/Wan2.2-TI2V-5B-Diffusers" MODEL="Wan-AI/Wan2.2-TI2V-5B-Diffusers"
# Not a valid PNG, example only
INPUT_REFERENCE_DATA_URL="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+aX3kAAAAASUVORK5CYII="
# Parse command line arguments # Parse command line arguments
EXTRA_ARGS=() EXTRA_ARGS=()
...@@ -45,7 +47,7 @@ curl -s http://localhost:${HTTP_PORT}/v1/videos \\ ...@@ -45,7 +47,7 @@ curl -s http://localhost:${HTTP_PORT}/v1/videos \\
-d '{ -d '{
"model": "${MODEL}", "model": "${MODEL}",
"prompt": "A bear sleeping", "prompt": "A bear sleeping",
"input_reference": "/tmp/input.png", "input_reference": "${INPUT_REFERENCE_DATA_URL}",
"size": "832x480", "size": "832x480",
"response_format": "url", "response_format": "url",
"nvext": { "nvext": {
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import base64
import dataclasses import dataclasses
import logging import logging
import os import os
import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from io import BytesIO
from typing import Any from typing import Any
import pytest import pytest
...@@ -91,15 +92,13 @@ class VideoGenerationPayload(BasePayload): ...@@ -91,15 +92,13 @@ class VideoGenerationPayload(BasePayload):
class I2VPayload(VideoGenerationPayload): class I2VPayload(VideoGenerationPayload):
"""Payload for image-to-video via /v1/videos with input_reference.""" """Payload for image-to-video via /v1/videos with input_reference."""
_tmp_dir: Any = field(default=None, init=False, repr=False, compare=False)
def __post_init__(self): def __post_init__(self):
from PIL import Image from PIL import Image
self._tmp_dir = tempfile.TemporaryDirectory() image_buffer = BytesIO()
path = os.path.join(self._tmp_dir.name, "input.png") Image.new("RGB", (64, 64), color="red").save(image_buffer, format="PNG")
Image.new("RGB", (64, 64), color="red").save(path) image_b64 = base64.b64encode(image_buffer.getvalue()).decode("ascii")
self.body["input_reference"] = path self.body["input_reference"] = f"data:image/png;base64,{image_b64}"
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment