Commit 59c97d77 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Miscellaneous dataset fixes (#1174)

* fix stl10

* fix lsun
parent 81021581
......@@ -5,6 +5,7 @@ import os.path
import six
import string
import sys
from collections import Iterable
if sys.version_info[0] == 2:
import cPickle as pickle
......@@ -72,6 +73,24 @@ class LSUN(VisionDataset):
def __init__(self, root, classes='train', transform=None, target_transform=None):
super(LSUN, self).__init__(root, transform=transform,
target_transform=target_transform)
self.classes = self._verify_classes(classes)
# for each class, create an LSUNClassDataset
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
root=root + '/' + c + '_lmdb',
transform=transform))
self.indices = []
count = 0
for db in self.dbs:
count += len(db)
self.indices.append(count)
self.length = count
def _verify_classes(self, classes):
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
......@@ -84,39 +103,28 @@ class LSUN(VisionDataset):
else:
classes = [c + '_' + classes for c in categories]
except ValueError:
# TODO: Should this check for Iterable instead of list?
if not isinstance(classes, list):
raise ValueError
if not isinstance(classes, Iterable):
msg = ("Expected type str or Iterable for argument classes, "
"but got type {}.")
raise ValueError(msg.format(type(classes)))
classes = list(classes)
msg_fmtstr = ("Expected type str for elements in argument classes, "
"but got type {}.")
for c in classes:
# TODO: This assumes each item is a str (or subclass). Should this
# also be checked?
verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
c_short = c.split('_')
category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
msg = msg_fmtstr.format(category, "LSUN class",
iterable_to_str(categories))
verify_str_arg(category, valid_values=categories, custom_msg=msg)
msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
finally:
self.classes = classes
# for each class, create an LSUNClassDataset
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
root=root + '/' + c + '_lmdb',
transform=transform))
self.indices = []
count = 0
for db in self.dbs:
count += len(db)
self.indices.append(count)
self.length = count
return classes
def __getitem__(self, index):
"""
......
......@@ -51,7 +51,7 @@ class STL10(VisionDataset):
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits)
self.folds = folds # one of the 10 pre-defined folds or the full dataset
self.folds = self._verify_folds(folds)
if download:
self.download()
......@@ -89,6 +89,19 @@ class STL10(VisionDataset):
with open(class_file) as f:
self.classes = f.read().splitlines()
def _verify_folds(self, folds):
if folds is None:
return folds
elif isinstance(folds, int):
if folds in range(10):
return folds
msg = ("Value for argument folds should be in the range [0, 10), "
"but got {}.")
raise ValueError(msg.format(folds))
else:
msg = "Expected type None or int for argument folds, but got type {}."
raise ValueError(msg.format(type(folds)))
def __getitem__(self, index):
"""
Args:
......@@ -154,15 +167,11 @@ class STL10(VisionDataset):
def __load_folds(self, folds):
# loads one of the folds if specified
if isinstance(folds, int):
if folds >= 0 and folds < 10:
path_to_folds = os.path.join(
self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
else:
# FIXME: docstring allows None for folds (it is even the default value)
# Is this intended?
raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
if folds is None:
return
path_to_folds = os.path.join(
self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment