Unverified Commit fb8b5e05 authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[CI] Add retry with 4x backoff to HTTP fetches for transient failures (#37218)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent e5d96dc8
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping, MutableMapping import asyncio
import functools
import time
from collections.abc import Callable, Coroutine, Mapping, MutableMapping
from pathlib import Path from pathlib import Path
from typing import Any, ParamSpec, TypeVar
import aiohttp import aiohttp
import requests import requests
from urllib3.util import parse_url from urllib3.util import parse_url
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
_P = ParamSpec("_P")
_T = TypeVar("_T")
# Multiplier applied to timeout and sleep on each retry attempt.
# Attempt N uses: base_timeout * (_RETRY_BACKOFF_FACTOR ** N) for the
# per-attempt timeout and sleeps _RETRY_BACKOFF_FACTOR ** N seconds.
_RETRY_BACKOFF_FACTOR = 4
def _is_retryable(exc: Exception) -> bool:
"""Return True for transient errors that are worth retrying.
Retryable:
- Timeouts (aiohttp, requests, stdlib)
- Connection-level failures (refused, reset, DNS)
- Server errors (5xx) -- includes S3 503 SlowDown
Not retryable:
- Client errors (4xx) -- bad URL, auth, not-found
- Programming errors (ValueError, TypeError, ...)
"""
# Timeouts
if isinstance(
exc,
(
TimeoutError,
asyncio.TimeoutError,
requests.exceptions.Timeout,
aiohttp.ServerTimeoutError,
),
):
return True
# Connection-level failures
if isinstance(
exc,
(
ConnectionError,
aiohttp.ClientConnectionError,
requests.exceptions.ConnectionError,
),
):
return True
# aiohttp server-side disconnects
if isinstance(exc, aiohttp.ServerDisconnectedError):
return True
# requests 5xx -- raise_for_status() throws HTTPError
if (
isinstance(exc, requests.exceptions.HTTPError)
and exc.response is not None
and exc.response.status_code >= 500
):
return True
# aiohttp 5xx -- raise_for_status() throws ClientResponseError
return isinstance(exc, aiohttp.ClientResponseError) and exc.status >= 500
def _log_retry(
args: tuple,
kwargs: dict,
attempt: int,
max_retries: int,
attempt_timeout: float | None,
exc: Exception,
backoff: float,
base_timeout: float | None,
) -> None:
# args[0] is `self` (bound method), args[1] is the URL
url = args[1] if len(args) > 1 else kwargs.get("url")
timeout_info = (
f"timeout={attempt_timeout:.3f}s" if base_timeout is not None else "no timeout"
)
next_timeout = (
f" with timeout={base_timeout * (_RETRY_BACKOFF_FACTOR ** (attempt + 1)):.3f}s"
if base_timeout is not None
else ""
)
logger.warning(
"HTTP fetch failed for %s (attempt %d/%d, %s): %s -- retrying in %.3fs%s",
url,
attempt + 1,
max_retries,
timeout_info,
exc,
backoff,
next_timeout,
)
def _sync_retry(
fn: Callable[_P, _T],
) -> Callable[_P, _T]:
"""Add retry logic with exponential backoff to a sync method.
The decorated method must accept ``timeout`` as a keyword argument.
The decorator replaces it with a per-attempt timeout that grows by
``_RETRY_BACKOFF_FACTOR`` on each retry so transient slowness on busy
hosts is absorbed.
"""
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> _T:
base_timeout: float | None = kwargs.get("timeout")
max_retries = max(envs.VLLM_MEDIA_FETCH_MAX_RETRIES, 1)
for attempt in range(max_retries):
attempt_timeout = (
base_timeout * (_RETRY_BACKOFF_FACTOR**attempt)
if base_timeout is not None
else None
)
kwargs["timeout"] = attempt_timeout
try:
return fn(*args, **kwargs)
except Exception as e:
if not _is_retryable(e) or attempt + 1 >= max_retries:
raise
backoff = _RETRY_BACKOFF_FACTOR**attempt
_log_retry(
args,
kwargs,
attempt,
max_retries,
attempt_timeout,
e,
backoff,
base_timeout,
)
time.sleep(backoff)
raise AssertionError("unreachable")
return wrapper # type: ignore[return-value]
def _async_retry(
fn: Callable[_P, Coroutine[Any, Any, _T]],
) -> Callable[_P, Coroutine[Any, Any, _T]]:
"""Add retry logic with exponential backoff to an async method.
The decorated method must accept ``timeout`` as a keyword argument.
The decorator replaces it with a per-attempt timeout that grows by
``_RETRY_BACKOFF_FACTOR`` on each retry so transient slowness on busy
hosts is absorbed.
"""
@functools.wraps(fn)
async def wrapper(*args: Any, **kwargs: Any) -> _T:
base_timeout: float | None = kwargs.get("timeout")
max_retries = max(envs.VLLM_MEDIA_FETCH_MAX_RETRIES, 1)
for attempt in range(max_retries):
attempt_timeout = (
base_timeout * (_RETRY_BACKOFF_FACTOR**attempt)
if base_timeout is not None
else None
)
kwargs["timeout"] = attempt_timeout
try:
return await fn(*args, **kwargs)
except Exception as e:
if not _is_retryable(e) or attempt + 1 >= max_retries:
raise
backoff = _RETRY_BACKOFF_FACTOR**attempt
_log_retry(
args,
kwargs,
attempt,
max_retries,
attempt_timeout,
e,
backoff,
base_timeout,
)
await asyncio.sleep(backoff)
raise AssertionError("unreachable")
return wrapper # type: ignore[return-value]
class HTTPConnection: class HTTPConnection:
"""Helper class to send HTTP requests.""" """Helper class to send HTTP requests."""
...@@ -89,6 +275,7 @@ class HTTPConnection: ...@@ -89,6 +275,7 @@ class HTTPConnection:
allow_redirects=allow_redirects, allow_redirects=allow_redirects,
) )
@_sync_retry
def get_bytes( def get_bytes(
self, url: str, *, timeout: float | None = None, allow_redirects: bool = True self, url: str, *, timeout: float | None = None, allow_redirects: bool = True
) -> bytes: ) -> bytes:
...@@ -99,6 +286,7 @@ class HTTPConnection: ...@@ -99,6 +286,7 @@ class HTTPConnection:
return r.content return r.content
@_async_retry
async def async_get_bytes( async def async_get_bytes(
self, self,
url: str, url: str,
...@@ -147,6 +335,7 @@ class HTTPConnection: ...@@ -147,6 +335,7 @@ class HTTPConnection:
return await r.json() return await r.json()
@_sync_retry
def download_file( def download_file(
self, self,
url: str, url: str,
...@@ -155,6 +344,7 @@ class HTTPConnection: ...@@ -155,6 +344,7 @@ class HTTPConnection:
timeout: float | None = None, timeout: float | None = None,
chunk_size: int = 128, chunk_size: int = 128,
) -> Path: ) -> Path:
try:
with self.get_response(url, timeout=timeout) as r: with self.get_response(url, timeout=timeout) as r:
r.raise_for_status() r.raise_for_status()
...@@ -163,7 +353,13 @@ class HTTPConnection: ...@@ -163,7 +353,13 @@ class HTTPConnection:
f.write(chunk) f.write(chunk)
return save_path return save_path
except Exception:
# Clean up partial downloads before retrying or propagating
if save_path.exists():
save_path.unlink()
raise
@_async_retry
async def async_download_file( async def async_download_file(
self, self,
url: str, url: str,
...@@ -172,7 +368,11 @@ class HTTPConnection: ...@@ -172,7 +368,11 @@ class HTTPConnection:
timeout: float | None = None, timeout: float | None = None,
chunk_size: int = 128, chunk_size: int = 128,
) -> Path: ) -> Path:
async with await self.get_async_response(url, timeout=timeout) as r: try:
async with await self.get_async_response(
url,
timeout=timeout,
) as r:
r.raise_for_status() r.raise_for_status()
with save_path.open("wb") as f: with save_path.open("wb") as f:
...@@ -180,6 +380,11 @@ class HTTPConnection: ...@@ -180,6 +380,11 @@ class HTTPConnection:
f.write(chunk) f.write(chunk)
return save_path return save_path
except Exception:
# Clean up partial downloads before retrying or propagating
if save_path.exists():
save_path.unlink()
raise
global_http_connection = HTTPConnection() global_http_connection = HTTPConnection()
......
...@@ -64,6 +64,7 @@ if TYPE_CHECKING: ...@@ -64,6 +64,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
...@@ -773,6 +774,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -773,6 +774,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": lambda: int( "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(
os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10") os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")
), ),
# Maximum number of retries for fetching media (images, audio, video)
# from URLs. Each retry quadruples the timeout. Default is 3.
"VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int(
os.getenv("VLLM_MEDIA_FETCH_MAX_RETRIES", "3")
),
# Whether to allow HTTP redirects when fetching from media URLs. # Whether to allow HTTP redirects when fetching from media URLs.
# Default to True # Default to True
"VLLM_MEDIA_URL_ALLOW_REDIRECTS": lambda: bool( "VLLM_MEDIA_URL_ALLOW_REDIRECTS": lambda: bool(
...@@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]: ...@@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]:
"VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_IMAGE_FETCH_TIMEOUT",
"VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT",
"VLLM_AUDIO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT",
"VLLM_MEDIA_FETCH_MAX_RETRIES",
"VLLM_MEDIA_URL_ALLOW_REDIRECTS", "VLLM_MEDIA_URL_ALLOW_REDIRECTS",
"VLLM_MEDIA_LOADING_THREAD_COUNT", "VLLM_MEDIA_LOADING_THREAD_COUNT",
"VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment