import os
import unittest

import mock
import PIL
from PIL import Image
from torch._utils_internal import get_file_path_2

import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root


class Tester(unittest.TestCase):
13
14
15
16
17
18
    def generic_classification_dataset_test(self, dataset, num_images=1):
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertTrue(isinstance(img, PIL.Image.Image))
        self.assertTrue(isinstance(target, int))

19
    def test_imagefolder(self):
20
21
22
23
        # TODO: create the fake data on-the-fly
        FAKEDATA_DIR = get_file_path_2(
            os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

24
        with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
            classes = sorted(['a', 'b'])
            class_a_image_files = [os.path.join(root, 'a', file)
                                   for file in ('a1.png', 'a2.png', 'a3.png')]
            class_b_image_files = [os.path.join(root, 'b', file)
                                   for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]
            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # test if all classes are present
            self.assertEqual(classes, sorted(dataset.classes))

            # test if combination of classes and class_to_index functions correctly
            for cls in classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            # test if all images were detected correctly
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            # test if the datasets outputs all images correctly
            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

            # redo all tests with specified valid image files
            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x,
                                                       is_valid_file=lambda x: '3' in x)
            self.assertEqual(classes, sorted(dataset.classes))

            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
                      if '3' in img_file]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
                      if '3' in img_file]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

68
69
70
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_mnist(self, mock_download_extract):
        num_examples = 30
71
        with mnist_root(num_examples, "MNIST") as root:
72
            dataset = torchvision.datasets.MNIST(root, download=True)
73
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
74
            img, target = dataset[0]
75
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
76

77
78
79
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_kmnist(self, mock_download_extract):
        num_examples = 30
80
        with mnist_root(num_examples, "KMNIST") as root:
81
            dataset = torchvision.datasets.KMNIST(root, download=True)
82
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
83
            img, target = dataset[0]
84
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
85

86
87
88
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_fashionmnist(self, mock_download_extract):
        num_examples = 30
89
        with mnist_root(num_examples, "FashionMNIST") as root:
90
            dataset = torchvision.datasets.FashionMNIST(root, download=True)
91
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
92
            img, target = dataset[0]
93
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
94
95
96

    @mock.patch('torchvision.datasets.utils.download_url')
    def test_imagenet(self, mock_download):
97
        with imagenet_root() as root:
98
            dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
99
            self.generic_classification_dataset_test(dataset)
100
101

            dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
102
            self.generic_classification_dataset_test(dataset)
103

Philip Meier's avatar
Philip Meier committed
104
105
106
107
108
109
110
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar10(self, mock_ext_check, mock_int_check):
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR10') as root:
            dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
111
            self.generic_classification_dataset_test(dataset, num_images=5)
Philip Meier's avatar
Philip Meier committed
112
            img, target = dataset[0]
113
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
114
115

            dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
116
            self.generic_classification_dataset_test(dataset)
Philip Meier's avatar
Philip Meier committed
117
            img, target = dataset[0]
118
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
119
120
121
122
123
124
125
126

    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar100(self, mock_ext_check, mock_int_check):
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR100') as root:
            dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
127
            self.generic_classification_dataset_test(dataset)
Philip Meier's avatar
Philip Meier committed
128
            img, target = dataset[0]
129
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
130
131

            dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
132
            self.generic_classification_dataset_test(dataset)
Philip Meier's avatar
Philip Meier committed
133
            img, target = dataset[0]
134
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()