Unverified Commit d1e134c7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add download tests for VOC (#2834)

parent 57b653c5
...@@ -46,7 +46,12 @@ urlopen = limit_requests_per_time()(urlopen) ...@@ -46,7 +46,12 @@ urlopen = limit_requests_per_time()(urlopen)
@contextlib.contextmanager @contextlib.contextmanager
def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None): def log_download_attempts(
urls_and_md5s=None,
patch=True,
download_url_target="torchvision.datasets.utils.download_url",
patch_auxiliaries=None,
):
if urls_and_md5s is None: if urls_and_md5s is None:
urls_and_md5s = set() urls_and_md5s = set()
if patch_auxiliaries is None: if patch_auxiliaries is None:
...@@ -54,7 +59,7 @@ def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None ...@@ -54,7 +59,7 @@ def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
download_url_mock = stack.enter_context( download_url_mock = stack.enter_context(
unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url) unittest.mock.patch(download_url_target, wraps=None if patch else download_url)
) )
if patch_auxiliaries: if patch_auxiliaries:
# download_and_extract_archive # download_and_extract_archive
...@@ -127,13 +132,9 @@ def make_download_configs(urls_and_md5s, name=None): ...@@ -127,13 +132,9 @@ def make_download_configs(urls_and_md5s, name=None):
] ]
def collect_download_configs(dataset_loader, name): def collect_download_configs(dataset_loader, name, **kwargs):
try: with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s:
with log_download_attempts() as urls_and_md5s:
dataset_loader() dataset_loader()
except Exception:
pass
return make_download_configs(urls_and_md5s, name) return make_download_configs(urls_and_md5s, name)
...@@ -164,6 +165,17 @@ def cifar100(): ...@@ -164,6 +165,17 @@ def cifar100():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100") return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
def voc():
download_configs = []
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"):
with contextlib.suppress(Exception), log_download_attempts(
download_url_target="torchvision.datasets.voc.download_url"
) as urls_and_md5s:
datasets.VOCSegmentation(".", year=year, download=True)
download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}"))
return download_configs
def make_parametrize_kwargs(download_configs): def make_parametrize_kwargs(download_configs):
argvalues = [] argvalues = []
ids = [] ids = []
...@@ -175,7 +187,16 @@ def make_parametrize_kwargs(download_configs): ...@@ -175,7 +187,16 @@ def make_parametrize_kwargs(download_configs):
@pytest.mark.parametrize( @pytest.mark.parametrize(
**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100())) **make_parametrize_kwargs(
itertools.chain(
places365(),
caltech101(),
caltech256(),
cifar10(),
cifar100(),
voc(),
)
)
) )
def test_url_is_accessible(url, md5): def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url)) retry(lambda: assert_url_is_accessible(url))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment