Unverified Commit 677424c7 authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[Core][CI] Add opt-in media URL caching via VLLM_MEDIA_CACHE (#37123)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 1031c84c
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import asyncio import asyncio
import mimetypes import mimetypes
import os import os
import shutil
import time
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
import aiohttp import aiohttp
...@@ -375,3 +377,113 @@ async def test_ssrf_bypass_backslash_disallowed_domain(): ...@@ -375,3 +377,113 @@ async def test_ssrf_bypass_backslash_disallowed_domain():
with pytest.raises(ValueError, match="allowed domains"): with pytest.raises(ValueError, match="allowed domains"):
await connector.fetch_image_async(bypass_url) await connector.fetch_image_async(bypass_url)
def _make_cached_connector(cache_dir, *, max_mb=10, ttl_hours=24):
    """Build a ``MediaConnector`` whose cache fields point at ``cache_dir``.

    The cache attributes are assigned directly on the instance after
    construction instead of going through environment variables, so the
    tests do not depend on the process environment. URLs handed to the
    returned connector only act as cache keys (they are hashed into file
    names); no HTTP request is issued.
    """
    bytes_per_mb = 1024 * 1024
    secs_per_hour = 3600

    conn = MediaConnector()
    conn._media_cache_ttl_secs = ttl_hours * secs_per_hour
    conn._media_cache_max_bytes = max_mb * bytes_per_mb
    conn._media_cache_dir = cache_dir
    return conn
def test_cache_put_and_get():
    """Bytes stored under a URL come back unchanged on lookup."""
    image_url = "https://example.com/image.png"
    payload = b"fake-image-bytes"

    with TemporaryDirectory() as tmp_dir:
        conn = _make_cached_connector(tmp_dir)
        conn._put_cached_bytes(image_url, payload)
        assert conn._get_cached_bytes(image_url) == payload
def test_cache_ttl_expiry():
    """Reading an entry older than the TTL returns None and deletes it."""
    stale_url = "https://example.com/old.png"

    with TemporaryDirectory() as tmp_dir:
        conn = _make_cached_connector(tmp_dir, ttl_hours=24)
        conn._put_cached_bytes(stale_url, b"old-data")

        # Move atime/mtime 25 hours into the past, beyond the 24 hour TTL.
        entry = conn._media_cache_path(stale_url)
        backdated = time.time() - (25 * 3600)
        os.utime(entry, times=(backdated, backdated))

        assert conn._get_cached_bytes(stale_url) is None
        assert not entry.exists()
def test_cache_lru_eviction():
    """Oldest entries are evicted when the cache exceeds its size budget."""
    payload = b"x" * 50
    with TemporaryDirectory() as cache_dir:
        # The helper only accepts whole megabytes, so the 100-byte budget is
        # set directly on the connector.
        connector = _make_cached_connector(cache_dir, max_mb=0)
        connector._media_cache_max_bytes = 100

        # Three 50-byte entries: 150 bytes in total, over the 100-byte budget.
        urls = [f"https://example.com/{i}.png" for i in range(3)]
        for i, url in enumerate(urls):
            connector._put_cached_bytes(url, payload)
            # Stagger mtimes so the eviction order is deterministic. Sample
            # the clock once so atime and mtime receive the same value.
            stamp = time.time() + i
            os.utime(connector._media_cache_path(url), (stamp, stamp))

        # The oldest entry (urls[0]) should have been evicted
        assert connector._get_cached_bytes(urls[0]) is None
        # Both newer entries fit the budget exactly and must still be present
        assert connector._get_cached_bytes(urls[1]) == payload
        assert connector._get_cached_bytes(urls[2]) == payload
def test_cache_ttl_eviction_during_write():
    """Writing a new entry purges expired ones even when under the size budget."""
    stale_url = "https://example.com/stale.png"
    fresh_url = "https://example.com/fresh.png"

    with TemporaryDirectory() as tmp_dir:
        conn = _make_cached_connector(tmp_dir, ttl_hours=1)
        conn._put_cached_bytes(stale_url, b"stale")

        # Age the first entry to two hours, past the one hour TTL.
        stale_path = conn._media_cache_path(stale_url)
        two_hours_ago = time.time() - (2 * 3600)
        os.utime(stale_path, times=(two_hours_ago, two_hours_ago))

        # The second write is what runs the eviction pass.
        conn._put_cached_bytes(fresh_url, b"fresh")

        assert not stale_path.exists()
        assert conn._get_cached_bytes(fresh_url) == b"fresh"
def test_put_cached_bytes_missing_dir():
    """A write into a cache directory that no longer exists must not raise."""
    with TemporaryDirectory() as tmp_dir:
        conn = _make_cached_connector(tmp_dir)

        # Simulate the directory disappearing while the connector is live.
        shutil.rmtree(tmp_dir)

        # The test passes as long as no exception escapes this call.
        conn._put_cached_bytes("https://example.com/x.png", b"data")
def test_get_cached_bytes_file_deleted_before_read():
    """A lookup yields None when the entry's file was removed after the write."""
    vanish_url = "https://example.com/vanish.png"

    with TemporaryDirectory() as tmp_dir:
        conn = _make_cached_connector(tmp_dir)
        conn._put_cached_bytes(vanish_url, b"data")

        # Remove the backing file, as a concurrent eviction would.
        entry = conn._media_cache_path(vanish_url)
        entry.unlink()

        assert conn._get_cached_bytes(vanish_url) is None
...@@ -64,6 +64,9 @@ if TYPE_CHECKING: ...@@ -64,6 +64,9 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MEDIA_CACHE: str = ""
VLLM_MEDIA_CACHE_MAX_SIZE_MB: int = 5120
VLLM_MEDIA_CACHE_TTL_HOURS: float = 24
VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3 VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
...@@ -776,6 +779,19 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -776,6 +779,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": lambda: int( "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(
os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10") os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")
), ),
# Directory for caching media downloads (images, video, audio fetched
# from URLs during inference). Empty string disables caching.
"VLLM_MEDIA_CACHE": lambda: os.getenv("VLLM_MEDIA_CACHE", ""),
# Maximum cache size in MB. When exceeded, least-recently-used entries
# are evicted. Default is 5120 (5 GB).
"VLLM_MEDIA_CACHE_MAX_SIZE_MB": lambda: int(
os.getenv("VLLM_MEDIA_CACHE_MAX_SIZE_MB", "5120")
),
# Time-to-live in hours for cached media files. Entries older than this
# are evicted regardless of cache size. Default is 24 hours.
"VLLM_MEDIA_CACHE_TTL_HOURS": lambda: float(
os.getenv("VLLM_MEDIA_CACHE_TTL_HOURS", "24")
),
# Maximum number of retries for fetching media (images, audio, video) # Maximum number of retries for fetching media (images, audio, video)
# from URLs. Each retry quadruples the timeout. Default is 3. # from URLs. Each retry quadruples the timeout. Default is 3.
"VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int( "VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int(
...@@ -1777,6 +1793,9 @@ def compile_factors() -> dict[str, object]: ...@@ -1777,6 +1793,9 @@ def compile_factors() -> dict[str, object]:
"VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_IMAGE_FETCH_TIMEOUT",
"VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT",
"VLLM_AUDIO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT",
"VLLM_MEDIA_CACHE",
"VLLM_MEDIA_CACHE_MAX_SIZE_MB",
"VLLM_MEDIA_CACHE_TTL_HOURS",
"VLLM_MEDIA_FETCH_MAX_RETRIES", "VLLM_MEDIA_FETCH_MAX_RETRIES",
"VLLM_MEDIA_URL_ALLOW_REDIRECTS", "VLLM_MEDIA_URL_ALLOW_REDIRECTS",
"VLLM_MEDIA_LOADING_THREAD_COUNT", "VLLM_MEDIA_LOADING_THREAD_COUNT",
......
...@@ -3,6 +3,11 @@ ...@@ -3,6 +3,11 @@
import asyncio import asyncio
import atexit import atexit
import contextlib
import hashlib
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar from typing import Any, TypeVar
...@@ -16,6 +21,7 @@ from urllib3.util import Url, parse_url ...@@ -16,6 +21,7 @@ from urllib3.util import Url, parse_url
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.utils.registry import ExtensionManager from vllm.utils.registry import ExtensionManager
from .audio import AudioEmbeddingMediaIO, AudioMediaIO from .audio import AudioEmbeddingMediaIO, AudioMediaIO
...@@ -23,6 +29,8 @@ from .base import MediaIO ...@@ -23,6 +29,8 @@ from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO from .video import VideoMediaIO
logger = init_logger(__name__)
_M = TypeVar("_M") _M = TypeVar("_M")
global_thread_pool = ThreadPoolExecutor( global_thread_pool = ThreadPoolExecutor(
...@@ -116,6 +124,115 @@ class MediaConnector: ...@@ -116,6 +124,115 @@ class MediaConnector:
allowed_media_domains = [] allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains self.allowed_media_domains = allowed_media_domains
# Media download cache (opt-in via VLLM_MEDIA_CACHE)
self._media_cache_dir: str | None = None
self._media_cache_max_bytes: int = 0
self._media_cache_ttl_secs: float = 0
media_cache = envs.VLLM_MEDIA_CACHE
if media_cache:
try:
os.makedirs(media_cache, exist_ok=True)
# Verify the directory is writable before enabling caching
with tempfile.NamedTemporaryFile(dir=media_cache, delete=True):
pass
self._media_cache_dir = media_cache
self._media_cache_max_bytes = (
envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB * 1024 * 1024
)
self._media_cache_ttl_secs = envs.VLLM_MEDIA_CACHE_TTL_HOURS * 3600
logger.info(
"Media cache enabled at %s (max %d MB, TTL %s hours)",
media_cache,
envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB,
envs.VLLM_MEDIA_CACHE_TTL_HOURS,
)
except OSError:
logger.warning(
"VLLM_MEDIA_CACHE path %s is not writable, media caching disabled",
media_cache,
)
def _get_cached_bytes(self, url: str) -> bytes | None:
"""Return cached bytes for a URL, or None if not cached/expired."""
if not self._media_cache_dir:
return None
cache_path = self._media_cache_path(url)
# Check TTL
try:
age = time.time() - cache_path.stat().st_mtime
except OSError:
return None
if age > self._media_cache_ttl_secs:
cache_path.unlink(missing_ok=True)
return None
# Touch mtime for LRU ordering
try:
cache_path.touch()
return cache_path.read_bytes()
except OSError:
return None
def _put_cached_bytes(self, url: str, data: bytes) -> None:
    """Store ``data`` as the cache entry for ``url``, then trim the cache.

    The payload is written to a temporary file inside the cache directory
    and moved onto its final name, so readers never observe a partially
    written entry. Caching is best effort: any ``OSError`` while writing or
    evicting is swallowed and the call becomes a no-op. Nothing is done
    when caching is disabled (no cache directory set).

    Args:
        url: The URL the bytes were downloaded from (used as the cache key).
        data: The raw downloaded bytes.
    """
    if not self._media_cache_dir:
        return

    cache_path = self._media_cache_path(url)

    # Atomic write via temp file + replace
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb", dir=self._media_cache_dir, delete=False
        ) as tmp_file:
            # Record the name before writing so that a failed write
            # (e.g. disk full) still gets its temp file cleaned up below.
            tmp_path = tmp_file.name
            tmp_file.write(data)
        # os.replace overwrites an existing destination on every platform;
        # os.rename fails on Windows when the destination already exists.
        os.replace(tmp_path, cache_path)
    except OSError:
        # Cache directory missing/unwritable or another disk issue
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
        return

    # Eviction is housekeeping; a filesystem error there must not fail a
    # request whose data has already been downloaded.
    with contextlib.suppress(OSError):
        self._maybe_evict(exclude=cache_path)
def _maybe_evict(self, exclude: Path | None = None) -> None:
"""Evict expired entries first, then LRU until under size limit."""
cache_dir = Path(self._media_cache_dir) # type: ignore[arg-type]
entries = []
expired = []
total_size = 0
now = time.time()
for f in cache_dir.iterdir():
if f.name.startswith("."):
continue
try:
stat = f.stat()
except OSError:
continue
age = now - stat.st_mtime
if age > self._media_cache_ttl_secs:
expired.append(f)
continue
total_size += stat.st_size
# Never evict the file we just wrote
if exclude is not None and f.name == exclude.name:
continue
entries.append((stat.st_mtime, stat.st_size, f))
# Evict items according to LRU policy
entries.sort(key=lambda e: e[0], reverse=True)
while total_size > self._media_cache_max_bytes and entries:
mtime, size, f = entries.pop()
expired.append(f)
total_size -= size
for f in expired:
f.unlink(missing_ok=True)
def _media_cache_path(self, url: str) -> Path:
    """Map ``url`` to the path of its entry inside the cache directory.

    The file name is the first 20 hex digits of the SHA-256 of the full
    URL, followed by the suffix of the URL with everything from the first
    ``?`` onwards dropped (empty when there is no suffix).
    """
    without_query, _, _ = url.partition("?")
    suffix = Path(without_query).suffix
    digest = hashlib.sha256(url.encode()).hexdigest()
    file_name = digest[:20] + suffix
    return Path(self._media_cache_dir) / file_name  # type: ignore[arg-type]
def _load_data_url( def _load_data_url(
self, self,
url_spec: Url, url_spec: Url,
...@@ -178,6 +295,10 @@ class MediaConnector: ...@@ -178,6 +295,10 @@ class MediaConnector:
if url_spec.scheme and url_spec.scheme.startswith("http"): if url_spec.scheme and url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec) self._assert_url_in_allowed_media_domains(url_spec)
cached = self._get_cached_bytes(url)
if cached is not None:
return media_io.load_bytes(cached)
connection = self.connection connection = self.connection
data = connection.get_bytes( data = connection.get_bytes(
url_spec.url, url_spec.url,
...@@ -185,6 +306,7 @@ class MediaConnector: ...@@ -185,6 +306,7 @@ class MediaConnector:
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
) )
self._put_cached_bytes(url, data)
return media_io.load_bytes(data) return media_io.load_bytes(data)
if url_spec.scheme == "data": if url_spec.scheme == "data":
...@@ -209,12 +331,25 @@ class MediaConnector: ...@@ -209,12 +331,25 @@ class MediaConnector:
if url_spec.scheme and url_spec.scheme.startswith("http"): if url_spec.scheme and url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec) self._assert_url_in_allowed_media_domains(url_spec)
cached = await loop.run_in_executor(
global_thread_pool, self._get_cached_bytes, url
)
if cached is not None:
future = loop.run_in_executor(
global_thread_pool, media_io.load_bytes, cached
)
return await future
connection = self.connection connection = self.connection
data = await connection.async_get_bytes( data = await connection.async_get_bytes(
url_spec.url, url_spec.url,
timeout=fetch_timeout, timeout=fetch_timeout,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
) )
await loop.run_in_executor(
global_thread_pool, self._put_cached_bytes, url, data
)
future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data) future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
return await future return await future
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment