Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
23e0d653
Commit
23e0d653
authored
Mar 17, 2017
by
Uridah Sami Ahmed
Committed by
Soumith Chintala
Mar 16, 2017
Browse files
SVHN dataset for torchvision (#98)
parent
c4f4c73a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
124 additions
and
1 deletion
+124
-1
README.rst
README.rst
+11
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/svhn.py
torchvision/datasets/svhn.py
+111
-0
No files found.
README.rst
View file @
23e0d653
...
...
@@ -168,6 +168,17 @@ STL10
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset already downloaded, does not do
anything.
SVHN
~~~~~
``dset.SVHN(root, split='train', transform=None, target_transform=None, download=False)``
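- ``root`` : root directory of the dataset; the ``.mat`` file for the chosen split (``train_32x32.mat``, ``test_32x32.mat`` or ``extra_32x32.mat``) is read from, and downloaded into, this directory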
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'extra'`` = Extra training set
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, does not do
anything.
ImageFolder
~~~~~~~~~~~
...
...
torchvision/datasets/__init__.py
View file @
23e0d653
...
...
@@ -4,9 +4,10 @@ from .coco import CocoCaptions, CocoDetection
from
.cifar
import
CIFAR10
,
CIFAR100
from
.stl10
import
STL10
from
.mnist
import
MNIST
from
.svhn
import
SVHN
# Public names re-exported by the datasets package. 'SVHN' is the entry
# added alongside the `from .svhn import SVHN` import above.
__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100',
           'MNIST', 'STL10', 'SVHN')
torchvision/datasets/svhn.py
0 → 100644
View file @
23e0d653
from
__future__
import
print_function
import
torch.utils.data
as
data
from
PIL
import
Image
import
os
import
os.path
import
errno
import
numpy
as
np
import
sys
class SVHN(data.Dataset):
    """Dataset backed by one of the SVHN ``*_32x32.mat`` files.

    Args:
        root: Directory that holds (or will receive) the ``.mat`` file for
            the chosen split.
        split: One of the keys of ``split_list``: ``'train'``, ``'test'`` or
            ``'extra'``.
        transform: Optional callable applied to the PIL image returned by
            ``__getitem__``.
        target_transform: Optional callable applied to the label returned by
            ``__getitem__``.
        download: When true, ``download()`` is called before the integrity
            check.

    Raises:
        ValueError: If ``split`` is not a key of ``split_list``.
        RuntimeError: If the file for the split is missing from ``root`` or
            its MD5 digest does not match the expected one.
    """

    # Placeholders; __init__ overwrites them per instance from split_list.
    url = ""
    filename = ""
    file_md5 = ""

    # split name -> [download URL, local file name, expected MD5 hex digest]
    split_list = {
        'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
                  "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
        'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
                 "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],
        'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
                  "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}

    def __init__(self, root, split='train',
                 transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.split = split  # training set or test set or extra set

        if self.split not in self.split_list:
            raise ValueError('Wrong split entered! Please use split="train" '
                             'or split="extra" or split="test"')

        self.url = self.split_list[split][0]
        self.filename = self.split_list[split][1]
        self.file_md5 = self.split_list[split][2]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # import here rather than at top of file because this is
        # an optional dependency for torchvision
        import scipy.io as sio

        # reading(loading) mat file as array
        loaded_mat = sio.loadmat(os.path.join(root, self.filename))

        self.data = loaded_mat['X']
        # NOTE(review): labels are stored exactly as loadmat returns them;
        # confirm their shape/encoding against what consumers expect.
        self.labels = loaded_mat['y']
        # Move the last axis of X to the front so that len(self.data) and
        # self.data[index] operate per sample (presumably X is stored as
        # (H, W, C, N) -- TODO confirm against the .mat files).
        self.data = np.transpose(self.data, (3, 2, 0, 1))

    def __getitem__(self, index):
        """Return ``(img, target)`` for position ``index``.

        ``img`` is a PIL image (passed through ``transform`` if set) and
        ``target`` is ``self.labels[index]`` (passed through
        ``target_transform`` if set).
        """
        img, target = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image; the transpose undoes, per sample, the axis
        # reordering applied in __init__.
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the length of ``self.data`` along its first axis."""
        return len(self.data)

    def _check_integrity(self):
        """Return True iff the split's file exists in ``root`` with the expected MD5."""
        import hashlib

        expected_md5 = self.split_list[self.split][2]
        fpath = os.path.join(self.root, self.filename)
        if not os.path.isfile(fpath):
            return False

        # Hash in chunks inside a context manager: the file handle is always
        # closed and the whole file never has to sit in memory at once.
        digest = hashlib.md5()
        with open(fpath, 'rb') as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                digest.update(chunk)
        return digest.hexdigest() == expected_md5

    def download(self):
        """Fetch the split's ``.mat`` file into ``root`` unless already present.

        Creates ``root`` if needed. Returns early when the file is present
        and passes the integrity check. A file that exists but fails the
        check is left in place (not re-downloaded).
        """
        from six.moves import urllib

        root = self.root
        fpath = os.path.join(root, self.filename)

        try:
            os.makedirs(root)
        except OSError as e:
            # An already-existing directory is fine; anything else is not.
            if e.errno != errno.EEXIST:
                raise

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # downloads file
        if os.path.isfile(fpath):
            print('Using downloaded file: ' + fpath)
        else:
            print('Downloading ' + self.url + ' to ' + fpath)
            urllib.request.urlretrieve(self.url, fpath)
            print('Downloaded!')
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment