Unverified commit c5d6f1f3, authored by Philip Meier, committed by GitHub

Add download tests for remaining datasets (#3338)



* kmnist

* emnist

* qmnist

* omniglot

* phototour

* sbdataset

* sbu

* semeion

* stl10

* svhn

* usps

* cifar100

* enable download logging for google drive

* celeba

* widerface

* lint

* add timeout logic

* lint

* debug CI connection to problematic server

* set timeout for ping

* [ci skip] remove ping

* revert debugging

* disable requests to problematic server

* re-enable all other tests

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 8317295c
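
For context: the diff below (to torchvision's download test module) reworks the log_download_attempts helper so that it also intercepts Google Drive downloads via download_file_from_google_drive, lets callers patch per-dataset modules instead of only datasets.utils, and adds timeouts to the URL accessibility checks. A minimal usage sketch of the helper under its defaults (patch=True replaces the real downloads with recording mocks); this illustration is not part of the diff:

    from torchvision import datasets

    # Record which (url, md5) pairs CIFAR100 would download, without touching
    # the network. Since the mocked download produces no files, constructing
    # the dataset is expected to fail after the intercepted download step.
    with log_download_attempts() as urls_and_md5s:
        try:
            datasets.CIFAR100(".", download=True)
        except Exception:
            pass
    print(sorted(urls_and_md5s))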
@@ -4,14 +4,14 @@ import time
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request

 import pytest

 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ urlopen = limit_requests_per_time()(urlopen)
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
+    file="utils",
     patch=True,
-    download_url_location=".utils",
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch

     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
-        )
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
+        )

-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)

         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))

+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))
+

 def retry(fn, times=1, wait=5.0):
     msgs = []
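
The add_mock helper above builds on the wraps argument of unittest.mock.patch: with wraps set to the real callable, calls pass through to it but are still recorded on the mock's call_args_list; with wraps=None the mock is a pure no-op recorder. A self-contained illustration with a hypothetical fetch function (not from this file):

    import unittest.mock

    def fetch(url, md5=None):
        return f"fetched {url}"

    # wraps=fetch: the real function still runs, and the call is recorded.
    with unittest.mock.patch("__main__.fetch", wraps=fetch) as mock:
        fetch("https://example.com/data.tar.gz", md5="d41d8cd9")

    args, kwargs = mock.call_args_list[0]
    assert args == ("https://example.com/data.tar.gz",)
    assert kwargs == {"md5": "d41d8cd9"}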
@@ -101,6 +113,8 @@ def retry(fn, times=1, wait=5.0):
 @contextlib.contextmanager
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
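
The new URLError branch covers connection failures and timeouts, which urlopen reports as URLError (typically wrapping socket.timeout) rather than HTTPError; note that HTTPError is a subclass of URLError, so the order of these handlers decides which message an HTTP error response produces. A rough illustration of the timeout path (the non-routable address is only an assumption chosen to force a connect timeout):

    from urllib.error import URLError
    from urllib.request import urlopen

    try:
        # connecting to a non-routable address cannot succeed and times out
        urlopen("http://10.255.255.1/", timeout=0.5)
    except URLError as error:
        print(f"request failed before any response arrived: {error.reason}")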
@@ -108,14 +122,14 @@ def assert_server_response_ok():
 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)


 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                 fh.write(response.read())

         assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
@@ -175,7 +189,7 @@ def cifar10():
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


 def voc():
@@ -184,7 +198,7 @@ def voc():
             collect_download_configs(
                 lambda: datasets.VOCSegmentation(".", year=year, download=True),
                 name=f"VOC, {year}",
-                download_url_location=".voc",
+                file="voc",
             )
             for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
         ]
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names suffixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
+            # requests time out from within CI. They are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
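
Each of these functions yields download configs (URL/MD5 pairs tagged with a dataset name) that make_parametrize_kwargs, whose body lies outside this hunk, flattens into pytest.mark.parametrize arguments, so every URL becomes its own test case labeled by dataset. A generic sketch of that pattern with hypothetical data:

    import pytest

    CONFIGS = [("MNIST", "https://example.com/mnist.gz", "abc123")]

    @pytest.mark.parametrize(
        ("url", "md5"),
        [(url, md5) for _, url, md5 in CONFIGS],
        ids=[name for name, _, _ in CONFIGS],
    )
    def test_url_is_accessible(url, md5):
        # one test case per (url, md5) pair, labeled by dataset name
        assert url.startswith("https://")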
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
             # voc(),
             mnist(),
             fashion_mnist(),
+            kmnist(),
+            emnist(),
+            qmnist(),
+            omniglot(),
+            phototour(),
+            sbdataset(),
+            sbu(),
+            semeion(),
+            stl10(),
+            svhn(),
+            usps(),
+            celeba(),
+            widerface(),
         )
     )
 )