Commit 71755dad authored by Soumith Chintala's avatar Soumith Chintala
Browse files

cifar10/100 only load train/test and not both

parent 2d55b9d5
......@@ -42,35 +42,35 @@ class CIFAR10(data.Dataset):
+ ' You can use download=True to download it')
# now load the picked numpy arrays
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
if self.train:
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.train_data.append(entry['data'])
if 'labels' in entry:
self.train_labels += entry['labels']
else:
self.train_labels += entry['fine_labels']
fo.close()
self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
else:
f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.train_data.append(entry['data'])
self.test_data = entry['data']
if 'labels' in entry:
self.train_labels += entry['labels']
self.test_labels = entry['labels']
else:
self.train_labels += entry['fine_labels']
self.test_labels = entry['fine_labels']
fo.close()
self.train_data = np.concatenate(self.train_data)
f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.test_data = entry['data']
if 'labels' in entry:
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']
fo.close()
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
def __getitem__(self, index):
if self.train:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment