Unverified commit 7e973b8a, authored by Francisco Massa and committed by GitHub
Browse files

Revert "Fix FashionMNIST loading MNIST" (#646)

* Revert "move area calculation out of loop (#641)"

This reverts commit 62cbf0bf.

* Revert "Fix FashionMNIST loading MNIST (#640)"

This reverts commit 36215690.
parent 62cbf0bf
import torch
from torchvision.datasets import MNIST, FashionMNIST
import unittest
import tempfile
import shutil
import os
class Tester(unittest.TestCase):
    """Builds MNIST and FashionMNIST against one shared root directory."""

    def setUp(self):
        # Every test gets its own scratch directory to use as dataset root.
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        # Remove the scratch directory together with anything written to it.
        shutil.rmtree(self.test_dir)

    def test_fashion_mnist_doesnt_load_mnist(self):
        """Construct MNIST, then FashionMNIST, with the same ``root``.

        There is no explicit assertion: the test passes when both
        constructors return without raising.
        """
        MNIST(root=self.test_dir, download=True)
        FashionMNIST(root=self.test_dir, download=True)
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -7,7 +7,6 @@ import gzip
import numpy as np
import torch
import codecs
import hashlib
from .utils import download_url, makedir_exist_ok
......@@ -27,25 +26,14 @@ class MNIST(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
urls = [
'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]
md5s = {
't10k-images-idx3-ubyte.gz': '9fb629c4189551a2d022fa330f9573f3',
't10k-labels-idx1-ubyte.gz': 'ec29112dd5afa0611ce80d1b7f02629c',
'train-images-idx3-ubyte.gz': 'f68b3c2dcbeaaa9fbdd348bbdeb94873',
'train-labels-idx1-ubyte.gz': 'd53e105ee54ea40749a09fcbcd1e9432',
}
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'mnist-training.pt'
test_file = 'mnist-test.pt'
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
......@@ -62,22 +50,11 @@ class MNIST(data.Dataset):
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
def load_data(filename):
loaded_data = torch.load(
os.path.join(self.root, self.processed_folder, filename))
if len(loaded_data) == 2:
return loaded_data
else:
clsname, data, labels = loaded_data
if clsname != type(self).__name__:
raise RuntimeError("Expected {} data but found {}"
.format(type(self).__name__, clsname, ))
return data, labels
if self.train:
self.train_data, self.train_labels = load_data(self.training_file)
data_file = self.training_file
else:
self.train_data, self.train_labels = load_data(self.test_file)
data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index):
"""
......@@ -87,10 +64,7 @@ class MNIST(data.Dataset):
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
......@@ -134,8 +108,6 @@ class MNIST(data.Dataset):
def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
if self._check_exists():
return
......@@ -147,21 +119,19 @@ class MNIST(data.Dataset):
for url in self.urls:
filename = url.rpartition('/')[2]
file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=self.raw_folder, filename=filename, md5=self.md5s[filename])
download_url(url, root=self.raw_folder, filename=filename, md5=None)
self.extract_gzip(gzip_path=file_path, remove_finished=True)
# process and save as torch files
print('Processing...')
training_set = (
type(self).__name__,
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
)
test_set = (
type(self).__name__,
read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
)
with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f)
......@@ -199,22 +169,14 @@ class FashionMNIST(MNIST):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
training_file = 'fashion-mnist-training.pt'
test_file = 'fashion-mnist-test.pt'
urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
md5s = {
't10k-images-idx3-ubyte.gz': 'bef4ecab320f06d8554ea6380940ec79',
't10k-labels-idx1-ubyte.gz': 'bb300cfdad3c16e7a12a480ee83cd310',
'train-images-idx3-ubyte.gz': '8d4fb7e6c68d591d4c3dfef9ec88bf0d',
'train-labels-idx1-ubyte.gz': '25c81989df183df01b3e8a0aad5dffbe',
}
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
class EMNIST(MNIST):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment