import sys
import os
import unittest
from unittest import mock
import numpy as np
import PIL
from PIL import Image
from torch._utils_internal import get_file_path_2

import torchvision
from common_utils import get_tmp_dir
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
    cityscapes_root, svhn_root, voc_root, ucf101_root
import xml.etree.ElementTree as ET

try:
    import scipy
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

22
23
24
25
26
27
try:
    import av
    HAS_PYAV = True
except ImportError:
    HAS_PYAV = False

28

Philip Meier's avatar
Philip Meier committed
29
class Tester(unittest.TestCase):
    def generic_classification_dataset_test(self, dataset, num_images=1):
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertTrue(isinstance(img, PIL.Image.Image))
        self.assertTrue(isinstance(target, int))

36
37
38
39
40
41
    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertTrue(isinstance(img, PIL.Image.Image))
        self.assertTrue(isinstance(target, PIL.Image.Image))

42
    def test_imagefolder(self):
43
44
45
46
        # TODO: create the fake data on-the-fly
        FAKEDATA_DIR = get_file_path_2(
            os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

47
        with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
48
            classes = sorted(['a', 'b'])
49
50
51
52
53
54
            class_a_image_files = [
                os.path.join(root, 'a', file) for file in ('a1.png', 'a2.png', 'a3.png')
            ]
            class_b_image_files = [
                os.path.join(root, 'b', file) for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
            ]
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # test if all classes are present
            self.assertEqual(classes, sorted(dataset.classes))

            # test if combination of classes and class_to_index functions correctly
            for cls in classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            # test if all images were detected correctly
            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            # test if the datasets outputs all images correctly
            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

            # redo all tests with specified valid image files
77
78
            dataset = torchvision.datasets.ImageFolder(
                root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
79
80
81
82
83
84
85
86
87
88
89
90
91
92
            self.assertEqual(classes, sorted(dataset.classes))

            class_a_idx = dataset.class_to_idx['a']
            class_b_idx = dataset.class_to_idx['b']
            imgs_a = [(img_file, class_a_idx) for img_file in class_a_image_files
                      if '3' in img_file]
            imgs_b = [(img_file, class_b_idx) for img_file in class_b_image_files
                      if '3' in img_file]
            imgs = sorted(imgs_a + imgs_b)
            self.assertEqual(imgs, dataset.imgs)

            outputs = sorted([dataset[i] for i in range(len(dataset))])
            self.assertEqual(imgs, outputs)

93
94
95
96
97
98
99
100
101
102
    def test_imagefolder_empty(self):
        with get_tmp_dir() as root:
            with self.assertRaises(RuntimeError):
                torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            with self.assertRaises(RuntimeError):
                torchvision.datasets.ImageFolder(
                    root, loader=lambda x: x, is_valid_file=lambda x: False
                )

103
104
105
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_mnist(self, mock_download_extract):
        num_examples = 30
106
        with mnist_root(num_examples, "MNIST") as root:
107
            dataset = torchvision.datasets.MNIST(root, download=True)
108
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
109
            img, target = dataset[0]
110
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
111

112
113
114
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_kmnist(self, mock_download_extract):
        num_examples = 30
115
        with mnist_root(num_examples, "KMNIST") as root:
116
            dataset = torchvision.datasets.KMNIST(root, download=True)
117
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
118
            img, target = dataset[0]
119
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
120

121
122
123
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_fashionmnist(self, mock_download_extract):
        num_examples = 30
124
        with mnist_root(num_examples, "FashionMNIST") as root:
125
            dataset = torchvision.datasets.FashionMNIST(root, download=True)
126
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
127
            img, target = dataset[0]
128
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
129

130
    @mock.patch('torchvision.datasets.imagenet._verify_archive')
131
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
132
    def test_imagenet(self, mock_verify):
133
        with imagenet_root() as root:
134
            dataset = torchvision.datasets.ImageNet(root, split='train')
135
            self.generic_classification_dataset_test(dataset)
136

137
            dataset = torchvision.datasets.ImageNet(root, split='val')
138
            self.generic_classification_dataset_test(dataset)
139

Philip Meier's avatar
Philip Meier committed
140
141
142
143
144
145
146
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar10(self, mock_ext_check, mock_int_check):
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR10') as root:
            dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
147
            self.generic_classification_dataset_test(dataset, num_images=5)
Philip Meier's avatar
Philip Meier committed
148
            img, target = dataset[0]
149
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
150
151

            dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
152
            self.generic_classification_dataset_test(dataset)
Philip Meier's avatar
Philip Meier committed
153
            img, target = dataset[0]
154
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
155
156
157
158
159
160
161
162

    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar100(self, mock_ext_check, mock_int_check):
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR100') as root:
            dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
163
            self.generic_classification_dataset_test(dataset)
Philip Meier's avatar
Philip Meier committed
164
            img, target = dataset[0]
165
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
166
167

            dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
168
            self.generic_classification_dataset_test(dataset)
Philip Meier's avatar
Philip Meier committed
169
            img, target = dataset[0]
170
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
171

Francisco Massa's avatar
Francisco Massa committed
172
    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
173
174
175
176
177
178
179
180
181
182
183
184
    def test_cityscapes(self):
        with cityscapes_root() as root:

            for mode in ['coarse', 'fine']:

                if mode == 'coarse':
                    splits = ['train', 'train_extra', 'val']
                else:
                    splits = ['train', 'val', 'test']

                for split in splits:
                    for target_type in ['semantic', 'instance']:
185
186
                        dataset = torchvision.datasets.Cityscapes(
                            root, split=split, target_type=target_type, mode=mode)
187
188
                        self.generic_segmentation_dataset_test(dataset, num_images=2)

189
190
                    color_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='color', mode=mode)
191
192
193
194
                    color_img, color_target = color_dataset[0]
                    self.assertTrue(isinstance(color_img, PIL.Image.Image))
                    self.assertTrue(np.array(color_target).shape[2] == 4)

195
196
                    polygon_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='polygon', mode=mode)
197
198
199
200
201
202
203
204
                    polygon_img, polygon_target = polygon_dataset[0]
                    self.assertTrue(isinstance(polygon_img, PIL.Image.Image))
                    self.assertTrue(isinstance(polygon_target, dict))
                    self.assertTrue(isinstance(polygon_target['imgHeight'], int))
                    self.assertTrue(isinstance(polygon_target['objects'], list))

                    # Test multiple target types
                    targets_combo = ['semantic', 'polygon', 'color']
205
206
                    multiple_types_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type=targets_combo, mode=mode)
207
208
209
210
211
212
213
214
215
216
                    output = multiple_types_dataset[0]
                    self.assertTrue(isinstance(output, tuple))
                    self.assertTrue(len(output) == 2)
                    self.assertTrue(isinstance(output[0], PIL.Image.Image))
                    self.assertTrue(isinstance(output[1], tuple))
                    self.assertTrue(len(output[1]) == 3)
                    self.assertTrue(isinstance(output[1][0], PIL.Image.Image))  # semantic
                    self.assertTrue(isinstance(output[1][1], dict))  # polygon
                    self.assertTrue(isinstance(output[1][2], PIL.Image.Image))  # color

Philip Meier's avatar
Philip Meier committed
217
    @mock.patch('torchvision.datasets.SVHN._check_integrity')
218
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
Philip Meier's avatar
Philip Meier committed
219
220
221
222
223
224
225
226
227
228
229
230
    def test_svhn(self, mock_check):
        mock_check.return_value = True
        with svhn_root() as root:
            dataset = torchvision.datasets.SVHN(root, split="train")
            self.generic_classification_dataset_test(dataset, num_images=2)

            dataset = torchvision.datasets.SVHN(root, split="test")
            self.generic_classification_dataset_test(dataset, num_images=2)

            dataset = torchvision.datasets.SVHN(root, split="extra")
            self.generic_classification_dataset_test(dataset, num_images=2)

231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
    @mock.patch('torchvision.datasets.voc.download_extract')
    def test_voc_parse_xml(self, mock_download_extract):
        with voc_root() as root:
            dataset = torchvision.datasets.VOCDetection(root)

            single_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
            </annotation>"""
            multiple_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
              <object>
                <name>dog</name>
              </object>
            </annotation>"""
249
250

            single_object_parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
251
252
            multiple_object_parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))

253
254
255
256
257
258
259
260
261
            self.assertEqual(single_object_parsed, {'annotation': {'object': [{'name': 'cat'}]}})
            self.assertEqual(multiple_object_parsed,
                             {'annotation': {
                                 'object': [{
                                     'name': 'cat'
                                 }, {
                                     'name': 'dog'
                                 }]
                             }})
262

263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
    @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
    def test_ucf101(self):
        with ucf101_root() as (root, ann_root):
            for split in {True, False}:
                for fold in range(1, 4):
                    for length in {10, 15, 20}:
                        dataset = torchvision.datasets.UCF101(
                            root, ann_root, length, fold=fold, train=split)
                        self.assertGreater(len(dataset), 0)

                        video, audio, label = dataset[0]
                        self.assertEqual(video.size(), (length, 320, 240, 3))
                        self.assertEqual(audio.numel(), 0)
                        self.assertEqual(label, 0)

                        video, audio, label = dataset[len(dataset) - 1]
                        self.assertEqual(video.size(), (length, 320, 240, 3))
                        self.assertEqual(audio.numel(), 0)
                        self.assertEqual(label, 1)

283
284
285

if __name__ == '__main__':
    unittest.main()