cifar.py 5.87 KB
Newer Older
Soumith Chintala's avatar
Soumith Chintala committed
1
2
3
4
5
6
7
8
9
10
11
12
from __future__ import print_function
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

soumith's avatar
soumith committed
13
import torch.utils.data as data
soumith's avatar
soumith committed
14
from .utils import download_url, check_integrity
15

16

Soumith Chintala's avatar
Soumith Chintala committed
17
class CIFAR10(data.Dataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    # Directory (relative to ``root``) that holds the batch files.
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # Each entry is [file name inside base_folder, md5 passed to check_integrity].
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Optionally download the data, verify it and load one split.

        Sets ``train_data``/``train_labels`` when ``train`` is true, otherwise
        ``test_data``/``test_labels``.

        Raises:
            RuntimeError: if ``_check_integrity()`` reports a missing or
                corrupted file.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the pickled numpy arrays
        if self.train:
            self.train_data, self.train_labels = \
                self._load_batches(self.train_list)
        else:
            self.test_data, self.test_labels = \
                self._load_batches(self.test_list)

    def _load_batches(self, batch_list):
        """Read every batch file named in ``batch_list``.

        Args:
            batch_list (list): entries whose first element is a file name
                inside ``root/base_folder``.

        Returns:
            tuple: (images, labels) where images is the concatenated ``data``
            arrays reshaped to (N, 3, 32, 32) and transposed to
            (N, 32, 32, 3), and labels is a flat list taken from the
            ``labels`` key (or ``fine_labels`` when ``labels`` is absent).
        """
        arrays = []
        labels = []
        for fentry in batch_list:
            path = os.path.join(self.root, self.base_folder, fentry[0])
            # NOTE(review): unpickling runs arbitrary code from the file.
            # __init__ only reaches this point after _check_integrity() has
            # passed; do not reuse this helper on unverified files.
            # The ``with`` block closes the handle even if unpickling fails.
            with open(path, 'rb') as fo:
                if sys.version_info[0] == 2:
                    entry = pickle.load(fo)
                else:
                    # the batch files are unpickled with latin1 on Python 3
                    entry = pickle.load(fo, encoding='latin1')
            arrays.append(entry['data'])
            if 'labels' in entry:
                labels.extend(entry['labels'])
            else:
                labels.extend(entry['fine_labels'])

        # -1 lets the sample count follow the files instead of being
        # hard-coded per split.
        images = np.concatenate(arrays).reshape((-1, 3, 32, 32))
        images = images.transpose((0, 2, 3, 1))  # convert to HWC
        return images, labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_integrity(self):
        """Return True only if every train and test file passes check_integrity."""
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download the archive into ``root`` and extract it there.

        Does nothing (apart from printing a message) when the batch files are
        already present and pass ``_check_integrity()``.
        """
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # extract file
        # Extract straight into ``root`` rather than chdir-ing there, so the
        # working directory is never changed and the archive is closed even
        # if extraction fails.
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)


class CIFAR100(CIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    This is a subclass of the `CIFAR10` Dataset: it only overrides the
    archive location, checksums and file lists, and inherits all loading
    and indexing behaviour from the parent class.
    """
    # Directory (relative to ``root``) that holds the extracted files.
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # Each entry is [file name inside base_folder, md5 of that file].
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]