Commit 23e0d653 authored by Uridah Sami Ahmed's avatar Uridah Sami Ahmed Committed by Soumith Chintala
Browse files

SVHN dataset for torchvision (#98)

parent c4f4c73a
...@@ -169,6 +169,17 @@ STL10 ...@@ -169,6 +169,17 @@ STL10
puts it in root directory. If dataset already downloaded, does not do puts it in root directory. If dataset already downloaded, does not do
anything. anything.
SVHN
~~~~~
``dset.SVHN(root, split='train', transform=None, target_transform=None, download=False)``
- ``root`` : root directory of dataset where there is folder ``SVHN``
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'extra'`` = Extra training set
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, does not do
anything.
ImageFolder ImageFolder
~~~~~~~~~~~ ~~~~~~~~~~~
......
...@@ -4,9 +4,10 @@ from .coco import CocoCaptions, CocoDetection ...@@ -4,9 +4,10 @@ from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10 from .stl10 import STL10
from .mnist import MNIST from .mnist import MNIST
from .svhn import SVHN
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'ImageFolder',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'CIFAR10', 'CIFAR100',
'MNIST', 'STL10') 'MNIST', 'STL10', 'SVHN')
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
class SVHN(data.Dataset):
url = ""
filename = ""
file_md5 = ""
split_list = {
'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
"train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
"test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.split = split # training set or test set or extra set
if self.split not in self.split_list:
raise ValueError('Wrong split entered! Please use split="train" or split="extra" or split="test"')
self.url = self.split_list[split][0]
self.filename = self.split_list[split][1]
self.file_md5 = self.split_list[split][2]
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
# import here rather than at top of file because this is
# an optional dependency for torchvision
import scipy.io as sio
# reading(loading) mat file as array
loaded_mat = sio.loadmat(os.path.join(root, self.filename))
self.data = loaded_mat['X']
self.labels = loaded_mat['y']
self.data = np.transpose(self.data, (3, 2, 0, 1))
def __getitem__(self, index):
img, target = self.data[index], self.labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(np.transpose(img, (1, 2, 0)))
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
import hashlib
root = self.root
md5 = self.split_list[self.split][2]
fpath = os.path.join(root, self.filename)
if not os.path.isfile(fpath):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False
return True
def download(self):
from six.moves import urllib
import tarfile
import hashlib
root = self.root
fpath = os.path.join(root, self.filename)
try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
if self._check_integrity():
print('Files already downloaded and verified')
return
# downloads file
if os.path.isfile(fpath):
print('Using downloaded file: ' + fpath)
else:
print('Downloading ' + self.url + ' to ' + fpath)
urllib.request.urlretrieve(self.url, fpath)
print ('Downloaded!')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment