Unverified commit 5acd5c9b, authored by Philip Meier and committed by GitHub
Browse files

remove generated dataset folders after download tests (#3376)

* remove generated dataset folders after download tests

* revert unnecessary style changes
parent 5905de08
......@@ -3,10 +3,12 @@ import itertools
import time
import unittest.mock
from datetime import datetime
from distutils import dir_util
from os import path
from urllib.error import HTTPError, URLError
from urllib.parse import urlparse
from urllib.request import urlopen, Request
import tempfile
import warnings
import pytest
......@@ -194,6 +196,17 @@ def collect_download_configs(dataset_loader, name=None, **kwargs):
return make_download_configs(urls_and_md5s, name)
# This is a workaround since fixtures, such as the built-in tmpdir, can only be used within a test but not within a
# parametrization. Thus, we use a single root directory for all datasets and remove it once all download tests have run.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    """Yield the shared dataset root directory and delete it during fixture teardown.

    The fixture is module-scoped and ``autouse``, so the code after the ``yield`` runs once, after
    the tests of this module have finished.
    """
    # ``distutils`` is deprecated (PEP 632) and removed in Python 3.12, so the directory is removed
    # with ``shutil`` instead. Imported locally to keep this block self-contained.
    import shutil

    yield ROOT
    # ``ignore_errors=True`` keeps the cleanup best-effort, like ``dir_util.remove_tree``, which only
    # logs a warning when a path cannot be removed.
    shutil.rmtree(ROOT, ignore_errors=True)
def places365():
with log_download_attempts(patch=False) as urls_and_md5s:
for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
......@@ -206,26 +219,26 @@ def places365():
def caltech101():
    """Collect the download configs of ``datasets.Caltech101``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"Caltech101"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")
def caltech256():
    """Collect the download configs of ``datasets.Caltech256``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"Caltech256"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")
def cifar10():
    """Collect the download configs of ``datasets.CIFAR10``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"CIFAR10"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")
def cifar100():
    """Collect the download configs of ``datasets.CIFAR100``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"CIFAR100"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")
def voc():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.VOCSegmentation(".", year=year, download=True),
lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
name=f"VOC, {year}",
file="voc",
)
......@@ -235,27 +248,27 @@ def voc():
def mnist():
    """Collect the download configs of ``datasets.MNIST``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"MNIST"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")
def fashion_mnist():
    """Collect the download configs of ``datasets.FashionMNIST``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"FashionMNIST"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")
def kmnist():
    """Collect the download configs of ``datasets.KMNIST``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"KMNIST"``.
    """
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")
def emnist():
    """Collect the download configs of ``datasets.EMNIST``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"EMNIST"``.
    """
    # the 'split' argument can be any valid one, since everything is downloaded anyway
    # Use the shared temporary ROOT instead of "." so nothing is created in the working directory.
    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")
def qmnist():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.QMNIST(".", what=what, download=True),
lambda: datasets.QMNIST(ROOT, what=what, download=True),
name=f"QMNIST, {what}",
file="mnist",
)
......@@ -268,7 +281,7 @@ def omniglot():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.Omniglot(".", background=background, download=True),
lambda: datasets.Omniglot(ROOT, background=background, download=True),
name=f"Omniglot, {'background' if background else 'evaluation'}",
)
for background in (True, False)
......@@ -280,7 +293,7 @@ def phototour():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.PhotoTour(".", name=name, download=True),
lambda: datasets.PhotoTour(ROOT, name=name, download=True),
name=f"PhotoTour, {name}",
file="phototour",
)
......@@ -293,7 +306,7 @@ def phototour():
def sbdataset():
    """Collect the download configs of ``datasets.SBDataset``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"SBDataset"``.
    """
    return collect_download_configs(
        # Only the ROOT-based loader is passed; a second positional lambda would collide with ``name=``.
        lambda: datasets.SBDataset(ROOT, download=True),
        name="SBDataset",
        file="voc",
    )
......@@ -301,7 +314,7 @@ def sbdataset():
def sbu():
    """Collect the download configs of ``datasets.SBU``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"SBU"``.
    """
    return collect_download_configs(
        # Only the ROOT-based loader is passed; a second positional lambda would collide with ``name=``.
        lambda: datasets.SBU(ROOT, download=True),
        name="SBU",
        file="sbu",
    )
......@@ -309,7 +322,7 @@ def sbu():
def semeion():
    """Collect the download configs of ``datasets.SEMEION``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"SEMEION"``.
    """
    return collect_download_configs(
        # Only the ROOT-based loader is passed; a second positional lambda would collide with ``name=``.
        lambda: datasets.SEMEION(ROOT, download=True),
        name="SEMEION",
        file="semeion",
    )
......@@ -317,7 +330,7 @@ def semeion():
def stl10():
    """Collect the download configs of ``datasets.STL10``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"STL10"``.
    """
    return collect_download_configs(
        # Only the ROOT-based loader is passed; a second positional lambda would collide with ``name=``.
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )
......@@ -326,7 +339,7 @@ def svhn():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.SVHN(".", split=split, download=True),
lambda: datasets.SVHN(ROOT, split=split, download=True),
name=f"SVHN, {split}",
file="svhn",
)
......@@ -339,7 +352,7 @@ def usps():
return itertools.chain(
*[
collect_download_configs(
lambda: datasets.USPS(".", train=train, download=True),
lambda: datasets.USPS(ROOT, train=train, download=True),
name=f"USPS, {'train' if train else 'test'}",
file="usps",
)
......@@ -350,7 +363,7 @@ def usps():
def celeba():
    """Collect the download configs of ``datasets.CelebA``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"CelebA"``.
    """
    return collect_download_configs(
        # Only the ROOT-based loader is passed; a second positional lambda would collide with ``name=``.
        lambda: datasets.CelebA(ROOT, download=True),
        name="CelebA",
        file="celeba",
    )
......@@ -358,7 +371,7 @@ def celeba():
def widerface():
    """Collect the download configs of ``datasets.WIDERFace``, instantiated inside ``ROOT``.

    Returns:
        Whatever ``collect_download_configs`` returns for the loader, labelled ``"WIDERFace"``.
    """
    return collect_download_configs(
        # Only the ROOT-based loader is passed; a second positional lambda would collide with ``name=``.
        lambda: datasets.WIDERFace(ROOT, download=True),
        name="WIDERFace",
        file="widerface",
    )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment