Commit c74b79c8 authored by Danylo Ulianych's avatar Danylo Ulianych Committed by Francisco Massa
Browse files

MNIST loader refactored: permanent 'data' and 'targets' fields (#578)

parent fe973cee
......@@ -40,13 +40,6 @@ class MNIST(data.Dataset):
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
class_to_idx = {_class: i for i, _class in enumerate(classes)}
@property
def targets(self):
if self.train:
return self.train_labels
else:
return self.test_labels
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
self.transform = transform
......@@ -61,11 +54,10 @@ class MNIST(data.Dataset):
' You can use download=True to download it')
if self.train:
self.train_data, self.train_labels = torch.load(
os.path.join(self.root, self.processed_folder, self.training_file))
data_file = self.training_file
else:
self.test_data, self.test_labels = torch.load(
os.path.join(self.root, self.processed_folder, self.test_file))
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file))
def __getitem__(self, index):
"""
......@@ -75,10 +67,7 @@ class MNIST(data.Dataset):
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
......@@ -93,10 +82,7 @@ class MNIST(data.Dataset):
return img, target
def __len__(self):
if self.train:
return len(self.train_data)
else:
return len(self.test_data)
return len(self.data)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
......@@ -104,7 +90,6 @@ class MNIST(data.Dataset):
def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
if self._check_exists():
......@@ -228,7 +213,6 @@ class EMNIST(MNIST):
def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
import shutil
import zipfile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment