Unverified Commit d1e134c7 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add download tests for VOC (#2834)

parent 57b653c5
......@@ -46,7 +46,12 @@ urlopen = limit_requests_per_time()(urlopen)
@contextlib.contextmanager
def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None):
def log_download_attempts(
urls_and_md5s=None,
patch=True,
download_url_target="torchvision.datasets.utils.download_url",
patch_auxiliaries=None,
):
if urls_and_md5s is None:
urls_and_md5s = set()
if patch_auxiliaries is None:
......@@ -54,7 +59,7 @@ def log_download_attempts(urls_and_md5s=None, patch=True, patch_auxiliaries=None
with contextlib.ExitStack() as stack:
download_url_mock = stack.enter_context(
unittest.mock.patch("torchvision.datasets.utils.download_url", wraps=None if patch else download_url)
unittest.mock.patch(download_url_target, wraps=None if patch else download_url)
)
if patch_auxiliaries:
# download_and_extract_archive
......@@ -127,13 +132,9 @@ def make_download_configs(urls_and_md5s, name=None):
]
def collect_download_configs(dataset_loader, name):
try:
with log_download_attempts() as urls_and_md5s:
def collect_download_configs(dataset_loader, name, **kwargs):
with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s:
dataset_loader()
except Exception:
pass
return make_download_configs(urls_and_md5s, name)
......@@ -164,6 +165,17 @@ def cifar100():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
def voc():
download_configs = []
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"):
with contextlib.suppress(Exception), log_download_attempts(
download_url_target="torchvision.datasets.voc.download_url"
) as urls_and_md5s:
datasets.VOCSegmentation(".", year=year, download=True)
download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}"))
return download_configs
def make_parametrize_kwargs(download_configs):
argvalues = []
ids = []
......@@ -175,7 +187,16 @@ def make_parametrize_kwargs(download_configs):
@pytest.mark.parametrize(
**make_parametrize_kwargs(itertools.chain(places365(), caltech101(), caltech256(), cifar10(), cifar100()))
**make_parametrize_kwargs(
itertools.chain(
places365(),
caltech101(),
caltech256(),
cifar10(),
cifar100(),
voc(),
)
)
)
def test_url_is_accessible(url, md5):
retry(lambda: assert_url_is_accessible(url))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment