Unverified Commit 6b41eb0b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add download tests for CIFAR (#2747)

* add download tests for CIFAR

* fix tests in case of bad request
parent fdca3073
...@@ -35,7 +35,7 @@ jobs: ...@@ -35,7 +35,7 @@ jobs:
run: pip install pytest run: pip install pytest
- name: Run tests - name: Run tests
run: pytest --durations=20 -ra test/test_datasets_download.py run: pytest -ra -v test/test_datasets_download.py
- uses: JasonEtco/create-an-issue@v2.4.0 - uses: JasonEtco/create-an-issue@v2.4.0
name: Create issue if download tests failed name: Create issue if download tests failed
......
...@@ -4,6 +4,7 @@ import time ...@@ -4,6 +4,7 @@ import time
import unittest.mock import unittest.mock
from datetime import datetime from datetime import datetime
from os import path from os import path
from urllib.error import HTTPError
from urllib.parse import urlparse from urllib.parse import urlparse
from urllib.request import urlopen, Request from urllib.request import urlopen, Request
...@@ -86,24 +87,25 @@ def retry(fn, times=1, wait=5.0): ...@@ -86,24 +87,25 @@ def retry(fn, times=1, wait=5.0):
) )
def assert_server_response_ok(response, url=None): @contextlib.contextmanager
msg = f"The server returned status code {response.code}" def assert_server_response_ok():
if url is not None: try:
msg += f"for the the URL {url}" yield
assert 200 <= response.code < 300, msg except HTTPError as error:
raise AssertionError(f"The server returned {error.code}: {error.reason}.") from error
def assert_url_is_accessible(url): def assert_url_is_accessible(url):
request = Request(url, headers=dict(method="HEAD")) request = Request(url, headers=dict(method="HEAD"))
response = urlopen(request) with assert_server_response_ok():
assert_server_response_ok(response, url) urlopen(request)
def assert_file_downloads_correctly(url, md5): def assert_file_downloads_correctly(url, md5):
with get_tmp_dir() as root: with get_tmp_dir() as root:
file = path.join(root, path.basename(url)) file = path.join(root, path.basename(url))
with assert_server_response_ok():
with urlopen(url) as response, open(file, "wb") as fh: with urlopen(url) as response, open(file, "wb") as fh:
assert_server_response_ok(response, url)
fh.write(response.read()) fh.write(response.read())
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch" assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
...@@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None): ...@@ -125,6 +127,16 @@ def make_download_configs(urls_and_md5s, name=None):
] ]
def collect_download_configs(dataset_loader, name):
try:
with log_download_attempts() as urls_and_md5s:
dataset_loader()
except Exception:
pass
return make_download_configs(urls_and_md5s, name)
def places365(): def places365():
with log_download_attempts(patch=False) as urls_and_md5s: with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)): for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
...@@ -137,23 +149,19 @@ def places365(): ...@@ -137,23 +149,19 @@ def places365():
def caltech101(): def caltech101():
try: return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101")
with log_download_attempts() as urls_and_md5s:
datasets.Caltech101(".", download=True)
except Exception:
pass
return make_download_configs(urls_and_md5s, "Caltech101")
def caltech256(): def caltech256():
try: return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256")
with log_download_attempts() as urls_and_md5s:
datasets.Caltech256(".", download=True)
except Exception: def cifar10():
pass return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10")
return make_download_configs(urls_and_md5s, "Caltech256") def cifar100():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
def make_parametrize_kwargs(download_configs): def make_parametrize_kwargs(download_configs):
...@@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs): ...@@ -166,7 +174,9 @@ def make_parametrize_kwargs(download_configs):
return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids) return dict(argnames=("url", "md5"), argvalues=argvalues, ids=ids)
@pytest.mark.parametrize(**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256()))) @pytest.mark.parametrize(
**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100()))
)
def test_url_is_accessible(url, md5): def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url)) retry(lambda: assert_url_is_accessible(url))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment