Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
59c97d77
Commit
59c97d77
authored
Jul 26, 2019
by
Philip Meier
Committed by
Francisco Massa
Jul 26, 2019
Browse files
Miscellaneous dataset fixes (#1174)
* fix stl10 * fix lsun
parent
81021581
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
35 deletions
+52
-35
torchvision/datasets/lsun.py
torchvision/datasets/lsun.py
+30
-22
torchvision/datasets/stl10.py
torchvision/datasets/stl10.py
+22
-13
No files found.
torchvision/datasets/lsun.py
View file @
59c97d77
...
...
@@ -5,6 +5,7 @@ import os.path
import
six
import
string
import
sys
from
collections
import
Iterable
if
sys
.
version_info
[
0
]
==
2
:
import
cPickle
as
pickle
...
...
@@ -72,6 +73,24 @@ class LSUN(VisionDataset):
def
__init__
(
self
,
root
,
classes
=
'train'
,
transform
=
None
,
target_transform
=
None
):
super
(
LSUN
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
classes
=
self
.
_verify_classes
(
classes
)
# for each class, create an LSUNClassDataset
self
.
dbs
=
[]
for
c
in
self
.
classes
:
self
.
dbs
.
append
(
LSUNClass
(
root
=
root
+
'/'
+
c
+
'_lmdb'
,
transform
=
transform
))
self
.
indices
=
[]
count
=
0
for
db
in
self
.
dbs
:
count
+=
len
(
db
)
self
.
indices
.
append
(
count
)
self
.
length
=
count
def
_verify_classes
(
self
,
classes
):
categories
=
[
'bedroom'
,
'bridge'
,
'church_outdoor'
,
'classroom'
,
'conference_room'
,
'dining_room'
,
'kitchen'
,
'living_room'
,
'restaurant'
,
'tower'
]
...
...
@@ -84,39 +103,28 @@ class LSUN(VisionDataset):
else
:
classes
=
[
c
+
'_'
+
classes
for
c
in
categories
]
except
ValueError
:
# TODO: Should this check for Iterable instead of list?
if
not
isinstance
(
classes
,
list
):
raise
ValueError
if
not
isinstance
(
classes
,
Iterable
):
msg
=
(
"Expected type str or Iterable for argument classes, "
"but got type {}."
)
raise
ValueError
(
msg
.
format
(
type
(
classes
)))
classes
=
list
(
classes
)
msg_fmtstr
=
(
"Expected type str for elements in argument classes, "
"but got type {}."
)
for
c
in
classes
:
# TODO: This assumes each item is a str (or subclass). Should this
# also be checked?
verify_str_arg
(
c
,
custom_msg
=
msg_fmtstr
.
format
(
type
(
c
)))
c_short
=
c
.
split
(
'_'
)
category
,
dset_opt
=
'_'
.
join
(
c_short
[:
-
1
]),
c_short
[
-
1
]
msg_fmtstr
=
"Unknown value '{}' for {}. Valid values are {{{}}}."
msg_fmtstr
=
"Unknown value '{}' for {}. Valid values are {{{}}}."
msg
=
msg_fmtstr
.
format
(
category
,
"LSUN class"
,
iterable_to_str
(
categories
))
verify_str_arg
(
category
,
valid_values
=
categories
,
custom_msg
=
msg
)
msg
=
msg_fmtstr
.
format
(
dset_opt
,
"postfix"
,
iterable_to_str
(
dset_opts
))
verify_str_arg
(
dset_opt
,
valid_values
=
dset_opts
,
custom_msg
=
msg
)
finally
:
self
.
classes
=
classes
# for each class, create an LSUNClassDataset
self
.
dbs
=
[]
for
c
in
self
.
classes
:
self
.
dbs
.
append
(
LSUNClass
(
root
=
root
+
'/'
+
c
+
'_lmdb'
,
transform
=
transform
))
self
.
indices
=
[]
count
=
0
for
db
in
self
.
dbs
:
count
+=
len
(
db
)
self
.
indices
.
append
(
count
)
self
.
length
=
count
return
classes
def
__getitem__
(
self
,
index
):
"""
...
...
torchvision/datasets/stl10.py
View file @
59c97d77
...
...
@@ -51,7 +51,7 @@ class STL10(VisionDataset):
super
(
STL10
,
self
).
__init__
(
root
,
transform
=
transform
,
target_transform
=
target_transform
)
self
.
split
=
verify_str_arg
(
split
,
"split"
,
self
.
splits
)
self
.
folds
=
folds
# one of the 10 pre-defined folds or the full dataset
self
.
folds
=
self
.
_verify_folds
(
folds
)
if
download
:
self
.
download
()
...
...
@@ -89,6 +89,19 @@ class STL10(VisionDataset):
with
open
(
class_file
)
as
f
:
self
.
classes
=
f
.
read
().
splitlines
()
def
_verify_folds
(
self
,
folds
):
if
folds
is
None
:
return
folds
elif
isinstance
(
folds
,
int
):
if
folds
in
range
(
10
):
return
folds
msg
=
(
"Value for argument folds should be in the range [0, 10), "
"but got {}."
)
raise
ValueError
(
msg
.
format
(
folds
))
else
:
msg
=
"Expected type None or int for argument folds, but got type {}."
raise
ValueError
(
msg
.
format
(
type
(
folds
)))
def
__getitem__
(
self
,
index
):
"""
Args:
...
...
@@ -154,15 +167,11 @@ class STL10(VisionDataset):
def
__load_folds
(
self
,
folds
):
# loads one of the folds if specified
if
isinstance
(
folds
,
int
):
if
folds
>=
0
and
folds
<
10
:
path_to_folds
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
self
.
folds_list_file
)
with
open
(
path_to_folds
,
'r'
)
as
f
:
str_idx
=
f
.
read
().
splitlines
()[
folds
]
list_idx
=
np
.
fromstring
(
str_idx
,
dtype
=
np
.
uint8
,
sep
=
' '
)
self
.
data
,
self
.
labels
=
self
.
data
[
list_idx
,
:,
:,
:],
self
.
labels
[
list_idx
]
else
:
# FIXME: docstring allows None for folds (it is even the default value)
# Is this intended?
raise
ValueError
(
'Folds "{}" not found. Valid splits are: 0-9.'
.
format
(
folds
))
if
folds
is
None
:
return
path_to_folds
=
os
.
path
.
join
(
self
.
root
,
self
.
base_folder
,
self
.
folds_list_file
)
with
open
(
path_to_folds
,
'r'
)
as
f
:
str_idx
=
f
.
read
().
splitlines
()[
folds
]
list_idx
=
np
.
fromstring
(
str_idx
,
dtype
=
np
.
uint8
,
sep
=
' '
)
self
.
data
,
self
.
labels
=
self
.
data
[
list_idx
,
:,
:,
:],
self
.
labels
[
list_idx
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment