Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c74b79c8
Commit
c74b79c8
authored
Aug 16, 2018
by
Danylo Ulianych
Committed by
Francisco Massa
Aug 16, 2018
Browse files
MNIST loader refactored: permanent 'data' and 'targets' fields (#578)
parent
fe973cee
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
21 deletions
+5
-21
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+5
-21
No files found.
torchvision/datasets/mnist.py
View file @
c74b79c8
...
...
@@ -40,13 +40,6 @@ class MNIST(data.Dataset):
'5 - five'
,
'6 - six'
,
'7 - seven'
,
'8 - eight'
,
'9 - nine'
]
class_to_idx
=
{
_class
:
i
for
i
,
_class
in
enumerate
(
classes
)}
@
property
def
targets
(
self
):
if
self
.
train
:
return
self
.
train_labels
else
:
return
self
.
test_labels
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
os
.
path
.
expanduser
(
root
)
self
.
transform
=
transform
...
...
@@ -61,11 +54,10 @@ class MNIST(data.Dataset):
' You can use download=True to download it'
)
if
self
.
train
:
self
.
train_data
,
self
.
train_labels
=
torch
.
load
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
training_file
))
data_file
=
self
.
training_file
else
:
self
.
test_data
,
self
.
test_labels
=
torch
.
load
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
test
_file
))
data_file
=
self
.
test_file
self
.
data
,
self
.
targets
=
torch
.
load
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
data
_file
))
def
__getitem__
(
self
,
index
):
"""
...
...
@@ -75,10 +67,7 @@ class MNIST(data.Dataset):
Returns:
tuple: (image, target) where target is index of the target class.
"""
if
self
.
train
:
img
,
target
=
self
.
train_data
[
index
],
self
.
train_labels
[
index
]
else
:
img
,
target
=
self
.
test_data
[
index
],
self
.
test_labels
[
index
]
img
,
target
=
self
.
data
[
index
],
self
.
targets
[
index
]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
...
...
@@ -93,10 +82,7 @@ class MNIST(data.Dataset):
return
img
,
target
def
__len__
(
self
):
if
self
.
train
:
return
len
(
self
.
train_data
)
else
:
return
len
(
self
.
test_data
)
return
len
(
self
.
data
)
def
_check_exists
(
self
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
training_file
))
and
\
...
...
@@ -104,7 +90,6 @@ class MNIST(data.Dataset):
def
download
(
self
):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from
six.moves
import
urllib
import
gzip
if
self
.
_check_exists
():
...
...
@@ -228,7 +213,6 @@ class EMNIST(MNIST):
def
download
(
self
):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
from
six.moves
import
urllib
import
gzip
import
shutil
import
zipfile
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment