Unverified commit 5acd5c9b authored by Philip Meier, committed by GitHub

remove generated dataset folders after download tests (#3376)

* remove generated dataset folders after download tests

* revert unnecessary style changes
parent 5905de08
@@ -3,10 +3,12 @@ import itertools
 import time
 import unittest.mock
 from datetime import datetime
+from distutils import dir_util
 from os import path
 from urllib.error import HTTPError, URLError
 from urllib.parse import urlparse
 from urllib.request import urlopen, Request
+import tempfile
 import warnings

 import pytest
@@ -194,6 +196,17 @@ def collect_download_configs(dataset_loader, name=None, **kwargs):
     return make_download_configs(urls_and_md5s, name)


+# This is a workaround since fixtures, such as the built-in tmp_dir, can only be used within a test but not within a
+# parametrization. Thus, we use a single root directory for all datasets and remove it when all download tests are run.
+ROOT = tempfile.mkdtemp()
+
+
+@pytest.fixture(scope="module", autouse=True)
+def root():
+    yield ROOT
+    dir_util.remove_tree(ROOT)
+
+
 def places365():
     with log_download_attempts(patch=False) as urls_and_md5s:
         for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
@@ -206,26 +219,26 @@ def places365():
 def caltech101():
-    return collect_download_configs(lambda: datasets.Caltech101(".", download=True), name="Caltech101")
+    return collect_download_configs(lambda: datasets.Caltech101(ROOT, download=True), name="Caltech101")


 def caltech256():
-    return collect_download_configs(lambda: datasets.Caltech256(".", download=True), name="Caltech256")
+    return collect_download_configs(lambda: datasets.Caltech256(ROOT, download=True), name="Caltech256")


 def cifar10():
-    return collect_download_configs(lambda: datasets.CIFAR10(".", download=True), name="CIFAR10")
+    return collect_download_configs(lambda: datasets.CIFAR10(ROOT, download=True), name="CIFAR10")


 def cifar100():
-    return collect_download_configs(lambda: datasets.CIFAR100(".", download=True), name="CIFAR100")
+    return collect_download_configs(lambda: datasets.CIFAR100(ROOT, download=True), name="CIFAR100")


 def voc():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.VOCSegmentation(".", year=year, download=True),
+                lambda: datasets.VOCSegmentation(ROOT, year=year, download=True),
                 name=f"VOC, {year}",
                 file="voc",
             )
@@ -235,27 +248,27 @@ def voc():
 def mnist():
-    return collect_download_configs(lambda: datasets.MNIST(".", download=True), name="MNIST")
+    return collect_download_configs(lambda: datasets.MNIST(ROOT, download=True), name="MNIST")


 def fashion_mnist():
-    return collect_download_configs(lambda: datasets.FashionMNIST(".", download=True), name="FashionMNIST")
+    return collect_download_configs(lambda: datasets.FashionMNIST(ROOT, download=True), name="FashionMNIST")


 def kmnist():
-    return collect_download_configs(lambda: datasets.KMNIST(".", download=True), name="KMNIST")
+    return collect_download_configs(lambda: datasets.KMNIST(ROOT, download=True), name="KMNIST")


 def emnist():
     # the 'split' argument can be any valid one, since everything is downloaded anyway
-    return collect_download_configs(lambda: datasets.EMNIST(".", split="byclass", download=True), name="EMNIST")
+    return collect_download_configs(lambda: datasets.EMNIST(ROOT, split="byclass", download=True), name="EMNIST")


 def qmnist():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.QMNIST(".", what=what, download=True),
+                lambda: datasets.QMNIST(ROOT, what=what, download=True),
                 name=f"QMNIST, {what}",
                 file="mnist",
             )
@@ -268,7 +281,7 @@ def omniglot():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.Omniglot(".", background=background, download=True),
+                lambda: datasets.Omniglot(ROOT, background=background, download=True),
                 name=f"Omniglot, {'background' if background else 'evaluation'}",
             )
             for background in (True, False)
@@ -280,7 +293,7 @@ def phototour():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.PhotoTour(".", name=name, download=True),
+                lambda: datasets.PhotoTour(ROOT, name=name, download=True),
                 name=f"PhotoTour, {name}",
                 file="phototour",
             )
@@ -293,7 +306,7 @@ def phototour():
 def sbdataset():
     return collect_download_configs(
-        lambda: datasets.SBDataset(".", download=True),
+        lambda: datasets.SBDataset(ROOT, download=True),
         name="SBDataset",
         file="voc",
     )
@@ -301,7 +314,7 @@ def sbdataset():
 def sbu():
     return collect_download_configs(
-        lambda: datasets.SBU(".", download=True),
+        lambda: datasets.SBU(ROOT, download=True),
         name="SBU",
         file="sbu",
     )
@@ -309,7 +322,7 @@ def sbu():
 def semeion():
     return collect_download_configs(
-        lambda: datasets.SEMEION(".", download=True),
+        lambda: datasets.SEMEION(ROOT, download=True),
         name="SEMEION",
         file="semeion",
     )
@@ -317,7 +330,7 @@ def semeion():
 def stl10():
     return collect_download_configs(
-        lambda: datasets.STL10(".", download=True),
+        lambda: datasets.STL10(ROOT, download=True),
         name="STL10",
     )
@@ -326,7 +339,7 @@ def svhn():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.SVHN(".", split=split, download=True),
+                lambda: datasets.SVHN(ROOT, split=split, download=True),
                 name=f"SVHN, {split}",
                 file="svhn",
             )
@@ -339,7 +352,7 @@ def usps():
     return itertools.chain(
         *[
             collect_download_configs(
-                lambda: datasets.USPS(".", train=train, download=True),
+                lambda: datasets.USPS(ROOT, train=train, download=True),
                 name=f"USPS, {'train' if train else 'test'}",
                 file="usps",
             )
@@ -350,7 +363,7 @@ def usps():
 def celeba():
     return collect_download_configs(
-        lambda: datasets.CelebA(".", download=True),
+        lambda: datasets.CelebA(ROOT, download=True),
         name="CelebA",
         file="celeba",
     )
@@ -358,7 +371,7 @@ def celeba():
 def widerface():
     return collect_download_configs(
-        lambda: datasets.WIDERFace(".", download=True),
+        lambda: datasets.WIDERFace(ROOT, download=True),
         name="WIDERFace",
         file="widerface",
     )
...
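For reference, the sketch below (a standalone, hypothetical test module, not part of this diff) shows the same workaround in isolation: a root directory created with `tempfile.mkdtemp()` at import time so it can be referenced from `@pytest.mark.parametrize`, where fixtures are not available, and a module-scoped, autouse fixture that deletes it after the last test in the module has run. It assumes `shutil.rmtree` for cleanup instead of `distutils.dir_util.remove_tree`; both remove the directory tree recursively.

```python
# Minimal sketch of the shared-root-directory pattern above; the test name and
# parameters here are hypothetical and not part of torchvision's test suite.
import os
import shutil
import tempfile

import pytest

# Created at import time, before test collection, so it can appear inside
# @pytest.mark.parametrize, where fixtures cannot be requested.
ROOT = tempfile.mkdtemp()


@pytest.fixture(scope="module", autouse=True)
def root():
    # Hand the shared directory to any test that requests this fixture ...
    yield ROOT
    # ... and remove it once the last test in this module has finished.
    shutil.rmtree(ROOT)  # assumption: shutil instead of distutils.dir_util


@pytest.mark.parametrize("subdir", [os.path.join(ROOT, name) for name in ("a", "b")])
def test_uses_shared_root(subdir, root):
    # Every parametrized case writes into the same pre-created root directory.
    os.makedirs(subdir, exist_ok=True)
    assert subdir.startswith(root)
```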