Commit 4886ccc8 authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Standardize str argument verification in datasets (#1167)

* introduced function to verify str arguments

* flake8

* added FIXME to VOC

* Fixed error message

* added test for verify_str_arg

* cleanup todos

* added option for custom error message

* fix VOC

* fixed Caltech
parent d9830d86
...@@ -107,6 +107,11 @@ class Tester(unittest.TestCase): ...@@ -107,6 +107,11 @@ class Tester(unittest.TestCase):
data = nf.read() data = nf.read()
self.assertEqual(data, 'this is the content') self.assertEqual(data, 'this is the content')
def test_verify_str_arg(self):
    """Check utils.verify_str_arg: returns valid values, rejects bad ones.

    verify_str_arg has the signature (value, arg=None, valid_values=None,
    custom_msg=None); it raises ValueError both for non-str values and for
    str values outside ``valid_values``.
    """
    self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
    # Non-str value must be rejected.  NOTE: the original test passed the
    # positional args in the wrong order (value, valid_values, arg); it
    # only passed by accident.  Fixed to (value, arg, valid_values).
    self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
    # A str value not contained in valid_values must be rejected.
    self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import os.path import os.path
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_and_extract_archive, makedir_exist_ok from .utils import download_and_extract_archive, makedir_exist_ok, verify_str_arg
class Caltech101(VisionDataset): class Caltech101(VisionDataset):
...@@ -32,10 +32,10 @@ class Caltech101(VisionDataset): ...@@ -32,10 +32,10 @@ class Caltech101(VisionDataset):
transform=transform, transform=transform,
target_transform=target_transform) target_transform=target_transform)
makedir_exist_ok(self.root) makedir_exist_ok(self.root)
if isinstance(target_type, list): if not isinstance(target_type, list):
self.target_type = target_type target_type = [target_type]
else: self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
self.target_type = [target_type] for t in target_type]
if download: if download:
self.download() self.download()
...@@ -88,8 +88,6 @@ class Caltech101(VisionDataset): ...@@ -88,8 +88,6 @@ class Caltech101(VisionDataset):
self.annotation_categories[self.y[index]], self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index]))) "annotation_{:04d}.mat".format(self.index[index])))
target.append(data["obj_contour"]) target.append(data["obj_contour"])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0] target = tuple(target) if len(target) > 1 else target[0]
if self.transform is not None: if self.transform is not None:
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
import os import os
import PIL import PIL
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
class CelebA(VisionDataset): class CelebA(VisionDataset):
...@@ -66,17 +66,14 @@ class CelebA(VisionDataset): ...@@ -66,17 +66,14 @@ class CelebA(VisionDataset):
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it') ' You can use download=True to download it')
if split.lower() == "train": split_map = {
split = 0 "train": 0,
elif split.lower() == "valid": "valid": 1,
split = 1 "test": 2,
elif split.lower() == "test": "all": None,
split = 2 }
elif split.lower() == "all": split = split_map[verify_str_arg(split.lower(), "split",
split = None ("train", "valid", "test", "all"))]
else:
raise ValueError('Wrong split entered! Please use "train", '
'"valid", "test", or "all"')
fn = partial(os.path.join, self.root, self.base_folder) fn = partial(os.path.join, self.root, self.base_folder)
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0) splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
...@@ -134,6 +131,7 @@ class CelebA(VisionDataset): ...@@ -134,6 +131,7 @@ class CelebA(VisionDataset):
elif t == "landmarks": elif t == "landmarks":
target.append(self.landmarks_align[index, :]) target.append(self.landmarks_align[index, :])
else: else:
# TODO: refactor with utils.verify_str_arg
raise ValueError("Target type \"{}\" is not recognized.".format(t)) raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0] target = tuple(target) if len(target) > 1 else target[0]
......
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
from collections import namedtuple from collections import namedtuple
import zipfile import zipfile
from .utils import extract_archive from .utils import extract_archive, verify_str_arg, iterable_to_str
from .vision import VisionDataset from .vision import VisionDataset
from PIL import Image from PIL import Image
...@@ -109,22 +109,21 @@ class Cityscapes(VisionDataset): ...@@ -109,22 +109,21 @@ class Cityscapes(VisionDataset):
self.images = [] self.images = []
self.targets = [] self.targets = []
if mode not in ['fine', 'coarse']: verify_str_arg(mode, "mode", ("fine", "coarse"))
raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"') if mode == "fine":
valid_modes = ("train", "test", "val")
if mode == 'fine' and split not in ['train', 'test', 'val']: else:
raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"' valid_modes = ("train", "train_extra", "val")
' or split="val"') msg = ("Unknown value '{}' for argument split if mode is '{}'. "
elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']: "Valid values are {{{}}}.")
raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"' msg = msg.format(split, mode, iterable_to_str(valid_modes))
' or split="val"') verify_str_arg(split, "split", valid_modes, msg)
if not isinstance(target_type, list): if not isinstance(target_type, list):
self.target_type = [target_type] self.target_type = [target_type]
[verify_str_arg(value, "target_type",
if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type): ("instance", "semantic", "polygon", "color"))
raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"' for value in self.target_type]
' or "color"')
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
......
...@@ -4,7 +4,8 @@ import shutil ...@@ -4,7 +4,8 @@ import shutil
import tempfile import tempfile
import torch import torch
from .folder import ImageFolder from .folder import ImageFolder
from .utils import check_integrity, download_and_extract_archive, extract_archive from .utils import check_integrity, download_and_extract_archive, extract_archive, \
verify_str_arg
ARCHIVE_DICT = { ARCHIVE_DICT = {
'train': { 'train': {
...@@ -48,7 +49,7 @@ class ImageNet(ImageFolder): ...@@ -48,7 +49,7 @@ class ImageNet(ImageFolder):
def __init__(self, root, split='train', download=False, **kwargs): def __init__(self, root, split='train', download=False, **kwargs):
root = self.root = os.path.expanduser(root) root = self.root = os.path.expanduser(root)
self.split = self._verify_split(split) self.split = verify_str_arg(split, "split", ("train", "val"))
if download: if download:
self.download() self.download()
...@@ -109,17 +110,6 @@ class ImageNet(ImageFolder): ...@@ -109,17 +110,6 @@ class ImageNet(ImageFolder):
def _save_meta_file(self, wnid_to_class, val_wnids): def _save_meta_file(self, wnid_to_class, val_wnids):
torch.save((wnid_to_class, val_wnids), self.meta_file) torch.save((wnid_to_class, val_wnids), self.meta_file)
def _verify_split(self, split):
if split not in self.valid_splits:
msg = "Unknown split {} .".format(split)
msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits))
raise ValueError(msg)
return split
@property
def valid_splits(self):
return 'train', 'val'
@property @property
def split_folder(self): def split_folder(self):
return os.path.join(self.root, self.split) return os.path.join(self.root, self.split)
......
...@@ -11,6 +11,8 @@ if sys.version_info[0] == 2: ...@@ -11,6 +11,8 @@ if sys.version_info[0] == 2:
else: else:
import pickle import pickle
from .utils import verify_str_arg, iterable_to_str
class LSUNClass(VisionDataset): class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None): def __init__(self, root, transform=None, target_transform=None):
...@@ -75,26 +77,30 @@ class LSUN(VisionDataset): ...@@ -75,26 +77,30 @@ class LSUN(VisionDataset):
'living_room', 'restaurant', 'tower'] 'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test'] dset_opts = ['train', 'val', 'test']
if type(classes) == str and classes in dset_opts: try:
verify_str_arg(classes, "classes", dset_opts)
if classes == 'test': if classes == 'test':
classes = [classes] classes = [classes]
else: else:
classes = [c + '_' + classes for c in categories] classes = [c + '_' + classes for c in categories]
elif type(classes) == list: except ValueError:
# TODO: Should this check for Iterable instead of list?
if not isinstance(classes, list):
raise ValueError
for c in classes: for c in classes:
# TODO: This assumes each item is a str (or subclass). Should this
# also be checked?
c_short = c.split('_') c_short = c.split('_')
c_short.pop(len(c_short) - 1) category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
c_short = '_'.join(c_short) msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
if c_short not in categories:
raise (ValueError('Unknown LSUN class: ' + c_short + '.' msg = msg_fmtstr.format(category, "LSUN class",
'Options are: ' + str(categories))) iterable_to_str(categories))
c_short = c.split('_') verify_str_arg(category, valid_values=categories, custom_msg=msg)
c_short = c_short.pop(len(c_short) - 1)
if c_short not in dset_opts: msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
raise (ValueError('Unknown postfix: ' + c_short + '.' verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
'Options are: ' + str(dset_opts))) finally:
else:
raise (ValueError('Unknown option for classes'))
self.classes = classes self.classes = classes
# for each class, create an LSUNClassDataset # for each class, create an LSUNClassDataset
......
...@@ -7,7 +7,8 @@ import os.path ...@@ -7,7 +7,8 @@ import os.path
import numpy as np import numpy as np
import torch import torch
import codecs import codecs
from .utils import download_url, download_and_extract_archive, extract_archive, makedir_exist_ok from .utils import download_url, download_and_extract_archive, extract_archive, \
makedir_exist_ok, verify_str_arg
class MNIST(VisionDataset): class MNIST(VisionDataset):
...@@ -230,11 +231,7 @@ class EMNIST(MNIST): ...@@ -230,11 +231,7 @@ class EMNIST(MNIST):
splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
def __init__(self, root, split, **kwargs): def __init__(self, root, split, **kwargs):
if split not in self.splits: self.split = verify_str_arg(split, "split", self.splits)
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
self.split = split
self.training_file = self._training_file(split) self.training_file = self._training_file(split)
self.test_file = self._test_file(split) self.test_file = self._test_file(split)
super(EMNIST, self).__init__(root, **kwargs) super(EMNIST, self).__init__(root, **kwargs)
...@@ -336,10 +333,7 @@ class QMNIST(MNIST): ...@@ -336,10 +333,7 @@ class QMNIST(MNIST):
def __init__(self, root, what=None, compat=True, train=True, **kwargs): def __init__(self, root, what=None, compat=True, train=True, **kwargs):
if what is None: if what is None:
what = 'train' if train else 'test' what = 'train' if train else 'test'
if not self.subsets.get(what): self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
raise RuntimeError("Argument 'what' should be one of: \n " +
repr(tuple(self.subsets.keys())))
self.what = what
self.compat = compat self.compat = compat
self.data_file = what + '.pt' self.data_file = what + '.pt'
self.training_file = self.data_file self.training_file = self.data_file
......
...@@ -5,7 +5,7 @@ from .vision import VisionDataset ...@@ -5,7 +5,7 @@ from .vision import VisionDataset
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .utils import download_url from .utils import download_url, verify_str_arg
from .voc import download_extract from .voc import download_extract
...@@ -64,12 +64,9 @@ class SBDataset(VisionDataset): ...@@ -64,12 +64,9 @@ class SBDataset(VisionDataset):
"pip install scipy") "pip install scipy")
super(SBDataset, self).__init__(root, transforms) super(SBDataset, self).__init__(root, transforms)
self.image_set = verify_str_arg(image_set, "image_set",
if mode not in ("segmentation", "boundaries"): ("train", "val", "train_noval"))
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'") self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
self.image_set = image_set
self.mode = mode
self.num_classes = 20 self.num_classes = 20
sbd_root = self.root sbd_root = self.root
...@@ -91,11 +88,6 @@ class SBDataset(VisionDataset): ...@@ -91,11 +88,6 @@ class SBDataset(VisionDataset):
split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt') split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="val" or image_set="train_noval"')
with open(os.path.join(split_f), "r") as f: with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()] file_names = [x.strip() for x in f.readlines()]
......
...@@ -5,7 +5,7 @@ import os.path ...@@ -5,7 +5,7 @@ import os.path
import numpy as np import numpy as np
from .vision import VisionDataset from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive from .utils import check_integrity, download_and_extract_archive, verify_str_arg
class STL10(VisionDataset): class STL10(VisionDataset):
...@@ -48,13 +48,9 @@ class STL10(VisionDataset): ...@@ -48,13 +48,9 @@ class STL10(VisionDataset):
def __init__(self, root, split='train', folds=None, transform=None, def __init__(self, root, split='train', folds=None, transform=None,
target_transform=None, download=False): target_transform=None, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
super(STL10, self).__init__(root, transform=transform, super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.split = split # train/test/unlabeled set self.split = verify_str_arg(split, "split", self.splits)
self.folds = folds # one of the 10 pre-defined folds or the full dataset self.folds = folds # one of the 10 pre-defined folds or the full dataset
if download: if download:
...@@ -167,4 +163,6 @@ class STL10(VisionDataset): ...@@ -167,4 +163,6 @@ class STL10(VisionDataset):
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ') list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx] self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
else: else:
# FIXME: docstring allows None for folds (it is even the default value)
# Is this intended?
raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds)) raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
...@@ -4,7 +4,7 @@ from PIL import Image ...@@ -4,7 +4,7 @@ from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
from .utils import download_url, check_integrity from .utils import download_url, check_integrity, verify_str_arg
class SVHN(VisionDataset): class SVHN(VisionDataset):
...@@ -43,12 +43,7 @@ class SVHN(VisionDataset): ...@@ -43,12 +43,7 @@ class SVHN(VisionDataset):
download=False): download=False):
super(SVHN, self).__init__(root, transform=transform, super(SVHN, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.split = split # training set or test set or extra set self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
if self.split not in self.split_list:
raise ValueError('Wrong split entered! Please use split="train" '
'or split="extra" or split="test"')
self.url = self.split_list[split][0] self.url = self.split_list[split][0]
self.filename = self.split_list[split][1] self.filename = self.split_list[split][1]
self.file_md5 = self.split_list[split][2] self.file_md5 = self.split_list[split][2]
......
...@@ -6,6 +6,7 @@ import errno ...@@ -6,6 +6,7 @@ import errno
import tarfile import tarfile
import zipfile import zipfile
import torch
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
...@@ -249,3 +250,32 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename ...@@ -249,3 +250,32 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename
archive = os.path.join(download_root, filename) archive = os.path.join(download_root, filename)
print("Extracting {} to {}".format(archive, extract_root)) print("Extracting {} to {}".format(archive, extract_root))
extract_archive(archive, extract_root, remove_finished) extract_archive(archive, extract_root, remove_finished)
def iterable_to_str(iterable):
    """Render *iterable* as a single-quoted, comma-separated listing.

    Each item is converted with ``str``; e.g. ``["a", 1]`` becomes
    ``"'a', '1'"``.  Used to build human-readable error messages.
    """
    quoted_items = "', '".join(str(item) for item in iterable)
    return "'{}'".format(quoted_items)
def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
if not isinstance(value, torch._six.string_classes):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
msg = "Expected type str for argument {arg}, but got type {type}."
msg = msg.format(type=type(value), arg=arg)
raise ValueError(msg)
if valid_values is None:
return value
if value not in valid_values:
if custom_msg is not None:
msg = custom_msg
else:
msg = ("Unknown value '{value}' for argument {arg}. "
"Valid values are {{{valid_values}}}.")
msg = msg.format(value=value, arg=arg,
valid_values=iterable_to_str(valid_values))
raise ValueError(msg)
return value
...@@ -10,7 +10,7 @@ else: ...@@ -10,7 +10,7 @@ else:
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from PIL import Image from PIL import Image
from .utils import download_url, check_integrity from .utils import download_url, check_integrity, verify_str_arg
DATASET_YEAR_DICT = { DATASET_YEAR_DICT = {
'2012': { '2012': {
...@@ -83,7 +83,8 @@ class VOCSegmentation(VisionDataset): ...@@ -83,7 +83,8 @@ class VOCSegmentation(VisionDataset):
self.url = DATASET_YEAR_DICT[year]['url'] self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename'] self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5'] self.md5 = DATASET_YEAR_DICT[year]['md5']
self.image_set = image_set self.image_set = verify_str_arg(image_set, "image_set",
("train", "trainval", "val"))
base_dir = DATASET_YEAR_DICT[year]['base_dir'] base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir) voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages') image_dir = os.path.join(voc_root, 'JPEGImages')
...@@ -100,11 +101,6 @@ class VOCSegmentation(VisionDataset): ...@@ -100,11 +101,6 @@ class VOCSegmentation(VisionDataset):
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="trainval" or image_set="val"')
with open(os.path.join(split_f), "r") as f: with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()] file_names = [x.strip() for x in f.readlines()]
...@@ -164,7 +160,8 @@ class VOCDetection(VisionDataset): ...@@ -164,7 +160,8 @@ class VOCDetection(VisionDataset):
self.url = DATASET_YEAR_DICT[year]['url'] self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename'] self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5'] self.md5 = DATASET_YEAR_DICT[year]['md5']
self.image_set = image_set self.image_set = verify_str_arg(image_set, "image_set",
("train", "trainval", "val"))
base_dir = DATASET_YEAR_DICT[year]['base_dir'] base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir) voc_root = os.path.join(self.root, base_dir)
...@@ -182,12 +179,6 @@ class VOCDetection(VisionDataset): ...@@ -182,12 +179,6 @@ class VOCDetection(VisionDataset):
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="trainval" or image_set="val" or a valid'
'image_set from the VOC ImageSets/Main folder.')
with open(os.path.join(split_f), "r") as f: with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()] file_names = [x.strip() for x in f.readlines()]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment