Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c1fdfa68
Commit
c1fdfa68
authored
Mar 02, 2017
by
Elad Hoffer
Committed by
Soumith Chintala
Mar 02, 2017
Browse files
add stl10 dataset (#83)
parent
66183f50
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
118 additions
and
1 deletion
+118
-1
README.rst
README.rst
+13
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/stl10.py
torchvision/datasets/stl10.py
+103
-0
No files found.
README.rst
View file @
c1fdfa68
...
@@ -49,6 +49,7 @@ The following dataset loaders are available:
...
@@ -49,6 +49,7 @@ The following dataset loaders are available:
- `ImageFolder <#imagefolder>`__
- `ImageFolder <#imagefolder>`__
- `Imagenet-12 <#imagenet-12>`__
- `Imagenet-12 <#imagenet-12>`__
- `CIFAR10 and CIFAR100 <#cifar>`__
- `CIFAR10 and CIFAR100 <#cifar>`__
- `STL10 <#stl10>`__
Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass
Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass
from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded
from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded
...
@@ -156,6 +157,18 @@ CIFAR
...
@@ -156,6 +157,18 @@ CIFAR
puts it in root directory. If dataset already downloaded, does not do
puts it in root directory. If dataset already downloaded, does not do
anything.
anything.
STL10
~~~~~
``dset.STL10(root, split='train', transform=None, target_transform=None, download=False)``
- ``root`` : root directory of dataset where there is folder ``stl10_binary``
- ``split`` : ``'train'`` = Training set, ``'test'`` = Test set, ``'unlabeled'`` = Unlabeled set,
``'train+unlabeled'`` = Training + Unlabeled set (missing label marked as ``-1``)
- ``download`` : ``True`` = downloads the dataset from the internet and
puts it in root directory. If dataset already downloaded, does not do
anything.
ImageFolder
ImageFolder
~~~~~~~~~~~
~~~~~~~~~~~
...
...
torchvision/datasets/__init__.py
View file @
c1fdfa68
...
@@ -2,10 +2,11 @@ from .lsun import LSUN, LSUNClass
...
@@ -2,10 +2,11 @@ from .lsun import LSUN, LSUNClass
from
.folder
import
ImageFolder
from
.folder
import
ImageFolder
from
.coco
import
CocoCaptions
,
CocoDetection
from
.coco
import
CocoCaptions
,
CocoDetection
from
.cifar
import
CIFAR10
,
CIFAR100
from
.cifar
import
CIFAR10
,
CIFAR100
from
.stl10
import
STL10
from
.mnist
import
MNIST
from
.mnist
import
MNIST
__all__
=
(
'LSUN'
,
'LSUNClass'
,
__all__
=
(
'LSUN'
,
'LSUNClass'
,
'ImageFolder'
,
'ImageFolder'
,
'CocoCaptions'
,
'CocoDetection'
,
'CocoCaptions'
,
'CocoDetection'
,
'CIFAR10'
,
'CIFAR100'
,
'CIFAR10'
,
'CIFAR100'
,
'MNIST'
)
'MNIST'
,
'STL10'
)
torchvision/datasets/stl10.py
0 → 100644
View file @
c1fdfa68
from
__future__
import
print_function
import
torch.utils.data
as
data
from
PIL
import
Image
import
os
import
os.path
import
errno
import
numpy
as
np
import
sys
from
.cifar
import
CIFAR10
class
STL10
(
CIFAR10
):
base_folder
=
'stl10_binary'
url
=
"http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz"
filename
=
"stl10_binary.tar.gz"
tgz_md5
=
'91f7769df0f17e558f3565bffb0c7dfb'
class_names_file
=
'class_names.txt'
train_list
=
[
[
'train_X.bin'
,
'918c2871b30a85fa023e0c44e0bee87f'
],
[
'train_y.bin'
,
'5a34089d4802c674881badbb80307741'
],
[
'unlabeled_X.bin'
,
'5242ba1fed5e4be9e1e742405eb56ca4'
]
]
test_list
=
[
[
'test_X.bin'
,
'7f263ba9f9e0b06b93213547f721ac82'
],
[
'test_y.bin'
,
'36f9794fa4beb8a2c72628de14fa638e'
]
]
def
__init__
(
self
,
root
,
split
=
'train'
,
transform
=
None
,
target_transform
=
None
,
download
=
False
):
self
.
root
=
root
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
split
=
split
# train/test/unlabeled set
if
download
:
self
.
download
()
if
not
self
.
_check_integrity
():
raise
RuntimeError
(
'Dataset not found or corrupted. You can use download=True to download it'
)
# now load the picked numpy arrays
if
self
.
split
==
'train'
:
self
.
data
,
self
.
labels
=
self
.
__loadfile
(
self
.
train_list
[
0
][
0
],
self
.
train_list
[
1
][
0
])
elif
self
.
split
==
'train+unlabeled'
:
self
.
data
,
self
.
labels
=
self
.
__loadfile
(
self
.
train_list
[
0
][
0
],
self
.
train_list
[
1
][
0
])
unlabeled_data
,
_
=
self
.
__loadfile
(
self
.
train_list
[
2
][
0
])
self
.
data
=
np
.
concatenate
((
self
.
data
,
unlabeled_data
))
self
.
labels
=
np
.
concatenate
(
(
self
.
labels
,
np
.
asarray
([
-
1
]
*
unlabeled_data
.
shape
[
0
])))
elif
self
.
split
==
'unlabeled'
:
self
.
data
,
_
=
self
.
__loadfile
(
self
.
train_list
[
2
][
0
])
self
.
labels
=
None
else
:
# self.split == 'test':
self
.
data
,
self
.
labels
=
self
.
__loadfile
(
self
.
test_list
[
0
][
0
],
self
.
test_list
[
1
][
0
])
class_file
=
os
.
path
.
join
(
root
,
self
.
base_folder
,
self
.
class_names_file
)
if
os
.
path
.
isfile
(
class_file
):
with
open
(
class_file
)
as
f
:
self
.
classes
=
f
.
read
().
splitlines
()
def
__getitem__
(
self
,
index
):
if
self
.
labels
is
not
None
:
img
,
target
=
self
.
data
[
index
],
int
(
self
.
labels
[
index
])
else
:
img
,
target
=
self
.
data
[
index
],
None
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img
=
Image
.
fromarray
(
np
.
transpose
(
img
,
(
1
,
2
,
0
)))
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
if
self
.
target_transform
is
not
None
:
target
=
self
.
target_transform
(
target
)
return
img
,
target
def
__len__
(
self
):
return
self
.
data
.
shape
[
0
]
def
__loadfile
(
self
,
data_file
,
labels_file
=
None
):
labels
=
None
if
labels_file
:
path_to_labels
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
labels_file
)
with
open
(
path_to_labels
,
'rb'
)
as
f
:
labels
=
np
.
fromfile
(
f
,
dtype
=
np
.
uint8
)
-
1
# 0-based
path_to_data
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
data_file
)
with
open
(
path_to_data
,
'rb'
)
as
f
:
# read whole file in uint8 chunks
everything
=
np
.
fromfile
(
f
,
dtype
=
np
.
uint8
)
images
=
np
.
reshape
(
everything
,
(
-
1
,
3
,
96
,
96
))
images
=
np
.
transpose
(
images
,
(
0
,
1
,
3
,
2
))
return
images
,
labels
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment