Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
754d526f
Commit
754d526f
authored
Nov 10, 2016
by
Soumith Chintala
Committed by
Soumith Chintala
Nov 10, 2016
Browse files
cifar 10 and 100
parent
9378febe
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
191 additions
and
1 deletion
+191
-1
.gitignore
.gitignore
+7
-0
README.md
README.md
+10
-0
test/cifar.py
test/cifar.py
+12
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+3
-1
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+159
-0
No files found.
.gitignore
0 → 100644
View file @
754d526f
build/
dist/
torchvision.egg-info/
*/**/__pycache__
*/**/*.pyc
*/**/*~
*~
\ No newline at end of file
README.md
View file @
754d526f
...
@@ -29,6 +29,7 @@ The following dataset loaders are available:
...
@@ -29,6 +29,7 @@ The following dataset loaders are available:
-
[
LSUN Classification
](
#lsun
)
-
[
LSUN Classification
](
#lsun
)
-
[
ImageFolder
](
#imagefolder
)
-
[
ImageFolder
](
#imagefolder
)
-
[
Imagenet-12
](
#imagenet-12
)
-
[
Imagenet-12
](
#imagenet-12
)
-
[
CIFAR10 and CIFAR100
](
#cifar
)
Datasets have the API:
Datasets have the API:
-
`__getitem__`
-
`__getitem__`
...
@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
...
@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
-
['bedroom_train', 'church_train', ...] : a list of categories to load
-
['bedroom_train', 'church_train', ...] : a list of categories to load
### CIFAR
`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`
-
`root`
: root directory of the dataset, which contains the folder
`cifar-10-batches-py`
-
`train`
:
`True`
= Training set,
`False`
= Test set
-
`download`
:
`True`
= downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
### ImageFolder
### ImageFolder
A generic data loader where the images are arranged in this way:
A generic data loader where the images are arranged in this way:
...
...
test/cifar.py
0 → 100644
View file @
754d526f
import torch
import torchvision.datasets as dset

# Smoke test: build each CIFAR dataset (downloading it if needed) and print
# one sample so a broken loader fails loudly.
for banner, dataset_cls in (('\n\nCifar 10', dset.CIFAR10),
                            ('\n\nCifar 100', dset.CIFAR100)):
    print(banner)
    a = dataset_cls(root="abc/def/ghi", download=True)
    print(a[3])
torchvision/datasets/__init__.py
View file @
754d526f
# Public dataset loaders re-exported at package level.
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100')
torchvision/datasets/cifar.py
0 → 100644
View file @
754d526f
from
__future__
import
print_function
import
torch.utils.data
as
data
from
PIL
import
Image
import
os
import
os.path
import
errno
import
numpy
as
np
import
sys
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
class CIFAR10(data.Dataset):
    """CIFAR-10 dataset backed by the pickled python batch files.

    Both the training and the test batches are loaded eagerly in
    ``__init__``; ``train`` only selects which split ``__getitem__`` and
    ``__len__`` serve.

    Args:
        root: directory that contains (or, with ``download=True``, will
            receive) the ``base_folder`` directory.
        train: ``True`` to serve the training split, ``False`` for the
            test split.
        transform: optional callable applied to every image array.
        target_transform: optional callable applied to every label.
        download: if ``True``, fetch and unpack the archive into ``root``
            unless verified batch files are already present.

    Raises:
        RuntimeError: if a batch file is missing or its md5 does not match.
    """

    base_folder = 'cifar-10-batches-py'
    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # Original misspelling of ``tgz_md5``; kept as an alias so existing
    # references keep working.
    tgz_mdf = tgz_md5
    # Each entry is [batch file name, expected md5 of that file].
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # Load the pickled numpy arrays of every training batch.
        self.train_data = []
        self.train_labels = []
        for fentry in self.train_list:
            entry = self._load_batch(fentry[0])
            self.train_data.append(entry['data'])
            self.train_labels += self._labels_of(entry)
        self.train_data = np.concatenate(self.train_data)

        entry = self._load_batch(self.test_list[0][0])
        self.test_data = entry['data']
        self.test_labels = self._labels_of(entry)

        # One row per image -> (N, 3, 32, 32); N is inferred from the data
        # instead of being hard-coded.
        self.train_data = self.train_data.reshape((-1, 3, 32, 32))
        self.test_data = self.test_data.reshape((-1, 3, 32, 32))

    def _load_batch(self, name):
        """Unpickle and return the batch file ``name`` under ``base_folder``."""
        path = os.path.join(self.root, self.base_folder, name)
        # NOTE(security): unpickling can execute arbitrary code. __init__
        # only reaches this point after _check_integrity() matched the md5
        # of every batch file.
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            # The batches were pickled by Python 2; 'latin1' lets Python 3
            # decode their 8-bit strings.
            return pickle.load(fo, encoding='latin1')

    @staticmethod
    def _labels_of(entry):
        """Return the label list of a batch ('labels', else 'fine_labels')."""
        if 'labels' in entry:
            return entry['labels']
        return entry['fine_labels']

    @staticmethod
    def _file_md5(path, chunk_size=1024 * 1024):
        """Return the hex md5 digest of the file at ``path``, read in chunks."""
        import hashlib
        digest = hashlib.md5()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                digest.update(chunk)
        return digest.hexdigest()

    def __getitem__(self, index):
        """Return ``(img, target)`` for ``index`` from the selected split."""
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images in the selected split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_integrity(self):
        """Return True if every listed batch file exists with the right md5."""
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(self.root, self.base_folder, filename)
            if not os.path.isfile(fpath):
                return False
            if self._file_md5(fpath) != md5:
                return False
        return True

    def download(self):
        """Download the archive into ``root`` (if needed) and extract it.

        Does nothing when the batch files are already present and verified.
        An archive already on disk is reused when its md5 matches
        ``tgz_md5``.
        """
        from six.moves import urllib
        import tarfile

        root = self.root
        fpath = os.path.join(root, self.filename)

        try:
            os.makedirs(root)
        except OSError as e:
            # An already existing directory is fine; anything else is not.
            if e.errno != errno.EEXIST:
                raise

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # downloads file
        if os.path.isfile(fpath) and self._file_md5(fpath) == self.tgz_md5:
            print('Using downloaded file: ' + fpath)
        else:
            print('Downloading ' + self.url + ' to ' + fpath)
            urllib.request.urlretrieve(self.url, fpath)

        # extract file; extracting straight into root avoids changing the
        # process-wide working directory.
        print('Extracting tar file')
        with tarfile.open(fpath, "r:gz") as tar:
            tar.extractall(path=root)
        print('Done!')
class CIFAR100(CIFAR10):
    """CIFAR-100 flavour of :class:`CIFAR10`.

    Overrides only the extracted folder name, the archive URL and file
    name, the archive md5, and the batch file lists.
    """

    base_folder = 'cifar-100-python'
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    # Each entry is [batch file name, expected md5 of that file].
    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment