Unverified Commit 6b41eb0b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add download tests for CIFAR (#2747)

* add download tests for CIFAR

* fix tests in case of bad request
parent fdca3073
......@@ -35,7 +35,7 @@ jobs:
run: pip install pytest
- name: Run tests
run: pytest --durations=20 -ra test/test_datasets_download.py
run: pytest -ra -v test/test_datasets_download.py
- uses: JasonEtco/create-an-issue@v2.4.0
name: Create issue if download tests failed
......
......@@ -4,6 +4,7 @@ import time
import unittest.mock
from datetime import datetime
from os import path
from urllib.error import HTTPError
from urllib.parse import urlparse
from urllib.request import urlopen, Request
......@@ -86,25 +87,26 @@ def retry(fn, times=1, wait=5.0):
)
def assert_server_response_ok(response, url=None):
msg = f"The server returned status code {response.code}"
if url is not None:
msg += f"for the the URL {url}"
assert 200 <= response.code < 300, msg
@contextlib.contextmanager
def assert_server_response_ok():
try:
yield
except HTTPError as error:
raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
def assert_url_is_accessible(url):
request = Request(url, headers=dict(method="HEAD"))
response = urlopen(request)
assert_server_response_ok(response, url)
with assert_server_response_ok():
urlopen(request)
def assert_file_downloads_correctly(url, md5):
with get_tmp_dir() as root:
file = path.join(root, path.basename(url))
with urlopen(url) as response, open(file, "wb") as fh:
assert_server_response_ok(response, url)
fh.write(response.read())
with assert_server_response_ok():
with urlopen(url) as response, open(file, "wb") as fh:
fh.write(response.read())
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
......@@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None):
]
def collect_download_configs(dataset_loader, name):
try:
with log_download_attempts() as urls_and_md5s:
dataset_loader()
except Exception:
pass
return make_download_configs(urls_and_md5s, name)
def places365():
with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
......@@ -137,23 +149,19 @@ def places365():
def caltech101():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech101(".", download=True)
except Exception:
pass
return make_download_configs(urls_and_md5s, "Caltech101")
return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101")
def caltech256():
try:
with log_download_attempts() as urls_and_md5s:
datasets.Caltech256(".", download=True)
except Exception:
pass
return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256")
def cifar10():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10")
return make_download_configs(urls_and_md5s, "Caltech256")
def cifar100():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
def make_parametrize_kwargs(download_configs):
......@@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs):
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256())))
@pytest.mark.parametrize(
**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100()))
)
def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment