Unverified commit c5d6f1f3, authored by Philip Meier and committed by GitHub

Add download tests for remaining datasets (#3338)



* kmnist

* emnist

* qmnist

* omniglot

* phototour

* sbdataset

* sbu

* semeion

* stl10

* svhn

* usps

* cifar100

* enable download logging for google drive

* celeba

* widerface

* lint

* add timeout logic

* lint

* debug CI connection to problematic server

* set timeout for ping

* [ci skip] remove ping

* revert debugging

* disable requests to problematic server

* re-enable all other tests
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 8317295c
@@ -4,14 +4,14 @@ import time
 import unittest.mock
 from datetime import datetime
 from os import path
-from urllib.error import HTTPError
+from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request

 import pytest

 from torchvision import datasets
-from torchvision.datasets.utils import download_url, check_integrity
+from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive

 from common_utils import get_tmp_dir
 from fakedata_generation import places365_root
@@ -48,35 +48,47 @@ urlopen = limit_requests_per_time()(urlopen)
 @contextlib.contextmanager
 def log_download_attempts(
     urls_and_md5s=None,
+    file="utils",
     patch=True,
-    download_url_location=".utils",
-    patch_auxiliaries=None,
+    mock_auxiliaries=None,
 ):
+    def add_mock(stack, name, file, **kwargs):
+        try:
+            return stack.enter_context(unittest.mock.patch(f"torchvision.datasets.{file}.{name}", **kwargs))
+        except AttributeError as error:
+            if file != "utils":
+                return add_mock(stack, name, "utils", **kwargs)
+            else:
+                raise pytest.UsageError from error
+
     if urls_and_md5s is None:
         urls_and_md5s = set()
-    if download_url_location.startswith("."):
-        download_url_location = f"torchvision.datasets{download_url_location}"
-    if patch_auxiliaries is None:
-        patch_auxiliaries = patch
+    if mock_auxiliaries is None:
+        mock_auxiliaries = patch

     with contextlib.ExitStack() as stack:
-        download_url_mock = stack.enter_context(
-            unittest.mock.patch(
-                f"{download_url_location}.download_url",
-                wraps=None if patch else download_url,
-            )
+        url_mock = add_mock(stack, "download_url", file, wraps=None if patch else download_url)
+        google_drive_mock = add_mock(
+            stack, "download_file_from_google_drive", file, wraps=None if patch else download_file_from_google_drive
         )
-        if patch_auxiliaries:
-            # download_and_extract_archive
-            stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
+
+        if mock_auxiliaries:
+            add_mock(stack, "extract_archive", file)

         try:
             yield urls_and_md5s
         finally:
-            for args, kwargs in download_url_mock.call_args_list:
+            for args, kwargs in url_mock.call_args_list:
                 url = args[0]
                 md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
                 urls_and_md5s.add((url, md5))
+            for args, kwargs in google_drive_mock.call_args_list:
+                id = args[0]
+                url = f"https://drive.google.com/file/d/{id}"
+                md5 = args[3] if len(args) == 4 else kwargs.get("md5")
+                urls_and_md5s.add((url, md5))


 def retry(fn, times=1, wait=5.0):
     msgs = []
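For orientation, here is a minimal usage sketch of the reworked helper; it is not part of the diff. It assumes the default patch=True, in which case download_url and download_file_from_google_drive are replaced by mocks and no real network traffic happens. Datasets whose download helpers live in their own module would pass the new file argument (for example file="voc"), and add_mock falls back to torchvision.datasets.utils when the name is not found in that module.

# Illustration only (not part of this commit): record which URLs a dataset
# constructor would try to download, without performing any real downloads.
import contextlib  # already imported at the top of this module

from torchvision import datasets  # already imported at the top of this module

with log_download_attempts() as urls_and_md5s:
    # The constructor may still fail after the mocked "download" (e.g. an
    # integrity check on files that were never written); that is fine here,
    # since only the recorded (url, md5) pairs matter.
    with contextlib.suppress(Exception):
        datasets.MNIST(".", download=True)

for url, md5 in urls_and_md5s:
    print(url, md5)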
@@ -101,6 +113,8 @@ def retry(fn, times=1, wait=5.0):
 def assert_server_response_ok():
     try:
         yield
+    except URLError as error:
+        raise AssertionError("The request timed out.") from error
     except HTTPError as error:
         raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
@@ -108,14 +122,14 @@ def assert_server_response_ok():
 def assert_url_is_accessible(url):
     request = Request(url, headers=dict(method="HEAD"))
     with assert_server_response_ok():
-        urlopen(request)
+        urlopen(request, timeout=5.0)


 def assert_file_downloads_correctly(url, md5):
     with get_tmp_dir() as root:
         file = path.join(root, path.basename(url))
         with assert_server_response_ok():
-            with urlopen(url) as response, open(file, "wb") as fh:
+            with urlopen(url, timeout=5.0) as response, open(file, "wb") as fh:
                 fh.write(response.read())

         assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
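As a rough illustration (not part of the diff) of how these helpers compose with the retry wrapper defined above; the URL below is hypothetical:

# Hypothetical sketch: probe a download URL with the 5 second timeout set above,
# letting retry() re-run the check after a short wait if an attempt fails.
retry(lambda: assert_url_is_accessible("https://example.com/dataset.tar.gz"), times=2, wait=5.0)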
@@ -175,7 +189,7 @@ def cifar10():
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")


 def voc():
@@ -184,7 +198,7 @@ def voc():
         collect_download_configs(
             lambda: datasets.VOCSegmentation(".", year=year, download=True),
             name=f"VOC, {year}",
-            download_url_location=".voc",
+            file="voc",
         )
         for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
     ]
@@ -199,6 +213,128 @@ def fashion_mnist():
     return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")


+def kmnist():
+    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+
+
+def emnist():
+    # the 'split' argument can be any valid one, since everything is downloaded anyway
+    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+
+
+def qmnist():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.QMNIST(".", what=what, download=True),
+                name=f"QMNIST, {what}",
+                file="mnist",
+            )
+            for what in ("train", "test", "nist")
+        ]
+    )
+
+
+def omniglot():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.Omniglot(".", background=background, download=True),
+                name=f"Omniglot, {'background' if background else 'evaluation'}",
+            )
+            for background in (True, False)
+        ]
+    )
+
+
+def phototour():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.PhotoTour(".", name=name, download=True),
+                name=f"PhotoTour, {name}",
+                file="phototour",
+            )
+            # The names postfixed with '_harris' point to the domain 'matthewalunbrown.com'. For some reason all
+            # requests timeout from within CI. They are disabled until this is resolved.
+            for name in ("notredame", "yosemite", "liberty")  # "notredame_harris", "yosemite_harris", "liberty_harris"
+        ]
+    )
+
+
+def sbdataset():
+    return collect_download_configs(
+        lambda: datasets.SBDataset(".", download=True),
+        name="SBDataset",
+        file="voc",
+    )
+
+
+def sbu():
+    return collect_download_configs(
+        lambda: datasets.SBU(".", download=True),
+        name="SBU",
+        file="sbu",
+    )
+
+
+def semeion():
+    return collect_download_configs(
+        lambda: datasets.SEMEION(".", download=True),
+        name="SEMEION",
+        file="semeion",
+    )
+
+
+def stl10():
+    return collect_download_configs(
+        lambda: datasets.STL10(".", download=True),
+        name="STL10",
+    )
+
+
+def svhn():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.SVHN(".", split=split, download=True),
+                name=f"SVHN, {split}",
+                file="svhn",
+            )
+            for split in ("train", "test", "extra")
+        ]
+    )
+
+
+def usps():
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.USPS(".", train=train, download=True),
+                name=f"USPS, {'train' if train else 'test'}",
+                file="usps",
+            )
+            for train in (True, False)
+        ]
+    )
+
+
+def celeba():
+    return collect_download_configs(
+        lambda: datasets.CelebA(".", download=True),
+        name="CelebA",
+        file="celeba",
+    )
+
+
+def widerface():
+    return collect_download_configs(
+        lambda: datasets.WIDERFace(".", download=True),
+        name="WIDERFace",
+        file="widerface",
+    )
+
+
 def make_parametrize_kwargs(download_configs):
     argvalues = []
     ids = []
@@ -221,6 +357,19 @@ def make_parametrize_kwargs(download_configs):
             # voc(),
             mnist(),
             fashion_mnist(),
+            kmnist(),
+            emnist(),
+            qmnist(),
+            omniglot(),
+            phototour(),
+            sbdataset(),
+            sbu(),
+            semeion(),
+            stl10(),
+            svhn(),
+            usps(),
+            celeba(),
+            widerface(),
         )
     )
 )
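The diff view is collapsed beyond this point. As a rough sketch of how the parametrization above is typically consumed (assuming make_parametrize_kwargs yields keyword arguments for pytest.mark.parametrize over (url, md5) pairs, consistent with the tuples collected by log_download_attempts; the test name, the subset of configs, and the body are illustrative, not copied from the file):

# Hypothetical consumer of the download configs collected above.
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(mnist(), kmnist(), widerface())))
def test_url_is_accessible(url, md5):
    # Fails with an AssertionError if the server errors out or the request
    # exceeds the 5 second timeout; retry() absorbs occasional flakiness.
    retry(lambda: assert_url_is_accessible(url))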