OpenDAS / vision · Commits

Commit fc7911c8
Authored Sep 11, 2018 by Danylo Ulianych; committed by Francisco Massa on Sep 11, 2018
Parent: f3d5e85d

CIFAR: permanent 'data' and 'targets' fields (#594)

Showing 1 changed file with 23 additions and 55 deletions:
    torchvision/datasets/cifar.py   +23  -55
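In practical terms, this commit makes the CIFAR datasets expose a single pair of attributes regardless of the split. A minimal sketch of the attribute change (the instance name, root path, and download flag are illustrative, not part of the diff):

    from torchvision.datasets import CIFAR10

    ds = CIFAR10(root='./data', train=True, download=True)

    # Before this commit: split-specific attributes, plus a read-only
    # 'targets' property that dispatched on ds.train:
    #   ds.train_data, ds.train_labels   (train=True only)
    #   ds.test_data,  ds.test_labels    (train=False only)

    # After this commit: one permanent pair of fields for both splits.
    ds.data      # numpy image array in HWC layout
    ds.targets   # plain Python list of class indices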
torchvision/datasets/cifar.py  (view file @ fc7911c8)
@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }
 
-    @property
-    def targets(self):
-        if self.train:
-            return self.train_labels
-        else:
-            return self.test_labels
-
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
                  download=False):

@@ -73,44 +66,30 @@ class CIFAR10(data.Dataset):
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')
 
-        # now load the picked numpy arrays
         if self.train:
-            self.train_data = []
-            self.train_labels = []
-            for fentry in self.train_list:
-                f = fentry[0]
-                file = os.path.join(self.root, self.base_folder, f)
-                fo = open(file, 'rb')
-                if sys.version_info[0] == 2:
-                    entry = pickle.load(fo)
-                else:
-                    entry = pickle.load(fo, encoding='latin1')
-                self.train_data.append(entry['data'])
-                if 'labels' in entry:
-                    self.train_labels += entry['labels']
-                else:
-                    self.train_labels += entry['fine_labels']
-                fo.close()
-
-            self.train_data = np.concatenate(self.train_data)
-            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
-            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
+            downloaded_list = self.train_list
         else:
-            f = self.test_list[0][0]
-            file = os.path.join(self.root, self.base_folder, f)
-            fo = open(file, 'rb')
-            if sys.version_info[0] == 2:
-                entry = pickle.load(fo)
-            else:
-                entry = pickle.load(fo, encoding='latin1')
-            self.test_data = entry['data']
-            if 'labels' in entry:
-                self.test_labels = entry['labels']
-            else:
-                self.test_labels = entry['fine_labels']
-            fo.close()
-            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
-            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
+            downloaded_list = self.test_list
+
+        self.data = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
+                if sys.version_info[0] == 2:
+                    entry = pickle.load(f)
+                else:
+                    entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
+                if 'labels' in entry:
+                    self.targets.extend(entry['labels'])
+                else:
+                    self.targets.extend(entry['fine_labels'])
+
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
 
         self._load_meta()
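A short usage sketch (not part of the commit) of how the rewritten loading code behaves; the root path is illustrative and download=True assumes network access:

    from torchvision.datasets import CIFAR10

    train_set = CIFAR10(root='./data', train=True, download=True)
    test_set = CIFAR10(root='./data', train=False, download=True)

    # Both splits now populate the same permanent fields in __init__:
    print(train_set.data.shape, len(train_set.targets))  # (50000, 32, 32, 3) 50000
    print(test_set.data.shape, len(test_set.targets))    # (10000, 32, 32, 3) 10000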

@@ -135,10 +114,7 @@ class CIFAR10(data.Dataset):
         Returns:
             tuple: (image, target) where target is index of the target class.
         """
-        if self.train:
-            img, target = self.train_data[index], self.train_labels[index]
-        else:
-            img, target = self.test_data[index], self.test_labels[index]
+        img, target = self.data[index], self.targets[index]
 
         # doing this so that it is consistent with all other datasets
         # to return a PIL Image

@@ -153,10 +129,7 @@ class CIFAR10(data.Dataset):
         return img, target
 
     def __len__(self):
-        if self.train:
-            return len(self.train_data)
-        else:
-            return len(self.test_data)
+        return len(self.data)
 
     def _check_integrity(self):
         root = self.root

@@ -174,16 +147,11 @@ class CIFAR10(data.Dataset):
             print('Files already downloaded and verified')
             return
 
-        root = self.root
-        download_url(self.url, root, self.filename, self.tgz_md5)
+        download_url(self.url, self.root, self.filename, self.tgz_md5)
 
         # extract file
-        cwd = os.getcwd()
-        tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
-        os.chdir(root)
-        tar.extractall()
-        tar.close()
-        os.chdir(cwd)
+        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+            tar.extractall(path=self.root)
 
     def __repr__(self):
         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
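The download hunk above replaces the save-cwd/chdir/extract/chdir-back sequence with a context-managed extraction. A minimal sketch of the same pattern in isolation (the archive name and target directory are assumptions, not taken from the diff):

    import os
    import tarfile

    root = './data'
    archive = os.path.join(root, 'cifar-10-python.tar.gz')  # assumed filename

    # The context manager closes the archive even if extraction raises, and
    # extractall(path=...) avoids changing the process working directory.
    with tarfile.open(archive, 'r:gz') as tar:
        tar.extractall(path=root)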