Commit 71755dad authored by Soumith Chintala's avatar Soumith Chintala
Browse files

cifar10/100 only load train/test and not both

parent 2d55b9d5
...@@ -42,6 +42,7 @@ class CIFAR10(data.Dataset): ...@@ -42,6 +42,7 @@ class CIFAR10(data.Dataset):
+ ' You can use download=True to download it') + ' You can use download=True to download it')
# now load the picked numpy arrays # now load the picked numpy arrays
if self.train:
self.train_data = [] self.train_data = []
self.train_labels = [] self.train_labels = []
for fentry in self.train_list: for fentry in self.train_list:
...@@ -57,7 +58,8 @@ class CIFAR10(data.Dataset): ...@@ -57,7 +58,8 @@ class CIFAR10(data.Dataset):
fo.close() fo.close()
self.train_data = np.concatenate(self.train_data) self.train_data = np.concatenate(self.train_data)
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
else:
f = self.test_list[0][0] f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f) file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb') fo = open(file, 'rb')
...@@ -68,8 +70,6 @@ class CIFAR10(data.Dataset): ...@@ -68,8 +70,6 @@ class CIFAR10(data.Dataset):
else: else:
self.test_labels = entry['fine_labels'] self.test_labels = entry['fine_labels']
fo.close() fo.close()
self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.test_data = self.test_data.reshape((10000, 3, 32, 32)) self.test_data = self.test_data.reshape((10000, 3, 32, 32))
def __getitem__(self, index): def __getitem__(self, index):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment