Unverified Commit e946e87d authored by Philip Meier, committed by GitHub

refactor download tests (#7546)

parent 6381f7b2
@@ -14,13 +14,7 @@ from urllib.request import Request, urlopen
 
 import pytest
 from torchvision import datasets
-from torchvision.datasets.utils import (
-    _get_redirect_url,
-    check_integrity,
-    download_file_from_google_drive,
-    download_url,
-    USER_AGENT,
-)
+from torchvision.datasets.utils import _get_redirect_url, USER_AGENT
 
 
 def limit_requests_per_time(min_secs_between_requests=2.0):
@@ -84,47 +78,45 @@ urlopen = resolve_redirects()(urlopen)
 
 @contextlib.contextmanager
 def log_download_attempts(
-    urls_and_md5s=None,
-    file="utils",
-    patch=True,
-    mock_auxiliaries=None,
+    urls,
+    *,
+    dataset_module,
 ):
-    def add_mock(stack, name, file, **kwargs):
+    def maybe_add_mock(*, module, name, stack, lst=None):
+        patcher = unittest.mock.patch(f"torchvision.datasets.{module}.{name}")
+
         try:
-            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
-        except AttributeError as error:
-            if file != "utils":
-                return add_mock(stack, name, "utils", **kwargs)
-            else:
-                raise pytest.UsageError from error
-
-    if urls_and_md5s is None:
-        urls_and_md5s = set()
-    if mock_auxiliaries is None:
-        mock_auxiliaries = patch
+            mock = stack.enter_context(patcher)
+        except AttributeError:
+            return
 
-    with contextlib.ExitStack() as stack:
-        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
-        google_drive_mock = add_mock(
-            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
-        )
+        if lst is not None:
+            lst.append(mock)
 
-        if mock_auxiliaries:
-            add_mock(stack, "extract_archive", file)
+    with contextlib.ExitStack() as stack:
+        download_url_mocks = []
+        download_file_from_google_drive_mocks = []
+        for module in [dataset_module, "utils"]:
+            maybe_add_mock(module=module, name="download_url", stack=stack, lst=download_url_mocks)
+            maybe_add_mock(
+                module=module,
+                name="download_file_from_google_drive",
+                stack=stack,
+                lst=download_file_from_google_drive_mocks,
+            )
+            maybe_add_mock(module=module, name="extract_archive", stack=stack)
 
         try:
-            yield urls_and_md5s
+            yield
         finally:
-            for args, kwargs in url_mock.call_args_list:
-                url = args[0]
-                md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
-                urls_and_md5s.add((url, md5))
+            for download_url_mock in download_url_mocks:
+                for args, kwargs in download_url_mock.call_args_list:
+                    urls.append(args[0] if args else kwargs["url"])
 
-            for args, kwargs in google_drive_mock.call_args_list:
-                id = args[0]
-                url = f"https://drive.google.com/file/d/{id}"
-                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
-                urls_and_md5s.add((url, md5))
+            for download_file_from_google_drive_mock in download_file_from_google_drive_mocks:
+                for args, kwargs in download_file_from_google_drive_mock.call_args_list:
+                    file_id = args[0] if args else kwargs["file_id"]
+                    urls.append(f"https://drive.google.com/file/d/{file_id}")
 
 
 def retry(fn, times=1, wait=5.0):
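The rewritten context manager no longer tracks MD5s; it only records the requested URLs, patching each download helper both in `utils` (where it is defined) and in the dataset's own module (where it may have been re-imported). A minimal sketch of how it is meant to be driven, assuming the `datasets` import and `ROOT` from this test module:

```python
# Hypothetical driver for the new context manager. The download helpers are
# mocked, so nothing is fetched; the constructor may raise once it finds no
# data on disk, hence the suppress().
urls = []
with contextlib.suppress(Exception), log_download_attempts(urls, dataset_module="mnist"):
    datasets.MNIST(ROOT, download=True)
# `urls` now holds every URL the dataset tried to download.
```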
@@ -170,45 +162,14 @@ def assert_url_is_accessible(url, timeout=5.0):
     urlopen(request, timeout=timeout)
 
 
-def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
-    file = path.join(tmpdir, path.basename(url))
-    with assert_server_response_ok():
-        with open(file, "wb") as fh:
-            request = Request(url, headers={"User-Agent": USER_AGENT})
-            response = urlopen(request, timeout=timeout)
-            fh.write(response.read())
-
-    assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
-
-
-class DownloadConfig:
-    def __init__(self, url, md5=None, id=None):
-        self.url = url
-        self.md5 = md5
-        self.id = id or url
-
-    def __repr__(self) -> str:
-        return self.id
-
-
-def make_download_configs(urls_and_md5s, name=None):
-    return [
-        DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
-    ]
-
-
-def collect_download_configs(dataset_loader, name=None, **kwargs):
-    urls_and_md5s = set()
-    try:
-        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
-            dataset = dataset_loader()
-    except Exception:
-        dataset = None
-
-    if name is None and dataset is not None:
-        name = type(dataset).__name__
-
-    return make_download_configs(urls_and_md5s, name)
+def collect_urls(dataset_cls, *args, **kwargs):
+    urls = []
+    with contextlib.suppress(Exception), log_download_attempts(
+        urls, dataset_module=dataset_cls.__module__.split(".")[-1]
+    ):
+        dataset_cls(*args, **kwargs)
+
+    return [(url, f"{dataset_cls.__name__}, {url}") for url in urls]
 
 
 # This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a
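`collect_urls` wraps that pattern for a dataset class and returns `(url, test_id)` pairs. Roughly, and with the URL shown purely for illustration:

```python
# Illustrative call: collect_urls pairs each recorded URL with a readable test id.
urls_and_ids = collect_urls(datasets.CIFAR10, ROOT, download=True)
# e.g. [("https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz",
#        "CIFAR10, https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz")]
```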
@@ -223,12 +184,14 @@ def root():
 def places365():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.Places365(ROOT, split=split, small=small, download=True),
-                name=f"Places365, {split}, {'small' if small else 'large'}",
-                file="places365",
+    return itertools.chain.from_iterable(
+        [
+            collect_urls(
+                datasets.Places365,
+                ROOT,
+                split=split,
+                small=small,
+                download=True,
             )
             for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True))
         ]
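The `itertools.chain(*iterable)` to `itertools.chain.from_iterable(iterable)` change repeated throughout this diff is behavior-preserving; the latter simply avoids unpacking the outer list into arguments. A quick check:

```python
import itertools

nested = [["a"], ["b", "c"]]
assert list(itertools.chain(*nested)) == list(itertools.chain.from_iterable(nested)) == ["a", "b", "c"]
```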
@@ -236,30 +199,26 @@ def places365():
 def caltech101():
-    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")
+    return collect_urls(datasets.Caltech101, ROOT, download=True)
 
 
 def caltech256():
-    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")
+    return collect_urls(datasets.Caltech256, ROOT, download=True)
 
 
 def cifar10():
-    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")
+    return collect_urls(datasets.CIFAR10, ROOT, download=True)
 
 
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")
+    return collect_urls(datasets.CIFAR100, ROOT, download=True)
 
 
 def voc():
     # TODO: Also test the "2007-test" key
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
-                name=f"VOC, {year}",
-                file="voc",
-            )
+    return itertools.chain.from_iterable(
+        [
+            collect_urls(datasets.VOCSegmentation, ROOT, year=year, download=True)
             for year in ("2007", "2008", "2009", "2010", "2011", "2012")
         ]
     )
@@ -267,59 +226,42 @@ def voc():
 def mnist():
     with unittest.mock.patch.object(datasets.MNIST, "mirrors", datasets.MNIST.mirrors[-1:]):
-        return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
+        return collect_urls(datasets.MNIST, ROOT, download=True)
 
 
 def fashion_mnist():
-    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")
+    return collect_urls(datasets.FashionMNIST, ROOT, download=True)
 
 
 def kmnist():
-    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")
+    return collect_urls(datasets.KMNIST, ROOT, download=True)
 
 
 def emnist():
     # the 'split' argument can be any valid one, since everything is downloaded anyway
-    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")
+    return collect_urls(datasets.EMNIST, ROOT, split="byclass", download=True)
 
 
 def qmnist():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.QMNIST(ROOT, what=what, download=True),
-                name=f"QMNIST, {what}",
-                file="mnist",
-            )
-            for what in ("train", "test", "nist")
-        ]
+    return itertools.chain.from_iterable(
+        [collect_urls(datasets.QMNIST, ROOT, what=what, download=True) for what in ("train", "test", "nist")]
     )
 
 
 def moving_mnist():
-    return collect_download_configs(lambda: datasets.MovingMNIST(ROOT, download=True), name="MovingMNIST")
+    return collect_urls(datasets.MovingMNIST, ROOT, download=True)
 
 
 def omniglot():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.Omniglot(ROOT, background=background, download=True),
-                name=f"Omniglot, {'background' if background else 'evaluation'}",
-            )
-            for background in (True, False)
-        ]
+    return itertools.chain.from_iterable(
+        [collect_urls(datasets.Omniglot, ROOT, background=background, download=True) for background in (True, False)]
     )
 
 
 def phototour():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
-                name=f"PhotoTour, {name}",
-                file="phototour",
-            )
+    return itertools.chain.from_iterable(
+        [
+            collect_urls(datasets.PhotoTour, ROOT, name=name, download=True)
             # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
             # requests timeout from within CI. They are disabled until this is resolved.
             for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
@@ -328,91 +270,51 @@ def phototour():
 def sbdataset():
-    return collect_download_configs(
-        lambda: datasets.SBDataset(ROOT, download=True),
-        name="SBDataset",
-        file="voc",
-    )
+    return collect_urls(datasets.SBDataset, ROOT, download=True)
 
 
 def sbu():
-    return collect_download_configs(
-        lambda: datasets.SBU(ROOT, download=True),
-        name="SBU",
-        file="sbu",
-    )
+    return collect_urls(datasets.SBU, ROOT, download=True)
 
 
 def semeion():
-    return collect_download_configs(
-        lambda: datasets.SEMEION(ROOT, download=True),
-        name="SEMEION",
-        file="semeion",
-    )
+    return collect_urls(datasets.SEMEION, ROOT, download=True)
 
 
 def stl10():
-    return collect_download_configs(
-        lambda: datasets.STL10(ROOT, download=True),
-        name="STL10",
-    )
+    return collect_urls(datasets.STL10, ROOT, download=True)
 
 
 def svhn():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.SVHN(ROOT, split=split, download=True),
-                name=f"SVHN, {split}",
-                file="svhn",
-            )
-            for split in ("train", "test", "extra")
-        ]
+    return itertools.chain.from_iterable(
+        [collect_urls(datasets.SVHN, ROOT, split=split, download=True) for split in ("train", "test", "extra")]
     )
 
 
 def usps():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.USPS(ROOT, train=train, download=True),
-                name=f"USPS, {'train' if train else 'test'}",
-                file="usps",
-            )
-            for train in (True, False)
-        ]
+    return itertools.chain.from_iterable(
+        [collect_urls(datasets.USPS, ROOT, train=train, download=True) for train in (True, False)]
     )
 
 
 def celeba():
-    return collect_download_configs(
-        lambda: datasets.CelebA(ROOT, download=True),
-        name="CelebA",
-        file="celeba",
-    )
+    return collect_urls(datasets.CelebA, ROOT, download=True)
 
 
 def widerface():
-    return collect_download_configs(
-        lambda: datasets.WIDERFace(ROOT, download=True),
-        name="WIDERFace",
-        file="widerface",
-    )
+    return collect_urls(datasets.WIDERFace, ROOT, download=True)
 
 
 def kinetics():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda: datasets.Kinetics(
-                    path.join(ROOT, f"Kinetics{num_classes}"),
-                    frames_per_clip=1,
-                    num_classes=num_classes,
-                    split=split,
-                    download=True,
-                ),
-                name=f"Kinetics, {num_classes}, {split}",
-                file="kinetics",
+    return itertools.chain.from_iterable(
+        [
+            collect_urls(
+                datasets.Kinetics,
+                path.join(ROOT, f"Kinetics{num_classes}"),
+                frames_per_clip=1,
+                num_classes=num_classes,
+                split=split,
+                download=True,
             )
             for num_classes, split in itertools.product(("400", "600", "700"), ("train", "val"))
         ]
@@ -420,58 +322,55 @@ def kinetics():
 def kitti():
-    return itertools.chain(
-        *[
-            collect_download_configs(
-                lambda train=train: datasets.Kitti(ROOT, train=train, download=True),
-                name=f"Kitti, {'train' if train else 'test'}",
-                file="kitti",
-            )
-            for train in (True, False)
-        ]
+    return itertools.chain.from_iterable(
+        [collect_urls(datasets.Kitti, ROOT, train=train, download=True) for train in (True, False)]
     )
 
 
-def make_parametrize_kwargs(download_configs):
-    argvalues = []
-    ids = []
-    for config in download_configs:
-        argvalues.append((config.url, config.md5))
-        ids.append(config.id)
-
-    return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
-
-
-@pytest.mark.parametrize(
-    **make_parametrize_kwargs(
-        itertools.chain(
-            caltech101(),
-            caltech256(),
-            cifar10(),
-            cifar100(),
-            # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
-            # voc(),
-            mnist(),
-            fashion_mnist(),
-            kmnist(),
-            emnist(),
-            qmnist(),
-            omniglot(),
-            phototour(),
-            sbdataset(),
-            semeion(),
-            stl10(),
-            svhn(),
-            usps(),
-            celeba(),
-            widerface(),
-            kinetics(),
-            kitti(),
-            places365(),
-        )
-    )
+def stanford_cars():
+    return itertools.chain.from_iterable(
+        [collect_urls(datasets.StanfordCars, ROOT, split=split, download=True) for split in ["train", "test"]]
+    )
+
+
+def url_parametrization(*dataset_urls_and_ids_fns):
+    return pytest.mark.parametrize(
+        "url",
+        [
+            pytest.param(url, id=id)
+            for dataset_urls_and_ids_fn in dataset_urls_and_ids_fns
+            for url, id in sorted(set(dataset_urls_and_ids_fn()))
+        ],
+    )
+
+
+@url_parametrization(
+    caltech101,
+    caltech256,
+    cifar10,
+    cifar100,
+    # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
+    # voc,
+    mnist,
+    fashion_mnist,
+    kmnist,
+    emnist,
+    qmnist,
+    omniglot,
+    phototour,
+    sbdataset,
+    semeion,
+    stl10,
+    svhn,
+    usps,
+    celeba,
+    widerface,
+    kinetics,
+    kitti,
+    places365,
+    sbu,
 )
-def test_url_is_accessible(url, md5):
+def test_url_is_accessible(url):
     """
     If you see this test failing, find the offending dataset in the parametrization and move it to
     ``test_url_is_not_accessible`` and link an issue detailing the problem.
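`url_parametrization` deduplicates the collected `(url, id)` pairs across datasets, sorts them, and turns each into an individually labeled pytest parameter. A self-contained sketch with a made-up collector:

```python
# `tiny` is a hypothetical collector used only to show the expansion; the real
# collectors are the dataset functions above.
def tiny():
    return [("https://example.com/data.tar.gz", "Tiny, https://example.com/data.tar.gz")]


@url_parametrization(tiny)
def test_demo(url):
    # Runs once per unique URL; the id string shows up in the pytest report.
    assert url.startswith("https://")
```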
@@ -479,15 +378,11 @@ def test_url_is_accessible(url, md5):
     retry(lambda: assert_url_is_accessible(url))
 
 
-@pytest.mark.parametrize(
-    **make_parametrize_kwargs(
-        itertools.chain(
-            sbu(),  # https://github.com/pytorch/vision/issues/7005
-        )
-    )
+@url_parametrization(
+    stanford_cars,  # https://github.com/pytorch/vision/issues/7545
 )
 @pytest.mark.xfail
-def test_url_is_not_accessible(url, md5):
+def test_url_is_not_accessible(url):
     """
     As the name implies, this test is the 'inverse' of ``test_url_is_accessible``. Since the download servers are
     beyond our control, some files might not be accessible for longer stretches of time. Still, we want to know if they
@@ -497,8 +392,3 @@ def test_url_is_not_accessible(url, md5):
     ``test_url_is_accessible``.
     """
     retry(lambda: assert_url_is_accessible(url))
-
-
-@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain()))
-def test_file_downloads_correctly(url, md5):
-    retry(lambda: assert_file_downloads_correctly(url, md5))
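The `@pytest.mark.xfail` marker keeps currently unreachable URLs in the suite: the test reports XFAIL while the server is down and XPASS once it recovers, signaling that the entry can move back to `test_url_is_accessible`. The bare pattern, for reference:

```python
import pytest


@pytest.mark.xfail
def test_currently_unreachable():
    raise AssertionError("simulates a download server that is down")
```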