Commit 36215690 authored by Leon Sixt, committed by Soumith Chintala
Browse files

Fix FashionMNIST loading MNIST (#640)

Before, this would lead FashionMNIST to contain MNIST data:

```
MNIST(root, download=True)
FashionMNIST(root, download=True)
```

As FashionMNIST subclasses MNIST and shares its implementation, the
processed outputs of both actually ended up in the same files. This commit
now stores them in different files and also stores the class name when
saving them. I also added md5 sums.
parent c7e9bd30
import torch
from torchvision.datasets import MNIST, FashionMNIST
import unittest
import tempfile
import shutil
import os
class Tester(unittest.TestCase):
    """Regression test: MNIST and FashionMNIST must coexist under one root."""

    def setUp(self):
        """Create a fresh scratch directory that serves as the dataset root."""
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Remove the scratch directory and everything downloaded into it."""
        shutil.rmtree(self.test_dir)

    def test_fashion_mnist_doesnt_load_mnist(self):
        """FashionMNIST must not pick up files written by an earlier MNIST load.

        Both datasets are created with ``download=True`` against the same
        root directory, MNIST first, which is the order that used to make
        FashionMNIST read MNIST's processed files.
        """
        mnist = MNIST(root=self.test_dir, train=True, download=True)
        fashion = FashionMNIST(root=self.test_dir, train=True, download=True)

        # Constructing both without an exception is not sufficient evidence:
        # compare the loaded tensors so the test fails loudly if FashionMNIST
        # silently ended up with MNIST's images.
        self.assertFalse(
            torch.equal(mnist.train_data, fashion.train_data),
            "FashionMNIST returned the same training images as MNIST",
        )
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -7,6 +7,7 @@ import gzip ...@@ -7,6 +7,7 @@ import gzip
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
import hashlib
from .utils import download_url, makedir_exist_ok from .utils import download_url, makedir_exist_ok
...@@ -26,14 +27,25 @@ class MNIST(data.Dataset): ...@@ -26,14 +27,25 @@ class MNIST(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
""" """
urls = [ urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
] ]
training_file = 'training.pt'
test_file = 'test.pt' md5s = {
't10k-images-idx3-ubyte.gz': '9fb629c4189551a2d022fa330f9573f3',
't10k-labels-idx1-ubyte.gz': 'ec29112dd5afa0611ce80d1b7f02629c',
'train-images-idx3-ubyte.gz': 'f68b3c2dcbeaaa9fbdd348bbdeb94873',
'train-labels-idx1-ubyte.gz': 'd53e105ee54ea40749a09fcbcd1e9432',
}
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'mnist-training.pt'
test_file = 'mnist-test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
...@@ -50,11 +62,22 @@ class MNIST(data.Dataset): ...@@ -50,11 +62,22 @@ class MNIST(data.Dataset):
raise RuntimeError('Dataset not found.' + raise RuntimeError('Dataset not found.' +
' You can use download=True to download it') ' You can use download=True to download it')
def load_data(filename):
loaded_data = torch.load(
os.path.join(self.root, self.processed_folder, filename))
if len(loaded_data) == 2:
return loaded_data
else:
clsname, data, labels = loaded_data
if clsname != type(self).__name__:
raise RuntimeError("Expected {} data but found {}"
.format(type(self).__name__, clsname, ))
return data, labels
if self.train: if self.train:
data_file = self.training_file self.train_data, self.train_labels = load_data(self.training_file)
else: else:
data_file = self.test_file self.train_data, self.train_labels = load_data(self.test_file)
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -64,7 +87,10 @@ class MNIST(data.Dataset): ...@@ -64,7 +87,10 @@ class MNIST(data.Dataset):
Returns: Returns:
tuple: (image, target) where target is index of the target class. tuple: (image, target) where target is index of the target class.
""" """
img, target = self.data[index], int(self.targets[index]) if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets # doing this so that it is consistent with all other datasets
# to return a PIL Image # to return a PIL Image
...@@ -108,6 +134,8 @@ class MNIST(data.Dataset): ...@@ -108,6 +134,8 @@ class MNIST(data.Dataset):
def download(self): def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already.""" """Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
if self._check_exists(): if self._check_exists():
return return
...@@ -119,19 +147,21 @@ class MNIST(data.Dataset): ...@@ -119,19 +147,21 @@ class MNIST(data.Dataset):
for url in self.urls: for url in self.urls:
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename) file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None) download_url(url, root=self.raw_folder, filename=filename, md5=self.md5s[filename])
self.extract_gzip(gzip_path=file_path, remove_finished=True) self.extract_gzip(gzip_path=file_path, remove_finished=True)
# process and save as torch files # process and save as torch files
print('Processing...') print('Processing...')
training_set = ( training_set = (
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')), type(self).__name__,
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
) )
test_set = ( test_set = (
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')), type(self).__name__,
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
) )
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f: with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f) torch.save(training_set, f)
...@@ -169,14 +199,22 @@ class FashionMNIST(MNIST): ...@@ -169,14 +199,22 @@ class FashionMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
""" """
training_file = 'fashion-mnist-training.pt'
test_file = 'fashion-mnist-test.pt'
urls = [ urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz', 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz', 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz', 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz', 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
] ]
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] md5s = {
't10k-images-idx3-ubyte.gz': 'bef4ecab320f06d8554ea6380940ec79',
't10k-labels-idx1-ubyte.gz': 'bb300cfdad3c16e7a12a480ee83cd310',
'train-images-idx3-ubyte.gz': '8d4fb7e6c68d591d4c3dfef9ec88bf0d',
'train-labels-idx1-ubyte.gz': '25c81989df183df01b3e8a0aad5dffbe',
}
class EMNIST(MNIST): class EMNIST(MNIST):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment