test_datasets.py 23.6 KB
Newer Older
1
import contextlib
Francisco Massa's avatar
Francisco Massa committed
2
import sys
3
import os
4
import unittest
Philip Meier's avatar
Philip Meier committed
5
from unittest import mock
6
import numpy as np
7
import PIL
8
from PIL import Image
9
from torch._utils_internal import get_file_path_2
10
import torchvision
11
from torchvision.datasets import utils
12
from common_utils import get_tmp_dir
Philip Meier's avatar
Philip Meier committed
13
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
14
    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root, stl10_root
15
import xml.etree.ElementTree as ET
Philip Meier's avatar
Philip Meier committed
16
17
from urllib.request import Request, urlopen
import itertools
18
19
20
21
import datasets_utils
import pathlib
import pickle
from torchvision import datasets
22
23


24
25
26
27
28
29
# scipy is optional: record whether it can be imported so that tests
# depending on it can be skipped instead of failing.
try:
    import scipy
except ImportError:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True

30
31
32
33
34
35
# PyAV is optional: record whether it can be imported so that tests
# depending on it can be skipped instead of failing.
try:
    import av
except ImportError:
    HAS_PYAV = False
else:
    HAS_PYAV = True

36

37
class DatasetTestcase(unittest.TestCase):
    """Base test case bundling smoke checks shared by the dataset tests."""

    def generic_classification_dataset_test(self, dataset, num_images=1):
        """Check the dataset length and that the first sample is ``(PIL image, int)``.

        Args:
            dataset: Object supporting ``len()`` and integer indexing whose items
                unpack into an ``(img, target)`` pair.
            num_images (int): Expected number of samples in ``dataset``.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        # assertIsInstance names the offending type when the check fails
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def generic_segmentation_dataset_test(self, dataset, num_images=1):
        """Check the dataset length and that the first sample is a pair of PIL images.

        Args:
            dataset: Object supporting ``len()`` and integer indexing whose items
                unpack into an ``(img, target)`` pair.
            num_images (int): Expected number of samples in ``dataset``.
        """
        self.assertEqual(len(dataset), num_images)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, PIL.Image.Image)

50
51

class Tester(DatasetTestcase):
52
    def test_imagefolder(self):
        # Runs ImageFolder on a temporary copy of the 'imagefolder' asset tree:
        # once accepting every file and once restricted through is_valid_file.
        # TODO: create the fake data on-the-fly
        this_dir = os.path.dirname(os.path.abspath(__file__))
        FAKEDATA_DIR = get_file_path_2(this_dir, 'assets', 'fakedata')

        def has_three(path):
            return '3' in path

        def accept_all(path):
            return True

        with get_tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root:
            file_names = {
                'a': ('a1.png', 'a2.png', 'a3.png'),
                'b': ('b1.png', 'b2.png', 'b3.png', 'b4.png'),
            }
            files_by_class = {
                cls_name: [os.path.join(root, cls_name, name) for name in names]
                for cls_name, names in file_names.items()
            }
            classes = sorted(files_by_class)

            def expected_samples(ds, keep):
                # sorted (path, class index) pairs for every file accepted by `keep`
                samples = []
                for cls_name in ('a', 'b'):
                    cls_idx = ds.class_to_idx[cls_name]
                    samples += [(path, cls_idx) for path in files_by_class[cls_name] if keep(path)]
                return sorted(samples)

            def loaded_samples(ds):
                # every item fetched through indexing, sorted
                return sorted(ds[i] for i in range(len(ds)))

            dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

            # test if all classes are present
            self.assertEqual(classes, sorted(dataset.classes))

            # test if combination of classes and class_to_index functions correctly
            for cls in classes:
                self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

            # test if all images were detected correctly
            imgs = expected_samples(dataset, accept_all)
            self.assertEqual(imgs, dataset.imgs)

            # test if the datasets outputs all images correctly
            self.assertEqual(imgs, loaded_samples(dataset))

            # redo all tests with specified valid image files
            dataset = torchvision.datasets.ImageFolder(
                root, loader=lambda x: x, is_valid_file=has_three)
            self.assertEqual(classes, sorted(dataset.classes))

            imgs = expected_samples(dataset, has_three)
            self.assertEqual(imgs, dataset.imgs)
            self.assertEqual(imgs, loaded_samples(dataset))

103
104
105
106
107
108
109
110
111
112
    def test_imagefolder_empty(self):
        # ImageFolder must raise RuntimeError on an empty root, both with the
        # default file filter and with a filter that rejects every file.
        filter_variants = ({}, {'is_valid_file': lambda x: False})
        with get_tmp_dir() as root:
            for extra_kwargs in filter_variants:
                with self.assertRaises(RuntimeError):
                    torchvision.datasets.ImageFolder(root, loader=lambda x: x, **extra_kwargs)

113
114
115
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_mnist(self, mock_download_extract):
        # download_and_extract_archive is patched out by the decorator; the
        # dataset is built on the root yielded by mnist_root.
        expected_len = 30
        with mnist_root(expected_len, "MNIST") as root:
            mnist = torchvision.datasets.MNIST(root, download=True)
            self.generic_classification_dataset_test(mnist, num_images=expected_len)
            # the first sample's target must be the index of the first class
            _, first_target = mnist[0]
            first_class = mnist.classes[0]
            self.assertEqual(mnist.class_to_idx[first_class], first_target)
121

122
123
124
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_kmnist(self, mock_download_extract):
        # download_and_extract_archive is patched out by the decorator; the
        # dataset is built on the root yielded by mnist_root.
        expected_len = 30
        with mnist_root(expected_len, "KMNIST") as root:
            kmnist = torchvision.datasets.KMNIST(root, download=True)
            self.generic_classification_dataset_test(kmnist, num_images=expected_len)
            # the first sample's target must be the index of the first class
            _, first_target = kmnist[0]
            first_class = kmnist.classes[0]
            self.assertEqual(kmnist.class_to_idx[first_class], first_target)
130

131
132
133
    @mock.patch('torchvision.datasets.mnist.download_and_extract_archive')
    def test_fashionmnist(self, mock_download_extract):
        # download_and_extract_archive is patched out by the decorator; the
        # dataset is built on the root yielded by mnist_root.
        expected_len = 30
        with mnist_root(expected_len, "FashionMNIST") as root:
            fashion = torchvision.datasets.FashionMNIST(root, download=True)
            self.generic_classification_dataset_test(fashion, num_images=expected_len)
            # the first sample's target must be the index of the first class
            _, first_target = fashion[0]
            first_class = fashion.classes[0]
            self.assertEqual(fashion.class_to_idx[first_class], first_target)
139

140
    @mock.patch('torchvision.datasets.imagenet._verify_archive')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_imagenet(self, mock_verify):
        # _verify_archive is patched out; both splits must pass the generic
        # classification smoke test with the default expected length.
        with imagenet_root() as root:
            for split in ('train', 'val'):
                dataset = torchvision.datasets.ImageNet(root, split=split)
                self.generic_classification_dataset_test(dataset)
149

Josh Bradley's avatar
Josh Bradley committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
    @mock.patch('torchvision.datasets.WIDERFace._check_integrity')
    # NOTE: a substring test ('win' in sys.platform) is also true for 'darwin'
    # and would skip this test on macOS, so compare against the Windows
    # platform strings explicitly.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_widerface(self, mock_check_integrity):
        # The integrity check is patched to succeed; every split of the fake
        # root must hold exactly one sample whose image is a PIL image.
        mock_check_integrity.return_value = True
        with widerface_root() as root:
            for split in ('train', 'val', 'test'):
                dataset = torchvision.datasets.WIDERFace(root, split=split)
                self.assertEqual(len(dataset), 1)
                img, target = dataset[0]
                self.assertIsInstance(img, PIL.Image.Image)

Philip Meier's avatar
Philip Meier committed
170
171
172
173
174
175
176
    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar10(self, mock_ext_check, mock_int_check):
        # Both patched integrity checks are made to report success.
        for patched_check in (mock_ext_check, mock_int_check):
            patched_check.return_value = True

        with cifar_root('CIFAR10') as root:
            # (train flag, expected number of images) per split
            for train, expected_len in ((True, 5), (False, 1)):
                dataset = torchvision.datasets.CIFAR10(root, train=train, download=True)
                self.generic_classification_dataset_test(dataset, num_images=expected_len)
                # the first sample's target must be the index of the first class
                _, target = dataset[0]
                self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
185
186
187
188
189
190
191
192

    @mock.patch('torchvision.datasets.cifar.check_integrity')
    @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
    def test_cifar100(self, mock_ext_check, mock_int_check):
        # Both patched integrity checks are made to report success.
        for patched_check in (mock_ext_check, mock_int_check):
            patched_check.return_value = True

        with cifar_root('CIFAR100') as root:
            for train in (True, False):
                dataset = torchvision.datasets.CIFAR100(root, train=train, download=True)
                self.generic_classification_dataset_test(dataset)
                # the first sample's target must be the index of the first class
                _, target = dataset[0]
                self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
Philip Meier's avatar
Philip Meier committed
201

Francisco Massa's avatar
Francisco Massa committed
202
    # NOTE: a substring test ('win' in sys.platform) is also true for 'darwin'
    # and would skip this test on macOS, so compare against the Windows
    # platform strings explicitly.
    @unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
    def test_cityscapes(self):
        # For every mode/split combination, checks each single target type and
        # one combined request of several target types.
        splits_by_mode = {
            'coarse': ['train', 'train_extra', 'val'],
            'fine': ['train', 'val', 'test'],
        }
        with cityscapes_root() as root:
            for mode, splits in splits_by_mode.items():
                for split in splits:
                    # image-valued targets go through the generic segmentation check
                    for target_type in ['semantic', 'instance']:
                        dataset = torchvision.datasets.Cityscapes(
                            root, split=split, target_type=target_type, mode=mode)
                        self.generic_segmentation_dataset_test(dataset, num_images=2)

                    color_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='color', mode=mode)
                    color_img, color_target = color_dataset[0]
                    self.assertIsInstance(color_img, PIL.Image.Image)
                    # the color target must have 4 entries along its third axis
                    self.assertEqual(np.array(color_target).shape[2], 4)

                    polygon_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type='polygon', mode=mode)
                    polygon_img, polygon_target = polygon_dataset[0]
                    self.assertIsInstance(polygon_img, PIL.Image.Image)
                    self.assertIsInstance(polygon_target, dict)
                    self.assertIsInstance(polygon_target['imgHeight'], int)
                    self.assertIsInstance(polygon_target['objects'], list)

                    # Test multiple target types
                    targets_combo = ['semantic', 'polygon', 'color']
                    multiple_types_dataset = torchvision.datasets.Cityscapes(
                        root, split=split, target_type=targets_combo, mode=mode)
                    output = multiple_types_dataset[0]
                    self.assertIsInstance(output, tuple)
                    self.assertEqual(len(output), 2)
                    self.assertIsInstance(output[0], PIL.Image.Image)
                    self.assertIsInstance(output[1], tuple)
                    self.assertEqual(len(output[1]), 3)
                    self.assertIsInstance(output[1][0], PIL.Image.Image)  # semantic
                    self.assertIsInstance(output[1][1], dict)  # polygon
                    self.assertIsInstance(output[1][2], PIL.Image.Image)  # color

Philip Meier's avatar
Philip Meier committed
247
    @mock.patch('torchvision.datasets.SVHN._check_integrity')
    @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
    def test_svhn(self, mock_check):
        # The integrity check is patched to succeed; each split of the fake
        # root is expected to contain two images.
        mock_check.return_value = True
        with svhn_root() as root:
            for split in ("train", "test", "extra"):
                dataset = torchvision.datasets.SVHN(root, split=split)
                self.generic_classification_dataset_test(dataset, num_images=2)

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
    @mock.patch('torchvision.datasets.voc.download_extract')
    def test_voc_parse_xml(self, mock_download_extract):
        # parse_voc_xml must return the 'object' entries as a list, whether
        # the annotation holds one object or several.
        with voc_root() as root:
            dataset = torchvision.datasets.VOCDetection(root)

            single_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
            </annotation>"""
            multiple_object_xml = """<annotation>
              <object>
                <name>cat</name>
              </object>
              <object>
                <name>dog</name>
              </object>
            </annotation>"""

            def parse(xml_text):
                return dataset.parse_voc_xml(ET.fromstring(xml_text))

            # parse both documents before asserting on either result
            parsed_single = parse(single_object_xml)
            parsed_multiple = parse(multiple_object_xml)

            cat = {'name': 'cat'}
            dog = {'name': 'dog'}
            self.assertEqual(parsed_single, {'annotation': {'object': [cat]}})
            self.assertEqual(parsed_multiple, {'annotation': {'object': [cat, dog]}})
292

293
294
    @unittest.skipIf(not HAS_PYAV, "PyAV unavailable")
    def test_ucf101(self):
        # Metadata of the first constructed dataset is handed to every later
        # one through _precomputed_metadata.
        metadata = None
        with ucf101_root() as (root, ann_root):
            grid = itertools.product({True, False}, range(1, 4), {10, 15, 20})
            for train, fold, length in grid:
                dataset = torchvision.datasets.UCF101(root, ann_root, length, fold=fold, train=train,
                                                      num_workers=2, _precomputed_metadata=metadata)
                if metadata is None:
                    metadata = dataset.metadata

                self.assertGreater(len(dataset), 0)

                # first clip is expected to carry label 0, last clip label 1
                for index, expected_label in ((0, 0), (len(dataset) - 1, 1)):
                    video, audio, label = dataset[index]
                    self.assertEqual(video.size(), (length, 320, 240, 3))
                    self.assertEqual(audio.numel(), 0)
                    self.assertEqual(label, expected_label)

Philip Meier's avatar
Philip Meier committed
316
    def test_places365(self):
        # Generic classification smoke test for every split / image-size
        # combination; the expected length comes from the fake data's "imgs".
        for split in ("train-standard", "train-challenge", "val"):
            for small in (False, True):
                with places365_root(split=split, small=small) as (root, data):
                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)
                    self.generic_classification_dataset_test(dataset, num_images=len(data["imgs"]))

    def test_places365_transforms(self):
        # Both transforms return fixed sentinels, so a sample equal to the
        # sentinels shows that each transform was applied.
        expected_image, expected_target = "image", "target"

        with places365_root() as (root, data):
            dataset = torchvision.datasets.Places365(
                root,
                transform=lambda image: expected_image,
                target_transform=lambda target: expected_target,
                download=True,
            )
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    def test_places365_devkit_download(self):
        # With download=True, the dataset's classes, class_to_idx and imgs must
        # match the values supplied by the fake data for every split.
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split), places365_root(split=split) as (root, data):
                dataset = torchvision.datasets.Places365(root, split=split, download=True)

                # (attribute / fake-data key, assertion used to compare them)
                checks = (
                    ("classes", self.assertSequenceEqual),
                    ("class_to_idx", self.assertDictEqual),
                    ("imgs", self.assertSequenceEqual),
                )
                for attr, assert_equal in checks:
                    with self.subTest(attr):
                        assert_equal(getattr(dataset, attr), data[attr])

    def test_places365_devkit_no_download(self):
        # Constructing the dataset with download=False on the fake root must
        # raise RuntimeError for every split.
        for split in ("train-standard", "train-challenge", "val"):
            with self.subTest(split=split), places365_root(split=split) as (root, data):
                with self.assertRaises(RuntimeError):
                    torchvision.datasets.Places365(root, split=split, download=False)

    def test_places365_images_download(self):
        # After construction with download=True, every image path listed in
        # dataset.imgs must exist on disk.
        for split, small in itertools.product(("train-standard", "train-challenge", "val"), (False, True)):
            with self.subTest(split=split, small=small):
                with places365_root(split=split, small=small) as places365:
                    root, data = places365

                    dataset = torchvision.datasets.Places365(root, split=split, small=small, download=True)

                    # A unittest assertion (not a bare `assert`, which is removed
                    # under -O) that also reports which paths are missing.
                    missing = [item[0] for item in dataset.imgs if not os.path.exists(item[0])]
                    self.assertEqual(missing, [])

    def test_places365_images_download_preexisting(self):
        # If the images directory already exists inside the root, constructing
        # the dataset with download=True must raise RuntimeError.
        config = dict(split="train-standard", small=False)
        images_dir = "data_large_standard"

        with places365_root(**config) as (root, data):
            os.mkdir(os.path.join(root, images_dir))

            with self.assertRaises(RuntimeError):
                torchvision.datasets.Places365(root, download=True, **config)

    def test_places365_repr_smoke(self):
        # Smoke test: repr() of the dataset must produce a string.
        with places365_root() as (root, data):
            dataset = torchvision.datasets.Places365(root, download=True)
            text = repr(dataset)
            self.assertIsInstance(text, str)

400

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
class STL10Tester(DatasetTestcase):
    """Tests for ``torchvision.datasets.STL10`` built on the ``stl10_root`` fake data."""

    @contextlib.contextmanager
    def mocked_root(self):
        """Yield the ``(root, data)`` pair produced by ``stl10_root``."""
        with stl10_root() as (root, data):
            yield root, data

    @contextlib.contextmanager
    def mocked_dataset(self, pre_extract=False, download=True, **kwargs):
        """Yield ``(dataset, data)`` for an ``STL10`` dataset created on the fake root.

        Args:
            pre_extract (bool): If true, extract the archive named by
                ``data["archive"]`` inside the root before creating the dataset.
            download (bool): Forwarded to ``STL10``.
            **kwargs: Further keyword arguments forwarded to ``STL10``.
        """
        with self.mocked_root() as (root, data):
            if pre_extract:
                utils.extract_archive(os.path.join(root, data["archive"]))
            dataset = torchvision.datasets.STL10(root, download=download, **kwargs)
            yield dataset, data

    def test_not_found(self):
        # without download the dataset cannot be created on the fake root
        with self.assertRaises(RuntimeError):
            with self.mocked_dataset(download=False):
                pass

    def test_splits(self):
        for split in ('train', 'train+unlabeled', 'unlabeled', 'test'):
            # subTest keeps one failing split from hiding the remaining ones
            with self.subTest(split=split):
                with self.mocked_dataset(split=split) as (dataset, data):
                    # a combined split such as 'train+unlabeled' sums its parts
                    num_images = sum(data["num_images_in_split"][part] for part in split.split("+"))
                    self.generic_classification_dataset_test(dataset, num_images=num_images)

    def test_folds(self):
        for fold in range(10):
            # subTest keeps one failing fold from hiding the remaining ones
            with self.subTest(fold=fold):
                with self.mocked_dataset(split="train", folds=fold) as (dataset, data):
                    num_images = data["num_images_in_folds"][fold]
                    self.assertEqual(len(dataset), num_images)

    def test_invalid_folds1(self):
        # fold index 10 must be rejected
        with self.assertRaises(ValueError):
            with self.mocked_dataset(folds=10):
                pass

    def test_invalid_folds2(self):
        # a string fold must be rejected
        with self.assertRaises(ValueError):
            with self.mocked_dataset(folds="0"):
                pass

    def test_transforms(self):
        # Both transforms return fixed sentinels, so a sample equal to the
        # sentinels shows that each transform was applied.
        expected_image = "image"
        expected_target = "target"

        def transform(image):
            return expected_image

        def target_transform(target):
            return expected_target

        with self.mocked_dataset(transform=transform, target_transform=target_transform) as (dataset, _):
            actual_image, actual_target = dataset[0]

            self.assertEqual(actual_image, expected_image)
            self.assertEqual(actual_target, expected_target)

    def test_unlabeled(self):
        # every sample of the 'unlabeled' split must carry the label -1
        with self.mocked_dataset(split="unlabeled") as (dataset, _):
            labels = [dataset[idx][1] for idx in range(len(dataset))]
            self.assertTrue(all(label == -1 for label in labels))

    # Uses the module-level `mock` import like the other tests in this file;
    # the parameter name no longer shadows that module.
    @mock.patch("torchvision.datasets.stl10.download_and_extract_archive")
    def test_download_preexisting(self, mock_download_extract):
        # with the archive already extracted, no download may be triggered
        with self.mocked_dataset(pre_extract=True) as (dataset, data):
            mock_download_extract.assert_not_called()

    def test_repr_smoke(self):
        # Smoke test: repr() of the dataset must produce a string.
        with self.mocked_dataset() as (dataset, _):
            self.assertIsInstance(repr(dataset), str)


473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.Caltech256`` using generated image folders."""

    DATASET_CLASS = datasets.Caltech256

    def inject_fake_data(self, tmpdir, config):
        """Create one image folder per fake category below ``tmpdir``.

        Args:
            tmpdir: Directory in which ``caltech256/256_ObjectCategories`` is populated.
            config: Not used by this implementation.

        Returns:
            int: Total number of images created.
        """
        tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"

        categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
        num_images_per_category = 2

        for idx, category in categories:
            datasets_utils.create_image_folder(
                tmpdir,
                name=f"{idx:03d}.{category}",
                # Bind idx per iteration: a plain closure would look it up only
                # when called and could see a later category's index.
                file_name_fn=lambda image_idx, idx=idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
                num_examples=num_images_per_category,
            )

        return num_images_per_category * len(categories)


class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.CIFAR10`` backed by generated pickle files.

    File names, dictionary keys and the number of categories are read from
    ``_VERSION_CONFIG`` through ``self``, so a subclass can swap them.
    """

    DATASET_CLASS = datasets.CIFAR10
    CONFIGS = datasets_utils.combinations_grid(train=(True, False))

    _VERSION_CONFIG = {
        "base_folder": "cifar-10-batches-py",
        "train_files": tuple(f"data_batch_{idx}" for idx in range(1, 6)),
        "test_files": ("test_batch",),
        "labels_key": "labels",
        "meta_file": "batches.meta",
        "num_categories": 10,
        "categories_key": "label_names",
    }

    def inject_fake_data(self, tmpdir, config):
        """Write all batch files and the meta file below ``tmpdir``.

        Returns:
            dict: ``num_examples`` for the split selected by ``config["train"]``
            and the list of ``categories`` written to the meta file.
        """
        version = self._VERSION_CONFIG
        base_dir = pathlib.Path(tmpdir) / version["base_folder"]
        os.makedirs(base_dir)

        num_images_per_file = 1
        for file_name in (*version["train_files"], *version["test_files"]):
            self._create_batch_file(base_dir, file_name, num_images_per_file)

        categories = self._create_meta_file(base_dir)

        if config["train"]:
            split_files = version["train_files"]
        else:
            split_files = version["test_files"]
        return dict(num_examples=num_images_per_file * len(split_files), categories=categories)

    def _create_batch_file(self, root, name, num_images):
        """Pickle generated data and random labels for ``num_images`` images into ``root / name``."""
        version = self._VERSION_CONFIG
        data = datasets_utils.create_image_or_video_tensor((num_images, 32 * 32 * 3))
        labels = np.random.randint(0, version["num_categories"], size=num_images).tolist()
        self._create_binary_file(root, name, {"data": data, version["labels_key"]: labels})

    def _create_meta_file(self, root):
        """Pickle the category names into the meta file and return them."""
        version = self._VERSION_CONFIG
        num_categories = version["num_categories"]
        # zero-pad every name to the width of the largest index
        width = len(str(num_categories - 1))
        categories = [f"{idx:0{width}d}" for idx in range(num_categories)]
        self._create_binary_file(root, version["meta_file"], {version["categories_key"]: categories})
        return categories

    def _create_binary_file(self, root, name, content):
        """Pickle ``content`` into the file ``root / name``."""
        target = pathlib.Path(root) / name
        with open(target, "wb") as fh:
            pickle.dump(content, fh)

    def test_class_to_idx(self):
        # class_to_idx must map each category to its position in the meta file
        with self.create_dataset() as (dataset, info):
            categories = info["categories"]
            expected = dict(zip(categories, range(len(categories))))
            actual = dataset.class_to_idx
            self.assertEqual(actual, expected)


class CIFAR100(CIFAR10TestCase):
    """Variant of ``CIFAR10TestCase`` targeting ``datasets.CIFAR100``.

    Only the dataset class and the file layout constants differ.
    """

    DATASET_CLASS = datasets.CIFAR100

    _VERSION_CONFIG = {
        "base_folder": "cifar-100-python",
        "train_files": ("train",),
        "test_files": ("test",),
        "labels_key": "fine_labels",
        "meta_file": "meta",
        "num_categories": 100,
        "categories_key": "fine_label_names",
    }


if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script.
    unittest.main()