test_datasets.py 17.1 KB
Newer Older
Francisco Massa's avatar
Francisco Massa committed
1
import sys
2
import os
3
import unittest
Philip Meier's avatar
Philip Meier committed
4
from unittest import mock
5
import numpy as np
6
import PIL
7
from PIL import Image
8
from torch._utils_internal import get_file_path_2
9
10
import torchvision
from common_utils import get_tmp_dir
Philip Meier's avatar
Philip Meier committed
11
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
Philip Meier's avatar
Philip Meier committed
12
    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root
13
import xml.etree.ElementTree as ET
Philip Meier's avatar
Philip Meier committed
14
15
from urllib.request import Request, urlopen
import itertools
16
17


18
19
20
21
22
23
# Optional dependency guard: some datasets under test (SVHN, the ImageNet
# devkit) require scipy; the corresponding tests are skipped when absent.
try:
    import scipy
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False

24
25
26
27
28
29
# Optional dependency guard: video datasets (UCF101) need PyAV for decoding;
# the corresponding tests are skipped when it is not installed.
try:
    import av
    HAS_PYAV = True
except ImportError:
    HAS_PYAV = False

30

Philip Meier's avatar
Philip Meier committed
31
class Tester(unittest.TestCase):
    """Smoke tests for torchvision dataset classes, driven by fake data."""

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Check basic classification-dataset invariants.

        Asserts that ``dataset`` has ``num_images`` entries and that indexing
        yields a ``(PIL.Image.Image, int)`` sample pair.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

38
39
40
41
42
43
    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Check basic segmentation-dataset invariants.

        Asserts that ``dataset`` has ``num_images`` entries and that indexing
        yields a pair of ``PIL.Image.Image`` objects (image and mask).
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)

44
    def test_imagefolder(self):
        """ImageFolder must discover classes/images and honor is_valid_file."""
        # TODO: create the fake data on-the-fly
        fakedata_dir = get_file_path_2(
            os.path.dirname(os.path.abspath(__file__)), 'assets', 'fakedata')

        with get_tmp_dir(src=os.path.join(fakedata_dir, 'imagefolder')) as root:
            classes = sorted(['a', 'b'])
            class_a_image_files = [
                os.path.join(root, 'a', name) for name in ('a1.png', 'a2.png', 'a3.png')
            ]
            class_b_image_files = [
                os.path.join(root, 'b', name) for name in ('b1.png', 'b2.png', 'b3.png', 'b4.png')
            ]

            def check_images(dataset, keep=lambda path: True):
                # Expected (path, class-index) pairs, restricted by ``keep``.
                idx_a = dataset.class_to_idx['a']
                idx_b = dataset.class_to_idx['b']
                expected = sorted(
                    [(path, idx_a) for path in class_a_image_files if keep(path)]
                    + [(path, idx_b) for path in class_b_image_files if keep(path)]
                )
                self.assertEqual(expected, dataset.imgs)
                # The dataset must also *yield* exactly those samples.
                self.assertEqual(expected, sorted([dataset[i] for i in range(len(dataset))]))

            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # every class directory must be discovered
            self.assertEqual(classes, sorted(dataset.classes))

            # classes and class_to_idx must be mutually consistent
            for cls in classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            check_images(dataset)

            # redo the checks with a filter that only accepts files containing '3'
            dataset = torchvision.datasets.ImageFolder(
                root, loader=lambda x: x, is_valid_file=lambda path: '3' in path)

            self.assertEqual(classes, sorted(dataset.classes))
            check_images(dataset, keep=lambda path: '3' in path)

95
96
97
98
99
100
101
102
103
104
    def test_imagefolder_empty(self):
        """Constructing ImageFolder over a directory without any valid image
        must raise RuntimeError, with and without an is_valid_file filter."""
        with get_tmp_dir() as root:
            for valid_file_fn in (None, lambda path: False):
                with self.assertRaises(RuntimeError):
                    torchvision.datasets.ImageFolder(
                        root, loader=lambda x: x, is_valid_file=valid_file_fn
                    )

105
106
107
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_mnist(self, mock_download_extract):
        """MNIST over fake data: length, sample types, and first target."""
        num_examples = 30
        with mnist_root(num_examples, "MNIST") as root:
            dataset = torchvision.datasets.MNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
            # the first sample's label must round-trip through class_to_idx
            _, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
113

114
115
116
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_kmnist(self, mock_download_extract):
        """KMNIST over fake data: length, sample types, and first target."""
        num_examples = 30
        with mnist_root(num_examples, "KMNIST") as root:
            dataset = torchvision.datasets.KMNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
            # the first sample's label must round-trip through class_to_idx
            _, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
122

123
124
125
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_fashionmnist(self, mock_download_extract):
        """FashionMNIST over fake data: length, sample types, first target."""
        num_examples = 30
        with mnist_root(num_examples, "FashionMNIST") as root:
            dataset = torchvision.datasets.FashionMNIST(root, download=True)
            self.generic_classification_dataset_test(dataset, num_images=num_examples)
            # the first sample's label must round-trip through class_to_idx
            _, target = dataset[0]
            self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
131

132
    @mock.patch('torchvision.datasets.imagenet._verify_archive')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_imagenet(self, mock_verify):
        """ImageNet over fake data: both the train and val splits load."""
        with imagenet_root() as root:
            for split in ('train', 'val'):
                dataset = torchvision.datasets.ImageNet(root, split=split)
                self.generic_classification_dataset_test(dataset)
141

Philip Meier's avatar
Philip Meier committed
142
143
144
145
146
147
148
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar10(self, mock_ext_check, mock_int_check):
        """CIFAR10 over fake data, covering both the train and test splits."""
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR10') as root:
            # the fake train split holds 5 images, the test split 1
            for train, expected_len in ((True, 5), (False, 1)):
                dataset = torchvision.datasets.CIFAR10(root, train=train, download=True)
                self.generic_classification_dataset_test(dataset, num_images=expected_len)
                # the first sample's label must round-trip through class_to_idx
                _, target = dataset[0]
                self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
157
158
159
160
161
162
163
164

    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar100(self, mock_ext_check, mock_int_check):
        """CIFAR100 over fake data, covering both the train and test splits."""
        mock_ext_check.return_value = True
        mock_int_check.return_value = True
        with cifar_root('CIFAR100') as root:
            for train in (True, False):
                dataset = torchvision.datasets.CIFAR100(root, train=train, download=True)
                self.generic_classification_dataset_test(dataset)
                # the first sample's label must round-trip through class_to_idx
                _, target = dataset[0]
                self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
173

Francisco Massa's avatar
Francisco Massa committed
174
    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
    def test_cityscapes(self):
        """Exercise every mode/split/target_type combination of Cityscapes."""
        # each quality mode offers a different set of splits
        splits_by_mode = {
            'coarse': ['train', 'train_extra', 'val'],
            'fine': ['train', 'val', 'test'],
        }
        with cityscapes_root() as root:
            for mode in ['coarse', 'fine']:
                for split in splits_by_mode[mode]:
                    # PIL-image targets
                    for target_type in ['semantic', 'instance']:
                        dataset = torchvision.datasets.Cityscapes(
                            root, split=split, target_type=target_type, mode=mode)
                        self.generic_segmentation_dataset_test(dataset, num_images=2)

                    # 'color' targets carry an alpha channel (4 planes)
                    color_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='color', mode=mode)
                    color_img, color_target = color_dataset[0]
                    self.assertTrue(isinstance(color_img, PIL.Image.Image))
                    self.assertTrue(np.array(color_target).shape[2] == 4)

                    # 'polygon' targets are parsed annotation dicts
                    polygon_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='polygon', mode=mode)
                    polygon_img, polygon_target = polygon_dataset[0]
                    self.assertTrue(isinstance(polygon_img, PIL.Image.Image))
                    self.assertTrue(isinstance(polygon_target, dict))
                    self.assertTrue(isinstance(polygon_target['imgHeight'], int))
                    self.assertTrue(isinstance(polygon_target['objects'], list))

                    # multiple target types yield a tuple of targets
                    multi_dataset = torchvision.datasets.Cityscapes(
                        root, split=split,
                        target_type=['semantic', 'polygon', 'color'], mode=mode)
                    sample = multi_dataset[0]
                    self.assertTrue(isinstance(sample, tuple))
                    self.assertTrue(len(sample) == 2)
                    self.assertTrue(isinstance(sample[0], PIL.Image.Image))
                    self.assertTrue(isinstance(sample[1], tuple))
                    self.assertTrue(len(sample[1]) == 3)
                    self.assertTrue(isinstance(sample[1][0], PIL.Image.Image))  # semantic
                    self.assertTrue(isinstance(sample[1][1], dict))  # polygon
                    self.assertTrue(isinstance(sample[1][2], PIL.Image.Image))  # color

Philip Meier's avatar
Philip Meier committed
219
    @mock.patch('torchvision.datasets.SVHN._check_integrity')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_svhn(self, mock_check):
        """SVHN over fake data: each split should expose two images."""
        mock_check.return_value = True
        with svhn_root() as root:
            for split in ("train", "test", "extra"):
                dataset = torchvision.datasets.SVHN(root, split=split)
                self.generic_classification_dataset_test(dataset, num_images=2)

233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
    @mock.patch('torchvision.datasets.voc.download_extract')
    def test_voc_parse_xml(self, mock_download_extract):
        """parse_voc_xml must collect repeated <object> tags into a list."""
        with voc_root() as root:
            dataset = torchvision.datasets.VOCDetection(root)

            single_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
            </annotation>"""
            multiple_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
              <object>
                <name>dog</name>
              </object>
            </annotation>"""

            parsed_single = dataset.parse_voc_xml(ET.fromstring(single_object_xml))
            parsed_multiple = dataset.parse_voc_xml(ET.fromstring(multiple_object_xml))

            # one <object> still becomes a one-element list
            self.assertEqual(
                parsed_single,
                {'annotation': {'object': [{'name': 'cat'}]}},
            )
            # several <object> tags accumulate in document order
            self.assertEqual(
                parsed_multiple,
                {'annotation': {'object': [{'name': 'cat'}, {'name': 'dog'}]}},
            )
264

265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
    @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
    def test_ucf101(self):
        """UCF101 over fake data: clip shape and labels for every config."""
        with ucf101_root() as (root, ann_root):
            for train_flag in {True, False}:
                for fold in range(1, 4):
                    for clip_len in {10, 15, 20}:
                        dataset = torchvision.datasets.UCF101(
                            root, ann_root, clip_len, fold=fold, train=train_flag)
                        self.assertGreater(len(dataset), 0)

                        # first sample belongs to class 0, last to class 1
                        for index, expected_label in ((0, 0), (len(dataset) - 1, 1)):
                            video, audio, label = dataset[index]
                            self.assertEqual(video.size(), (clip_len, 320, 240, 3))
                            self.assertEqual(audio.numel(), 0)
                            self.assertEqual(label, expected_label)

Philip Meier's avatar
Philip Meier committed
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
    def test_places365(self):
        """Places365 over fake data for every split/small combination."""
        splits = ("train-standard", "train-challenge", "val", "test")
        for split, small in itertools.product(splits, (False, True)):
            with places365_root(split=split, small=small) as (root, data):
                dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))

    def test_places365_transforms(self):
        """transform/target_transform outputs must be returned verbatim."""
        expected_image = "image"
        expected_target = "target"

        with places365_root() as (root, data):
            dataset = torchvision.datasets.Places365(
                root,
                transform=lambda image: expected_image,
                target_transform=lambda target: expected_target,
                download=True,
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    @mock.patch("torchvision.datasets.utils.download_url")
    def test_places365_downloadable(self, download_url):
        """Every URL Places365 would download must answer HTTP HEAD with 200.

        download_url is mocked, so nothing is actually fetched while the
        datasets are constructed; only the recorded URLs are probed.
        """
        for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
            with places365_root(split=split, small=small) as places365:
                root, data = places365

                torchvision.datasets.Places365(root, split=split, small=small, download=True)

        urls = {call_args[0][0] for call_args in download_url.call_args_list}
        for url in urls:
            with self.subTest(url=url):
                response = urlopen(Request(url, method="HEAD"))
                # A plain ``assert`` is stripped under ``python -O``; the
                # unittest assertion always runs and reports the mismatch.
                self.assertEqual(
                    response.code, 200,
                    msg=f"Server returned status code {response.code} for {url}."
                )

    def test_places365_devkit_download(self):
        """Devkit download must populate classes, class_to_idx and imgs."""
        for split in ("train-standard", "train-challenge", "val", "test"):
            with self.subTest(split=split):
                with places365_root(split=split) as (root, meta):
                    dataset = torchvision.datasets.Places365(root, split=split, download=True)

                    with self.subTest("classes"):
                        self.assertSequenceEqual(dataset.classes, meta["classes"])
                    with self.subTest("class_to_idx"):
                        self.assertDictEqual(dataset.class_to_idx, meta["class_to_idx"])
                    with self.subTest("imgs"):
                        self.assertSequenceEqual(dataset.imgs, meta["imgs"])

    def test_places365_devkit_no_download(self):
        """Without download, a missing devkit must raise RuntimeError."""
        for split in ("train-standard", "train-challenge", "val", "test"):
            with self.subTest(split=split):
                with places365_root(split=split, extract_images=False) as (root, _):
                    with self.assertRaises(RuntimeError):
                        torchvision.datasets.Places365(root, split=split, download=False)

    def test_places365_images_download(self):
        """After download=True every image path in dataset.imgs must exist."""
        for split, small in itertools.product(("train-standard", "train-challenge", "val", "test"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as places365:
                    root, data = places365

                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # A plain ``assert`` is stripped under ``python -O``; the
                    # unittest assertion always runs and integrates with subTest.
                    self.assertTrue(all(os.path.exists(item[0]) for item in dataset.imgs))

    def test_places365_images_download_preexisting(self):
        """download=True must fail when the images directory already exists."""
        split = "train-standard"
        small = False

        with places365_root(split=split, small=small) as (root, _):
            # pre-create the directory the download would populate
            os.mkdir(os.path.join(root, "train_large_places365standard"))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, split=split, small=small, download=True)

    def test_places365_repr_smoke(self):
        """repr() of a Places365 dataset must not crash and yield a str."""
        with places365_root(extract_images=False) as (root, _):
            dataset = torchvision.datasets.Places365(root, download=True)
            self.assertIsInstance(repr(dataset), str)

383
384
385

# Allow running this test module directly (outside a test runner).
if __name__ == '__main__':
    unittest.main()