mnist.py 12.7 KB
Newer Older
Tian Qi Chen's avatar
Tian Qi Chen committed
1
2
3
4
5
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
6
import gzip
7
import numpy as np
Tian Qi Chen's avatar
Tian Qi Chen committed
8
9
import torch
import codecs
10
from .utils import download_url, makedir_exist_ok
Tian Qi Chen's avatar
Tian Qi Chen committed
11

12

Tian Qi Chen's avatar
Tian Qi Chen committed
13
class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Downloaded archives are kept in ``<root>/<ClassName>/raw`` and the serialized
    tensors in ``<root>/<ClassName>/processed``; because the directory name comes
    from the class name, subclasses that only override ``urls``/``classes`` get
    their own directories.

    Args:
        root (string): Root directory of dataset; ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` are expected beneath it.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        # The processed file holds the (images, labels) tuple written by download().
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        """Directory holding downloaded/extracted raw files: ``<root>/<ClassName>/raw``."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        """Directory holding the ``.pt`` files: ``<root>/<ClassName>/processed``."""
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        """Map each entry of ``classes`` to its position (label value)."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        """Return True only if both the training and the test file are present."""
        return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.processed_folder, self.test_file))

    @staticmethod
    def extract_gzip(gzip_path, remove_finished=False):
        """Decompress ``gzip_path`` next to itself, dropping the ``.gz`` suffix.

        Args:
            gzip_path (string): Path of a file whose name ends in ``.gz``.
            remove_finished (bool, optional): If True, delete the ``.gz`` file
                once it has been extracted.

        Raises:
            ValueError: If ``gzip_path`` does not end in ``.gz``.
        """
        import shutil

        if not gzip_path.endswith('.gz'):
            # Otherwise the output path would equal the input path, and opening
            # it for writing would truncate the archive before it is read.
            raise ValueError('Expected a path ending in ".gz", got {}'.format(gzip_path))
        print('Extracting {}'.format(gzip_path))
        # Strip only the trailing suffix; a blanket str.replace would also
        # rewrite any ".gz" that appears in a parent directory name.
        out_path = gzip_path[:-len('.gz')]
        with gzip.GzipFile(gzip_path) as zip_f, open(out_path, 'wb') as out_f:
            # Stream in chunks rather than holding the whole payload in memory.
            shutil.copyfileobj(zip_f, out_f)
        if remove_finished:
            os.unlink(gzip_path)

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already.

        Each archive in ``urls`` is fetched into ``raw_folder`` and decompressed
        (the ``.gz`` is then deleted); the IDX files are parsed and saved as
        ``(images, labels)`` tuples to ``training_file`` and ``test_file``.
        """
        if self._check_exists():
            return

        makedir_exist_ok(self.raw_folder)
        makedir_exist_ok(self.processed_folder)

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            # assumes download_url stores the file as raw_folder/filename
            file_path = os.path.join(self.raw_folder, filename)
            download_url(url, root=self.raw_folder, filename=filename, md5=None)
            self.extract_gzip(gzip_path=file_path, remove_finished=True)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(len(self))
        # Use the same truthiness test as __init__ so the reported split always
        # matches the file that was actually loaded.
        tmp = 'train' if self.train else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Drop-in replacement for :class:`MNIST`: identical constructor and loading
    logic, different download locations and class names. Files are stored
    under ``<root>/FashionMNIST/``.

    Args:
        root (string): Root directory of dataset; ``FashionMNIST/processed/training.pt``
            and ``FashionMNIST/processed/test.pt`` are expected beneath it.
        train (bool, optional): When True, load ``training.pt``; otherwise load
            ``test.pt``.
        download (bool, optional): When true, fetch the dataset from the internet
            into the root directory. Nothing is fetched if it is already there.
        transform (callable, optional): Callable applied to each PIL image,
            returning a transformed version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Callable applied to each target.
    """

    # The four IDX archives (train/test images and labels) in download order.
    urls = [
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]

    # Human-readable names; label i corresponds to classes[i].
    classes = [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ]
class KMNIST(MNIST):
    """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Shares the constructor and loading logic of :class:`MNIST`; only the
    download locations and class names differ. Files are stored under
    ``<root>/KMNIST/``.

    Args:
        root (string): Root directory of dataset; ``KMNIST/processed/training.pt``
            and ``KMNIST/processed/test.pt`` are expected beneath it.
        train (bool, optional): When True, load ``training.pt``; otherwise load
            ``test.pt``.
        download (bool, optional): When true, fetch the dataset from the internet
            into the root directory. Nothing is fetched if it is already there.
        transform (callable, optional): Callable applied to each PIL image,
            returning a transformed version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Callable applied to each target.
    """

    # The four IDX archives (train/test images and labels) in download order.
    urls = [
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
        'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
    ]

    # Class names; label i corresponds to classes[i].
    classes = [
        'o', 'ki', 'su', 'tsu', 'na',
        'ha', 'ma', 'ya', 're', 'wo',
    ]
class EMNIST(MNIST):
    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.

    Files for every split are produced by a single download. The selected split
    is read from ``<root>/EMNIST/processed/training_<split>.pt`` or
    ``<root>/EMNIST/processed/test_<split>.pt``.

    Args:
        root (string): Root directory of dataset; the per-split processed files
            described above are expected beneath it.
        split (string): One of the 6 splits ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` or ``mnist``; selects which
            pair of processed files is loaded.
        train (bool, optional): When True, load the training file of the split;
            otherwise load its test file.
        download (bool, optional): When true, fetch the dataset from the internet
            into the root directory. Nothing is fetched if the selected split's
            files are already there.
        transform (callable, optional): Callable applied to each PIL image,
            returning a transformed version, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Callable applied to each target.

    Raises:
        ValueError: If ``split`` is not one of ``splits``.
    """
    url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

    def __init__(self, root, split, **kwargs):
        if split not in self.splits:
            valid_names = ', '.join(self.splits)
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, valid_names,
            ))
        self.split = split
        # The per-split file names must be set before the base constructor runs,
        # because it checks for and loads training_file / test_file.
        self.training_file = self._training_file(split)
        self.test_file = self._test_file(split)
        super(EMNIST, self).__init__(root, **kwargs)

    @staticmethod
    def _training_file(split):
        """Name of the processed training file for ``split``."""
        return 'training_{name}.pt'.format(name=split)

    @staticmethod
    def _test_file(split):
        """Name of the processed test file for ``split``."""
        return 'test_{name}.pt'.format(name=split)

    def download(self):
        """Download the EMNIST data if it doesn't exist in processed_folder already.

        Fetches the zip archive, unpacks it, decompresses every ``.gz`` inside
        ``raw_folder/gzip`` and writes processed files for all splits, then
        deletes the zip and the extracted ``gzip`` folder.
        """
        import shutil
        import zipfile

        if self._check_exists():
            return

        for folder in (self.raw_folder, self.processed_folder):
            makedir_exist_ok(folder)

        # download files
        archive_name = self.url.rpartition('/')[2]
        archive_path = os.path.join(self.raw_folder, archive_name)
        download_url(self.url, root=self.raw_folder, filename=archive_name, md5=None)

        print('Extracting zip archive')
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(self.raw_folder)
        os.unlink(archive_path)

        gzip_folder = os.path.join(self.raw_folder, 'gzip')
        for entry in os.listdir(gzip_folder):
            if not entry.endswith('.gz'):
                continue
            self.extract_gzip(gzip_path=os.path.join(gzip_folder, entry))

        # process and save as torch files
        for split in self.splits:
            print('Processing ' + split)
            training_set = (
                read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))),
            )
            test_set = (
                read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
                read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))),
            )
            # Both subsets are fully parsed before either file is written.
            outputs = (
                (self._training_file(split), training_set),
                (self._test_file(split), test_set),
            )
            for out_name, payload in outputs:
                with open(os.path.join(self.processed_folder, out_name), 'wb') as out_f:
                    torch.save(payload, out_f)
        shutil.rmtree(gzip_folder)

        print('Done!')
def get_int(b):
    """Interpret the byte string ``b`` as an unsigned big-endian integer.

    The IDX readers in this module use it to decode their header fields.
    """
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)
def read_label_file(path):
    """Read an IDX1 label file into a 1-D ``torch.LongTensor``.

    The header is a 4-byte magic number (2049) followed by a 4-byte item count;
    one unsigned byte per label follows.

    Args:
        path (string): Path of the uncompressed ``*-labels-idx1-ubyte`` file.

    Returns:
        torch.LongTensor: Labels of shape ``(length,)``.

    Raises:
        ValueError: If the magic number is not 2049 (or the header is truncated).
    """
    with open(path, 'rb') as f:
        # A bytearray gives numpy a writable buffer, so torch.from_numpy does
        # not have to wrap a read-only array.
        data = bytearray(f.read())
    magic = get_int(data[:4])
    if magic != 2049:
        # Explicit check instead of assert: asserts vanish under ``python -O``.
        raise ValueError('Invalid magic number {} in label file {} (expected 2049)'.format(magic, path))
    length = get_int(data[4:8])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
    return torch.from_numpy(parsed).view(length).long()
def read_image_file(path):
    """Read an IDX3 image file into a ``(length, rows, cols)`` ``torch.ByteTensor``.

    The header is a 4-byte magic number (2051) followed by the 4-byte item count,
    row count and column count; one unsigned byte per pixel follows.

    Args:
        path (string): Path of the uncompressed ``*-images-idx3-ubyte`` file.

    Returns:
        torch.ByteTensor: Images of shape ``(length, num_rows, num_cols)``.

    Raises:
        ValueError: If the magic number is not 2051 (or the header is truncated).
    """
    with open(path, 'rb') as f:
        # The returned tensor shares this buffer; a bytearray keeps it writable,
        # whereas an immutable bytes object would make in-place ops unsafe.
        data = bytearray(f.read())
    magic = get_int(data[:4])
    if magic != 2051:
        # Explicit check instead of assert: asserts vanish under ``python -O``.
        raise ValueError('Invalid magic number {} in image file {} (expected 2051)'.format(magic, path))
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)