Unverified Commit 00c460af authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix redirection in download tests (#3568)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 8ee63393
......@@ -14,7 +14,13 @@ import warnings
import pytest
from torchvision import datasets
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT
from torchvision.datasets.utils import (
download_url,
check_integrity,
download_file_from_google_drive,
_get_redirect_url,
USER_AGENT,
)
from common_utils import get_tmp_dir
from fakedata_generation import places365_root
......@@ -48,22 +54,28 @@ def limit_requests_per_time(min_secs_between_requests=2.0):
urlopen = limit_requests_per_time()(urlopen)
def resolve_redirects(max_redirects=3):
def resolve_redirects(max_hops=3):
def outer_wrapper(fn):
def inner_wrapper(request, *args, **kwargs):
url = initial_url = request.full_url if isinstance(request, Request) else request
initial_url = request.full_url if isinstance(request, Request) else request
url = _get_redirect_url(initial_url, max_hops=max_hops)
for _ in range(max_redirects + 1):
response = fn(request, *args, **kwargs)
if url == initial_url:
return fn(request, *args, **kwargs)
if response.url == url or response.url is None:
if url != initial_url:
warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")
return response
warnings.warn(f"The URL {initial_url} ultimately redirects to {url}.")
url = response.url
else:
raise RecursionError(f"Request to {initial_url} exceeded {max_redirects} redirects.")
if not isinstance(request, Request):
return fn(url, *args, **kwargs)
request_attrs = {
attr: getattr(request, attr) for attr in ("data", "headers", "origin_req_host", "unverifiable")
}
# the 'method' attribute does only exist if the request was created with it
if hasattr(request, "method"):
request_attrs["method"] = request.method
return fn(Request(url, **request_attrs), *args, **kwargs)
return inner_wrapper
......@@ -150,7 +162,7 @@ def assert_server_response_ok():
def assert_url_is_accessible(url, timeout=5.0):
request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT})
request = Request(url, headers={"User-Agent": USER_AGENT}, method="HEAD")
with assert_server_response_ok():
urlopen(request, timeout=timeout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment