Unverified Commit 2f1399e8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add download tests for Caltech(101|256) (#2731)

* add download tests for Caltech(101|256)

* lint
parent a0a92670
...@@ -45,13 +45,23 @@ urlopen = limit_requests_per_time()(urlopen) ...@@ -45,13 +45,23 @@ urlopen = limit_requests_per_time()(urlopen)
@contextlib.contextmanager @contextlib.contextmanager
def log_download_attempts(patch=True): def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None):
urls_and_md5s = set() if urls_and_md5s is None:
with unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) as mock: urls_and_md5s = set()
if patch_auxiliaries is None:
patch_auxiliaries = patch
with contextlib.ExitStack() as stack:
download_url_mock = stack.enter_context(
unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url)
)
if patch_auxiliaries:
# download_and_extract_archive
stack.enter_context(unittest.mock.patch("torchvision.datasets.utils.extract_archive"))
try: try:
yield urls_and_md5s yield urls_and_md5s
finally: finally:
for args, kwargs in mock.call_args_list: for args, kwargs in download_url_mock.call_args_list:
url = args[0] url = args[0]
md5 = args[-1] if len(args) == 4 else kwargs.get("md5") md5 = args[-1] if len(args) == 4 else kwargs.get("md5")
urls_and_md5s.add((url, md5)) urls_and_md5s.add((url, md5))
...@@ -105,15 +115,14 @@ class DownloadConfig: ...@@ -105,15 +115,14 @@ class DownloadConfig:
self.md5 = md5 self.md5 = md5
self.id = id or url self.id = id or url
def __repr__(self):
return self.id
def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)
return dict(argnames="url, md5", argvalues=argvalues, ids=ids) def make_download_configs(urls_and_md5s, name=None):
return [
DownloadConfig(url, md5=md5, id=f"{name}, {url}" if name is not None else None) for url, md5 in urls_and_md5s
]
def places365(): def places365():
...@@ -124,10 +133,40 @@ def places365(): ...@@ -124,10 +133,40 @@ def places365():
datasets.Places365(root, split=split, small=small, download=True) datasets.Places365(root, split=split, small=small, download=True)
return [DownloadConfig(url, md5=md5, id=f"Places365, {url}") for url, md5 in urls_and_md5s] return make_download_configs(urls_and_md5s, "Places365")
def caltech101():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech101(".", download=True)
except Exception:
pass
return make_download_configs(urls_and_md5s, "Caltech101")
def caltech256():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech256(".", download=True)
except Exception:
pass
return make_download_configs(urls_and_md5s, "Caltech256")
def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
for config in download_configs:
argvalues.append((config.url, config.md5))
ids.append(config.id)
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(),))) @pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256())))
def test_url_is_accessible(url, md5): def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url)) retry(lambda: assert_url_is_accessible(url))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment