Commit ac2e995a authored by Leon Bottou, committed by Francisco Massa

Added support for the QMNIST dataset (#995)

* Added general reader for sn3 tensors in "pascalvincent" format

* Added class QMNIST into mnist.py

* QMNIST dataset: make some pt files smaller

* Change request from fmassa.

* read_sn3_pascalvincent_tensor: cse

* read_sn3_pascalvincent_tensor: check file size (when strict!=False)

* Fix lint

* More lint

* Add documentation and expose QMNIST to dataset namespace
parent 2f64dd90
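For reference, a minimal usage sketch of the QMNIST class introduced by this commit. The `data/qmnist` path and the batch size are placeholders, and the snippet assumes a torchvision build that includes this change:

```python
# Sketch only: exercises the QMNIST API added in this commit.
# 'data/qmnist' is a placeholder path; download=True fetches the files on first use.
import torch
from torchvision import datasets, transforms

# MNIST-compatible usage: with compat=True the target is a plain class index.
train_set = datasets.QMNIST('data/qmnist', what='train', compat=True, download=True,
                            transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)     # torch.Size([64, 1, 28, 28]) torch.Size([64])

# With compat=False the target is the full QMNIST label vector for the example.
test10k = datasets.QMNIST('data/qmnist', what='test10k', compat=False, download=True)
img, target = test10k[0]
print(target)                         # torch vector carrying the extended QMNIST information
```

With `compat=True` (the default) QMNIST behaves as a drop-in replacement for `torchvision.datasets.MNIST`; with `compat=False` each target is the full QMNIST label record rather than a bare class index.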
@@ -45,6 +45,11 @@ EMNIST
 .. autoclass:: EMNIST

+QMNIST
+~~~~~~
+
+.. autoclass:: QMNIST
+
 FakeData
 ~~~~~~~~
@@ -3,7 +3,7 @@ from .folder import ImageFolder, DatasetFolder
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
 from .stl10 import STL10
-from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST
+from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
 from .svhn import SVHN
 from .phototour import PhotoTour
 from .fakedata import FakeData
@@ -23,7 +23,7 @@ from .usps import USPS
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
            'CocoCaptions', 'CocoDetection',
-           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
+           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST',
            'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
            'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
@@ -7,7 +7,7 @@ import os.path
 import numpy as np
 import torch
 import codecs
-from .utils import download_and_extract_archive, extract_archive, makedir_exist_ok
+from .utils import download_url, download_and_extract_archive, extract_archive, makedir_exist_ok


 class MNIST(VisionDataset):
@@ -286,25 +286,176 @@ class EMNIST(MNIST):
         print('Done!')
+class QMNIST(MNIST):
+    """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset whose ``processed``
+            subdir contains torch binary files with the datasets.
+        what (string, optional): Can be 'train', 'test', 'test10k',
+            'test50k', or 'nist' for respectively the MNIST-compatible
+            training set, the 60k QMNIST testing set, the 10k QMNIST
+            examples that match the MNIST testing set, the 50k
+            remaining QMNIST testing examples, or all the NIST
+            digits. The default is to select 'train' or 'test'
+            according to the compatibility argument 'train'.
+        compat (bool, optional): A boolean that says whether the target
+            for each example is the class number (for compatibility with
+            the MNIST dataloader) or a torch vector containing the
+            full QMNIST information. Default=True.
+        download (bool, optional): If True, downloads the dataset from
+            the internet and puts it in the root directory. If the dataset
+            is already downloaded, it is not downloaded again.
+        transform (callable, optional): A function/transform that
+            takes in a PIL image and returns a transformed
+            version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform
+            that takes in the target and transforms it.
+        train (bool, optional, compatibility): When argument 'what' is
+            not specified, this boolean decides whether to load the
+            training set or the testing set. Default: True.
+    """
+
+    subsets = {
+        'train': 'train',
+        'test': 'test', 'test10k': 'test', 'test50k': 'test',
+        'nist': 'nist'
+    }
+    urls = {
+        'train': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
+                  'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz'],
+        'test': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
+                 'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz'],
+        'nist': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
+                 'https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz']
+    }
+    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
+               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
+
+    def __init__(self, root, what=None, compat=True, train=True, **kwargs):
+        if what is None:
+            what = 'train' if train else 'test'
+        if not self.subsets.get(what):
+            raise RuntimeError("Argument 'what' should be one of: \n " +
+                               repr(tuple(self.subsets.keys())))
+        self.what = what
+        self.compat = compat
+        self.data_file = what + '.pt'
+        self.training_file = self.data_file
+        self.test_file = self.data_file
+        super(QMNIST, self).__init__(root, train, **kwargs)
+
+    def download(self):
+        """Download the QMNIST data if it doesn't exist in processed_folder already.
+        Note that we only download what has been asked for (argument 'what').
+        """
+        if self._check_exists():
+            return
+        makedir_exist_ok(self.raw_folder)
+        makedir_exist_ok(self.processed_folder)
+        urls = self.urls[self.subsets[self.what]]
+        files = []
+
+        # download data files if not already there
+        for url in urls:
+            filename = url.rpartition('/')[2]
+            file_path = os.path.join(self.raw_folder, filename)
+            if not os.path.isfile(file_path):
+                download_url(url, root=self.raw_folder, filename=filename, md5=None)
+            files.append(file_path)
+
+        # process and save as torch files
+        print('Processing...')
+        data = read_sn3_pascalvincent_tensor(files[0])
+        assert(data.dtype == torch.uint8)
+        assert(data.ndimension() == 3)
+        targets = read_sn3_pascalvincent_tensor(files[1]).long()
+        assert(targets.ndimension() == 2)
+        if self.what == 'test10k':
+            data = data[0:10000, :, :].clone()
+            targets = targets[0:10000, :].clone()
+        if self.what == 'test50k':
+            data = data[10000:, :, :].clone()
+            targets = targets[10000:, :].clone()
+        with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
+            torch.save((data, targets), f)
+
+    def __getitem__(self, index):
+        # redefined to handle the compat flag
+        img, target = self.data[index], self.targets[index]
+        img = Image.fromarray(img.numpy(), mode='L')
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.compat:
+            target = int(target[0])
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return img, target
+
+    def extra_repr(self):
+        return "Split: {}".format(self.what)
+
+
 def get_int(b):
     return int(codecs.encode(b, 'hex'), 16)


+def open_maybe_compressed_file(path):
+    """Return a file object that possibly decompresses 'path' on the fly.
+       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
+    """
+    if not isinstance(path, torch._six.string_classes):
+        return path
+    if path.endswith('.gz'):
+        import gzip
+        return gzip.open(path, 'rb')
+    if path.endswith('.xz'):
+        import lzma
+        return lzma.open(path, 'rb')
+    return open(path, 'rb')
+
+
+def read_sn3_pascalvincent_tensor(path, strict=True):
+    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
+       Argument may be a filename, compressed filename, or file object.
+    """
+    # typemap
+    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
+        read_sn3_pascalvincent_tensor.typemap = {
+            8: (torch.uint8, np.uint8, np.uint8),
+            9: (torch.int8, np.int8, np.int8),
+            11: (torch.int16, np.dtype('>i2'), 'i2'),
+            12: (torch.int32, np.dtype('>i4'), 'i4'),
+            13: (torch.float32, np.dtype('>f4'), 'f4'),
+            14: (torch.float64, np.dtype('>f8'), 'f8')}
+    # read
+    with open_maybe_compressed_file(path) as f:
+        data = f.read()
+    # parse
+    magic = get_int(data[0:4])
+    nd = magic % 256
+    ty = magic // 256
+    assert nd >= 1 and nd <= 3
+    assert ty >= 8 and ty <= 14
+    m = read_sn3_pascalvincent_tensor.typemap[ty]
+    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
+    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
+    assert parsed.shape[0] == np.prod(s) or not strict
+    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
+
+
 def read_label_file(path):
     with open(path, 'rb') as f:
-        data = f.read()
-        assert get_int(data[:4]) == 2049
-        length = get_int(data[4:8])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
-        return torch.from_numpy(parsed).view(length).long()
+        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    assert(x.dtype == torch.uint8)
+    assert(x.ndimension() == 1)
+    return x.long()


 def read_image_file(path):
     with open(path, 'rb') as f:
-        data = f.read()
-        assert get_int(data[:4]) == 2051
-        length = get_int(data[4:8])
-        num_rows = get_int(data[8:12])
-        num_cols = get_int(data[12:16])
-        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
-        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
+        x = read_sn3_pascalvincent_tensor(f, strict=False)
+    assert(x.dtype == torch.uint8)
+    assert(x.ndimension() == 3)
+    return x
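For reference, a small self-contained sketch of the "SN3 / Pascal Vincent" header layout that `read_sn3_pascalvincent_tensor` parses: a 4-byte big-endian magic word whose low byte is the number of dimensions and whose higher-order bytes select the element type (8 = uint8 through 14 = float64, per the typemap above), followed by one big-endian 32-bit size per dimension and then the raw element data. The snippet builds such a buffer in memory and decodes it with the same arithmetic; it does not depend on torchvision, and the same bytes could equally be handed to the helper above wrapped in `io.BytesIO`.

```python
# Sketch: build and decode a tiny buffer in the IDX-style layout parsed above.
# Standalone on purpose; mirrors the magic/size arithmetic of read_sn3_pascalvincent_tensor.
import struct
import numpy as np

# Two 3x4 uint8 "images" -> element type code 8 (uint8), 3 dimensions.
payload = np.arange(2 * 3 * 4, dtype=np.uint8)
magic = 8 * 256 + 3                     # high bytes: type, low byte: ndim (2051, as in MNIST image files)
header = struct.pack('>i', magic) + struct.pack('>3i', 2, 3, 4)
data = header + payload.tobytes()

# Decode exactly as the helper does: magic % 256 -> ndim, magic // 256 -> type code.
magic_read, = struct.unpack('>i', data[0:4])
nd, ty = magic_read % 256, magic_read // 256
sizes = struct.unpack('>' + 'i' * nd, data[4:4 * (nd + 1)])
parsed = np.frombuffer(data, dtype=np.uint8, offset=4 * (nd + 1)).reshape(sizes)
print(ty, sizes, parsed.shape)          # -> 8 (2, 3, 4) (2, 3, 4)
```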