Unverified commit 57861ae4, authored by Juan Pérez de Algaba and committed by GitHub
Browse files

(security) Fix SSRF in batch runner download_bytes_from_url (#38482)


Signed-off-by: jperezde <jperezde@redhat.com>
parent ac30a831
......@@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
This protection applies to both the online serving API (multimodal inputs) and
the **batch runner** (`vllm run-batch`), where `file_url` values in batch
transcription/translation requests are validated against the same allowlist.
Without domain restrictions, a malicious user could supply URLs that:
- **Target internal services**: Access internal network endpoints, cloud metadata
......
......@@ -4,11 +4,15 @@
import json
import subprocess
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from vllm.assets.audio import AudioAsset
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
from vllm.entrypoints.openai.run_batch import (
BatchRequestOutput,
download_bytes_from_url,
)
CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
......@@ -746,3 +750,131 @@ def test_tool_calling():
assert "arguments" in tool_call["function"]
# Verify the tool name matches our tool definition
assert tool_call["function"]["name"] == "get_current_weather"
# ---------------------------------------------------------------------------
# Unit tests for download_bytes_from_url SSRF protection
# ---------------------------------------------------------------------------
def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200):
"""Create mock objects that simulate aiohttp.ClientSession context managers."""
mock_resp = MagicMock()
mock_resp.status = status
mock_resp.read = AsyncMock(return_value=response_data)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=mock_resp)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
return mock_session
@pytest.mark.asyncio
async def test_download_bytes_data_url_bypasses_domain_check():
    """data: URLs must work regardless of the domain allowlist."""
    data_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"
    # The allowlist is deliberately unrelated to the data URL: a data URL has
    # no hostname to validate, so it must still decode successfully.
    result = await download_bytes_from_url(
        data_url, allowed_media_domains=["example.com"]
    )
    assert isinstance(result, bytes)
    assert len(result) > 0
@pytest.mark.asyncio
async def test_download_bytes_rejects_disallowed_domain():
    """HTTP URLs whose hostname is not in the allowlist must be rejected."""
    url = "https://evil.internal/secret"
    # No HTTP mocking here: the call is expected to raise instead of
    # returning data.
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(url, allowed_media_domains=["example.com"])
@pytest.mark.asyncio
async def test_download_bytes_rejects_cloud_metadata_ip():
    """Cloud metadata endpoints must be blocked when an allowlist is set."""
    # A link-local metadata address reached by IP literal rather than by
    # name; it is not in the allowlist, so it must be rejected.
    url = "http://169.254.169.254/latest/meta-data/"
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(url, allowed_media_domains=["example.com"])
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "internal_url",
    [
        "http://10.0.0.1/secret",
        "http://192.168.1.1/admin",
        "http://127.0.0.1:8080/internal",
    ],
)
async def test_download_bytes_rejects_internal_ip(internal_url: str):
    """Private-range IPs must be blocked when an allowlist is set.

    Each address is its own test case, so a failure for one URL is reported
    individually instead of hiding the URLs that follow it.
    """
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            internal_url, allowed_media_domains=["example.com"]
        )
@pytest.mark.asyncio
async def test_download_bytes_allows_permitted_domain():
    """HTTP URLs whose hostname IS in the allowlist must be fetched."""
    url = "https://example.com/audio.wav"
    expected = b"audio-bytes"

    mock_session = _make_aiohttp_mocks(expected)
    # Replace the ClientSession constructor as seen from run_batch so that no
    # real network request is made.
    with patch(
        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
        return_value=mock_session,
    ):
        result = await download_bytes_from_url(
            url, allowed_media_domains=["example.com"]
        )

    assert result == expected
@pytest.mark.asyncio
async def test_download_bytes_no_allowlist_permits_any_domain():
    """Without an allowlist all HTTP URLs must be attempted (backward compat)."""
    url = "https://any-domain.example.org/file.wav"
    expected = b"some-data"

    mock_session = _make_aiohttp_mocks(expected)
    # Replace the ClientSession constructor as seen from run_batch so that no
    # real network request is made.
    with patch(
        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
        return_value=mock_session,
    ):
        # None (as opposed to an empty list) means "no restriction".
        result = await download_bytes_from_url(url, allowed_media_domains=None)

    assert result == expected
@pytest.mark.asyncio
async def test_download_bytes_empty_allowlist_denies_all():
    """An empty allowlist must deny all HTTP URLs (least privilege)."""
    url = "https://any-domain.example.org/file.wav"
    # [] is distinct from None: it is an allowlist that permits nothing.
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(url, allowed_media_domains=[])
@pytest.mark.asyncio
async def test_download_bytes_unsupported_scheme():
    """Unsupported URL schemes must be rejected regardless of allowlist."""
    # Without an allowlist.
    with pytest.raises(ValueError, match="Unsupported URL scheme"):
        await download_bytes_from_url("ftp://example.com/file.wav")

    # With an allowlist that contains the URL's host: the scheme error must
    # still be the one reported.
    with pytest.raises(ValueError, match="Unsupported URL scheme"):
        await download_bytes_from_url(
            "ftp://example.com/file.wav",
            allowed_media_domains=["example.com"],
        )
@pytest.mark.asyncio
async def test_download_bytes_backslash_bypass():
    """Backslash-@ URL confusion must not bypass the allowlist.

    urllib3.parse_url() and aiohttp/yarl disagree on backslash-before-@.
    The fix normalizes through urllib3 before handing to aiohttp.
    """
    bypass_url = "http://allowed.example.com\\@evil.internal/secret"
    # Only evil.internal is allowlisted. The call must still be rejected,
    # i.e. the hostname used for validation must not be evil.internal.
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            bypass_url, allowed_media_domains=["evil.internal"]
        )
......@@ -20,7 +20,9 @@ from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo
from starlette.datastructures import State
from tqdm import tqdm
from urllib3.util import parse_url
import vllm.envs as envs
from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
......@@ -439,19 +441,25 @@ async def write_file(
await write_local_file(path_or_url, batch_outputs)
async def download_bytes_from_url(url: str) -> bytes:
async def download_bytes_from_url(
url: str,
allowed_media_domains: list[str] | None = None,
) -> bytes:
"""
Download data from a URL or decode from a data URL.
Args:
url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
allowed_media_domains: If set, only HTTP/HTTPS URLs whose hostname
is in this list are permitted. data: URLs are not subject to
this restriction.
Returns:
Data as bytes
"""
parsed = urlparse(url)
# Handle data URLs (base64 encoded)
# Handle data URLs (base64 encoded) - not subject to domain restrictions
if parsed.scheme == "data":
# Format: data:...;base64,<base64_data>
if "," in url:
......@@ -465,9 +473,24 @@ async def download_bytes_from_url(url: str) -> bytes:
# Handle HTTP/HTTPS URLs
elif parsed.scheme in ("http", "https"):
if allowed_media_domains is not None:
url_spec = parse_url(url)
if url_spec.hostname not in allowed_media_domains:
raise ValueError(
f"The URL must be from one of the allowed domains: "
f"{allowed_media_domains}. Input URL domain: "
f"{url_spec.hostname}"
)
# Use the normalized URL to prevent parsing discrepancies
# between urllib3 and aiohttp (e.g. backslash-@ attacks).
url = url_spec.url
async with (
aiohttp.ClientSession() as session,
session.get(url) as resp,
session.get(
url,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
) as resp,
):
if resp.status != 200:
raise Exception(
......@@ -593,7 +616,10 @@ def handle_endpoint_request(
return run_request(handler_fn, request, tracker)
def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
def make_transcription_wrapper(
is_translation: bool,
allowed_media_domains: list[str] | None = None,
) -> WrapperFn:
"""
Factory function to create a wrapper for transcription/translation handlers.
The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
......@@ -602,6 +628,8 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
Args:
is_translation: If True, process as translation; otherwise process
as transcription
allowed_media_domains: If set, only URLs from these domains are
permitted for HTTP/HTTPS fetches.
Returns:
A function that takes a handler and returns a wrapped handler
......@@ -619,7 +647,10 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
):
try:
# Download data from URL
audio_data = await download_bytes_from_url(batch_request_body.file_url)
audio_data = await download_bytes_from_url(
batch_request_body.file_url,
allowed_media_domains=allowed_media_domains,
)
# Create a mock file from the downloaded audio data
mock_file = UploadFile(
......@@ -691,6 +722,8 @@ async def build_endpoint_registry(
serving_embedding = getattr(state, "serving_embedding", None)
serving_scores = getattr(state, "serving_scores", None)
allowed_media_domains = getattr(args, "allowed_media_domains", None)
# Registry of endpoint configurations
endpoint_registry: dict[str, dict[str, Any]] = {
"completions": {
......@@ -730,7 +763,10 @@ async def build_endpoint_registry(
if openai_serving_transcription is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=False),
"wrapper_fn": make_transcription_wrapper(
is_translation=False,
allowed_media_domains=allowed_media_domains,
),
},
"translations": {
"url_matcher": lambda url: url == "/v1/audio/translations",
......@@ -739,7 +775,10 @@ async def build_endpoint_registry(
if openai_serving_translation is not None
else None
),
"wrapper_fn": make_transcription_wrapper(is_translation=True),
"wrapper_fn": make_transcription_wrapper(
is_translation=True,
allowed_media_domains=allowed_media_domains,
),
},
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment