"src/libtorchaudio/utils.cpp" did not exist on "f70b970ab11694c035db0062df6b60f2550c1d43"
Unverified Commit d5bc2b24 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Add download tests for MNIST (#3336)



* cleanup

* mnist

* lint
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent d5096a7f
......@@ -31,6 +31,9 @@ jobs:
pip install numpy
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
- name: Install all optional dataset requirements
run: pip install scipy pandas pycocotools lmdb requests
- name: Install tests requirements
run: pip install pytest
......
......@@ -49,17 +49,22 @@ urlopen = limit_requests_per_time()(urlopen)
def log_download_attempts(
urls_and_md5s=None,
patch=True,
download_url_target="torchvision.datasets.utils.download_url",
download_url_location=".utils",
patch_auxiliaries=None,
):
if urls_and_md5s is None:
urls_and_md5s = set()
if download_url_location.startswith("."):
download_url_location = f"torchvision.datasets{download_url_location}"
if patch_auxiliaries is None:
patch_auxiliaries = patch
with contextlib.ExitStack() as stack:
download_url_mock = stack.enter_context(
unittest.mock.patch(download_url_target, wraps=None if patch else download_url)
unittest.mock.patch(
f"{download_url_location}.download_url",
wraps=None if patch else download_url,
)
)
if patch_auxiliaries:
# download_and_extract_archive
......@@ -132,9 +137,17 @@ def make_download_configs(urls_and_md5s, name=None):
]
def collect_download_configs(dataset_loader, name, **kwargs):
with contextlib.suppress(Exception), log_download_attempts(**kwargs) as urls_and_md5s:
dataset_loader()
def collect_download_configs(dataset_loader, name=None, **kwargs):
urls_and_md5s = set()
try:
with log_download_attempts(urls_and_md5s=urls_and_md5s, **kwargs):
dataset = dataset_loader()
except Exception:
dataset = None
if name is None and dataset is not None:
name = type(dataset).__name__
return make_download_configs(urls_and_md5s, name)
......@@ -146,34 +159,40 @@ def places365():
datasets.Places365(root, split=split, small=small, download=True)
return make_download_configs(urls_and_md5s, "Places365")
return make_download_configs(urls_and_md5s, name="Places365")
def caltech101():
return collect_download_configs(lambda: datasets.Caltech101(".", download=True), "Caltech101")
return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")
def caltech256():
return collect_download_configs(lambda: datasets.Caltech256(".", download=True), "Caltech256")
return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")
def cifar10():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR10")
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")
def cifar100():
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), "CIFAR100")
return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR100")
def voc():
download_configs = []
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012"):
with contextlib.suppress(Exception), log_download_attempts(
download_url_target="torchvision.datasets.voc.download_url"
) as urls_and_md5s:
datasets.VOCSegmentation(".", year=year, download=True)
download_configs.extend(make_download_configs(urls_and_md5s, f"VOC, {year}"))
return download_configs
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.VOCSegmentation(".", year=year, download=True),
name=f"VOC, {year}",
download_url_location=".voc",
)
for year in ("2007", "2007-test", "2008", "2009", "2010", "2011", "2012")
]
)
def mnist():
return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")
def make_parametrize_kwargs(download_configs):
......@@ -196,6 +215,7 @@ def make_parametrize_kwargs(download_configs):
cifar100(),
# The VOC download server is unstable. See https://github.com/pytorch/vision/issues/2953 for details.
# voc(),
mnist(),
)
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment