Commit fc7911c8 authored by Danylo Ulianych, committed by Francisco Massa

CIFAR: permanent 'data' and 'targets' fields (#594)

parent f3d5e85d
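The change below replaces the split-specific train_data/train_labels and test_data/test_labels attributes (and the read-only targets property) with permanent self.data and self.targets fields that are populated in __init__ for both splits. A minimal usage sketch of the resulting API, with illustrative root path and download flag:

    import torchvision.datasets as datasets

    # root and download are illustrative; point them at your own environment
    train_set = datasets.CIFAR10(root='./data', train=True, download=True)
    test_set = datasets.CIFAR10(root='./data', train=False, download=True)

    # both splits now expose the same attributes, regardless of train/test
    print(train_set.data.shape, len(train_set.targets))  # (50000, 32, 32, 3), 50000
    print(test_set.data.shape, len(test_set.targets))    # (10000, 32, 32, 3), 10000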
@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }
 
-    @property
-    def targets(self):
-        if self.train:
-            return self.train_labels
-        else:
-            return self.test_labels
-
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
                  download=False):
@@ -73,44 +66,30 @@ class CIFAR10(data.Dataset):
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')
 
-        # now load the picked numpy arrays
         if self.train:
-            self.train_data = []
-            self.train_labels = []
-            for fentry in self.train_list:
-                f = fentry[0]
-                file = os.path.join(self.root, self.base_folder, f)
-                fo = open(file, 'rb')
-                if sys.version_info[0] == 2:
-                    entry = pickle.load(fo)
-                else:
-                    entry = pickle.load(fo, encoding='latin1')
-                self.train_data.append(entry['data'])
-                if 'labels' in entry:
-                    self.train_labels += entry['labels']
-                else:
-                    self.train_labels += entry['fine_labels']
-                fo.close()
-
-            self.train_data = np.concatenate(self.train_data)
-            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
-            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
+            downloaded_list = self.train_list
         else:
-            f = self.test_list[0][0]
-            file = os.path.join(self.root, self.base_folder, f)
-            fo = open(file, 'rb')
-            if sys.version_info[0] == 2:
-                entry = pickle.load(fo)
-            else:
-                entry = pickle.load(fo, encoding='latin1')
-            self.test_data = entry['data']
-            if 'labels' in entry:
-                self.test_labels = entry['labels']
-            else:
-                self.test_labels = entry['fine_labels']
-            fo.close()
-            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
-            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
+            downloaded_list = self.test_list
+
+        self.data = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
+                if sys.version_info[0] == 2:
+                    entry = pickle.load(f)
+                else:
+                    entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
+                if 'labels' in entry:
+                    self.targets.extend(entry['labels'])
+                else:
+                    self.targets.extend(entry['fine_labels'])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
 
         self._load_meta()
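The refactored loader stacks each batch file's flat (N, 3072) array, reshapes it to channel-first (N, 3, 32, 32) as stored on disk, and transposes to HWC for PIL. A standalone sketch of the same reshape/transpose step on a dummy array, independent of the dataset code:

    import numpy as np

    # dummy stand-in for one pickled CIFAR batch: N flat rows of 3*32*32 bytes
    batch = np.zeros((2, 3072), dtype=np.uint8)

    data = np.vstack([batch])            # stack the per-file arrays into one (N, 3072) array
    data = data.reshape(-1, 3, 32, 32)   # flat rows are stored channel-first (CHW)
    data = data.transpose((0, 2, 3, 1))  # convert to HWC so PIL can consume each row
    print(data.shape)                    # (2, 32, 32, 3)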
@@ -135,10 +114,7 @@ class CIFAR10(data.Dataset):
         Returns:
             tuple: (image, target) where target is index of the target class.
         """
-        if self.train:
-            img, target = self.train_data[index], self.train_labels[index]
-        else:
-            img, target = self.test_data[index], self.test_labels[index]
+        img, target = self.data[index], self.targets[index]
 
         # doing this so that it is consistent with all other datasets
         # to return a PIL Image
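Because self.data is kept channel-last, each indexed row can be handed straight to Image.fromarray, which expects an (H, W, 3) uint8 array. A minimal sketch with a dummy array:

    import numpy as np
    from PIL import Image

    # dummy 32x32 RGB array in the HWC uint8 layout used by self.data rows
    img_array = np.zeros((32, 32, 3), dtype=np.uint8)
    img = Image.fromarray(img_array)  # works because the array is already HWC
    print(img.size, img.mode)         # (32, 32) RGB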
@@ -153,10 +129,7 @@ class CIFAR10(data.Dataset):
         return img, target
 
     def __len__(self):
-        if self.train:
-            return len(self.train_data)
-        else:
-            return len(self.test_data)
+        return len(self.data)
 
     def _check_integrity(self):
         root = self.root
@@ -174,16 +147,11 @@ class CIFAR10(data.Dataset):
             print('Files already downloaded and verified')
             return
 
-        root = self.root
-        download_url(self.url, root, self.filename, self.tgz_md5)
+        download_url(self.url, self.root, self.filename, self.tgz_md5)
 
         # extract file
-        cwd = os.getcwd()
-        tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
-        os.chdir(root)
-        tar.extractall()
-        tar.close()
-        os.chdir(cwd)
+        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+            tar.extractall(path=self.root)
 
     def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
...
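The download path now extracts the archive with a tarfile context manager and an explicit path= argument instead of the old chdir/extractall/chdir-back dance, so the working directory never changes and the file handle is closed even if extraction fails. A sketch of the same pattern, with an illustrative root directory:

    import os
    import tarfile

    root = './data'                      # illustrative download root
    filename = 'cifar-10-python.tar.gz'  # the CIFAR-10 archive name

    with tarfile.open(os.path.join(root, filename), 'r:gz') as tar:
        tar.extractall(path=root)        # extract relative to root, no os.chdir needed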