test_datasets.py 17.5 KB
Newer Older
Francisco Massa's avatar
Francisco Massa committed
1
import sys
2
import os
3
import unittest
Philip Meier's avatar
Philip Meier committed
4
from unittest import mock
5
import numpy as np
6
import PIL
7
from PIL import Image
8
from torch._utils_internal import get_file_path_2
9
10
import torchvision
from common_utils import get_tmp_dir
Philip Meier's avatar
Philip Meier committed
11
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
Josh Bradley's avatar
Josh Bradley committed
12
    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
13
import xml.etree.ElementTree as ET
Philip Meier's avatar
Philip Meier committed
14
15
from urllib.request import Request, urlopen
import itertools
16
17


18
19
20
21
22
23
# Optional dependencies: tests that need them are skipped when absent.
try:
    import scipy  # noqa: F401
except ImportError:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True

try:
    import av  # noqa: F401
except ImportError:
    HAS_PYAV = False
else:
    HAS_PYAV = True
class Tester(unittest.TestCase):

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Smoke-test a classification dataset.

        Checks that ``len(dataset)`` equals *num_images* and that indexing
        yields a ``(PIL image, int label)`` pair.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Smoke-test a segmentation dataset.

        Checks that ``len(dataset)`` equals *num_images* and that indexing
        yields a ``(PIL image, PIL image)`` pair.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)
    def test_imagefolder(self):
        """ImageFolder discovers classes and images, with and without a file filter."""
        # TODO: create the fake data on-the-fly
        fakedata_dir = get_file_path_2(
            os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

        with get_tmp_dir(src=os.path.join(fakedata_dir, 'imagefolder')) as root:
            expected_classes = sorted(['a', 'b'])
            files_a = [os.path.join(root, 'a', name) for name in ('a1.png', 'a2.png', 'a3.png')]
            files_b = [os.path.join(root, 'b', name) for name in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]

            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # every class must be discovered
            self.assertEqual(expected_classes, sorted(dataset.classes))

            # classes and class_to_idx must be mutually consistent
            for cls in expected_classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            # every image must be detected and paired with its class index
            idx_a = dataset.class_to_idx['a']
            idx_b = dataset.class_to_idx['b']
            expected_imgs = sorted(
                [(path, idx_a) for path in files_a] + [(path, idx_b) for path in files_b])
            self.assertEqual(expected_imgs, dataset.imgs)

            # indexing must yield exactly the detected (path, label) pairs
            self.assertEqual(expected_imgs, sorted(dataset[i] for i in range(len(dataset))))

            # redo the checks with an is_valid_file filter keeping only '*3*' files
            dataset = torchvision.datasets.ImageFolder(
                root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)

            self.assertEqual(expected_classes, sorted(dataset.classes))

            idx_a = dataset.class_to_idx['a']
            idx_b = dataset.class_to_idx['b']
            expected_imgs = sorted(
                [(path, idx_a) for path in files_a if '3' in path]
                + [(path, idx_b) for path in files_b if '3' in path])
            self.assertEqual(expected_imgs, dataset.imgs)

            self.assertEqual(expected_imgs, sorted(dataset[i] for i in range(len(dataset))))
    def test_imagefolder_empty(self):
        """ImageFolder must raise RuntimeError when it finds no samples at all."""
        with get_tmp_dir() as root:
            # empty directory, and a filter that rejects every candidate file
            for extra_kwargs in ({}, {'is_valid_file': lambda x: False}):
                with self.assertRaises(RuntimeError):
                    torchvision.datasets.ImageFolder(root, loader=lambda x: x, **extra_kwargs)
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_mnist(self, mock_download_extract):
        """MNIST over a fake root yields the expected samples; download is mocked."""
        n = 30
        with mnist_root(n, "MNIST") as root:
            dataset = torchvision.datasets.MNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=n)
            # the first sample must carry the index of the first class
            _, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_kmnist(self, mock_download_extract):
        """KMNIST over a fake root yields the expected samples; download is mocked."""
        n = 30
        with mnist_root(n, "KMNIST") as root:
            dataset = torchvision.datasets.KMNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=n)
            # the first sample must carry the index of the first class
            _, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_fashionmnist(self, mock_download_extract):
        """FashionMNIST over a fake root yields the expected samples; download is mocked."""
        n = 30
        with mnist_root(n, "FashionMNIST") as root:
            dataset = torchvision.datasets.FashionMNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=n)
            # the first sample must carry the index of the first class
            _, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
    @mock.patch('torchvision.datasets.imagenet._verify_archive')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_imagenet(self, mock_verify):
        """ImageNet loads both splits of the fake root; archive verification is mocked."""
        with imagenet_root() as root:
            for split in ('train', 'val'):
                dataset = torchvision.datasets.ImageNet(root, split=split)
                self.generic_classification_dataset_test(dataset)
    @mock.patch('torchvision.datasets.WIDERFace._check_integrity')
    # `'win' in sys.platform` also matched macOS ('darwin' contains 'win'),
    # silently skipping the test there; compare against 'win32' so the skip
    # only applies to Windows as the message says.
    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_widerface(self, mock_check_integrity):
        """Each WIDERFace split of the fake root yields one (PIL image, target) sample."""
        mock_check_integrity.return_value = True
        with widerface_root() as root:
            for split in ('train', 'val', 'test'):
                dataset = torchvision.datasets.WIDERFace(root, split=split)
                self.assertEqual(len(dataset), 1)
                img, target = dataset[0]
                self.assertIsInstance(img, PIL.Image.Image)
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar10(self, mock_ext_check, mock_int_check):
        """CIFAR10 train/test splits load from the fake root; integrity checks are mocked."""
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR10') as root:
            # fake train split carries 5 images, fake test split carries 1
            for train, num_images in ((True, 5), (False, 1)):
                dataset = torchvision.datasets.CIFAR10(root, train=train, download=True)
                self.generic_classification_dataset_test(dataset, num_images=num_images)
                _, target = dataset[0]
                self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar100(self, mock_ext_check, mock_int_check):
        """CIFAR100 train/test splits load from the fake root; integrity checks are mocked."""
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR100') as root:
            for train in (True, False):
                dataset = torchvision.datasets.CIFAR100(root, train=train, download=True)
                self.generic_classification_dataset_test(dataset)
                _, target = dataset[0]
                self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
    # `'win' in sys.platform` also matched macOS ('darwin' contains 'win'),
    # silently skipping the test there; compare against 'win32' so the skip
    # only applies to Windows as the message says.
    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_cityscapes(self):
        """Exercise Cityscapes over both modes, all their splits, and every target_type."""
        with cityscapes_root() as root:

            for mode in ['coarse', 'fine']:

                # the 'coarse' annotations also cover train_extra; 'fine' covers test
                if mode == 'coarse':
                    splits = ['train', 'train_extra', 'val']
                else:
                    splits = ['train', 'val', 'test']

                for split in splits:
                    # 'semantic' and 'instance' targets are single PIL images
                    for target_type in ['semantic', 'instance']:
                        dataset = torchvision.datasets.Cityscapes(
                            root, split=split, target_type=target_type, mode=mode)
                        self.generic_segmentation_dataset_test(dataset, num_images=2)

                    # 'color' targets are images with 4 channels
                    color_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='color', mode=mode)
                    color_img, color_target = color_dataset[0]
                    self.assertIsInstance(color_img, PIL.Image.Image)
                    self.assertEqual(np.array(color_target).shape[2], 4)

                    # 'polygon' targets are dicts with image metadata and an object list
                    polygon_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='polygon', mode=mode)
                    polygon_img, polygon_target = polygon_dataset[0]
                    self.assertIsInstance(polygon_img, PIL.Image.Image)
                    self.assertIsInstance(polygon_target, dict)
                    self.assertIsInstance(polygon_target['imgHeight'], int)
                    self.assertIsInstance(polygon_target['objects'], list)

                    # multiple target types come back as a tuple in request order
                    targets_combo = ['semantic', 'polygon', 'color']
                    multiple_types_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type=targets_combo, mode=mode)
                    output = multiple_types_dataset[0]
                    self.assertIsInstance(output, tuple)
                    self.assertEqual(len(output), 2)
                    self.assertIsInstance(output[0], PIL.Image.Image)
                    self.assertIsInstance(output[1], tuple)
                    self.assertEqual(len(output[1]), 3)
                    self.assertIsInstance(output[1][0], PIL.Image.Image)  # semantic
                    self.assertIsInstance(output[1][1], dict)  # polygon
                    self.assertIsInstance(output[1][2], PIL.Image.Image)  # color
    @mock.patch('torchvision.datasets.SVHN._check_integrity')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_svhn(self, mock_check):
        """Every SVHN split of the fake root yields two samples; integrity check mocked."""
        mock_check.return_value = True
        with svhn_root() as root:
            for split in ("train", "test", "extra"):
                dataset = torchvision.datasets.SVHN(root, split=split)
                self.generic_classification_dataset_test(dataset, num_images=2)
    @mock.patch('torchvision.datasets.voc.download_extract')
    def test_voc_parse_xml(self, mock_download_extract):
        """parse_voc_xml collects repeated <object> tags into a list under 'object'."""
        with voc_root() as root:
            dataset = torchvision.datasets.VOCDetection(root)

            single_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
            </annotation>"""
            multiple_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
              <object>
                <name>dog</name>
              </object>
            </annotation>"""

            # a lone <object> still ends up wrapped in a one-element list
            parsed = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
            self.assertEqual(parsed, {'annotation': {'object': [{'name': 'cat'}]}})

            # sibling <object> tags are accumulated in document order
            parsed = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))
            self.assertEqual(
                parsed,
                {'annotation': {'object': [{'name': 'cat'}, {'name': 'dog'}]}})
    @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
    def test_ucf101(self):
        """UCF101 yields correctly-shaped clips for every split/fold/clip-length combo."""
        # metadata is expensive to compute; reuse it across dataset instances
        cached_meta_data = None
        with ucf101_root() as (root, ann_root):
            for train in {True, False}:
                for fold in range(1, 4):
                    for clip_len in {10, 15, 20}:
                        dataset = torchvision.datasets.UCF101(
                            root, ann_root, clip_len, fold=fold, train=train,
                            num_workers=2, _precomputed_metadata=cached_meta_data)
                        if cached_meta_data is None:
                            cached_meta_data = dataset.metadata

                        self.assertGreater(len(dataset), 0)

                        # first sample: silent clip of the expected shape, label 0
                        video, audio, label = dataset[0]
                        self.assertEqual(video.size(), (clip_len, 320, 240, 3))
                        self.assertEqual(audio.numel(), 0)
                        self.assertEqual(label, 0)

                        # last sample: silent clip of the expected shape, label 1
                        video, audio, label = dataset[len(dataset) - 1]
                        self.assertEqual(video.size(), (clip_len, 320, 240, 3))
                        self.assertEqual(audio.numel(), 0)
                        self.assertEqual(label, 1)
    def test_places365(self):
        """Places365 loads every split in both image sizes from a matching fake root."""
        splits = ("train-standard", "train-challenge", "val")
        for split, small in itertools.product(splits, (False, True)):
            with places365_root(split=split, small=small) as (root, data):
                dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))
    def test_places365_transforms(self):
        """transform and target_transform are applied to both members of a sample."""
        expected_image = "image"
        expected_target = "target"

        with places365_root() as (root, data):
            dataset = torchvision.datasets.Places365(
                root,
                transform=lambda image: expected_image,
                target_transform=lambda target: expected_target,
                download=True,
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)
    def test_places365_devkit_download(self):
        """Downloading the devkit populates classes, class_to_idx and imgs for every split."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split), places365_root(split=split) as (root, data):
                dataset = torchvision.datasets.Places365(root, split=split, download=True)

                with self.subTest("classes"):
                    self.assertSequenceEqual(dataset.classes, data["classes"])

                with self.subTest("class_to_idx"):
                    self.assertDictEqual(dataset.class_to_idx, data["class_to_idx"])

                with self.subTest("imgs"):
                    self.assertSequenceEqual(dataset.imgs, data["imgs"])
    def test_places365_devkit_no_download(self):
        """With no extracted images and download=False, construction must fail."""
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split), places365_root(split=split, extract_images=False) as (root, data):
                with self.assertRaises(RuntimeError):
                    torchvision.datasets.Places365(root, split=split, download=False)
    def test_places365_images_download(self):
        """After download=True, every image path recorded by the dataset exists on disk."""
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as places365:
                    root, data = places365

                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # a plain `assert` is stripped under `python -O`; use the
                    # unittest assertion so the check always runs
                    self.assertTrue(all(os.path.exists(item[0]) for item in dataset.imgs))
    def test_places365_images_download_preexisting(self):
        """A pre-existing images directory makes download=True fail instead of clobbering it."""
        split = "train-standard"
        small = False
        images_dir = "data_large_standard"

        with places365_root(split=split, small=small) as (root, data):
            # simulate a leftover images directory from a previous download
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)
    def test_places365_repr_smoke(self):
        """repr() works even when no images were extracted."""
        with places365_root(extract_images=False) as (root, data):
            dataset = torchvision.datasets.Places365(root, download=True)
            self.assertIsInstance(repr(dataset), str)
# Run the full suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()