mnist.py 20.3 KB
Newer Older
1
import codecs
Tian Qi Chen's avatar
Tian Qi Chen committed
2
3
import os
import os.path
4
import shutil
5
import string
6
import sys
7
import warnings
8
from typing import Any, Callable, Dict, List, Optional, Tuple
9
from urllib.error import URLError
10
11
12
13
14

import numpy as np
import torch
from PIL import Image

15
from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
16
from .vision import VisionDataset
Tian Qi Chen's avatar
Tian Qi Chen committed
17

18

19
class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and  ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    # Base URLs tried in order by ``download``; each resource file name is appended to a mirror.
    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]

    # (archive file name, md5 checksum passed to the download helper)
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    # File names of the legacy pre-processed cache (only read, see ``_load_legacy_data``).
    training_file = "training.pt"
    test_file = "test.pt"
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    # Deprecated aliases. ``stacklevel=2`` makes each warning point at the caller's line
    # instead of at this module, so users can find the code that needs updating.
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data", stacklevel=2)
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data", stacklevel=2)
        return self.data

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        # Backward compatibility: if the legacy processed files are present, use them and skip the raw files.
        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found." + " You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _check_legacy_exist(self):
        """Return True if ``processed_folder`` exists and both legacy files pass ``check_integrity``.

        ``check_integrity`` is called without an md5 here.
        """
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary code; only point ``root``
        # at directories whose contents you trust.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        """Read the raw images and labels of the selected split from ``raw_folder``."""
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        # Uses the concrete class name, so subclasses get their own ``<root>/<ClassName>/raw`` directory.
        return os.path.join(self.root, self.__class__.__name__, "raw")

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, "processed")

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        """Return True if every resource, minus its archive extension, passes ``check_integrity`` in ``raw_folder``."""
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already.

        Each resource is tried against every mirror in order; a ``URLError`` moves on to the next
        mirror, and a ``RuntimeError`` is raised if all mirrors fail for a file.
        """

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
                except URLError as error:
                    print("Failed to download (trying next):\n{}".format(error))
                    continue
                finally:
                    print()
                break
            else:
                # for/else: reached only when no mirror succeeded (the loop never hit ``break``).
                raise RuntimeError("Error downloading {}".format(filename))

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")
199

200

201
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Reuses all of :class:`MNIST`'s loading and downloading logic; only the mirror,
    the archive checksums and the class names are overridden.

    Args:
        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
            and  ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = [
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
    ]

    # (archive file name, md5 checksum)
    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"),
    ]

    # Class name for each label index 0-9.
    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
227
228


hysts's avatar
hysts committed
229
230
231
232
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Reuses all of :class:`MNIST`'s loading and downloading logic; only the mirror,
    the archive checksums and the class names are overridden.

    Args:
        root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
            and  ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    mirrors = [
        "http://codh.rois.ac.jp/kmnist/dataset/kmnist/",
    ]

    # (archive file name, md5 checksum)
    resources = [
        ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"),
        ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"),
        ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"),
        ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"),
    ]

    # Class name for each label index 0-9.
    classes = [
        "o",
        "ki",
        "su",
        "tsu",
        "na",
        "ha",
        "ma",
        "ya",
        "re",
        "wo",
    ]
hysts's avatar
hysts committed
255
256


257
class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where
            ``EMNIST/raw/emnist-<split>-train-images-idx3-ubyte`` and
            ``EMNIST/raw/emnist-<split>-test-images-idx3-ubyte`` (plus the matching
            ``*-labels-idx1-ubyte`` files) exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from the ``emnist-<split>-train-*`` files,
            otherwise from the ``emnist-<split>-test-*`` files.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip"
    md5 = "58c8d27c78d21e728a6bc7b3cc06412e"
    splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist")
    # Merged Classes assumes Same structure for both uppercase and lowercase version
    _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"}
    _all_classes = set(string.digits + string.ascii_letters)
    classes_split_dict = {
        "byclass": sorted(_all_classes),
        "bymerge": sorted(_all_classes - _merged_classes),
        "balanced": sorted(_all_classes - _merged_classes),
        "letters": ["N/A"] + list(string.ascii_lowercase),
        "digits": list(string.digits),
        "mnist": list(string.digits),
    }

    def __init__(self, root: str, split: str, **kwargs: Any) -> None:
        self.split = verify_str_arg(split, "split", self.splits)
        # Legacy cache names must be set before MNIST.__init__, which checks for them.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)
        self.classes = self.classes_split_dict[self.split]

    @staticmethod
    def _training_file(split) -> str:
        return "training_{}.pt".format(split)

    @staticmethod
    def _test_file(split) -> str:
        return "test_{}.pt".format(split)

    @property
    def _file_prefix(self) -> str:
        return f"emnist-{self.split}-{'train' if self.train else 'test'}"

    @property
    def images_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte")

    @property
    def labels_file(self) -> str:
        return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte")

    def _load_data(self):
        return read_image_file(self.images_file), read_label_file(self.labels_file)

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def download(self) -> None:
        """Download the EMNIST data if it doesn't exist already.

        The archive is extracted into ``raw_folder``; every ``.gz`` file from the resulting
        ``gzip`` sub-folder is then extracted into ``raw_folder`` and the sub-folder is removed.
        """

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5)
        gzip_folder = os.path.join(self.raw_folder, "gzip")
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith(".gz"):
                extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder)
        shutil.rmtree(gzip_folder)
339
340


341
342
343
344
class QMNIST(MNIST):
    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset whose ``QMNIST/raw``
            subdir contains binary files of the datasets.
        what (string,optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool,optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If True, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool,optional,compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set or the testing set.  Default: True.
    """

    # Maps each selectable subset to the key of ``resources`` whose files contain it.
    subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
    # key -> [(images url, md5), (labels url, md5)]
    resources: Dict[str, List[Tuple[str, str]]] = {  # type: ignore[assignment]
        "train": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz",
                "ed72d4157d28c017586c42bc6afe6370",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz",
                "0058f8dd561b90ffdd0f734c6a30e5e4",
            ),
        ],
        "test": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz",
                "1394631089c404de565df7b7aeaf9412",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz",
                "5b5b05890a5e13444e108efe57b788aa",
            ),
        ],
        "nist": [
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz",
                "7f124b3b8ab81486c9d8c2749c17f834",
            ),
            (
                "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz",
                "5ed0e788978e45d4a8bd4b7caec3d79d",
            ),
        ],
    }
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(
        self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any
    ) -> None:
        if what is None:
            what = "train" if train else "test"
        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
        self.compat = compat
        # Built from the validated value; both legacy names point at the same per-subset file.
        self.data_file = self.what + ".pt"
        self.training_file = self.data_file
        self.test_file = self.data_file
        super(QMNIST, self).__init__(root, train, **kwargs)

    @property
    def images_file(self) -> str:
        # Local name of the images resource: its URL basename without the compression extension.
        (url, _), _ = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    @property
    def labels_file(self) -> str:
        _, (url, _) = self.resources[self.subsets[self.what]]
        return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])

    def _check_exists(self) -> bool:
        return all(check_integrity(file) for file in (self.images_file, self.labels_file))

    def _load_data(self):
        """Read images (3-D uint8) and labels (2-D, converted to int64), slicing for test10k/test50k."""
        data = read_sn3_pascalvincent_tensor(self.images_file)
        assert data.dtype == torch.uint8
        assert data.ndimension() == 3

        targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
        assert targets.ndimension() == 2

        # test10k / test50k are the first 10000 and the remaining rows of the "test" files.
        if self.what == "test10k":
            data = data[0:10000, :, :].clone()
            targets = targets[0:10000, :].clone()
        elif self.what == "test50k":
            data = data[10000:, :, :].clone()
            targets = targets[10000:, :].clone()

        return data, targets

    def download(self) -> None:
        """Download the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        """
        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        split = self.resources[self.subsets[self.what]]

        for url, md5 in split:
            download_and_extract_archive(url, self.raw_folder, md5=md5)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)``; with ``compat`` the target is the int in column 0 of the label row."""
        # redefined to handle the compat flag
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img.numpy(), mode="L")
        if self.transform is not None:
            img = self.transform(img)
        if self.compat:
            target = int(target[0])
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def extra_repr(self) -> str:
        return "Split: {}".format(self.what)


488
def get_int(b: bytes) -> int:
    """Decode ``b`` as an unsigned big-endian integer.

    Used to parse the 4-byte header fields of SN3 ("Pascal Vincent") files. ``int.from_bytes``
    does this directly instead of round-tripping through a hex string.
    """
    return int.from_bytes(b, byteorder="big")
Tian Qi Chen's avatar
Tian Qi Chen committed
490

491

492
# Type code stored in an SN3 ("Pascal Vincent") file header -> torch dtype of the stored values.
# Codes are written in hex; 0x0A has no entry.
SN3_PASCALVINCENT_TYPEMAP = {
    0x08: torch.uint8,
    0x09: torch.int8,
    0x0B: torch.int16,
    0x0C: torch.int32,
    0x0D: torch.float32,
    0x0E: torch.float64,
}


502
def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    ``path`` is opened as a regular binary file, so it must name an uncompressed file.

    Layout: a 4-byte big-endian magic number whose lowest byte is the number of dimensions and
    whose higher bytes form a type code (a key of ``SN3_PASCALVINCENT_TYPEMAP``), then one 4-byte
    big-endian size per dimension, then the big-endian payload.

    Args:
        path: Path of the file to read.
        strict: If True, assert that the payload holds exactly ``prod(sizes)`` values.

    Returns:
        Tensor of shape ``sizes`` with the dtype selected by the type code, with values converted
        to the host byte order.

    Raises:
        AssertionError: If the rank or type code is outside the accepted range, or (with
            ``strict``) the payload length does not match the declared shape.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()

    # parse the header
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]

    parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))

    # The payload is big-endian, while torch.frombuffer() interprets bytes in the host order. On little-endian
    # hosts the bytes *inside* each multi-byte value must be reversed; flipping the element order
    # (``parsed.flip(0)``) is not a byte swap and corrupts the values. ``element_size()`` is used instead of
    # ``torch.iinfo(...).bits`` because iinfo raises for the float32/float64 type codes.
    num_bytes_per_value = parsed.element_size()
    if sys.byteorder == "little" and num_bytes_per_value > 1:
        parsed = (
            parsed.view(torch.uint8)  # reinterpret the values as raw bytes
            .view(-1, num_bytes_per_value)  # one row of bytes per value
            .flip(-1)  # reverse the bytes of every value (returns a new tensor)
            .view(torch_type)  # reinterpret each row as one value -> shape (N, 1)
            .view(-1)
        )

    assert parsed.shape[0] == np.prod(s) or not strict
    return parsed.view(*s)
528
529


530
def read_label_file(path: str) -> torch.Tensor:
    """Load an SN3 label file and return its labels as a 1-D int64 tensor.

    The file must contain a one-dimensional uint8 array; anything else fails an assertion.
    """
    labels = read_sn3_pascalvincent_tensor(path, strict=False)
    assert labels.dtype == torch.uint8
    assert labels.dim() == 1
    return labels.to(torch.int64)
Tian Qi Chen's avatar
Tian Qi Chen committed
535

536

537
def read_image_file(path: str) -> torch.Tensor:
    """Load an SN3 image file and return the uint8 tensor it contains, unchanged.

    The file must contain a three-dimensional uint8 array; anything else fails an assertion.
    """
    images = read_sn3_pascalvincent_tensor(path, strict=False)
    assert images.dtype == torch.uint8
    assert images.dim() == 3
    return images