Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c74b79c8
Commit
c74b79c8
authored
Aug 16, 2018
by
Danylo Ulianych
Committed by
Francisco Massa
Aug 16, 2018
Browse files
MNIST loader refactored: permanent 'data' and 'targets' fields (#578)
parent
fe973cee
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
21 deletions
+5
-21
torchvision/datasets/mnist.py
torchvision/datasets/mnist.py
+5
-21
No files found.
torchvision/datasets/mnist.py
View file @
c74b79c8
...
@@ -40,13 +40,6 @@ class MNIST(data.Dataset):
...
@@ -40,13 +40,6 @@ class MNIST(data.Dataset):
'5 - five'
,
'6 - six'
,
'7 - seven'
,
'8 - eight'
,
'9 - nine'
]
'5 - five'
,
'6 - six'
,
'7 - seven'
,
'8 - eight'
,
'9 - nine'
]
class_to_idx
=
{
_class
:
i
for
i
,
_class
in
enumerate
(
classes
)}
class_to_idx
=
{
_class
:
i
for
i
,
_class
in
enumerate
(
classes
)}
@
property
def
targets
(
self
):
if
self
.
train
:
return
self
.
train_labels
else
:
return
self
.
test_labels
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
os
.
path
.
expanduser
(
root
)
self
.
root
=
os
.
path
.
expanduser
(
root
)
self
.
transform
=
transform
self
.
transform
=
transform
...
@@ -61,11 +54,10 @@ class MNIST(data.Dataset):
...
@@ -61,11 +54,10 @@ class MNIST(data.Dataset):
' You can use download=True to download it'
)
' You can use download=True to download it'
)
if
self
.
train
:
if
self
.
train
:
self
.
train_data
,
self
.
train_labels
=
torch
.
load
(
data_file
=
self
.
training_file
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
training_file
))
else
:
else
:
self
.
test_data
,
self
.
test_labels
=
torch
.
load
(
data_file
=
self
.
test_file
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
test
_file
))
self
.
data
,
self
.
targets
=
torch
.
load
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
data
_file
))
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
"""
"""
...
@@ -75,10 +67,7 @@ class MNIST(data.Dataset):
...
@@ -75,10 +67,7 @@ class MNIST(data.Dataset):
Returns:
Returns:
tuple: (image, target) where target is index of the target class.
tuple: (image, target) where target is index of the target class.
"""
"""
if
self
.
train
:
img
,
target
=
self
.
data
[
index
],
self
.
targets
[
index
]
img
,
target
=
self
.
train_data
[
index
],
self
.
train_labels
[
index
]
else
:
img
,
target
=
self
.
test_data
[
index
],
self
.
test_labels
[
index
]
# doing this so that it is consistent with all other datasets
# doing this so that it is consistent with all other datasets
# to return a PIL Image
# to return a PIL Image
...
@@ -93,10 +82,7 @@ class MNIST(data.Dataset):
...
@@ -93,10 +82,7 @@ class MNIST(data.Dataset):
return
img
,
target
return
img
,
target
def
__len__
(
self
):
def
__len__
(
self
):
if
self
.
train
:
return
len
(
self
.
data
)
return
len
(
self
.
train_data
)
else
:
return
len
(
self
.
test_data
)
def
_check_exists
(
self
):
def
_check_exists
(
self
):
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
training_file
))
and
\
return
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
root
,
self
.
processed_folder
,
self
.
training_file
))
and
\
...
@@ -104,7 +90,6 @@ class MNIST(data.Dataset):
...
@@ -104,7 +90,6 @@ class MNIST(data.Dataset):
def
download
(
self
):
def
download
(
self
):
"""Download the MNIST data if it doesn't exist in processed_folder already."""
"""Download the MNIST data if it doesn't exist in processed_folder already."""
from
six.moves
import
urllib
import
gzip
import
gzip
if
self
.
_check_exists
():
if
self
.
_check_exists
():
...
@@ -228,7 +213,6 @@ class EMNIST(MNIST):
...
@@ -228,7 +213,6 @@ class EMNIST(MNIST):
def
download
(
self
):
def
download
(
self
):
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
"""Download the EMNIST data if it doesn't exist in processed_folder already."""
from
six.moves
import
urllib
import
gzip
import
gzip
import
shutil
import
shutil
import
zipfile
import
zipfile
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment