Commit 628e90cb authored by David Morton, committed by Soumith Chintala
Browse files

Add metadata to some datasets (#501)

* Add classes metadata to MNIST and FashionMNIST

* Add `targets` property to MNIST and FashionMNIST

* Add class metadata to CIFAR10/CIFAR100

* Add `targets` property to CIFAR10/CIFAR100

* Add targets attribute to DatasetFolder
parent 1a47a44d
......@@ -45,6 +45,18 @@ class CIFAR10(data.Dataset):
# NOTE(review): each entry looks like a [file name, md5 hex digest] pair for a
# test batch file -- how the list is consumed is not visible in this excerpt.
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
# Descriptor of the pickled metadata file that _load_meta() reads:
#   'filename' - file name, joined onto root/base_folder to build the path
#   'key'      - key of the unpickled object whose value becomes self.classes
#   'md5'      - checksum handed to check_integrity() before the file is opened
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
@property
def targets(self):
    """Return the labels of the split this dataset instance represents.

    Returns:
        ``self.train_labels`` when ``self.train`` is truthy, otherwise
        ``self.test_labels``. Only the selected attribute is accessed.
    """
    return self.train_labels if self.train else self.test_labels
def __init__(self, root, train=True,
transform=None, target_transform=None,
......@@ -100,6 +112,21 @@ class CIFAR10(data.Dataset):
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def _load_meta(self):
    """Read the class names from the dataset's pickled metadata file.

    The file location and the lookup key come from ``self.meta``. On
    success this sets ``self.classes`` (the value stored under
    ``self.meta['key']``) and ``self.class_to_idx`` (class name -> position
    in ``self.classes``).

    Raises:
        RuntimeError: if ``check_integrity`` rejects the file for the md5
            recorded in ``self.meta``.
    """
    meta_path = os.path.join(self.root, self.base_folder, self.meta['filename'])
    if not check_integrity(meta_path, self.meta['md5']):
        raise RuntimeError('Dataset metadata file not found or corrupted.'
                           ' You can use download=True to download it')

    # Python 2's pickle.load() takes no ``encoding`` keyword, so it is only
    # supplied on newer interpreters (presumably because the file was pickled
    # under Python 2 -- confirm against the dataset distribution).
    load_kwargs = {} if sys.version_info[0] == 2 else {'encoding': 'latin1'}

    # NOTE(review): unpickling can execute arbitrary code; the md5 check above
    # is the only validation performed before the file is loaded.
    with open(meta_path, 'rb') as handle:
        unpickled = pickle.load(handle, **load_kwargs)

    self.classes = unpickled[self.meta['key']]
    self.class_to_idx = dict(
        (name, position) for position, name in enumerate(self.classes)
    )
def __getitem__(self, index):
"""
Args:
......@@ -187,3 +214,8 @@ class CIFAR100(CIFAR10):
# NOTE(review): looks like a [file name, md5 hex digest] pair for the test
# split file -- how the list is consumed is not visible in this excerpt.
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
# CIFAR100's own metadata descriptor, with the same three keys as the CIFAR10
# `meta` it replaces: 'filename' (metadata file name), 'key' (entry of the
# unpickled object holding the class names) and 'md5' (expected checksum).
# Presumably consumed by the _load_meta() inherited from CIFAR10 -- confirm.
meta = {
'filename': 'meta',
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}
......@@ -69,6 +69,7 @@ class DatasetFolder(data.Dataset):
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
......@@ -85,6 +86,7 @@ class DatasetFolder(data.Dataset):
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
self.transform = transform
self.target_transform = target_transform
......
......@@ -35,6 +35,16 @@ class MNIST(data.Dataset):
# File and folder name constants (their use is outside this excerpt).
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'

# Human-readable class names; the list position is the class index.
classes = [
    '0 - zero',
    '1 - one',
    '2 - two',
    '3 - three',
    '4 - four',
    '5 - five',
    '6 - six',
    '7 - seven',
    '8 - eight',
    '9 - nine',
]
# Reverse lookup: class name -> index into `classes`.
class_to_idx = dict(zip(classes, range(len(classes))))
@property
def targets(self):
    """Return the labels of the split this dataset instance represents.

    Returns:
        ``self.train_labels`` when ``self.train`` is truthy, otherwise
        ``self.test_labels``. Only the selected attribute is accessed.
    """
    return self.train_labels if self.train else self.test_labels
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
......@@ -174,6 +184,9 @@ class FashionMNIST(MNIST):
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
# Human-readable class names; the list position is the class index.
classes = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]
# Reverse lookup: class name -> index into `classes`.
class_to_idx = dict(zip(classes, range(len(classes))))
class EMNIST(MNIST):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment