"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "79b6fbd84aa5a06b0e9ea3a829fc116444ff64b9"
Commit 972b80c9 authored by Taihyun Hwang's avatar Taihyun Hwang Committed by Soumith Chintala
Browse files

Load and parse metadata for CIFAR-10, CIFAR-100 (#502)

* cifar10.meta['label_names']

* cifar100.meta['fine_label_names']

* cifar100.meta['coarse_label_names']
parent 9f28cff7
...@@ -46,6 +46,10 @@ class CIFAR10(data.Dataset): ...@@ -46,6 +46,10 @@ class CIFAR10(data.Dataset):
['test_batch', '40351d587109b95175f43aff81a1287e'], ['test_batch', '40351d587109b95175f43aff81a1287e'],
] ]
meta_list = [
['batches.meta', '5ff9c542aee3614f3951f8cda6e48888'],
]
def __init__(self, root, train=True, def __init__(self, root, train=True,
transform=None, target_transform=None, transform=None, target_transform=None,
download=False): download=False):
...@@ -100,6 +104,16 @@ class CIFAR10(data.Dataset): ...@@ -100,6 +104,16 @@ class CIFAR10(data.Dataset):
self.test_data = self.test_data.reshape((10000, 3, 32, 32)) self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
f = self.meta_list[0][0]
file = os.path.join(self.root, self.base_folder, f)
fo = open(file, 'rb')
if sys.version_info[0] == 2:
entry = pickle.load(fo)
else:
entry = pickle.load(fo, encoding='latin1')
fo.close()
self.meta = entry
def __getitem__(self, index): def __getitem__(self, index):
""" """
Args: Args:
...@@ -133,7 +147,7 @@ class CIFAR10(data.Dataset): ...@@ -133,7 +147,7 @@ class CIFAR10(data.Dataset):
def _check_integrity(self): def _check_integrity(self):
root = self.root root = self.root
for fentry in (self.train_list + self.test_list): for fentry in (self.train_list + self.test_list + self.meta_list):
filename, md5 = fentry[0], fentry[1] filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename) fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5): if not check_integrity(fpath, md5):
...@@ -187,3 +201,7 @@ class CIFAR100(CIFAR10): ...@@ -187,3 +201,7 @@ class CIFAR100(CIFAR10):
test_list = [ test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
] ]
meta_list = [
['meta', '7973b15100ade9c7d40fb424638fde48'],
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment