"...text-generation-inference.git" did not exist on "08b3eac2ce54e25bec12088fd7e69ee3c07adaf5"
Commit 7eb9e658 authored by neoglez's avatar neoglez Committed by Soumith Chintala
Browse files

Add SEMEION Dataset (#324)

* First commit for semeion dataset

SEMEION Handwritten Digits Data Set
http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit

* SEMEION Class

* Fix Lint errors

* More linting errors
parent 20913655
...@@ -7,9 +7,10 @@ from .mnist import MNIST, FashionMNIST ...@@ -7,9 +7,10 @@ from .mnist import MNIST, FashionMNIST
from .svhn import SVHN from .svhn import SVHN
from .phototour import PhotoTour from .phototour import PhotoTour
from .fakedata import FakeData from .fakedata import FakeData
from .semeion import SEMEION
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData', 'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour') 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
from __future__ import print_function
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import torch.utils.data as data
from .utils import download_url, check_integrity
class SEMEION(data.Dataset):
"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``semeion.py`` exists.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data"
filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None, download=True):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.data = []
self.labels = []
fp = os.path.join(root, self.filename)
file = open(fp, 'r')
data = file.read()
file.close()
dataSplitted = data.split("\n")[:-1]
datasetLength = len(dataSplitted)
i = 0
while i < datasetLength:
# Get the 'i-th' row
strings = dataSplitted[i]
# Split row into numbers(string), and avoid blank at the end
stringsSplitted = (strings[:-1]).split(" ")
# Get data (which ends at column 256th), then in a numpy array.
rawData = stringsSplitted[:256]
dataFloat = [float(j) for j in rawData]
img = np.array(dataFloat[:16])
j = 16
k = 0
while j < len(dataFloat):
temp = np.array(dataFloat[k:j])
img = np.vstack((img, temp))
k = j
j += 16
self.data.append(img)
# Get label and convert it into numbers, then in a numpy array.
labelString = stringsSplitted[256:]
labelInt = [int(index) for index in labelString]
self.labels.append(np.array(labelInt))
i += 1
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
# convert value to 8 bit unsigned integer
# color (white #255) the pixels
img = img.astype('uint8') * 255
img = Image.fromarray(img, mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
root = self.root
fpath = os.path.join(root, self.filename)
if not check_integrity(fpath, self.md5_checksum):
return False
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
root = self.root
download_url(self.url, root, self.filename, self.md5_checksum)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment