from __future__ import print_function
import torch.utils.data as data
try:
    # ``Image`` is not referenced by the dataset classes below; import it
    # opportunistically so Pillow is not a hard requirement for reading the
    # raw arrays.
    from PIL import Image  # noqa: F401
except ImportError:
    Image = None
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


def _md5sum(fpath, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of the file at ``fpath``.

    The file is read in ``chunk_size`` pieces inside a ``with`` block, so
    large archives are never held in memory at once and no handle leaks.
    """
    import hashlib
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def _load_pickled_batch(fpath):
    """Unpickle one batch file and return the resulting dict.

    On Python 3, ``encoding='latin1'`` is needed to read NumPy arrays that
    were pickled by Python 2. The upstream CIFAR archives are presumed to be
    such pickles (verify); latin1 maps every byte 1:1, so it is also harmless
    for Python 3 pickles.
    """
    # SECURITY NOTE(review): unpickling can execute arbitrary code. Callers in
    # this module only load files whose MD5 matched the class's expected list,
    # but only point ``root`` at data from a trusted source.
    with open(fpath, 'rb') as fo:
        if sys.version_info[0] == 2:
            return pickle.load(fo)
        return pickle.load(fo, encoding='latin1')


def _batch_labels(entry):
    """Return the label list of an unpickled batch.

    Uses ``'labels'`` when present (CIFAR-10 style), else ``'fine_labels'``.
    """
    if 'labels' in entry:
        return entry['labels']
    return entry['fine_labels']


class CIFAR10(data.Dataset):
    """CIFAR-10 dataset read from the pickled batch files under ``root``.

    Args:
        root: Directory that contains (or will receive) ``base_folder``.
        train: If True, index the training batches; otherwise the test batch.
        transform: Optional callable applied to each image before returning it.
        target_transform: Optional callable applied to each label.
        download: If True, call :meth:`download` before loading.

    Raises:
        RuntimeError: If any batch file listed in ``train_list``/``test_list``
            is missing or fails its MD5 check.

    Items are ``(img, target)`` pairs. ``img`` is one row of the data array
    reshaped to ``(3, 32, 32)``.
    """
    base_folder = 'cifar-10-batches-py'
    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # Misspelled name used by earlier versions; kept so old references work.
    tgz_mdf = tgz_md5
    # Each entry is [file name inside base_folder, expected MD5 hex digest].
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Load every training batch and concatenate the data arrays.
        self.train_data = []
        self.train_labels = []
        for fentry in self.train_list:
            entry = _load_pickled_batch(
                os.path.join(self.root, self.base_folder, fentry[0]))
            self.train_data.append(entry['data'])
            self.train_labels += _batch_labels(entry)
        self.train_data = np.concatenate(self.train_data)

        entry = _load_pickled_batch(
            os.path.join(self.root, self.base_folder, self.test_list[0][0]))
        self.test_data = entry['data']
        self.test_labels = _batch_labels(entry)

        # Infer the image count (-1) instead of hard-coding 50000/10000, so
        # the shape always matches whatever the batch files contain.
        self.train_data = self.train_data.reshape((-1, 3, 32, 32))
        self.test_data = self.test_data.reshape((-1, 3, 32, 32))

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index`` after applying transforms."""
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images in the selected split."""
        # Derived from the loaded arrays rather than hard-coded constants.
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_integrity(self):
        """Return True iff every listed batch file exists and matches its MD5."""
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(self.root, self.base_folder, filename)
            if not os.path.isfile(fpath):
                return False
            if _md5sum(fpath) != md5:
                return False
        return True

    def download(self):
        """Download and extract the archive into ``root`` if not already there.

        Raises:
            RuntimeError: If a freshly downloaded archive does not match
                ``tgz_md5``.
        """
        from six.moves import urllib
        import tarfile

        root = self.root
        fpath = os.path.join(root, self.filename)

        try:
            os.makedirs(root)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # downloads file
        if os.path.isfile(fpath) and _md5sum(fpath) == self.tgz_md5:
            print('Using downloaded file: ' + fpath)
        else:
            print('Downloading ' + self.url + ' to ' + fpath)
            urllib.request.urlretrieve(self.url, fpath)
            # The URL is plain HTTP; never extract an archive we cannot verify.
            if _md5sum(fpath) != self.tgz_md5:
                raise RuntimeError('Downloaded file ' + fpath +
                                   ' does not match the expected MD5')

        # extract file
        print('Extracting tar file')
        # Extract relative to ``root`` directly instead of os.chdir(), which
        # mutates process-wide state.
        with tarfile.open(fpath, "r:gz") as tar:
            tar.extractall(path=root)
        print('Done!')


class CIFAR100(CIFAR10):
    """CIFAR-100 variant of :class:`CIFAR10`; only the file metadata differs.

    Its batches carry ``'fine_labels'`` rather than ``'labels'``, which the
    shared loading code handles.
    """
    base_folder = 'cifar-100-python'
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]