Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
63dabcaf
Commit
63dabcaf
authored
Nov 10, 2016
by
Soumith Chintala
Committed by
GitHub
Nov 10, 2016
Browse files
Merge pull request #3 from pytorch/cifar
cifar 10 and 100
parents
e37323d9
754d526f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
191 additions
and
1 deletion
+191
-1
.gitignore
.gitignore
+7
-0
README.md
README.md
+10
-0
test/cifar.py
test/cifar.py
+12
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+3
-1
torchvision/datasets/cifar.py
torchvision/datasets/cifar.py
+159
-0
No files found.
.gitignore
0 → 100644
View file @
63dabcaf
build/
dist/
torchvision.egg-info/
*/**/__pycache__
*/**/*.pyc
*/**/*~
*~
\ No newline at end of file
README.md
View file @
63dabcaf
...
...
@@ -29,6 +29,7 @@ The following dataset loaders are available:
-
[
LSUN Classification
](
#lsun
)
-
[
ImageFolder
](
#imagefolder
)
-
[
Imagenet-12
](
#imagenet-12
)
-
[
CIFAR10 and CIFAR100
](
#cifar
)
Datasets have the API:
-
`__getitem__`
...
...
@@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
-
['bedroom_train', 'church_train', ...] : a list of categories to load
### CIFAR
`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`
-
`root`
: root directory of dataset where there is folder
`cifar-10-batches-py`
-
`train`
:
`True`
= Training set,
`False`
= Test set
-
`download`
:
`True`
= downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
### ImageFolder
A generic data loader where the images are arranged in this way:
...
...
test/cifar.py
0 → 100644
View file @
63dabcaf
import torch
import torchvision.datasets as dset

# Smoke test: download each CIFAR variant and print one sample from it.
for title, dataset_cls in (('Cifar 10', dset.CIFAR10), ('Cifar 100', dset.CIFAR100)):
    print('\n\n' + title)
    ds = dataset_cls(root="abc/def/ghi", download=True)
    print(ds[3])
torchvision/datasets/__init__.py
View file @
63dabcaf
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100

# Public API of torchvision.datasets.
__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100')
torchvision/datasets/cifar.py
0 → 100644
View file @
63dabcaf
from
__future__
import
print_function
import
torch.utils.data
as
data
from
PIL
import
Image
import
os
import
os.path
import
errno
import
numpy
as
np
import
sys
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
else
:
import
pickle
class CIFAR10(data.Dataset):
    """CIFAR-10 dataset (http://www.cs.toronto.edu/~kriz/cifar.html).

    Args:
        root: root directory holding the ``cifar-10-batches-py`` folder.
        train: ``True`` serves the 50000-image training split, ``False``
            the 10000-image test split.
        transform: optional callable applied to each image on access.
        target_transform: optional callable applied to each label on access.
        download: if ``True``, download and extract the archive into
            ``root`` when the batch files are missing or corrupted.

    Raises:
        RuntimeError: from ``__init__`` when the batch files are absent or
            fail their md5 check and ``download`` is ``False``.
    """
    base_folder = 'cifar-10-batches-py'
    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    # BUGFIX: was misspelled `tgz_mdf`, but download() reads `self.tgz_md5`,
    # which raised AttributeError on the cached-archive path.
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    # (filename, md5) pairs expected inside base_folder.
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the pickled numpy arrays
        self.train_data = []
        self.train_labels = []
        for fentry in self.train_list:
            path = os.path.join(root, self.base_folder, fentry[0])
            entry = self._load_batch(path)
            self.train_data.append(entry['data'])
            # CIFAR-10 batches carry 'labels'; CIFAR-100 carries 'fine_labels'.
            if 'labels' in entry:
                self.train_labels += entry['labels']
            else:
                self.train_labels += entry['fine_labels']
        self.train_data = np.concatenate(self.train_data)

        path = os.path.join(root, self.base_folder, self.test_list[0][0])
        entry = self._load_batch(path)
        self.test_data = entry['data']
        if 'labels' in entry:
            self.test_labels = entry['labels']
        else:
            self.test_labels = entry['fine_labels']

        # Raw batch rows are flat 3072-byte vectors; expose them as CHW images.
        self.train_data = self.train_data.reshape((50000, 3, 32, 32))
        self.test_data = self.test_data.reshape((10000, 3, 32, 32))

    @staticmethod
    def _load_batch(path):
        """Unpickle one CIFAR batch file and return its dict.

        BUGFIX: the batches are Python-2 pickles; on Python 3 they must be
        loaded with ``encoding='latin1'`` or ``pickle.load`` fails.
        """
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            return pickle.load(fo, encoding='latin1')

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at ``index``, transformed."""
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        # Fixed sizes of the CIFAR train / test splits.
        return 50000 if self.train else 10000

    def _check_integrity(self):
        """Return True iff every expected batch file exists with its md5."""
        import hashlib
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not os.path.isfile(fpath):
                return False
            with open(fpath, 'rb') as f:
                md5c = hashlib.md5(f.read()).hexdigest()
            if md5c != md5:
                return False
        return True

    def download(self):
        """Download, verify, and extract the archive into ``self.root``.

        Does nothing when the batch files are already present and verified;
        reuses a previously downloaded archive when its md5 matches.
        """
        from six.moves import urllib
        import tarfile
        import hashlib

        root = self.root
        fpath = os.path.join(root, self.filename)

        try:
            os.makedirs(root)
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass  # root already exists — fine
            else:
                raise

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        # downloads file
        if os.path.isfile(fpath):
            with open(fpath, 'rb') as f:
                archive_md5 = hashlib.md5(f.read()).hexdigest()
        else:
            archive_md5 = None
        if archive_md5 == self.tgz_md5:
            print('Using downloaded file: ' + fpath)
        else:
            print('Downloading ' + self.url + ' to ' + fpath)
            urllib.request.urlretrieve(self.url, fpath)

        # extract file
        cwd = os.getcwd()
        print('Extracting tar file')
        tar = tarfile.open(fpath, "r:gz")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('Done!')
class CIFAR100(CIFAR10):
    """CIFAR-100 dataset: same layout as CIFAR-10, different archive.

    Only the class-level constants change; all loading logic is inherited.
    """

    # Archive location and checksum for the CIFAR-100 python release.
    url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'

    # Extracted folder and the (file, md5) pairs it must contain.
    base_folder = 'cifar-100-python'
    train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
    test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment