Commit e564f137 authored by Danylo Ulianych's avatar Danylo Ulianych Committed by Francisco Massa
Browse files

MNIST and FashionMNIST now have their own 'raw' and 'processed' folders (#601)

* MNIST and FashionMNIST now have their own 'raw' and 'processed' folders

* mkdir exist_ok
parent a7935ea6
...@@ -3,11 +3,11 @@ import torch.utils.data as data ...@@ -3,11 +3,11 @@ import torch.utils.data as data
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
import errno import gzip
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
from .utils import download_url from .utils import download_url, makedir_exist_ok
class MNIST(data.Dataset): class MNIST(data.Dataset):
...@@ -32,13 +32,10 @@ class MNIST(data.Dataset): ...@@ -32,13 +32,10 @@ class MNIST(data.Dataset):
'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
] ]
raw_folder = 'raw'
processed_folder = 'processed'
training_file = 'training.pt' training_file = 'training.pt'
test_file = 'test.pt' test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
class_to_idx = {_class: i for i, _class in enumerate(classes)}
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root) self.root = os.path.expanduser(root)
...@@ -57,7 +54,7 @@ class MNIST(data.Dataset): ...@@ -57,7 +54,7 @@ class MNIST(data.Dataset):
data_file = self.training_file data_file = self.training_file
else: else:
data_file = self.test_file data_file = self.test_file
self.data, self.targets = torch.load(os.path.join(self.root, self.processed_folder, data_file)) self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -84,51 +81,61 @@ class MNIST(data.Dataset): ...@@ -84,51 +81,61 @@ class MNIST(data.Dataset):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@property
def raw_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'raw')
@property
def processed_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'processed')
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self): def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)) os.path.exists(os.path.join(self.processed_folder, self.test_file))
@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
print('Extracting {}'.format(gzip_path))
with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(gzip_path) as zip_f:
out_f.write(zip_f.read())
if remove_finished:
os.unlink(gzip_path)
def download(self): def download(self):
"""Download the MNIST data if it doesn't exist in processed_folder already.""" """Download the MNIST data if it doesn't exist in processed_folder already."""
import gzip
if self._check_exists(): if self._check_exists():
return return
# download files makedir_exist_ok(self.raw_folder)
try: makedir_exist_ok(self.processed_folder)
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
# download files
for url in self.urls: for url in self.urls:
filename = url.rpartition('/')[2] filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename) file_path = os.path.join(self.raw_folder, filename)
download_url(url, root=os.path.join(self.root, self.raw_folder), download_url(url, root=self.raw_folder, filename=filename, md5=None)
filename=filename, md5=None) self.extract_gzip(gzip_path=file_path, remove_finished=True)
with open(file_path.replace('.gz', ''), 'wb') as out_f, \
gzip.GzipFile(file_path) as zip_f:
out_f.write(zip_f.read())
os.unlink(file_path)
# process and save as torch files # process and save as torch files
print('Processing...') print('Processing...')
training_set = ( training_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')), read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte')) read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
) )
test_set = ( test_set = (
read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')), read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte')) read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
) )
with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f: with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
torch.save(training_set, f) torch.save(training_set, f)
with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f: with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
torch.save(test_set, f) torch.save(test_set, f)
print('Done!') print('Done!')
...@@ -170,7 +177,6 @@ class FashionMNIST(MNIST): ...@@ -170,7 +177,6 @@ class FashionMNIST(MNIST):
] ]
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
class_to_idx = {_class: i for i, _class in enumerate(classes)}
class EMNIST(MNIST): class EMNIST(MNIST):
...@@ -205,64 +211,55 @@ class EMNIST(MNIST): ...@@ -205,64 +211,55 @@ class EMNIST(MNIST):
self.test_file = self._test_file(split) self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs) super(EMNIST, self).__init__(root, **kwargs)
def _training_file(self, split): @staticmethod
def _training_file(split):
return 'training_{}.pt'.format(split) return 'training_{}.pt'.format(split)
def _test_file(self, split): @staticmethod
def _test_file(split):
return 'test_{}.pt'.format(split) return 'test_{}.pt'.format(split)
def download(self): def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already.""" """Download the EMNIST data if it doesn't exist in processed_folder already."""
import gzip
import shutil import shutil
import zipfile import zipfile
if self._check_exists(): if self._check_exists():
return return
# download files makedir_exist_ok(self.raw_folder)
try: makedir_exist_ok(self.processed_folder)
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
# download files
filename = self.url.rpartition('/')[2] filename = self.url.rpartition('/')[2]
raw_folder = os.path.join(self.root, self.raw_folder) file_path = os.path.join(self.raw_folder, filename)
file_path = os.path.join(raw_folder, filename) download_url(self.url, root=self.raw_folder, filename=filename, md5=None)
download_url(self.url, root=file_path, filename=filename, md5=None)
print('Extracting zip archive') print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f: with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(raw_folder) zip_f.extractall(self.raw_folder)
os.unlink(file_path) os.unlink(file_path)
gzip_folder = os.path.join(raw_folder, 'gzip') gzip_folder = os.path.join(self.raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder): for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'): if gzip_file.endswith('.gz'):
print('Extracting ' + gzip_file) self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file))
with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
out_f.write(zip_f.read())
shutil.rmtree(gzip_folder)
# process and save as torch files # process and save as torch files
for split in self.splits: for split in self.splits:
print('Processing ' + split) print('Processing ' + split)
training_set = ( training_set = (
read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))), read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))) read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
) )
test_set = ( test_set = (
read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))), read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))) read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
) )
with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f: with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
torch.save(training_set, f) torch.save(training_set, f)
with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f: with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
torch.save(test_set, f) torch.save(test_set, f)
shutil.rmtree(gzip_folder)
print('Done!') print('Done!')
......
...@@ -31,20 +31,27 @@ def check_integrity(fpath, md5=None): ...@@ -31,20 +31,27 @@ def check_integrity(fpath, md5=None):
return True return True
def download_url(url, root, filename, md5): def makedir_exist_ok(dirpath):
from six.moves import urllib """
Python2 support for os.makedirs(.., exist_ok=True)
root = os.path.expanduser(root) """
fpath = os.path.join(root, filename)
try: try:
os.makedirs(root) os.makedirs(dirpath)
except OSError as e: except OSError as e:
if e.errno == errno.EEXIST: if e.errno == errno.EEXIST:
pass pass
else: else:
raise raise
def download_url(url, root, filename, md5):
from six.moves import urllib
root = os.path.expanduser(root)
fpath = os.path.join(root, filename)
makedir_exist_ok(root)
# downloads file # downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5): if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath) print('Using downloaded and verified file: ' + fpath)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment