Commit 4886ccc8 authored by Philip Meier, committed by Francisco Massa

Standardize str argument verification in datasets (#1167)

* introduced function to verify str arguments

* flake8

* added FIXME to VOC

* Fixed error message

* added test for verify_str_arg

* cleanup todos

* added option for custom error message

* fix VOC

* fixed Caltech
parent d9830d86
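In a nutshell, the pattern this commit standardizes: each dataset used to hand-roll its own membership check and error message, and after the change they all delegate to the new verify_str_arg helper in torchvision.datasets.utils. A minimal before/after sketch (the argument names here are illustrative, not taken from a specific hunk):

    # Before: one hand-rolled variant per dataset
    if split not in ('train', 'val'):
        raise ValueError('Wrong split entered! Please use "train" or "val"')
    self.split = split

    # After: standardized; verify_str_arg returns the value on success
    self.split = verify_str_arg(split, "split", ("train", "val"))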
@@ -107,6 +107,11 @@ class Tester(unittest.TestCase):
             data = nf.read()
         self.assertEqual(data, 'this is the content')
 
+    def test_verify_str_arg(self):
+        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
+        self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
+        self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))
+
 
 if __name__ == '__main__':
     unittest.main()
@@ -4,7 +4,7 @@ import os
 import os.path
 
 from .vision import VisionDataset
-from .utils import download_and_extract_archive, makedir_exist_ok
+from .utils import download_and_extract_archive, makedir_exist_ok, verify_str_arg
 
 
 class Caltech101(VisionDataset):
@@ -32,10 +32,10 @@ class Caltech101(VisionDataset):
                                          transform=transform,
                                          target_transform=target_transform)
         makedir_exist_ok(self.root)
-        if isinstance(target_type, list):
-            self.target_type = target_type
-        else:
-            self.target_type = [target_type]
+        if not isinstance(target_type, list):
+            target_type = [target_type]
+        self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation"))
+                            for t in target_type]
 
         if download:
             self.download()
@@ -88,8 +88,6 @@ class Caltech101(VisionDataset):
                                                  self.annotation_categories[self.y[index]],
                                                  "annotation_{:04d}.mat".format(self.index[index])))
                 target.append(data["obj_contour"])
-            else:
-                raise ValueError("Target type \"{}\" is not recognized.".format(t))
         target = tuple(target) if len(target) > 1 else target[0]
 
         if self.transform is not None:
@@ -3,7 +3,7 @@ import torch
 import os
 import PIL
 from .vision import VisionDataset
-from .utils import download_file_from_google_drive, check_integrity
+from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
 
 
 class CelebA(VisionDataset):
@@ -66,17 +66,14 @@ class CelebA(VisionDataset):
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')
 
-        if split.lower() == "train":
-            split = 0
-        elif split.lower() == "valid":
-            split = 1
-        elif split.lower() == "test":
-            split = 2
-        elif split.lower() == "all":
-            split = None
-        else:
-            raise ValueError('Wrong split entered! Please use "train", '
-                             '"valid", "test", or "all"')
+        split_map = {
+            "train": 0,
+            "valid": 1,
+            "test": 2,
+            "all": None,
+        }
+        split = split_map[verify_str_arg(split.lower(), "split",
+                                         ("train", "valid", "test", "all"))]
 
         fn = partial(os.path.join, self.root, self.base_folder)
         splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
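The split_map construct above separates validation from the mapping to integer partition codes: verify_str_arg returns the (lowercased) string on success, which is then used as the dict key. A standalone sketch of the same pattern (the import is an assumption of where the helper lands):

    from torchvision.datasets.utils import verify_str_arg

    split_map = {"train": 0, "valid": 1, "test": 2, "all": None}
    split = split_map[verify_str_arg("Valid".lower(), "split",
                                     ("train", "valid", "test", "all"))]
    # split == 1; an unknown string raises ValueError before the dict lookup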
@@ -134,6 +131,7 @@ class CelebA(VisionDataset):
             elif t == "landmarks":
                 target.append(self.landmarks_align[index, :])
             else:
+                # TODO: refactor with utils.verify_str_arg
                 raise ValueError("Target type \"{}\" is not recognized.".format(t))
         target = tuple(target) if len(target) > 1 else target[0]
 
@@ -3,7 +3,7 @@ import os
 from collections import namedtuple
 import zipfile
 
-from .utils import extract_archive
+from .utils import extract_archive, verify_str_arg, iterable_to_str
 from .vision import VisionDataset
 from PIL import Image
@@ -109,22 +109,21 @@ class Cityscapes(VisionDataset):
         self.images = []
         self.targets = []
 
-        if mode not in ['fine', 'coarse']:
-            raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"')
-
-        if mode == 'fine' and split not in ['train', 'test', 'val']:
-            raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"'
-                             ' or split="val"')
-        elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']:
-            raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"'
-                             ' or split="val"')
+        verify_str_arg(mode, "mode", ("fine", "coarse"))
+        if mode == "fine":
+            valid_modes = ("train", "test", "val")
+        else:
+            valid_modes = ("train", "train_extra", "val")
+        msg = ("Unknown value '{}' for argument split if mode is '{}'. "
+               "Valid values are {{{}}}.")
+        msg = msg.format(split, mode, iterable_to_str(valid_modes))
+        verify_str_arg(split, "split", valid_modes, msg)
 
         if not isinstance(target_type, list):
             self.target_type = [target_type]
-        if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
-            raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
-                             ' or "color"')
+        [verify_str_arg(value, "target_type",
+                        ("instance", "semantic", "polygon", "color"))
+         for value in self.target_type]
 
         if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
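Since the set of valid splits depends on mode, Cityscapes cannot rely on verify_str_arg's default message and pre-formats one through the custom_msg parameter instead. A condensed, standalone sketch of that pattern (the failing values are assumptions for illustration):

    from torchvision.datasets.utils import verify_str_arg, iterable_to_str

    valid_modes = ("train", "train_extra", "val")  # as for mode="coarse"
    msg = ("Unknown value '{}' for argument split if mode is '{}'. "
           "Valid values are {{{}}}.").format("test", "coarse",
                                              iterable_to_str(valid_modes))
    # raises ValueError with the pre-formatted message, since "test" is not
    # a valid split when mode="coarse"
    verify_str_arg("test", "split", valid_modes, msg)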
@@ -4,7 +4,8 @@ import shutil
 import tempfile
 import torch
 
 from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive
+from .utils import check_integrity, download_and_extract_archive, extract_archive, \
+    verify_str_arg
 
 ARCHIVE_DICT = {
     'train': {
@@ -48,7 +49,7 @@ class ImageNet(ImageFolder):
     def __init__(self, root, split='train', download=False, **kwargs):
         root = self.root = os.path.expanduser(root)
-        self.split = self._verify_split(split)
+        self.split = verify_str_arg(split, "split", ("train", "val"))
 
         if download:
             self.download()
@@ -109,17 +110,6 @@ class ImageNet(ImageFolder):
     def _save_meta_file(self, wnid_to_class, val_wnids):
         torch.save((wnid_to_class, val_wnids), self.meta_file)
 
-    def _verify_split(self, split):
-        if split not in self.valid_splits:
-            msg = "Unknown split {} .".format(split)
-            msg += "Valid splits are {{}}.".format(", ".join(self.valid_splits))
-            raise ValueError(msg)
-        return split
-
-    @property
-    def valid_splits(self):
-        return 'train', 'val'
-
     @property
     def split_folder(self):
         return os.path.join(self.root, self.split)
@@ -11,6 +11,8 @@ if sys.version_info[0] == 2:
 else:
     import pickle
 
+from .utils import verify_str_arg, iterable_to_str
+
 
 class LSUNClass(VisionDataset):
     def __init__(self, root, transform=None, target_transform=None):
@@ -75,26 +77,30 @@ class LSUN(VisionDataset):
                       'living_room', 'restaurant', 'tower']
         dset_opts = ['train', 'val', 'test']
 
-        if type(classes) == str and classes in dset_opts:
+        try:
+            verify_str_arg(classes, "classes", dset_opts)
             if classes == 'test':
                 classes = [classes]
             else:
                 classes = [c + '_' + classes for c in categories]
-        elif type(classes) == list:
+        except ValueError:
+            # TODO: Should this check for Iterable instead of list?
+            if not isinstance(classes, list):
+                raise ValueError
             for c in classes:
+                # TODO: This assumes each item is a str (or subclass). Should this
+                #  also be checked?
                 c_short = c.split('_')
-                c_short.pop(len(c_short) - 1)
-                c_short = '_'.join(c_short)
-                if c_short not in categories:
-                    raise (ValueError('Unknown LSUN class: ' + c_short + '.'
-                                      'Options are: ' + str(categories)))
-                c_short = c.split('_')
-                c_short = c_short.pop(len(c_short) - 1)
-                if c_short not in dset_opts:
-                    raise (ValueError('Unknown postfix: ' + c_short + '.'
-                                      'Options are: ' + str(dset_opts)))
-        else:
-            raise (ValueError('Unknown option for classes'))
-        self.classes = classes
+                category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
+
+                msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
+                msg = msg_fmtstr.format(category, "LSUN class",
+                                        iterable_to_str(categories))
+                verify_str_arg(category, valid_values=categories, custom_msg=msg)
+
+                msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
+                verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
+        finally:
+            self.classes = classes
 
         # for each class, create an LSUNClassDataset
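The except-branch above splits each class string on its last underscore into a category and a postfix, then validates both parts separately. A tiny sketch of that parsing step (the input string is hypothetical):

    c = "living_room_train"
    c_short = c.split('_')                                    # ['living', 'room', 'train']
    category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]  # 'living_room', 'train'

This also removes the old code's double split, which rebuilt c_short twice to extract the same two pieces.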
@@ -7,7 +7,8 @@ import os.path
 import numpy as np
 import torch
 import codecs
-from .utils import download_url, download_and_extract_archive, extract_archive, makedir_exist_ok
+from .utils import download_url, download_and_extract_archive, extract_archive, \
+    makedir_exist_ok, verify_str_arg
 
 
 class MNIST(VisionDataset):
@@ -230,11 +231,7 @@ class EMNIST(MNIST):
     splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')
 
     def __init__(self, root, split, **kwargs):
-        if split not in self.splits:
-            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
-                split, ', '.join(self.splits),
-            ))
-        self.split = split
+        self.split = verify_str_arg(split, "split", self.splits)
         self.training_file = self._training_file(split)
         self.test_file = self._test_file(split)
         super(EMNIST, self).__init__(root, **kwargs)
@@ -336,10 +333,7 @@ class QMNIST(MNIST):
     def __init__(self, root, what=None, compat=True, train=True, **kwargs):
         if what is None:
             what = 'train' if train else 'test'
-        if not self.subsets.get(what):
-            raise RuntimeError("Argument 'what' should be one of: \n " +
-                               repr(tuple(self.subsets.keys())))
-        self.what = what
+        self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
         self.compat = compat
         self.data_file = what + '.pt'
         self.training_file = self.data_file
@@ -5,7 +5,7 @@ from .vision import VisionDataset
 import numpy as np
 
 from PIL import Image
-from .utils import download_url
+from .utils import download_url, verify_str_arg
 from .voc import download_extract
@@ -64,12 +64,9 @@ class SBDataset(VisionDataset):
                               "pip install scipy")
 
         super(SBDataset, self).__init__(root, transforms)
 
-        if mode not in ("segmentation", "boundaries"):
-            raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
-
-        self.image_set = image_set
-        self.mode = mode
+        self.image_set = verify_str_arg(image_set, "image_set",
+                                        ("train", "val", "train_noval"))
+        self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries"))
         self.num_classes = 20
 
         sbd_root = self.root
@@ -91,11 +88,6 @@ class SBDataset(VisionDataset):
         split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt')
 
-        if not os.path.exists(split_f):
-            raise ValueError(
-                'Wrong image_set entered! Please use image_set="train" '
-                'or image_set="val" or image_set="train_noval"')
-
         with open(os.path.join(split_f), "r") as f:
             file_names = [x.strip() for x in f.readlines()]
@@ -5,7 +5,7 @@ import os.path
 import numpy as np
 
 from .vision import VisionDataset
-from .utils import check_integrity, download_and_extract_archive
+from .utils import check_integrity, download_and_extract_archive, verify_str_arg
 
 
 class STL10(VisionDataset):
@@ -48,13 +48,9 @@ class STL10(VisionDataset):
     def __init__(self, root, split='train', folds=None, transform=None,
                  target_transform=None, download=False):
-        if split not in self.splits:
-            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
-                split, ', '.join(self.splits),
-            ))
-
         super(STL10, self).__init__(root, transform=transform,
                                     target_transform=target_transform)
-        self.split = split  # train/test/unlabeled set
+        self.split = verify_str_arg(split, "split", self.splits)
         self.folds = folds  # one of the 10 pre-defined folds or the full dataset
 
         if download:
@@ -167,4 +163,6 @@ class STL10(VisionDataset):
             list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
             self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
         else:
+            # FIXME: docstring allows None for folds (it is even the default value)
+            #  Is this intended?
             raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
@@ -4,7 +4,7 @@ from PIL import Image
 import os
 import os.path
 import numpy as np
-from .utils import download_url, check_integrity
+from .utils import download_url, check_integrity, verify_str_arg
 
 
 class SVHN(VisionDataset):
@@ -43,12 +43,7 @@ class SVHN(VisionDataset):
                  download=False):
         super(SVHN, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
-        self.split = split  # training set or test set or extra set
-
-        if self.split not in self.split_list:
-            raise ValueError('Wrong split entered! Please use split="train" '
-                             'or split="extra" or split="test"')
-
+        self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
         self.url = self.split_list[split][0]
         self.filename = self.split_list[split][1]
         self.file_md5 = self.split_list[split][2]
@@ -6,6 +6,7 @@ import errno
 import tarfile
 import zipfile
 
+import torch
 from torch.utils.model_zoo import tqdm
@@ -249,3 +250,32 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename
     archive = os.path.join(download_root, filename)
     print("Extracting {} to {}".format(archive, extract_root))
     extract_archive(archive, extract_root, remove_finished)
+
+
+def iterable_to_str(iterable):
+    return "'" + "', '".join([str(item) for item in iterable]) + "'"
+
+
+def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
+    if not isinstance(value, torch._six.string_classes):
+        if arg is None:
+            msg = "Expected type str, but got type {type}."
+        else:
+            msg = "Expected type str for argument {arg}, but got type {type}."
+        msg = msg.format(type=type(value), arg=arg)
+        raise ValueError(msg)
+
+    if valid_values is None:
+        return value
+
+    if value not in valid_values:
+        if custom_msg is not None:
+            msg = custom_msg
+        else:
+            msg = ("Unknown value '{value}' for argument {arg}. "
+                   "Valid values are {{{valid_values}}}.")
+            msg = msg.format(value=value, arg=arg,
+                             valid_values=iterable_to_str(valid_values))
+        raise ValueError(msg)
+
+    return value
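For reference, a REPL-style sketch of how the new helper behaves; the messages follow the templates above, while the session itself is illustrative:

    >>> from torchvision.datasets.utils import verify_str_arg
    >>> verify_str_arg("train", "split", ("train", "val"))
    'train'
    >>> verify_str_arg("test", "split", ("train", "val"))
    ValueError: Unknown value 'test' for argument split. Valid values are {'train', 'val'}.
    >>> verify_str_arg(0, "split", ("train", "val"))
    ValueError: Expected type str for argument split, but got type <class 'int'>.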
@@ -10,7 +10,7 @@ else:
     import xml.etree.ElementTree as ET
 
 from PIL import Image
-from .utils import download_url, check_integrity
+from .utils import download_url, check_integrity, verify_str_arg
 
 DATASET_YEAR_DICT = {
     '2012': {
@@ -83,7 +83,8 @@ class VOCSegmentation(VisionDataset):
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
         self.md5 = DATASET_YEAR_DICT[year]['md5']
-        self.image_set = image_set
+        self.image_set = verify_str_arg(image_set, "image_set",
+                                        ("train", "trainval", "val"))
         base_dir = DATASET_YEAR_DICT[year]['base_dir']
         voc_root = os.path.join(self.root, base_dir)
         image_dir = os.path.join(voc_root, 'JPEGImages')
@@ -100,11 +101,6 @@ class VOCSegmentation(VisionDataset):
         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
 
-        if not os.path.exists(split_f):
-            raise ValueError(
-                'Wrong image_set entered! Please use image_set="train" '
-                'or image_set="trainval" or image_set="val"')
-
         with open(os.path.join(split_f), "r") as f:
             file_names = [x.strip() for x in f.readlines()]
@@ -164,7 +160,8 @@ class VOCDetection(VisionDataset):
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
         self.md5 = DATASET_YEAR_DICT[year]['md5']
-        self.image_set = image_set
+        self.image_set = verify_str_arg(image_set, "image_set",
+                                        ("train", "trainval", "val"))
         base_dir = DATASET_YEAR_DICT[year]['base_dir']
         voc_root = os.path.join(self.root, base_dir)
@@ -182,12 +179,6 @@ class VOCDetection(VisionDataset):
         split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
 
-        if not os.path.exists(split_f):
-            raise ValueError(
-                'Wrong image_set entered! Please use image_set="train" '
-                'or image_set="trainval" or image_set="val" or a valid'
-                'image_set from the VOC ImageSets/Main folder.')
-
         with open(os.path.join(split_f), "r") as f:
             file_names = [x.strip() for x in f.readlines()]