mnist.py 12.8 KB
Newer Older
Tian Qi Chen's avatar
Tian Qi Chen committed
1
2
3
4
5
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
6
import gzip
7
import numpy as np
Tian Qi Chen's avatar
Tian Qi Chen committed
8
9
import torch
import codecs
10
from .utils import download_url, makedir_exist_ok
Tian Qi Chen's avatar
Tian Qi Chen committed
11

12

Tian Qi Chen's avatar
Tian Qi Chen committed
13
class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset. Processed files are stored as
            ``root/<ClassName>/processed/training.pt`` and
            ``root/<ClassName>/processed/test.pt``; raw downloads go to
            ``root/<ClassName>/raw``.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        # Each processed file holds the (images, labels) tuple written by download().
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        # Keyed on the concrete class name so subclasses get their own directory.
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        """Map each entry of ``classes`` to its position in that list."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        """Return True only if both the training and the test processed files exist."""
        return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.processed_folder, self.test_file))

    @staticmethod
    def extract_gzip(gzip_path, remove_finished=False):
        """Decompress ``gzip_path`` next to itself, dropping the ``.gz`` suffix.

        Args:
            gzip_path (string): Path to a file whose name ends in ``.gz``.
            remove_finished (bool, optional): If True, delete the ``.gz`` file
                once it has been extracted.

        Raises:
            ValueError: If ``gzip_path`` does not end in ``.gz`` (the output
                path would otherwise coincide with the input).
        """
        import shutil

        if not gzip_path.endswith('.gz'):
            raise ValueError('Expected a path ending in .gz, got {}'.format(gzip_path))
        print('Extracting {}'.format(gzip_path))
        # Strip only the trailing suffix; str.replace would also rewrite any
        # '.gz' appearing earlier in the path (e.g. inside a directory name).
        out_path = gzip_path[:-len('.gz')]
        # Open the source first so a missing input does not leave an empty output
        # file behind, and stream the data instead of reading it all into memory.
        with gzip.GzipFile(gzip_path) as zip_f, open(out_path, 'wb') as out_f:
            shutil.copyfileobj(zip_f, out_f)
        if remove_finished:
            os.unlink(gzip_path)

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.raw_folder, filename)
            download_url(url, root=self.raw_folder, filename=filename, md5=None)
            self.extract_gzip(gzip_path=file_path, remove_finished=True)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        # Use truthiness, exactly like __init__ does when choosing the data file,
        # so e.g. train=1 is reported as the training split it actually loaded.
        tmp = 'train' if self.train else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

155

156
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Shares the loading, downloading and on-disk layout of :class:`MNIST`; only
    the download URLs and the class names are overridden.

    Args:
        root (string): Root directory of dataset. Processed files are stored as
            ``root/FashionMNIST/processed/training.pt`` and
            ``root/FashionMNIST/processed/test.pt``.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    # Label names, listed in target-index order.
    classes = [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ]

    # The same four idx archives MNIST expects, served from the Fashion-MNIST mirror.
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]
180
181


hysts's avatar
hysts committed
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Shares the loading, downloading and on-disk layout of :class:`MNIST`; only
    the download URLs and the class names are overridden.

    Args:
        root (string): Root directory of dataset. Processed files are stored as
            ``root/KMNIST/processed/training.pt`` and
            ``root/KMNIST/processed/test.pt``.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    # Romanized label names, listed in target-index order.
    classes = [
        'o', 'ki', 'su', 'tsu', 'na',
        'ha', 'ma', 'ya', 're', 'wo',
    ]

    # The same four idx archives MNIST expects, served from the CODH mirror.
    urls = [
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
    ]


207
class EMNIST(MNIST):
    """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset. Processed files for a split are
            stored as ``training_<split>.pt`` and ``test_<split>.pt`` in
            ``root/EMNIST/processed`` (the directory name follows the class name).
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training_<split>.pt``,
            otherwise from ``test_<split>.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If ``split`` is not one of ``EMNIST.splits``.
    """
    # Updated URL from https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist
    url = 'https://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        self.split = split
        # The per-split file names must be set before MNIST.__init__ runs: it
        # calls _check_exists() and then loads training_file / test_file.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    @staticmethod
    def _training_file(split):
        return 'training_{}.pt'.format(split)

    @staticmethod
    def _test_file(split):
        return 'test_{}.pt'.format(split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already.

        Fetches the single zip archive, extracts it, decompresses the gzip files
        it contains and writes processed files for every split in ``splits``
        (not only ``self.split``).
        """
        import shutil
        import zipfile

        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        filename = self.url.rpartition('/')[2]
        file_path = os.path.join(self.raw_folder, filename)
        download_url(self.url, root=self.raw_folder, filename=filename, md5=None)

        print('Extracting zip archive')
        with zipfile.ZipFile(file_path) as zip_f:
            zip_f.extractall(self.raw_folder)
        os.unlink(file_path)
        # The code expects the archive to unpack its idx files into raw/gzip.
        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for gzip_file in os.listdir(gzip_folder):
            if gzip_file.endswith('.gz'):
                self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))

        # process and save as torch files
        for split in self.splits:
            print('Processing ' + split)
            self._process_split(gzip_folder, split)
        # Everything needed is now in processed_folder; drop the extracted raw files.
        shutil.rmtree(gzip_folder)

        print('Done!')

    def _process_split(self, gzip_folder, split):
        """Read one split's raw idx files from ``gzip_folder`` and save them as .pt files."""
        training_set = (
            read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
            read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
        )
        test_set = (
            read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
            read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
        )
        with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
            torch.save(test_set, f)


def get_int(b):
    """Decode the byte string ``b`` as a big-endian unsigned integer.

    The IDX readers use this for their 4-byte header fields
    (magic number, item count, rows, columns).
    """
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)
Tian Qi Chen's avatar
Tian Qi Chen committed
295

296

Tian Qi Chen's avatar
Tian Qi Chen committed
297
298
299
300
301
def read_label_file(path):
    """Parse an uncompressed IDX1 label file into a 1-D int64 tensor.

    Args:
        path (string): Path to a ``*-labels-idx1-ubyte`` file.

    Returns:
        torch.LongTensor: Labels of shape ``(length,)``, where ``length`` is the
        item count stored in bytes 4-8 of the header.

    Raises:
        ValueError: If the file does not start with the label magic number 2049.
    """
    with open(path, 'rb') as f:
        data = f.read()
    magic = get_int(data[:4])
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if magic != 2049:
        raise ValueError('Invalid magic number {} in label file {} (expected 2049)'.format(magic, path))
    length = get_int(data[4:8])
    # np.frombuffer over bytes is read-only; astype() makes the writable int64
    # copy that the old ``.long()`` produced, without torch.from_numpy warning
    # about a non-writable array.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8).astype(np.int64)
    return torch.from_numpy(parsed).view(length)
Tian Qi Chen's avatar
Tian Qi Chen committed
304

305

Tian Qi Chen's avatar
Tian Qi Chen committed
306
307
308
309
310
311
312
def read_image_file(path):
    """Parse an uncompressed IDX3 image file into a uint8 tensor.

    Args:
        path (string): Path to a ``*-images-idx3-ubyte`` file.

    Returns:
        torch.ByteTensor: Images of shape ``(length, num_rows, num_cols)``, with
        all three sizes read from the file header.

    Raises:
        ValueError: If the file does not start with the image magic number 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()
    magic = get_int(data[:4])
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if magic != 2051:
        raise ValueError('Invalid magic number {} in image file {} (expected 2051)'.format(magic, path))
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    # np.frombuffer over immutable bytes yields a read-only array: torch.from_numpy
    # warns about it, and writing to the resulting tensor would be undefined.
    # Parsing from a bytearray gives the tensor a writable buffer of its own.
    parsed = np.frombuffer(bytearray(data), dtype=np.uint8, offset=16)
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)