Commit eec5ba44 authored by Dr. Kashif Rasul's avatar Dr. Kashif Rasul Committed by Soumith Chintala
Browse files

added FashionMNIST dataset (#238)

* added FashionMNIST dataset

* documentation

* fixed formatting

* fixed formatting
parent 7492fae4
...@@ -43,7 +43,7 @@ Datasets ...@@ -43,7 +43,7 @@ Datasets
The following dataset loaders are available: The following dataset loaders are available:
- `MNIST <#mnist>`__ - `MNIST and FashionMNIST <#mnist>`__
- `COCO (Captioning and Detection) <#coco>`__ - `COCO (Captioning and Detection) <#coco>`__
- `LSUN Classification <#lsun>`__ - `LSUN Classification <#lsun>`__
- `ImageFolder <#imagefolder>`__ - `ImageFolder <#imagefolder>`__
...@@ -77,6 +77,8 @@ MNIST ...@@ -77,6 +77,8 @@ MNIST
~~~~~ ~~~~~
``dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)`` ``dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)``
``dset.FashionMNIST(root, train=True, transform=None, target_transform=None, download=False)``
``root``: root directory of dataset where ``processed/training.pt`` and ``processed/test.pt`` exist ``root``: root directory of dataset where ``processed/training.pt`` and ``processed/test.pt`` exist
``train``: ``True`` - use training set, ``False`` - use test set. ``train``: ``True`` - use training set, ``False`` - use test set.
...@@ -390,32 +392,32 @@ For example: ...@@ -390,32 +392,32 @@ For example:
Utils Utils
===== =====
make\_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale\_each=False, pad\_value=0) ``make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Given a 4D mini-batch Tensor of shape (B x C x H x W), Given a 4D mini-batch Tensor of shape (B x C x H x W),
or a list of images all of the same size, or a list of images all of the same size,
makes a grid of images makes a grid of images
normalize=True will shift the image to the range (0, 1), ``normalize=True`` will shift the image to the range (0, 1),
by subtracting the minimum and dividing by the maximum pixel value. by subtracting the minimum and dividing by the maximum pixel value.
if range=(min, max) where min and max are numbers, then these numbers are used to if ``range=(min, max)`` where ``min`` and ``max`` are numbers, then these numbers are used to
normalize the image. normalize the image.
scale_each=True will scale each image in the batch of images separately rather than ``scale_each=True`` will scale each image in the batch of images separately rather than
computing the (min, max) over all images. computing the ``(min, max)`` over all images.
pad_value=<float> sets the value for the padded pixels. ``pad_value=<float>`` sets the value for the padded pixels.
`Example usage is given in this notebook` <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91> `Example usage is given in this notebook` <https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91>
save\_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale\_each=False, pad\_value=0) ``save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Saves a given Tensor into an image file. Saves a given Tensor into an image file.
If given a mini-batch tensor, will save the tensor as a grid of images. If given a mini-batch tensor, will save the tensor as a grid of images.
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for All options after ``filename`` are passed through to ``make_grid``. Refer to it's documentation for
more details more details
...@@ -3,7 +3,7 @@ from .folder import ImageFolder ...@@ -3,7 +3,7 @@ from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10 from .stl10 import STL10
from .mnist import MNIST from .mnist import MNIST, FashionMNIST
from .svhn import SVHN from .svhn import SVHN
from .phototour import PhotoTour from .phototour import PhotoTour
from .fakedata import FakeData from .fakedata import FakeData
...@@ -11,5 +11,5 @@ from .fakedata import FakeData ...@@ -11,5 +11,5 @@ from .fakedata import FakeData
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData', 'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'CIFAR10', 'CIFAR100', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour') 'MNIST', 'STL10', 'SVHN', 'PhotoTour')
...@@ -139,6 +139,17 @@ class MNIST(data.Dataset): ...@@ -139,6 +139,17 @@ class MNIST(data.Dataset):
print('Done!') print('Done!')
class FashionMNIST(MNIST):
"""`Fashion MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
"""
urls = [
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
def get_int(b): def get_int(b):
return int(codecs.encode(b, 'hex'), 16) return int(codecs.encode(b, 'hex'), 16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment