Commit fc7911c8 authored by Danylo Ulianych's avatar Danylo Ulianych Committed by Francisco Massa
Browse files

CIFAR: permanent 'data' and 'targets' fields (#594)

parent f3d5e85d
......@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
@property
def targets(self):
    """Labels for the currently selected split.

    Returns ``self.train_labels`` when ``self.train`` is true,
    otherwise ``self.test_labels``.
    """
    labels = self.train_labels if self.train else self.test_labels
    return labels
def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
......@@ -73,44 +66,30 @@ class CIFAR10(data.Dataset):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
# now load the pickled numpy arrays
if self.train:
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data = []
self.targets = []
# now load the pickled numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(fo)
entry = pickle.load(f)
else:
entry = pickle.load(fo, encoding='latin1')
self.train_data.append(entry['data'])
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.train_labels += entry['labels']
self.targets.extend(entry['labels'])
else:
self.train_labels += entry['fine_labels']
fo.close()
self.targets.extend(entry['fine_labels'])
self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else:
f = self.test_list[0][0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
self.test_data = entry['data']
if 'labels' in entry:
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']
fo.close()
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
......@@ -135,10 +114,7 @@ class CIFAR10(data.Dataset):
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
......@@ -153,10 +129,7 @@ class CIFAR10(data.Dataset):
return img, target
def __len__(self):
    """Return the number of images in the selected split.

    After the refactor that made ``data`` a permanent field, images for
    both the train and test splits are stored in ``self.data``, so the
    old ``train_data``/``test_data`` branching (which referenced
    attributes that no longer exist) is gone.
    """
    return len(self.data)
def _check_integrity(self):
root = self.root
......@@ -174,16 +147,11 @@ class CIFAR10(data.Dataset):
print('Files already downloaded and verified')
return
root = self.root
download_url(self.url, root, self.filename, self.tgz_md5)
download_url(self.url, self.root, self.filename, self.tgz_md5)
# extract file
cwd = os.getcwd()
tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
os.chdir(root)
tar.extractall()
tar.close()
os.chdir(cwd)
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment