Unverified Commit bda17ddf authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Add tests for the STL10 dataset (#3345)



* extract some functionality from places365 fakedata for common use

* add a common DatasetTestcase

* add fakedata generation and tests for STL10

* lint
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent c645f9d2
......@@ -13,6 +13,48 @@ from torchvision.io.video import write_video
import unittest.mock
import hashlib
from distutils import dir_util
import re
def mock_class_attribute(stack, target, new):
mock = unittest.mock.patch(target, new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock
def compute_md5(file):
with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest()
def make_tar(root, name, *files, compression=None):
ext = ".tar"
mode = "w"
if compression is not None:
ext = f"{ext}.{compression}"
mode = f"{mode}:{compression}"
name = os.path.splitext(name)[0] + ext
archive = os.path.join(root, name)
with tarfile.open(archive, mode) as fh:
for file in files:
fh.add(os.path.join(root, file), arcname=file)
return name, compute_md5(archive)
def clean_dir(root, *keep):
pattern = re.compile(f"({f')|('.join(keep)})")
for file_or_dir in os.listdir(root):
if pattern.search(file_or_dir):
continue
file_or_dir = os.path.join(root, file_or_dir)
if os.path.isfile(file_or_dir):
os.remove(file_or_dir)
else:
dir_util.remove_tree(file_or_dir)
@contextlib.contextmanager
......@@ -385,7 +427,7 @@ def ucf101_root():
@contextlib.contextmanager
def places365_root(split="train-standard", small=False, extract_images=True):
def places365_root(split="train-standard", small=False):
VARIANTS = {
"train-standard": "standard",
"train-challenge": "challenge",
......@@ -425,15 +467,6 @@ def places365_root(split="train-standard", small=False, extract_images=True):
def mock_target(attr, partial="torchvision.datasets.places365.Places365"):
return f"{partial}.{attr}"
def mock_class_attribute(stack, attr, new):
mock = unittest.mock.patch(mock_target(attr), new_callable=unittest.mock.PropertyMock, return_value=new)
stack.enter_context(mock)
return mock
def compute_md5(file):
with open(file, "rb") as fh:
return hashlib.md5(fh.read()).hexdigest()
def make_txt(root, name, seq):
file = os.path.join(root, name)
with open(file, "w") as fh:
......@@ -451,37 +484,20 @@ def places365_root(split="train-standard", small=False, extract_images=True):
os.makedirs(os.path.dirname(file), exist_ok=True)
PIL.Image.fromarray(np.zeros((*size, 3), dtype=np.uint8)).save(file)
def make_tar(root, name, *files, remove_files=True):
name = f"{os.path.splitext(name)[0]}.tar"
archive = os.path.join(root, name)
with tarfile.open(archive, "w") as fh:
for file in files:
fh.add(os.path.join(root, file), arcname=file)
if remove_files:
for file in [os.path.join(root, file) for file in files]:
if os.path.isdir(file):
dir_util.remove_tree(file)
else:
os.remove(file)
return name, compute_md5(archive)
def make_devkit_archive(stack, root, split):
archive = DEVKITS[split]
files = []
meta = make_categories_txt(root, CATEGORIES)
mock_class_attribute(stack, "_CATEGORIES_META", meta)
mock_class_attribute(stack, mock_target("_CATEGORIES_META"), meta)
files.append(meta[0])
meta = {split: make_file_list_txt(root, FILE_LISTS[split])}
mock_class_attribute(stack, "_FILE_LIST_META", meta)
mock_class_attribute(stack, mock_target("_FILE_LIST_META"), meta)
files.extend([item[0] for item in meta.values()])
meta = {VARIANTS[split]: make_tar(root, archive, *files)}
mock_class_attribute(stack, "_DEVKIT_META", meta)
mock_class_attribute(stack, mock_target("_DEVKIT_META"), meta)
def make_images_archive(stack, root, split, small):
archive, folder_default, folder_renamed = IMAGES[(split, small)]
......@@ -493,7 +509,7 @@ def places365_root(split="train-standard", small=False, extract_images=True):
make_image(os.path.join(root, folder_default, image), image_size)
meta = {(split, small): make_tar(root, archive, folder_default)}
mock_class_attribute(stack, "_IMAGES_META", meta)
mock_class_attribute(stack, mock_target("_IMAGES_META"), meta)
return [(os.path.join(root, folder_renamed, image), idx) for image, idx in zip(images, idcs)]
......@@ -501,12 +517,89 @@ def places365_root(split="train-standard", small=False, extract_images=True):
make_devkit_archive(stack, root, split)
class_to_idx = dict(CATEGORIES_CONTENT)
classes = list(class_to_idx.keys())
data = {"class_to_idx": class_to_idx, "classes": classes}
if extract_images:
data = {"class_to_idx": class_to_idx, "classes": classes}
data["imgs"] = make_images_archive(stack, root, split, small)
else:
stack.enter_context(unittest.mock.patch(mock_target("download_images")))
data["imgs"] = None
clean_dir(root, ".tar$")
yield root, data
@contextlib.contextmanager
def stl10_root(_extracted=False):
CLASS_NAMES = ("airplane", "bird")
ARCHIVE_NAME = "stl10_binary"
NUM_FOLDS = 10
def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
return f"{partial}.{attr}"
def make_binary_file(num_elements, root, name):
file = os.path.join(root, name)
np.zeros(num_elements, dtype=np.uint8).tofile(file)
return name, compute_md5(file)
def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
return make_binary_file(num_images * num_channels * height * width, root, name)
def make_label_file(num_images, root, name):
return make_binary_file(num_images, root, name)
def make_class_names_file(root, name="class_names.txt"):
with open(os.path.join(root, name), "w") as fh:
for name in CLASS_NAMES:
fh.write(f"{name}\n")
def make_fold_indices_file(root):
offset = 0
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
for fold in range(NUM_FOLDS):
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
fh.write(f"{line}\n")
offset += fold + 1
return tuple(range(1, NUM_FOLDS + 1))
def make_train_files(stack, root, num_unlabeled_images=1):
num_images_in_fold = make_fold_indices_file(root)
num_train_images = sum(num_images_in_fold)
train_list = [
list(make_image_file(num_train_images, root, "train_X.bin")),
list(make_label_file(num_train_images, root, "train_y.bin")),
list(make_image_file(1, root, "unlabeled_X.bin"))
]
mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)
return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)
def make_test_files(stack, root, num_images=2):
test_list = [
list(make_image_file(num_images, root, "test_X.bin")),
list(make_label_file(num_images, root, "test_y.bin")),
]
mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)
return dict(test=num_images)
def make_archive(stack, root, name):
archive, md5 = make_tar(root, name, name, compression="gz")
mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
return archive
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
archive_folder = os.path.join(root, ARCHIVE_NAME)
os.mkdir(archive_folder)
num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
num_images_in_split.update(make_test_files(stack, archive_folder))
make_class_names_file(archive_folder)
archive = make_archive(stack, root, ARCHIVE_NAME)
dir_util.remove_tree(archive_folder)
data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)
yield root, data
import contextlib
import sys
import os
import unittest
......@@ -7,9 +8,10 @@ import PIL
from PIL import Image
from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
......@@ -28,7 +30,7 @@ except ImportError:
HAS_PYAV = False
class Tester(unittest.TestCase):
class DatasetTestcase(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
......@@ -41,6 +43,8 @@ class Tester(unittest.TestCase):
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, PIL.Image.Image))
class Tester(DatasetTestcase):
def test_imagefolder(self):
# TODO: create the fake data on-the-fly
FAKEDATA_DIR = get_file_path_2(
......@@ -354,7 +358,7 @@ class Tester(unittest.TestCase):
def test_places365_devkit_no_download(self):
for split in ("train-standard", "train-challenge", "val"):
with self.subTest(split=split):
with places365_root(split=split, extract_images=False) as places365:
with places365_root(split=split) as places365:
root, data = places365
with self.assertRaises(RuntimeError):
......@@ -383,12 +387,84 @@ class Tester(unittest.TestCase):
torchvision.datasets.Places365(root, split=split, small=small, download=True)
def test_places365_repr_smoke(self):
with places365_root(extract_images=False) as places365:
with places365_root() as places365:
root, data = places365
dataset = torchvision.datasets.Places365(root, download=True)
self.assertIsInstance(repr(dataset), str)
class STL10Tester(DatasetTestcase):
@contextlib.contextmanager
def mocked_root(self):
with stl10_root() as (root, data):
yield root, data
@contextlib.contextmanager
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
with self.mocked_root() as (root, data):
if pre_extract:
utils.extract_archive(os.path.join(root, data["archive"]))
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
yield dataset, data
def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.mocked_dataset(download=False):
pass
def test_splits(self):
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
with self.mocked_dataset(split=split) as (dataset, data):
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
self.generic_classification_dataset_test(dataset, num_images=num_images)
def test_folds(self):
for fold in range(10):
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
num_images = data["num_images_in_folds"][fold]
self.assertEqual(len(dataset), num_images)
def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds=10):
pass
def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds="0"):
pass
def test_transforms(self):
expected_image = "image"
expected_target = "target"
def transform(image):
return expected_image
def target_transform(target):
return expected_target
with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
actual_image, actual_target = dataset[0]
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
def test_unlabeled(self):
with self.mocked_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all([label == -1 for label in labels]))
@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
def test_download_preexisting(self, mock):
with self.mocked_dataset(pre_extract=True) as (dataset, data):
mock.assert_not_called()
def test_repr_smoke(self):
with self.mocked_dataset() as (dataset, _):
self.assertIsInstance(repr(dataset), str)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment