# cifar.py: CIFAR-10 and CIFAR-100 dataset loaders.
from __future__ import print_function

import os
import os.path
import sys

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

import numpy as np
from PIL import Image
import torch.utils.data as data

from .utils import download_url, check_integrity


class CIFAR10(data.Dataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the batch files are missing or fail the integrity
            check (evaluated after the optional download).

    """
    # Directory under ``root`` that holds the batch files.
    base_folder = 'cifar-10-batches-py'
    # Archive location, local archive name and the md5 handed to ``download_url``.
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # ``[file name, md5]`` pairs, consumed by ``_check_integrity`` and the loader.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the pickled numpy arrays
        if self.train:
            self.train_data, self.train_labels = self._load_split(self.train_list)
        else:
            self.test_data, self.test_labels = self._load_split(self.test_list)

    def _load_split(self, file_list):
        """Load and merge the pickled batch files named in ``file_list``.

        Args:
            file_list (list): ``[file name, md5]`` entries; only the file name
                is used here.

        Returns:
            tuple: ``(images, labels)``. ``images`` is the concatenation of
            every batch's ``'data'`` array, reshaped to ``(N, 3, 32, 32)`` and
            transposed to ``(N, 32, 32, 3)``. ``labels`` is a flat list taken
            from ``'labels'`` or, when that key is absent, ``'fine_labels'``.
        """
        chunks = []
        labels = []
        for fentry in file_list:
            path = os.path.join(self.root, self.base_folder, fentry[0])
            # NOTE(security): unpickling can execute code stored in the file.
            # __init__ only reaches this point after _check_integrity() passed.
            # The ``with`` block closes the handle even if unpickling fails.
            with open(path, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    # presumably the batches were written by Python 2, hence
                    # the explicit encoding -- TODO confirm
                    entry = pickle.load(fo, encoding='latin1')
            chunks.append(entry['data'])
            if 'labels' in entry:
                labels.extend(entry['labels'])
            else:
                labels.extend(entry['fine_labels'])

        images = np.concatenate(chunks)
        # -1 lets numpy infer the sample count instead of hard-coding it.
        images = images.reshape((-1, 3, 32, 32))
        images = images.transpose((0, 2, 3, 1))  # convert to HWC
        return images, labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_integrity(self):
        """Return True only if every train and test batch file passes
        ``check_integrity`` against its listed md5."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Fetch the archive into ``root`` and unpack it there.

        Does nothing (beyond printing a notice) when the batch files already
        pass the integrity check.
        """
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # extract file
        # Extract straight into ``root`` rather than os.chdir()-ing there: the
        # working directory is process-wide state and the old code left it
        # changed if extraction raised. The ``with`` block also guarantees the
        # archive handle is closed.
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)

    def __repr__(self):
        """Return a multi-line summary: size, split, root and transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        # Use truthiness (not ``is True``) so the reported split always matches
        # the one selected by ``if self.train`` everywhere else in the class.
        tmp = 'train' if self.train else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

# CIFAR100 reuses the CIFAR10 implementation and overrides only class attributes.

class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset.

    Only the class attributes below are overridden (data folder, archive URL
    and name, and the md5-tagged file lists); construction, loading, indexing
    and downloading are all inherited from `CIFAR10`.
    """
    # Directory under ``root`` that holds the batch files.
    base_folder = 'cifar-100-python'
    # Archive location, local archive name and the md5 handed to ``download_url``.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # ``[file name, md5]`` pairs for each split.
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]