"vscode:/vscode.git/clone" did not exist on "b2ca39c8ac160d58923c889a6ffc16a5734f7e84"
Unverified Commit d4a126b6 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Usps dataset (#961)

* add USPS dataset

* minor fixes

* Improvements to the USPS dataset

Add it to the documentation, expose it to torchvision.datasets
and inherit from VisionDataset
parent b45cdbfa
...@@ -188,3 +188,10 @@ SBD ...@@ -188,3 +188,10 @@ SBD
.. autoclass:: SBDataset .. autoclass:: SBDataset
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
USPS
~~~~~
.. autoclass:: USPS
:members: __getitem__
:special-members:
...@@ -18,6 +18,7 @@ from .caltech import Caltech101, Caltech256 ...@@ -18,6 +18,7 @@ from .caltech import Caltech101, Caltech256
from .celeba import CelebA from .celeba import CelebA
from .sbd import SBDataset from .sbd import SBDataset
from .vision import VisionDataset from .vision import VisionDataset
from .usps import USPS
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData', 'ImageFolder', 'DatasetFolder', 'FakeData',
...@@ -26,4 +27,5 @@ __all__ = ('LSUN', 'LSUNClass', ...@@ -26,4 +27,5 @@ __all__ = ('LSUN', 'LSUNClass',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset') 'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
'USPS')
from __future__ import print_function
from PIL import Image
import os
import numpy as np
from .utils import download_url
from .vision import VisionDataset
class USPS(VisionDataset):
    """`USPS <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps>`_ Dataset.

    The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
    and make pixel values in ``[0, 255]``.

    Args:
        root (string): Root directory of dataset to store ``USPS`` data files.
        train (bool, optional): If True, creates dataset from ``usps.bz2``,
            otherwise from ``usps.t.bz2``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    # Per split: [download URL, file name inside ``root``, MD5 checksum].
    split_list = {
        'train': [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
            "usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197'
        ],
        'test': [
            "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
            "usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118'
        ],
    }

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        """Load one split of USPS into memory, downloading it first if requested.

        After construction ``self.data`` is a ``uint8`` array reshaped to
        ``(-1, 16, 16)`` and ``self.targets`` is a list of ``int`` labels
        (the file's labels minus one).
        """
        super(USPS, self).__init__(root, transform=transform, target_transform=target_transform)
        split = 'train' if train else 'test'
        url, filename, checksum = self.split_list[split]
        full_path = os.path.join(self.root, filename)

        # Only fetch when the archive is not already present on disk.
        if download and not os.path.exists(full_path):
            download_url(url, self.root, filename, md5=checksum)

        import bz2
        with bz2.open(full_path) as fp:
            # Iterate the file directly instead of materialising it with
            # readlines(), and drop blank lines: an empty line splits to []
            # and would otherwise fail on ``d[0]`` / produce ragged rows.
            raw_data = [
                fields
                for fields in (line.decode().split() for line in fp)
                if fields
            ]

        # Each record is ``label idx:value idx:value ...``; keep only the values.
        imgs = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
        imgs = np.asarray(imgs, dtype=np.float32).reshape((-1, 16, 16))
        # Rescale from [-1, 1] to [0, 255]; the uint8 cast drops the fraction.
        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
        # Shift labels down by one so they start at 0.
        targets = [int(d[0]) - 1 for d in raw_data]

        self.data = imgs
        self.targets = targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img, mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment