Commit 5861f14a authored by Martin Raison's avatar Martin Raison Committed by Alykhan Tejani
Browse files

EMNIST dataset + speedup *MNIST preprocessing (#334)

* EMNIST dataset + speedup *MNIST preprocessing
parent ff3f738e
......@@ -35,6 +35,11 @@ Fashion-MNIST
.. autoclass:: FashionMNIST
EMNIST
~~~~~~
.. autoclass:: EMNIST
COCO
~~~~
......
......@@ -3,7 +3,7 @@ from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST, FashionMNIST
from .mnist import MNIST, EMNIST, FashionMNIST
from .svhn import SVHN
from .phototour import PhotoTour
from .fakedata import FakeData
......@@ -12,5 +12,5 @@ from .semeion import SEMEION
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'FashionMNIST',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
......@@ -4,6 +4,7 @@ from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
......@@ -163,14 +164,106 @@ class FashionMNIST(MNIST):
]
def get_int(b):
return int(codecs.encode(b, 'hex'), 16)
class EMNIST(MNIST):
"""`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``processed/training.pt``
and ``processed/test.pt`` exist.
split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
which one to use.
train (bool, optional): If True, creates dataset from ``training.pt``,
otherwise from ``test.pt``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
def __init__(self, root, split, **kwargs):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
self.split = split
self.training_file = self._training_file(split)
self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs)
def _training_file(self, split):
return 'training_{}.pt'.format(split)
def _test_file(self, split):
return 'test_{}.pt'.format(split)
def download(self):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import gzip
import shutil
import zipfile
if self._check_exists():
return
# download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
print('Downloading ' + self.url)
data = urllib.request.urlopen(self.url)
filename = self.url.rpartition('/')[2]
raw_folder = os.path.join(self.root, self.raw_folder)
file_path = os.path.join(raw_folder, filename)
with open(file_path, 'wb') as f:
f.write(data.read())
print('Extracting zip archive')
with zipfile.ZipFile(file_path) as zip_f:
zip_f.extractall(raw_folder)
os.unlink(file_path)
gzip_folder = os.path.join(raw_folder, 'gzip')
for gzip_file in os.listdir(gzip_folder):
if gzip_file.endswith('.gz'):
print('Extracting ' + gzip_file)
with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
out_f.write(zip_f.read())
shutil.rmtree(gzip_folder)
def parse_byte(b):
if isinstance(b, str):
return ord(b)
return b
# process and save as torch files
for split in self.splits:
print('Processing ' + split)
training_set = (
read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
)
test_set = (
read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
)
with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
torch.save(test_set, f)
print('Done!')
def get_int(b):
return int(codecs.encode(b, 'hex'), 16)
def read_label_file(path):
......@@ -178,9 +271,8 @@ def read_label_file(path):
data = f.read()
assert get_int(data[:4]) == 2049
length = get_int(data[4:8])
labels = [parse_byte(b) for b in data[8:]]
assert len(labels) == length
return torch.LongTensor(labels)
parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
return torch.from_numpy(parsed).view(length).long()
def read_image_file(path):
......@@ -191,15 +283,5 @@ def read_image_file(path):
num_rows = get_int(data[8:12])
num_cols = get_int(data[12:16])
images = []
idx = 16
for l in range(length):
img = []
images.append(img)
for r in range(num_rows):
row = []
img.append(row)
for c in range(num_cols):
row.append(parse_byte(data[idx]))
idx += 1
assert len(images) == length
return torch.ByteTensor(images).view(-1, 28, 28)
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
return torch.from_numpy(parsed).view(length, num_rows, num_cols)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment