Commit 3c81d474 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Add a generic test for the datasets (#1015)

* added a generic test for the datasets

* addressed requested changes

- renamed generic*() to generic_classification*()
- moved function inside Tester
- test class_to_idx attribute outside of generic_classification*()
parent 250bac89
...@@ -29,7 +29,7 @@ def mnist_root(num_images, cls_name): ...@@ -29,7 +29,7 @@ def mnist_root(num_images, cls_name):
f.write(img.numpy().tobytes()) f.write(img.numpy().tobytes())
def _make_label_file(filename, num_images): def _make_label_file(filename, num_images):
labels = torch.randint(0, 10, size=(num_images,), dtype=torch.uint8) labels = torch.zeros((num_images,), dtype=torch.uint8)
with open(filename, "wb") as f: with open(filename, "wb") as f:
f.write(_encode(2049)) # magic header f.write(_encode(2049)) # magic header
f.write(_encode(num_images)) f.write(_encode(num_images))
......
...@@ -10,6 +10,12 @@ from fakedata_generation import mnist_root, cifar_root, imagenet_root ...@@ -10,6 +10,12 @@ from fakedata_generation import mnist_root, cifar_root, imagenet_root
class Tester(unittest.TestCase): class Tester(unittest.TestCase):
def generic_classification_dataset_test(self, dataset, num_images=1):
self.assertEqual(len(dataset), num_images)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
def test_imagefolder(self): def test_imagefolder(self):
# TODO: create the fake data on-the-fly # TODO: create the fake data on-the-fly
FAKEDATA_DIR = get_file_path_2( FAKEDATA_DIR = get_file_path_2(
...@@ -64,47 +70,36 @@ class Tester(unittest.TestCase): ...@@ -64,47 +70,36 @@ class Tester(unittest.TestCase):
num_examples = 30 num_examples = 30
with mnist_root(num_examples, "MNIST") as root: with mnist_root(num_examples, "MNIST") as root:
dataset = torchvision.datasets.MNIST(root, download=True) dataset = torchvision.datasets.MNIST(root, download=True)
self.assertEqual(len(dataset), num_examples) self.generic_classification_dataset_test(dataset, num_images=num_examples)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(target, int))
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive') @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_kmnist(self, mock_download_extract): def test_kmnist(self, mock_download_extract):
num_examples = 30 num_examples = 30
with mnist_root(num_examples, "KMNIST") as root: with mnist_root(num_examples, "KMNIST") as root:
dataset = torchvision.datasets.KMNIST(root, download=True) dataset = torchvision.datasets.KMNIST(root, download=True)
self.generic_classification_dataset_test(dataset, num_images=num_examples)
img, target = dataset[0] img, target = dataset[0]
self.assertEqual(len(dataset), num_examples) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
@mock.patch('torchvision.datasets.mnist.download_and_extract_archive') @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
def test_fashionmnist(self, mock_download_extract): def test_fashionmnist(self, mock_download_extract):
num_examples = 30 num_examples = 30
with mnist_root(num_examples, "FashionMNIST") as root: with mnist_root(num_examples, "FashionMNIST") as root:
dataset = torchvision.datasets.FashionMNIST(root, download=True) dataset = torchvision.datasets.FashionMNIST(root, download=True)
self.generic_classification_dataset_test(dataset, num_images=num_examples)
img, target = dataset[0] img, target = dataset[0]
self.assertEqual(len(dataset), num_examples) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
@mock.patch('torchvision.datasets.utils.download_url') @mock.patch('torchvision.datasets.utils.download_url')
def test_imagenet(self, mock_download): def test_imagenet(self, mock_download):
with imagenet_root() as root: with imagenet_root() as root:
dataset = torchvision.datasets.ImageNet(root, split='train', download=True) dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
self.assertEqual(len(dataset), 1) self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
dataset = torchvision.datasets.ImageNet(root, split='val', download=True) dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
self.assertEqual(len(dataset), 1) self.generic_classification_dataset_test(dataset)
img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image))
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
@mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
...@@ -113,18 +108,14 @@ class Tester(unittest.TestCase): ...@@ -113,18 +108,14 @@ class Tester(unittest.TestCase):
mock_int_check.return_value = True mock_int_check.return_value = True
with cifar_root('CIFAR10') as root: with cifar_root('CIFAR10') as root:
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True) dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
self.assertEqual(len(dataset), 5) self.generic_classification_dataset_test(dataset, num_images=5)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True) dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
self.assertEqual(len(dataset), 1) self.generic_classification_dataset_test(dataset)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
@mock.patch('torchvision.datasets.cifar.check_integrity') @mock.patch('torchvision.datasets.cifar.check_integrity')
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity') @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
...@@ -133,18 +124,14 @@ class Tester(unittest.TestCase): ...@@ -133,18 +124,14 @@ class Tester(unittest.TestCase):
mock_int_check.return_value = True mock_int_check.return_value = True
with cifar_root('CIFAR100') as root: with cifar_root('CIFAR100') as root:
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True) dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
self.assertEqual(len(dataset), 1) self.generic_classification_dataset_test(dataset)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True) dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
self.assertEqual(len(dataset), 1) self.generic_classification_dataset_test(dataset)
img, target = dataset[0] img, target = dataset[0]
self.assertTrue(isinstance(img, PIL.Image.Image)) self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
self.assertTrue(isinstance(target, int))
self.assertEqual(dataset.class_to_idx['fakedata'], target)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment