import sys
import os
import unittest
from unittest import mock
import numpy as np
import PIL
from PIL import Image
from torch._utils_internal import get_file_path_2
import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
    cityscapes_root, svhn_root, voc_root
import xml.etree.ElementTree as ET


try:
    import scipy
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


class Tester(unittest.TestCase):
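    # Shared helper: a classification dataset should report the expected length and
    # yield (PIL image, int label) samples.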
    def generic_classification_dataset_test(self, dataset, num_images=1):
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

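    # Shared helper: a segmentation dataset should report the expected length and
    # yield (PIL image, PIL target) samples.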
    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)

    def test_imagefolder(self):
        # TODO: create the fake data on-the-fly
        FAKEDATA_DIR = get_file_path_2(
            os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

        with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            classes = sorted(['a', 'b'])
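            # The fake data on disk contains three images for class 'a' and four for class 'b'.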
            class_a_image_files = [
                os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png')
            ]
            class_b_image_files = [
                os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
            ]
            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # test if all classes are present
            self.assertEqual(classes, sorted(dataset.classes))

            # test if combination of classes and class_to_index functions correctly
            for cls in classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            # test if all images were detected correctly
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            # test if the datasets outputs all images correctly
            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

            # redo all tests with specified valid image files
            dataset = torchvision.datasets.ImageFolder(
                root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
            self.assertEqual(classes, sorted(dataset.classes))

            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
                      if '3' in img_file]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
                      if '3' in img_file]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

    def test_imagefolder_empty(self):
        with get_tmp_dir() as root:
            with self.assertRaises(RuntimeError):
                torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            with self.assertRaises(RuntimeError):
                torchvision.datasets.ImageFolder(
                    root, loader=lambda x: x, is_valid_file=lambda x: False
                )

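    # The MNIST-family tests patch download_and_extract_archive so nothing is downloaded;
    # mnist_root() from fakedata_generation provides a small fake dataset on disk.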
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_mnist(self, mock_download_extract):
        num_examples = 30
        with mnist_root(num_examples, "MNIST") as root:
            dataset = torchvision.datasets.MNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_kmnist(self, mock_download_extract):
        num_examples = 30
        with mnist_root(num_examples, "KMNIST") as root:
            dataset = torchvision.datasets.KMNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_fashionmnist(self, mock_download_extract):
        num_examples = 30
        with mnist_root(num_examples, "FashionMNIST") as root:
            dataset = torchvision.datasets.FashionMNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

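    # _verify_archive is mocked out, so the fake metadata generated by imagenet_root() is used as-is.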
    @mock.patch('torchvision.datasets.imagenet._verify_archive')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_imagenet(self, mock_verify):
        with imagenet_root() as root:
            dataset = torchvision.datasets.ImageNet(root, split='train')
            self.generic_classification_dataset_test(dataset)

            dataset = torchvision.datasets.ImageNet(root, split='val')
            self.generic_classification_dataset_test(dataset)

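    # Both integrity checks are mocked, so the fake archives generated by cifar_root() pass verification.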
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar10(self, mock_ext_check, mock_int_check):
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR10') as root:
            dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
            self.generic_classification_dataset_test(dataset, num_images=5)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

            dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
            self.generic_classification_dataset_test(dataset)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar100(self, mock_ext_check, mock_int_check):
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR100') as root:
            dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
            self.generic_classification_dataset_test(dataset)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

            dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
            self.generic_classification_dataset_test(dataset)
            img, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)

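    # Cityscapes: exercise both annotation qualities ('coarse' and 'fine') with their valid splits
    # and every supported target_type, checking the structure of the returned samples.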
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_cityscapes(self):
        with cityscapes_root() as root:

            for mode in ['coarse', 'fine']:

                if mode == 'coarse':
                    splits = ['train', 'train_extra', 'val']
                else:
                    splits = ['train', 'val', 'test']

                for split in splits:
                    for target_type in ['semantic', 'instance']:
                        dataset = torchvision.datasets.Cityscapes(
                            root, split=split, target_type=target_type, mode=mode)
                        self.generic_segmentation_dataset_test(dataset, num_images=2)

                    color_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='color', mode=mode)
                    color_img, color_target = color_dataset[0]
                    self.assertIsInstance(color_img, PIL.Image.Image)
                    self.assertEqual(np.array(color_target).shape[2], 4)

                    polygon_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='polygon', mode=mode)
                    polygon_img, polygon_target = polygon_dataset[0]
                    self.assertIsInstance(polygon_img, PIL.Image.Image)
                    self.assertIsInstance(polygon_target, dict)
                    self.assertIsInstance(polygon_target['imgHeight'], int)
                    self.assertIsInstance(polygon_target['objects'], list)

                    # Test multiple target types
                    targets_combo = ['semantic', 'polygon', 'color']
                    multiple_types_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type=targets_combo, mode=mode)
                    output = multiple_types_dataset[0]
                    self.assertIsInstance(output, tuple)
                    self.assertEqual(len(output), 2)
                    self.assertIsInstance(output[0], PIL.Image.Image)
                    self.assertIsInstance(output[1], tuple)
                    self.assertEqual(len(output[1]), 3)
                    self.assertIsInstance(output[1][0], PIL.Image.Image)  # semantic
                    self.assertIsInstance(output[1][1], dict)  # polygon
                    self.assertIsInstance(output[1][2], PIL.Image.Image)  # color

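    # SVHN: _check_integrity is mocked so the fake data from svhn_root() is accepted; the dataset
    # itself relies on scipy to read the .mat files, hence the skip when scipy is unavailable.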
    @mock.patch('torchvision.datasets.SVHN._check_integrity')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_svhn(self, mock_check):
        mock_check.return_value = True
        with svhn_root() as root:
            dataset = torchvision.datasets.SVHN(root, split="train")
            self.generic_classification_dataset_test(dataset, num_images=2)

            dataset = torchvision.datasets.SVHN(root, split="test")
            self.generic_classification_dataset_test(dataset, num_images=2)

            dataset = torchvision.datasets.SVHN(root, split="extra")
            self.generic_classification_dataset_test(dataset, num_images=2)

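    # parse_voc_xml should always return the <object> entries as a list, even when the
    # annotation contains only a single object.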
    @mock.patch('torchvision.datasets.voc.download_extract')
    def test_voc_parse_xml(self, mock_download_extract):
        with voc_root() as root:
            dataset = torchvision.datasets.VOCDetection(root)

            single_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
            </annotation>"""
            multiple_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
              <object>
                <name>dog</name>
              </object>
            </annotation>"""

            single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
            multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))

            self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
            self.assertEqual(multiple_object_parsed,
                             {'annotation': {
                                 'object': [{
                                     'name': 'cat'
                                 }, {
                                     'name': 'dog'
                                 }]
                             }})


if __name__ == '__main__':
    unittest.main()