Unverified Commit a89da92b authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Ported STL10 dataset's tests to new test framework (#3665)

* Ported STL10 dataset's tests to new test framework

* Added additional tests

* Removed unused import

* Made private methods static and other minor changes
parent a18b4af1
......@@ -308,82 +308,3 @@ def places365_root(split="train-standard", small=False):
clean_dir(root, ".tar$")
yield root, data
@contextlib.contextmanager
def stl10_root(_extracted=False):
CLASS_NAMES = ("airplane", "bird")
ARCHIVE_NAME = "stl10_binary"
NUM_FOLDS = 10
def mock_target(attr, partial="torchvision.datasets.stl10.STL10"):
return f"{partial}.{attr}"
def make_binary_file(num_elements, root, name):
file = os.path.join(root, name)
np.zeros(num_elements, dtype=np.uint8).tofile(file)
return name, compute_md5(file)
def make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
return make_binary_file(num_images * num_channels * height * width, root, name)
def make_label_file(num_images, root, name):
return make_binary_file(num_images, root, name)
def make_class_names_file(root, name="class_names.txt"):
with open(os.path.join(root, name), "w") as fh:
for name in CLASS_NAMES:
fh.write(f"{name}\n")
def make_fold_indices_file(root):
offset = 0
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
for fold in range(NUM_FOLDS):
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
fh.write(f"{line}\n")
offset += fold + 1
return tuple(range(1, NUM_FOLDS + 1))
def make_train_files(stack, root, num_unlabeled_images=1):
num_images_in_fold = make_fold_indices_file(root)
num_train_images = sum(num_images_in_fold)
train_list = [
list(make_image_file(num_train_images, root, "train_X.bin")),
list(make_label_file(num_train_images, root, "train_y.bin")),
list(make_image_file(1, root, "unlabeled_X.bin"))
]
mock_class_attribute(stack, target=mock_target("train_list"), new=train_list)
return num_images_in_fold, dict(train=num_train_images, unlabeled=num_unlabeled_images)
def make_test_files(stack, root, num_images=2):
test_list = [
list(make_image_file(num_images, root, "test_X.bin")),
list(make_label_file(num_images, root, "test_y.bin")),
]
mock_class_attribute(stack, target=mock_target("test_list"), new=test_list)
return dict(test=num_images)
def make_archive(stack, root, name):
archive, md5 = make_tar(root, name, name, compression="gz")
mock_class_attribute(stack, target=mock_target("tgz_md5"), new=md5)
return archive
with contextlib.ExitStack() as stack, get_tmp_dir() as root:
archive_folder = os.path.join(root, ARCHIVE_NAME)
os.mkdir(archive_folder)
num_images_in_folds, num_images_in_split = make_train_files(stack, archive_folder)
num_images_in_split.update(make_test_files(stack, archive_folder))
make_class_names_file(archive_folder)
archive = make_archive(stack, root, ARCHIVE_NAME)
dir_util.remove_tree(archive_folder)
data = dict(num_images_in_folds=num_images_in_folds, num_images_in_split=num_images_in_split, archive=archive)
yield root, data
......@@ -9,7 +9,7 @@ from torch._utils_internal import get_file_path_2
import torchvision
from torchvision.datasets import utils
from common_utils import get_tmp_dir
from fakedata_generation import places365_root, widerface_root, stl10_root
from fakedata_generation import places365_root
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
......@@ -141,76 +141,89 @@ class Tester(DatasetTestcase):
self.assertIsInstance(repr(dataset), str)
class STL10Tester(DatasetTestcase):
@contextlib.contextmanager
def mocked_root(self):
with stl10_root() as (root, data):
yield root, data
class STL10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.STL10
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test", "unlabeled", "train+unlabeled"))
@contextlib.contextmanager
def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
with self.mocked_root() as (root, data):
if pre_extract:
utils.extract_archive(os.path.join(root, data["archive"]))
dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
yield dataset, data
def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.mocked_dataset(download=False):
pass
@staticmethod
def _make_binary_file(num_elements, root, name):
file_name = os.path.join(root, name)
np.zeros(num_elements, dtype=np.uint8).tofile(file_name)
def test_splits(self):
for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
with self.mocked_dataset(split=split) as (dataset, data):
num_images = sum([data["num_images_in_split"][part] for part in split.split("+")])
self.generic_classification_dataset_test(dataset, num_images=num_images)
@staticmethod
def _make_image_file(num_images, root, name, num_channels=3, height=96, width=96):
STL10TestCase._make_binary_file(num_images * num_channels * height * width, root, name)
def test_folds(self):
for fold in range(10):
with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
num_images = data["num_images_in_folds"][fold]
self.assertEqual(len(dataset), num_images)
@staticmethod
def _make_label_file(num_images, root, name):
STL10TestCase._make_binary_file(num_images, root, name)
def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds=10):
pass
@staticmethod
def _make_class_names_file(root, name="class_names.txt"):
with open(os.path.join(root, name), "w") as fh:
for cname in ("airplane", "bird"):
fh.write(f"{cname}\n")
def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with self.mocked_dataset(folds="0"):
pass
@staticmethod
def _make_fold_indices_file(root):
num_folds = 10
offset = 0
with open(os.path.join(root, "fold_indices.txt"), "w") as fh:
for fold in range(num_folds):
line = " ".join([str(idx) for idx in range(offset, offset + fold + 1)])
fh.write(f"{line}\n")
offset += fold + 1
def test_transforms(self):
expected_image = "image"
expected_target = "target"
return tuple(range(1, num_folds + 1))
def transform(image):
return expected_image
@staticmethod
def _make_train_files(root, num_unlabeled_images=1):
num_images_in_fold = STL10TestCase._make_fold_indices_file(root)
num_train_images = sum(num_images_in_fold)
def target_transform(target):
return expected_target
STL10TestCase._make_image_file(num_train_images, root, "train_X.bin")
STL10TestCase._make_label_file(num_train_images, root, "train_y.bin")
STL10TestCase._make_image_file(1, root, "unlabeled_X.bin")
with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
actual_image, actual_target = dataset[0]
return dict(train=num_train_images, unlabeled=num_unlabeled_images)
self.assertEqual(actual_image, expected_image)
self.assertEqual(actual_target, expected_target)
@staticmethod
def _make_test_files(root, num_images=2):
STL10TestCase._make_image_file(num_images, root, "test_X.bin")
STL10TestCase._make_label_file(num_images, root, "test_y.bin")
return dict(test=num_images)
def inject_fake_data(self, tmpdir, config):
root_folder = os.path.join(tmpdir, "stl10_binary")
os.mkdir(root_folder)
num_images_in_split = self._make_train_files(root_folder)
num_images_in_split.update(self._make_test_files(root_folder))
self._make_class_names_file(root_folder)
return sum(num_images_in_split[part] for part in config["split"].split("+"))
def test_folds(self):
for fold in range(10):
with self.create_dataset(split="train", folds=fold) as (dataset, _):
self.assertEqual(len(dataset), fold + 1)
def test_unlabeled(self):
with self.mocked_dataset(split="unlabeled") as (dataset, _):
with self.create_dataset(split="unlabeled") as (dataset, _):
labels = [dataset[idx][1] for idx in range(len(dataset))]
self.assertTrue(all([label == -1 for label in labels]))
self.assertTrue(all(label == -1 for label in labels))
@unittest.mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
def test_download_preexisting(self, mock):
with self.mocked_dataset(pre_extract=True) as (dataset, data):
mock.assert_not_called()
def test_invalid_folds1(self):
with self.assertRaises(ValueError):
with self.create_dataset(folds=10):
pass
def test_repr_smoke(self):
with self.mocked_dataset() as (dataset, _):
self.assertIsInstance(repr(dataset), str)
def test_invalid_folds2(self):
with self.assertRaises(ValueError):
with self.create_dataset(folds="0"):
pass
class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment