Commit 36215690 authored by Leon Sixt's avatar Leon Sixt Committed by Soumith Chintala
Browse files

Fix FashionMNIST loading MNIST (#640)

Before this would lead FashionMNIST to contain mnist data:

```
MNIST(root, download=True)
FashionMNIST(root, download=True)
```

As MNIST and FashionMNIST are the same classes, the processed
outputs actual ended up to be the same files. This commit now
stores them at different files and also stores the class name when
saving them. I also added md5 sums.
parent c7e9bd30
import torch
from torchvision.datasets import MNIST, FashionMNIST
import unittest
import tempfile
import shutil
import os
class Tester(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.test_dir)
def test_fashion_mnist_doesnt_load_mnist(self):
MNIST(root=self.test_dir, download=True)
FashionMNIST(root=self.test_dir, download=True)
if __name__ == '__main__':
unittest.main()
......@@ -7,6 +7,7 @@ import gzip
import numpy as np
import torch
import codecs
import hashlib
from .utils import download_url, makedir_exist_ok
......@@ -26,14 +27,25 @@ class MNIST(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]
training_file = 'training.pt'
test_file = 'test.pt'
md5s = {
't10k-images-idx3-ubyte.gz': '9fb629c4189551a2d022fa330f9573f3',
't10k-labels-idx1-ubyte.gz': 'ec29112dd5afa0611ce80d1b7f02629c',
'train-images-idx3-ubyte.gz': 'f68b3c2dcbeaaa9fbdd348bbdeb94873',
'train-labels-idx1-ubyte.gz': 'd53e105ee54ea40749a09fcbcd1e9432',
}
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'mnist-training.pt'
test_file = 'mnist-test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
......@@ -50,11 +62,22 @@ class MNIST(data.Dataset):
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
def load_data(filename):
loaded_data = torch.load(
os.path.join(self.root, self.processed_folder, filename))
if len(loaded_data) == 2:
return loaded_data
else:
clsname, data, labels = loaded_data
if clsname != type(self).__name__:
raise RuntimeError("Expected {} data but found {}"
.format(type(self).__name__, clsname, ))
return data, labels
if self.train:
data_file = self.training_file
self.train_data, self.train_labels = load_data(self.training_file)
else:
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
self.train_data, self.train_labels = load_data(self.test_file)
def __getitem__(self, index):
"""
......@@ -64,7 +87,10 @@ class MNIST(data.Dataset):
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
......@@ -108,6 +134,8 @@ class MNIST(data.Dataset):
def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
if self._check_exists():
return
......@@ -119,19 +147,21 @@ class MNIST(data.Dataset):
for url in self.urls:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=None)
download_url(url, root=self.raw_folder, filename=filename, md5=self.md5s[filename])
self.extract_gzip(gzip_path=file_path, remove_finished=True)
# process and save as torch files
print('Processing...')
training_set = (
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
type(self).__name__,
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
type(self).__name__,
read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
......@@ -169,14 +199,22 @@ class FashionMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
training_file = 'fashion-mnist-training.pt'
test_file = 'fashion-mnist-test.pt'
urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
md5s = {
't10k-images-idx3-ubyte.gz': 'bef4ecab320f06d8554ea6380940ec79',
't10k-labels-idx1-ubyte.gz': 'bb300cfdad3c16e7a12a480ee83cd310',
'train-images-idx3-ubyte.gz': '8d4fb7e6c68d591d4c3dfef9ec88bf0d',
'train-labels-idx1-ubyte.gz': '25c81989df183df01b3e8a0aad5dffbe',
}
class EMNIST(MNIST):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment