Commit c74b79c8 authored by Danylo Ulianych's avatar Danylo Ulianych Committed by Francisco Massa
Browse files

MNIST loader refactored: permanent 'data' and 'targets' fields (#578)

parent fe973cee
...@@ -40,13 +40,6 @@ class MNIST(data.Dataset): ...@@ -40,13 +40,6 @@ class MNIST(data.Dataset):
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
class_to_idx = {_class: i for i, _class in enumerate(classes)} class_to_idx = {_class: i for i, _class in enumerate(classes)}
@property
def targets(self):
if self.train:
return self.train_labels
else:
return self.test_labels
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
self.transform = transform self.transform = transform
...@@ -61,11 +54,10 @@ class MNIST(data.Dataset): ...@@ -61,11 +54,10 @@ class MNIST(data.Dataset):
' You can use download=True to download it') ' You can use download=True to download it')
if self.train: if self.train:
self.train_data, self.train_labels = torch.load( data_file = self.training_file
os.path.join(self.root, self.processed_folder, self.training_file))
else: else:
self.test_data, self.test_labels = torch.load( data_file = self.test_file
os.path.join(self.root, self.processed_folder, self.test_file)) self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file))
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -75,10 +67,7 @@ class MNIST(data.Dataset): ...@@ -75,10 +67,7 @@ class MNIST(data.Dataset):
Returns: Returns:
tuple: (image, target) where target is index of the target class. tuple: (image, target) where target is index of the target class.
""" """
if self.train: img, target = self.data[index], self.targets[index]
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets # doing this so that it is consistent with all other datasets
# to return a PIL Image # to return a PIL Image
...@@ -93,10 +82,7 @@ class MNIST(data.Dataset): ...@@ -93,10 +82,7 @@ class MNIST(data.Dataset):
return img, target return img, target
def __len__(self): def __len__(self):
if self.train: return len(self.data)
return len(self.train_data)
else:
return len(self.test_data)
def _check_exists(self): def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
...@@ -104,7 +90,6 @@ class MNIST(data.Dataset): ...@@ -104,7 +90,6 @@ class MNIST(data.Dataset):
def download(self): def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already.""" """Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip import gzip
if self._check_exists(): if self._check_exists():
...@@ -228,7 +213,6 @@ class EMNIST(MNIST): ...@@ -228,7 +213,6 @@ class EMNIST(MNIST):
def download(self): def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already.""" """Download the EMNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip import gzip
import shutil import shutil
import zipfile import zipfile
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment