Unverified Commit d5bc2b24 authored by Philip Meier, committed by GitHub

Add download tests for MNIST (#3336)



* cleanup

* mnist

* lint
Co-authored-by: vfdev <vfdev.5@gmail.com>
parent d5096a7f
@@ -31,6 +31,9 @@ jobs:
           pip install numpy
           pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
+      - name: Install all optional dataset requirements
+        run: pip install scipy pandas pycocotools lmdb requests
       - name: Install tests requirements
         run: pip install pytest
@@ -49,17 +49,22 @@ urlopen = limit_requests_per_time()(urlopen)
 def log_download_attempts(
     urls_and_md5s=None,
     patch=True,
-    download_url_target="torchvision.datasets.utils.download_url",
+    download_url_location=".utils",
     patch_auxiliaries=None,
 ):
     if urls_and_md5s is None:
         urls_and_md5s = set()
+    if download_url_location.startswith("."):
+        download_url_location = f"torchvision.datasets{download_url_location}"
     if patch_auxiliaries is None:
         patch_auxiliaries = patch
 
     with contextlib.ExitStack() as stack:
         download_url_mock = stack.enter_context(
-            unittest.mock.patch(download_url_target, wraps=None if patch else download_url)
+            unittest.mock.patch(
+                f"{download_url_location}.download_url",
+                wraps=None if patch else download_url,
+            )
         )
         if patch_auxiliaries:
             # download_and_extract_archive
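For readers unfamiliar with the helper, here is a minimal sketch of how the new relative download_url_location is meant to be used. The dataset call is illustrative only, and it assumes the context manager still yields its urls_and_md5s set as in the pre-diff code:

# Illustrative usage, not part of this diff; assumes the test module's imports.
# A relative location such as ".voc" expands to "torchvision.datasets.voc",
# so the download_url imported *inside that module* is the one being mocked.
with log_download_attempts(download_url_location=".voc") as urls_and_md5s:
    datasets.VOCSegmentation(".", year="2012", download=True)
# urls_and_md5s now holds the (url, md5) pairs the dataset tried to download.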
@@ -132,9 +137,17 @@ def make_download_configs(urls_and_md5s, name=None):
     ]
 
 
-def collect_download_configs(dataset_loader, name, **kwargs):
-    with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s:
-        dataset_loader()
+def collect_download_configs(dataset_loader, name=None, **kwargs):
+    urls_and_md5s = set()
+    try:
+        with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
+            dataset = dataset_loader()
+    except Exception:
+        dataset = None
+
+    if name is None and dataset is not None:
+        name = type(dataset).__name__
 
     return make_download_configs(urls_and_md5s, name)
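With this rewrite, collect_download_configs can infer the display name. A short sketch of both call styles; the datasets chosen here are examples, assuming the test module's imports:

# Explicit name, as before:
collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")

# With name omitted, the name is inferred from the constructed dataset,
# here type(dataset).__name__ == "MNIST". If the loader raises (e.g. the
# download server is down), any URLs recorded before the failure are still
# returned, with name left as None.
collect_download_configs(lambda: datasets.MNIST(".", download=True))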
@@ -146,34 +159,40 @@ def places365():
             datasets.Places365(root, split=split, small=small, download=True)
 
-    return make_download_configs(urls_and_md5s, "Places365")
+    return make_download_configs(urls_and_md5s, name="Places365")
 
 
 def caltech101():
-    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101")
+    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")
 
 
 def caltech256():
-    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256")
+    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")
 
 
 def cifar10():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10")
+    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")
 
 
 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
 
 
 def voc():
-    download_configs = []
-    for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"):
-        with contextlib.suppress(Exception), log_download_attempts(
-            download_url_target="torchvision.datasets.voc.download_url"
-        ) as urls_and_md5s:
-            datasets.VOCSegmentation(".", year=year, download=True)
-        download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}"))
-    return download_configs
+    return itertools.chain(
+        *[
+            collect_download_configs(
+                lambda: datasets.VOCSegmentation(".", year=year, download=True),
+                name=f"VOC, {year}",
+                download_url_location=".voc",
+            )
+            for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
+        ]
+    )
+
+
+def mnist():
+    return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")
 
 
 def make_parametrize_kwargs(download_configs):
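The voc() rewrite swaps an accumulator loop for itertools.chain(*[...]). The lambda closing over year is safe here because collect_download_configs invokes it eagerly within each comprehension step. A self-contained toy sketch of the flattening pattern, with strings standing in for the real download configs:

import itertools

# Each inner list stands in for the per-year download configs.
per_year = [[f"VOC, {year}"] for year in ("2007", "2008", "2009")]

# chain(*lists) flattens the list of lists into a single iterable,
# which is what voc() now returns instead of a materialized list.
flat = list(itertools.chain(*per_year))
print(flat)  # ['VOC, 2007', 'VOC, 2008', 'VOC, 2009']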
@@ -196,6 +215,7 @@ def make_parametrize_kwargs(download_configs):
             cifar100(),
             # The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
             # voc(),
+            mnist(),
         )
     )
 )
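For context on where mnist() lands: the surrounding (elided) code chains all config generators and feeds them through make_parametrize_kwargs into pytest.mark.parametrize. A hedged sketch of that shape; make_parametrize_kwargs is outside this diff, so the parameter names and test body below are assumptions:

import itertools
import pytest

# Assumed wiring: make_parametrize_kwargs(configs) returns a dict of
# pytest.mark.parametrize keyword arguments built from the download configs.
@pytest.mark.parametrize(
    **make_parametrize_kwargs(itertools.chain(cifar10(), cifar100(), mnist()))
)
def test_url_is_downloadable(url, md5):  # hypothetical test signature
    ...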