Unverified Commit 57861ae4 authored by Juan Pérez de Algaba's avatar Juan Pérez de Algaba Committed by GitHub
Browse files

(security) Fix SSRF in batch runner download_bytes_from_url (#38482)


Signed-off-by: default avatarjperezde <jperezde@redhat.com>
parent ac30a831
...@@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting ...@@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks. `--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`) (e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
This protection applies to both the online serving API (multimodal inputs) and
the **batch runner** (`vllm run-batch`), where `file_url` values in batch
transcription/translation requests are validated against the same allowlist.
Without domain restrictions, a malicious user could supply URLs that: Without domain restrictions, a malicious user could supply URLs that:
- **Target internal services**: Access internal network endpoints, cloud metadata - **Target internal services**: Access internal network endpoints, cloud metadata
......
...@@ -4,11 +4,15 @@ ...@@ -4,11 +4,15 @@
import json import json
import subprocess import subprocess
import tempfile import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.entrypoints.openai.run_batch import BatchRequestOutput from vllm.entrypoints.openai.run_batch import (
BatchRequestOutput,
download_bytes_from_url,
)
CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small" EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
...@@ -746,3 +750,131 @@ def test_tool_calling(): ...@@ -746,3 +750,131 @@ def test_tool_calling():
assert "arguments" in tool_call["function"] assert "arguments" in tool_call["function"]
# Verify the tool name matches our tool definition # Verify the tool name matches our tool definition
assert tool_call["function"]["name"] == "get_current_weather" assert tool_call["function"]["name"] == "get_current_weather"
# ---------------------------------------------------------------------------
# Unit tests for download_bytes_from_url SSRF protection
# ---------------------------------------------------------------------------
def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200):
"""Create mock objects that simulate aiohttp.ClientSession context managers."""
mock_resp = MagicMock()
mock_resp.status = status
mock_resp.read = AsyncMock(return_value=response_data)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=mock_resp)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
return mock_session
@pytest.mark.asyncio
async def test_download_bytes_data_url_bypasses_domain_check():
"""data: URLs must work regardless of the domain allowlist."""
data_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"
result = await download_bytes_from_url(
data_url, allowed_media_domains=["example.com"]
)
assert isinstance(result, bytes)
assert len(result) > 0
@pytest.mark.asyncio
async def test_download_bytes_rejects_disallowed_domain():
"""HTTP URLs whose hostname is not in the allowlist must be rejected."""
url = "https://evil.internal/secret"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(url, allowed_media_domains=["example.com"])
@pytest.mark.asyncio
async def test_download_bytes_rejects_cloud_metadata_ip():
"""Cloud metadata endpoints must be blocked when an allowlist is set."""
url = "http://169.254.169.254/latest/meta-data/"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(url, allowed_media_domains=["example.com"])
@pytest.mark.asyncio
async def test_download_bytes_rejects_internal_ip():
"""Private-range IPs must be blocked when an allowlist is set."""
for internal_url in [
"http://10.0.0.1/secret",
"http://192.168.1.1/admin",
"http://127.0.0.1:8080/internal",
]:
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(
internal_url, allowed_media_domains=["example.com"]
)
@pytest.mark.asyncio
async def test_download_bytes_allows_permitted_domain():
"""HTTP URLs whose hostname IS in the allowlist must be fetched."""
url = "https://example.com/audio.wav"
expected = b"audio-bytes"
mock_session = _make_aiohttp_mocks(expected)
with patch(
"vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
return_value=mock_session,
):
result = await download_bytes_from_url(
url, allowed_media_domains=["example.com"]
)
assert result == expected
@pytest.mark.asyncio
async def test_download_bytes_no_allowlist_permits_any_domain():
"""Without an allowlist all HTTP URLs must be attempted (backward compat)."""
url = "https://any-domain.example.org/file.wav"
expected = b"some-data"
mock_session = _make_aiohttp_mocks(expected)
with patch(
"vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
return_value=mock_session,
):
result = await download_bytes_from_url(url, allowed_media_domains=None)
assert result == expected
@pytest.mark.asyncio
async def test_download_bytes_empty_allowlist_denies_all():
"""An empty allowlist must deny all HTTP URLs (least privilege)."""
url = "https://any-domain.example.org/file.wav"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(url, allowed_media_domains=[])
@pytest.mark.asyncio
async def test_download_bytes_unsupported_scheme():
"""Unsupported URL schemes must be rejected regardless of allowlist."""
with pytest.raises(ValueError, match="Unsupported URL scheme"):
await download_bytes_from_url("ftp://example.com/file.wav")
with pytest.raises(ValueError, match="Unsupported URL scheme"):
await download_bytes_from_url(
"ftp://example.com/file.wav",
allowed_media_domains=["example.com"],
)
@pytest.mark.asyncio
async def test_download_bytes_backslash_bypass():
"""Backslash-@ URL confusion must not bypass the allowlist.
urllib3.parse_url() and aiohttp/yarl disagree on backslash-before-@.
The fix normalizes through urllib3 before handing to aiohttp.
"""
bypass_url = "http://allowed.example.com\\@evil.internal/secret"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(
bypass_url, allowed_media_domains=["evil.internal"]
)
...@@ -20,7 +20,9 @@ from pydantic import Field, TypeAdapter, field_validator, model_validator ...@@ -20,7 +20,9 @@ from pydantic import Field, TypeAdapter, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
from starlette.datastructures import State from starlette.datastructures import State
from tqdm import tqdm from tqdm import tqdm
from urllib3.util import parse_url
import vllm.envs as envs
from vllm.config import config from vllm.config import config
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -439,19 +441,25 @@ async def write_file( ...@@ -439,19 +441,25 @@ async def write_file(
await write_local_file(path_or_url, batch_outputs) await write_local_file(path_or_url, batch_outputs)
async def download_bytes_from_url(url: str) -> bytes: async def download_bytes_from_url(
url: str,
allowed_media_domains: list[str] | None = None,
) -> bytes:
""" """
Download data from a URL or decode from a data URL. Download data from a URL or decode from a data URL.
Args: Args:
url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...) url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
allowed_media_domains: If set, only HTTP/HTTPS URLs whose hostname
is in this list are permitted. data: URLs are not subject to
this restriction.
Returns: Returns:
Data as bytes Data as bytes
""" """
parsed = urlparse(url) parsed = urlparse(url)
# Handle data URLs (base64 encoded) # Handle data URLs (base64 encoded) - not subject to domain restrictions
if parsed.scheme == "data": if parsed.scheme == "data":
# Format: data:...;base64,<base64_data> # Format: data:...;base64,<base64_data>
if "," in url: if "," in url:
...@@ -465,9 +473,24 @@ async def download_bytes_from_url(url: str) -> bytes: ...@@ -465,9 +473,24 @@ async def download_bytes_from_url(url: str) -> bytes:
# Handle HTTP/HTTPS URLs # Handle HTTP/HTTPS URLs
elif parsed.scheme in ("http", "https"): elif parsed.scheme in ("http", "https"):
if allowed_media_domains is not None:
url_spec = parse_url(url)
if url_spec.hostname not in allowed_media_domains:
raise ValueError(
f"The URL must be from one of the allowed domains: "
f"{allowed_media_domains}. Input URL domain: "
f"{url_spec.hostname}"
)
# Use the normalized URL to prevent parsing discrepancies
# between urllib3 and aiohttp (e.g. backslash-@ attacks).
url = url_spec.url
async with ( async with (
aiohttp.ClientSession() as session, aiohttp.ClientSession() as session,
session.get(url) as resp, session.get(
url,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
) as resp,
): ):
if resp.status != 200: if resp.status != 200:
raise Exception( raise Exception(
...@@ -593,7 +616,10 @@ def handle_endpoint_request( ...@@ -593,7 +616,10 @@ def handle_endpoint_request(
return run_request(handler_fn, request, tracker) return run_request(handler_fn, request, tracker)
def make_transcription_wrapper(is_translation: bool) -> WrapperFn: def make_transcription_wrapper(
is_translation: bool,
allowed_media_domains: list[str] | None = None,
) -> WrapperFn:
""" """
Factory function to create a wrapper for transcription/translation handlers. Factory function to create a wrapper for transcription/translation handlers.
The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
...@@ -602,6 +628,8 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn: ...@@ -602,6 +628,8 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
Args: Args:
is_translation: If True, process as translation; otherwise process is_translation: If True, process as translation; otherwise process
as transcription as transcription
allowed_media_domains: If set, only URLs from these domains are
permitted for HTTP/HTTPS fetches.
Returns: Returns:
A function that takes a handler and returns a wrapped handler A function that takes a handler and returns a wrapped handler
...@@ -619,7 +647,10 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn: ...@@ -619,7 +647,10 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
): ):
try: try:
# Download data from URL # Download data from URL
audio_data = await download_bytes_from_url(batch_request_body.file_url) audio_data = await download_bytes_from_url(
batch_request_body.file_url,
allowed_media_domains=allowed_media_domains,
)
# Create a mock file from the downloaded audio data # Create a mock file from the downloaded audio data
mock_file = UploadFile( mock_file = UploadFile(
...@@ -691,6 +722,8 @@ async def build_endpoint_registry( ...@@ -691,6 +722,8 @@ async def build_endpoint_registry(
serving_embedding = getattr(state, "serving_embedding", None) serving_embedding = getattr(state, "serving_embedding", None)
serving_scores = getattr(state, "serving_scores", None) serving_scores = getattr(state, "serving_scores", None)
allowed_media_domains = getattr(args, "allowed_media_domains", None)
# Registry of endpoint configurations # Registry of endpoint configurations
endpoint_registry: dict[str, dict[str, Any]] = { endpoint_registry: dict[str, dict[str, Any]] = {
"completions": { "completions": {
...@@ -730,7 +763,10 @@ async def build_endpoint_registry( ...@@ -730,7 +763,10 @@ async def build_endpoint_registry(
if openai_serving_transcription is not None if openai_serving_transcription is not None
else None else None
), ),
"wrapper_fn": make_transcription_wrapper(is_translation=False), "wrapper_fn": make_transcription_wrapper(
is_translation=False,
allowed_media_domains=allowed_media_domains,
),
}, },
"translations": { "translations": {
"url_matcher": lambda url: url == "/v1/audio/translations", "url_matcher": lambda url: url == "/v1/audio/translations",
...@@ -739,7 +775,10 @@ async def build_endpoint_registry( ...@@ -739,7 +775,10 @@ async def build_endpoint_registry(
if openai_serving_translation is not None if openai_serving_translation is not None
else None else None
), ),
"wrapper_fn": make_transcription_wrapper(is_translation=True), "wrapper_fn": make_transcription_wrapper(
is_translation=True,
allowed_media_domains=allowed_media_domains,
),
}, },
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment