Commit 5861f14a (OpenDAS / vision)

EMNIST dataset + speedup *MNIST preprocessing (#334)

* EMNIST dataset + speedup *MNIST preprocessing

Authored Dec 06, 2017 by Martin Raison; committed by Alykhan Tejani, Dec 06, 2017.
Parent: ff3f738e
Showing 3 changed files with 110 additions and 23 deletions.
docs/source/datasets.rst (+5 -0)
torchvision/datasets/__init__.py (+2 -2)
torchvision/datasets/mnist.py (+103 -21)
docs/source/datasets.rst

```diff
@@ -35,6 +35,11 @@ Fashion-MNIST
 .. autoclass:: FashionMNIST
 
+EMNIST
+~~~~~~
+
+.. autoclass:: EMNIST
+
 COCO
 ~~~~
```
torchvision/datasets/__init__.py

```diff
@@ -3,7 +3,7 @@ from .folder import ImageFolder
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
 from .stl10 import STL10
-from .mnist import MNIST, FashionMNIST
+from .mnist import MNIST, EMNIST, FashionMNIST
 from .svhn import SVHN
 from .phototour import PhotoTour
 from .fakedata import FakeData
@@ -12,5 +12,5 @@ from .semeion import SEMEION
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'FakeData',
            'CocoCaptions', 'CocoDetection',
-           'CIFAR10', 'CIFAR100', 'FashionMNIST',
+           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
            'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
```
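With this re-export, `EMNIST` becomes reachable from the package namespace like the other datasets. A minimal import check (hypothetical session; the printed tuple comes from the `splits` class attribute added in `mnist.py` below):

```python
from torchvision.datasets import EMNIST

print(EMNIST.splits)
# ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
```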
torchvision/datasets/mnist.py

```diff
@@ -4,6 +4,7 @@ from PIL import Image
 import os
 import os.path
 import errno
+import numpy as np
 import torch
 import codecs
@@ -163,14 +164,106 @@ class FashionMNIST(MNIST):
     ]
 
 
-def get_int(b):
-    return int(codecs.encode(b, 'hex'), 16)
-
-
-def parse_byte(b):
-    if isinstance(b, str):
-        return ord(b)
-    return b
+class EMNIST(MNIST):
+    """`EMNIST <https://www.nist.gov/itl/iad/image-group/emnist-dataset/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``processed/training.pt``
+            and ``processed/test.pt`` exist.
+        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
+            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
+            which one to use.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+    url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip'
+    splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
+
+    def __init__(self, root, split, **kwargs):
+        if split not in self.splits:
+            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+                split, ', '.join(self.splits),
+            ))
+        self.split = split
+        self.training_file = self._training_file(split)
+        self.test_file = self._test_file(split)
+        super(EMNIST, self).__init__(root, **kwargs)
+
+    def _training_file(self, split):
+        return 'training_{}.pt'.format(split)
+
+    def _test_file(self, split):
+        return 'test_{}.pt'.format(split)
+
+    def download(self):
+        """Download the EMNIST data if it doesn't exist in processed_folder already."""
+        from six.moves import urllib
+        import gzip
+        import shutil
+        import zipfile
+
+        if self._check_exists():
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        print('Downloading ' + self.url)
+        data = urllib.request.urlopen(self.url)
+        filename = self.url.rpartition('/')[2]
+        raw_folder = os.path.join(self.root, self.raw_folder)
+        file_path = os.path.join(raw_folder, filename)
+        with open(file_path, 'wb') as f:
+            f.write(data.read())
+
+        print('Extracting zip archive')
+        with zipfile.ZipFile(file_path) as zip_f:
+            zip_f.extractall(raw_folder)
+        os.unlink(file_path)
+        gzip_folder = os.path.join(raw_folder, 'gzip')
+        for gzip_file in os.listdir(gzip_folder):
+            if gzip_file.endswith('.gz'):
+                print('Extracting ' + gzip_file)
+                with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \
+                        gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f:
+                    out_f.write(zip_f.read())
+        shutil.rmtree(gzip_folder)
+
+        # process and save as torch files
+        for split in self.splits:
+            print('Processing ' + split)
+            training_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
+            )
+            test_set = (
+                read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
+                read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
+            )
+            with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f:
+                torch.save(training_set, f)
+            with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f:
+                torch.save(test_set, f)
+
+        print('Done!')
+
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
 
 
 def read_label_file(path):
```
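The new class reuses the existing MNIST machinery, swapping in per-split processed file names and a zip-then-gzip download pipeline. A minimal usage sketch, assuming this revision is installed; `root`, the chosen split, and the transform are arbitrary example choices, and `download=True` fetches the full NIST archive on first use:

```python
from torchvision import datasets, transforms

# Hypothetical example: load the 'letters' split as tensors.
emnist = datasets.EMNIST(root='./data', split='letters', train=True,
                         download=True, transform=transforms.ToTensor())

image, target = emnist[0]   # first (image, label) pair
print(len(emnist), image.size(), target)
```

Passing an unknown split, e.g. `split='words'`, raises the `ValueError` from `__init__` before any download starts.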
```diff
@@ -178,9 +271,8 @@ def read_label_file(path):
         data = f.read()
         assert get_int(data[:4]) == 2049
         length = get_int(data[4:8])
-        labels = [parse_byte(b) for b in data[8:]]
-        assert len(labels) == length
-        return torch.LongTensor(labels)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
+        return torch.from_numpy(parsed).view(length).long()
```
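The rewritten label parser replaces a Python-level loop over every byte with a single `np.frombuffer` view, which is where most of the advertised speedup comes from. A rough micro-benchmark sketch of the two strategies on a synthetic payload (hypothetical data and timings; the `.copy()` is added only because `bytes` buffers are read-only):

```python
import timeit

import numpy as np
import torch

# Synthetic IDX label payload: 8-byte header (magic 2049, count 60000)
# followed by 60000 one-byte labels, mimicking MNIST's train-labels file.
data = b'\x00\x00\x08\x01\x00\x00\xea\x60' + bytes(60000)
length = 60000


def parse_byte(b):
    # The removed helper: a per-byte Python 2/3 compatibility shim.
    if isinstance(b, str):
        return ord(b)
    return b


def old_way():
    # Pre-commit approach: one Python call per byte.
    labels = [parse_byte(b) for b in data[8:]]
    return torch.LongTensor(labels)


def new_way():
    # Post-commit approach: one vectorized read of the whole buffer.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
    return torch.from_numpy(parsed.copy()).view(length).long()


print('loop:      ', timeit.timeit(old_way, number=10))
print('frombuffer:', timeit.timeit(new_way, number=10))
```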
```diff
@@ -191,15 +283,5 @@ def read_image_file(path):
         num_rows = get_int(data[8:12])
         num_cols = get_int(data[12:16])
         images = []
-        idx = 16
-        for l in range(length):
-            img = []
-            images.append(img)
-            for r in range(num_rows):
-                row = []
-                img.append(row)
-                for c in range(num_cols):
-                    row.append(parse_byte(data[idx]))
-                    idx += 1
-        assert len(images) == length
-        return torch.ByteTensor(images).view(-1, 28, 28)
+        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
+        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
```
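The image parser gets the same treatment. A quick equivalence sketch on hypothetical 3x3 "images" checks that the vectorized read reproduces the nested-loop result; note that the loop version's hard-coded `.view(-1, 28, 28)` is dropped in the new code, which derives the shape from the header's `num_rows`/`num_cols` instead:

```python
import numpy as np
import torch

# Synthetic IDX image payload: 16-byte header + two 3x3 images.
length, num_rows, num_cols = 2, 3, 3
data = b'\x00' * 16 + bytes(range(length * num_rows * num_cols))

# Pre-commit approach: nested Python loops building lists of lists.
images, idx = [], 16
for l in range(length):
    img = []
    images.append(img)
    for r in range(num_rows):
        row = []
        img.append(row)
        for c in range(num_cols):
            row.append(data[idx])
            idx += 1
old = torch.ByteTensor(images)

# Post-commit approach: one vectorized read of the same payload.
parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
new = torch.from_numpy(parsed.copy()).view(length, num_rows, num_cols)

assert torch.equal(old, new)
print(new.size())  # torch.Size([2, 3, 3])
```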