Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
6db1569c
Unverified
Commit
6db1569c
authored
Jul 31, 2020
by
Philip Meier
Committed by
GitHub
Jul 31, 2020
Browse files
cifar (#2527)
parent
47f80acc
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
9 deletions
+16
-9
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+16
-9
No files found.
torchvision/datasets/cifar.py
View file @
6db1569c
...
...
@@ -3,6 +3,7 @@ import os
import
os.path
import
numpy
as
np
import
pickle
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
from
.vision
import
VisionDataset
from
.utils
import
check_integrity
,
download_and_extract_archive
...
...
@@ -46,8 +47,14 @@ class CIFAR10(VisionDataset):
'md5'
:
'5ff9c542aee3614f3951f8cda6e48888'
,
}
def
__init__
(
self
,
root
,
train
=
True
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
def
__init__
(
self
,
root
:
str
,
train
:
bool
=
True
,
transform
:
Optional
[
Callable
]
=
None
,
target_transform
:
Optional
[
Callable
]
=
None
,
download
:
bool
=
False
,
)
->
None
:
super
(
CIFAR10
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
...
...
@@ -66,7 +73,7 @@ class CIFAR10(VisionDataset):
else
:
downloaded_list
=
self
.
test_list
self
.
data
=
[]
self
.
data
:
Any
=
[]
self
.
targets
=
[]
# now load the picked numpy arrays
...
...
@@ -85,7 +92,7 @@ class CIFAR10(VisionDataset):
self
.
_load_meta
()
def
_load_meta
(
self
):
def
_load_meta
(
self
)
->
None
:
path
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
self
.
meta
[
'filename'
])
if
not
check_integrity
(
path
,
self
.
meta
[
'md5'
]):
raise
RuntimeError
(
'Dataset metadata file not found or corrupted.'
+
...
...
@@ -95,7 +102,7 @@ class CIFAR10(VisionDataset):
self
.
classes
=
data
[
self
.
meta
[
'key'
]]
self
.
class_to_idx
=
{
_class
:
i
for
i
,
_class
in
enumerate
(
self
.
classes
)}
def
__getitem__
(
self
,
index
)
:
def
__getitem__
(
self
,
index
:
int
)
->
Tuple
[
Any
,
Any
]
:
"""
Args:
index (int): Index
...
...
@@ -117,10 +124,10 @@ class CIFAR10(VisionDataset):
return
img
,
target
def
__len__
(
self
):
def
__len__
(
self
)
->
int
:
return
len
(
self
.
data
)
def
_check_integrity
(
self
):
def
_check_integrity
(
self
)
->
bool
:
root
=
self
.
root
for
fentry
in
(
self
.
train_list
+
self
.
test_list
):
filename
,
md5
=
fentry
[
0
],
fentry
[
1
]
...
...
@@ -129,13 +136,13 @@ class CIFAR10(VisionDataset):
return
False
return
True
def
download
(
self
):
def
download
(
self
)
->
None
:
if
self
.
_check_integrity
():
print
(
'Files already downloaded and verified'
)
return
download_and_extract_archive
(
self
.
url
,
self
.
root
,
filename
=
self
.
filename
,
md5
=
self
.
tgz_md5
)
def
extra_repr
(
self
):
def
extra_repr
(
self
)
->
str
:
return
"Split: {}"
.
format
(
"Train"
if
self
.
train
is
True
else
"Test"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment