Unverified commit 09600b90, authored by Philip Meier, committed by GitHub

Enable custom default config for dataset tests (#3578)

parent 661d89fd
@@ -117,19 +117,45 @@ def requires_lazy_imports(*modules):
 def test_all_configs(test):
     """Decorator to run test against all configurations.
 
-    Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
-    provided as the first parameter:
+    Add this as decorator to an arbitrary test to run it against all configurations. This includes
+    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
+
+    The current configuration is provided as the first parameter for the test:
 
     .. code-block::
 
-        @test_all_configs
+        @test_all_configs()
         def test_foo(self, config):
             pass
 
+    .. note::
+
+        This will try to remove duplicate configurations. During this process it will not preserve a potential
+        ordering of the configurations or an inner ordering of a configuration.
     """
 
+    def maybe_remove_duplicates(configs):
+        try:
+            return [dict(config_) for config_ in set(tuple(sorted(config.items())) for config in configs)]
+        except TypeError:
+            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case
+            # duplicate removal would be a lot more elaborate and we simply bail out.
+            return configs
+
     @functools.wraps(test)
     def wrapper(self):
-        for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
+        configs = []
+        if self.DEFAULT_CONFIG is not None:
+            configs.append(self.DEFAULT_CONFIG)
+        if self.ADDITIONAL_CONFIGS is not None:
+            configs.extend(self.ADDITIONAL_CONFIGS)
+
+        if not configs:
+            configs = [self._KWARG_DEFAULTS.copy()]
+        else:
+            configs = maybe_remove_duplicates(configs)
+
+        for config in configs:
             with self.subTest(**config):
                 test(self, config)
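To make the deduplication concrete, here is a standalone sketch (not part of the diff) of what maybe_remove_duplicates does, with made-up configs, including the bail-out for unhashable values:

    configs = [
        {"split": "train", "fold": 1},
        {"fold": 1, "split": "train"},  # same config, different key order
        {"split": "test", "fold": 1},
    ]
    deduped = [dict(c) for c in set(tuple(sorted(config.items())) for config in configs)]
    assert len(deduped) == 2  # key order of a dict does not matter for equality

    # With an unhashable value, e.g. a list, building the set raises a TypeError
    # and the decorator falls back to the original, possibly duplicated, configs.
    configs = [{"target_type": ["category", "annotation"]}]
    try:
        set(tuple(sorted(config.items())) for config in configs)
    except TypeError:
        deduped = configs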
@@ -166,9 +192,13 @@ class DatasetTestCase(unittest.TestCase):
 
     Optionally, you can overwrite the following class attributes:
 
-    - CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictonary can contain an
-      arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
-      ``transforms``, or ``download``. The first element will be used as default configuration.
+    - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
+      keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
+      ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
+      not provide one.
+    - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
+      contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
+      ``transforms``, or ``download``.
     - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
       available, the tests are skipped.
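A hypothetical subclass showing how the two attributes interact; datasets.Foo and its split parameter are invented for illustration, and the required inject_fake_data() override is omitted from this sketch:

    class FooTestCase(datasets_utils.ImageDatasetTestCase):
        DATASET_CLASS = datasets.Foo  # made up
        # Suppose datasets.Foo requires `split` and provides no default for it.
        DEFAULT_CONFIG = dict(split="train")
        ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

        @datasets_utils.test_all_configs
        def test_splits(self, config):
            # Runs for split="train" and split="test"; the duplicate of
            # DEFAULT_CONFIG inside ADDITIONAL_CONFIGS is removed beforehand.
            pass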
@@ -218,22 +248,31 @@ class DatasetTestCase(unittest.TestCase):
 
     DATASET_CLASS = None
     FEATURE_TYPES = None
-    CONFIGS = None
+    DEFAULT_CONFIG = None
+    ADDITIONAL_CONFIGS = None
     REQUIRED_PACKAGES = None
-    _DEFAULT_CONFIG = None
 
     # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
     _TRANSFORM_KWARGS = {
         "transform",
         "target_transform",
         "transforms",
     }
+    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
     _SPECIAL_KWARGS = {
         *_TRANSFORM_KWARGS,
         "download",
     }
 
+    # These fields are populated during setUpClass() within _populate_private_class_attributes().
+
+    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
+    # the dataset constructor.
+    _KWARG_DEFAULTS = None
+    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
+    _HAS_SPECIAL_KWARG = None
 
     # These functions are disabled during dataset creation in create_dataset().
     _CHECK_FUNCTIONS = {
         "check_md5",
         "check_integrity",
@@ -256,7 +295,8 @@ class DatasetTestCase(unittest.TestCase):
 
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
+                fields for all dataset parameters with default values.
 
         Returns:
             (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
@@ -273,7 +313,8 @@ class DatasetTestCase(unittest.TestCase):
 
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
+                fields for all dataset parameters with default values.
 
         Needs to return one of the following:
@@ -293,9 +334,16 @@ class DatasetTestCase(unittest.TestCase):
     ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
         r"""Create the dataset in a temporary directory.
 
+        The configuration passed to the dataset is populated to contain at least all parameters with default values.
+        For this the following order of precedence is used:
+
+        1. Parameters in :attr:`kwargs`.
+        2. Configuration in :attr:`config`.
+        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
+        4. Default parameters of the dataset.
+
         Args:
-            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
-                default configuration is used.
+            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
             inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
                 creating the dataset.
             patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
@@ -308,30 +356,33 @@ class DatasetTestCase(unittest.TestCase):
             info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
                 for details.
         """
-        default_config = self._DEFAULT_CONFIG.copy()
-        if config is not None:
-            default_config.update(config)
-        config = default_config
-
         if patch_checks is None:
             patch_checks = inject_fake_data
 
         special_kwargs, other_kwargs = self._split_kwargs(kwargs)
 
+        complete_config = self._KWARG_DEFAULTS.copy()
+        if self.DEFAULT_CONFIG:
+            complete_config.update(self.DEFAULT_CONFIG)
+        if config:
+            complete_config.update(config)
+        if other_kwargs:
+            complete_config.update(other_kwargs)
+
         if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
             # tests should never download: force a truthy download argument to False
             special_kwargs["download"] = False
-        config.update(other_kwargs)
 
         patchers = self._patch_download_extract()
         if patch_checks:
             patchers.update(self._patch_checks())
 
         with get_tmp_dir() as tmpdir:
-            args = self.dataset_args(tmpdir, config)
-            info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
+            args = self.dataset_args(tmpdir, complete_config)
+            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
 
             with self._maybe_apply_patches(patchers), disable_console_output():
-                dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
+                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
                 yield dataset, info
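For clarity, here is the precedence chain from the docstring replayed with made-up values; all names and values below are illustrative, only the update order matters:

    kwarg_defaults = dict(split="train", small=False)   # 4. dataset defaults
    default_config = dict(small=True)                   # 3. DEFAULT_CONFIG
    config = dict(split="test")                         # 2. config argument
    other_kwargs = dict(small=False)                    # 1. kwargs

    # Later updates win, mirroring the complete_config construction above.
    complete_config = kwarg_defaults.copy()
    for override in (default_config, config, other_kwargs):
        complete_config.update(override)
    assert complete_config == dict(split="test", small=False)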
@@ -357,26 +408,69 @@ class DatasetTestCase(unittest.TestCase):
     @classmethod
     def _populate_private_class_attributes(cls):
-        argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
+        defaults = []
+        for cls_ in cls.DATASET_CLASS.__mro__:
+            if cls_ is torchvision.datasets.VisionDataset:
+                break
 
-        cls._DEFAULT_CONFIG = {
-            kwarg: default
-            for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
-            if kwarg not in cls._SPECIAL_KWARGS
-        }
+            argspec = inspect.getfullargspec(cls_.__init__)
 
-        cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
+            if not argspec.defaults:
+                continue
+
+            defaults.append(
+                {kwarg: default for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)}
+            )
+
+            if not argspec.varkw:
+                break
+
+        kwarg_defaults = dict()
+        for config in reversed(defaults):
+            kwarg_defaults.update(config)
+
+        has_special_kwargs = set()
+        for name in cls._SPECIAL_KWARGS:
+            if name not in kwarg_defaults:
+                continue
+
+            del kwarg_defaults[name]
+            has_special_kwargs.add(name)
+
+        cls._KWARG_DEFAULTS = kwarg_defaults
+        cls._HAS_SPECIAL_KWARG = has_special_kwargs
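The MRO walk above is the subtle part: defaults are collected subclass-first, the walk only continues past a constructor that accepts **kwargs, and applying the collected dicts in reverse lets subclass defaults win over inherited ones. A self-contained toy example of the same idea (stopping at object instead of torchvision.datasets.VisionDataset):

    import inspect

    class Base:
        def __init__(self, split="train", small=False):
            ...

    class Child(Base):
        def __init__(self, split="test", **kwargs):
            super().__init__(split=split, **kwargs)

    defaults = []
    for cls_ in Child.__mro__:
        if cls_ is object:
            break
        argspec = inspect.getfullargspec(cls_.__init__)
        if not argspec.defaults:
            continue
        defaults.append(dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults)))
        if not argspec.varkw:
            # Without **kwargs, parent parameters cannot be forwarded anyway.
            break

    kwarg_defaults = dict()
    for config in reversed(defaults):  # base first, subclass last
        kwarg_defaults.update(config)
    assert kwarg_defaults == dict(split="test", small=False)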
     @classmethod
     def _process_optional_public_class_attributes(cls):
-        if cls.REQUIRED_PACKAGES is not None:
-            try:
+        def check_config(config, name):
+            special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config)
+            if special_kwargs:
+                raise UsageError(
+                    f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
+                    f"These are handled separately by the test case and should not be set here. "
+                    f"If you need to test some custom behavior regarding these parameters, "
+                    f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
+                )
+
+        if cls.DEFAULT_CONFIG is not None:
+            check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")
+
+        if cls.ADDITIONAL_CONFIGS is not None:
+            for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
+                check_config(config, f"ADDITIONAL_CONFIGS[{idx}]")
+
+        if cls.REQUIRED_PACKAGES:
+            missing_pkgs = []
             for pkg in cls.REQUIRED_PACKAGES:
+                try:
                     importlib.import_module(pkg)
-            except ImportError as error:
+                except ImportError:
+                    missing_pkgs.append(f"'{pkg}'")
+
+            if missing_pkgs:
                 raise unittest.SkipTest(
-                    f"The package '{error.name}' is required to load the dataset '{cls.DATASET_CLASS.__name__}' but is "
-                    f"not installed."
+                    f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
+                    f"'{cls.DATASET_CLASS.__name__}', but are not installed."
                 )
 
     def _split_kwargs(self, kwargs):
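A hypothetical misuse that this guard catches; assuming _process_optional_public_class_attributes runs during test-case setup, the class below is rejected before any test runs (datasets.Foo is again made up):

    class BrokenTestCase(datasets_utils.ImageDatasetTestCase):
        DATASET_CLASS = datasets.Foo
        # Raises UsageError at setup time: 'download' is one of the special
        # kwargs handled by the test case itself and must not be set here.
        DEFAULT_CONFIG = dict(download=True)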
@@ -369,7 +369,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
 
     DATASET_CLASS = datasets.Caltech101
     FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
-    CONFIGS = datasets_utils.combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        target_type=("category", "annotation", ["category", "annotation"])
+    )
     REQUIRED_PACKAGES = ("scipy",)
 
     def inject_fake_data(self, tmpdir, config):
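Assuming combinations_grid builds one config dict per element of the Cartesian product of its keyword arguments (consistent with how it is used throughout this diff), the Caltech101 grid above expands to:

    datasets_utils.combinations_grid(
        target_type=("category", "annotation", ["category", "annotation"])
    )
    # A single parameter with three values yields three configs:
    # [
    #     {"target_type": "category"},
    #     {"target_type": "annotation"},
    #     {"target_type": ["category", "annotation"]},
    # ]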
@@ -466,7 +468,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
 class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.WIDERFace
     FEATURE_TYPES = (PIL.Image.Image, (dict, type(None)))  # test split returns None as target
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
 
     def inject_fake_data(self, tmpdir, config):
         widerface_dir = pathlib.Path(tmpdir) / 'widerface'
@@ -521,7 +523,7 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
     REQUIRED_PACKAGES = ('scipy',)
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
 
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir)
@@ -551,7 +553,7 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _VERSION_CONFIG = dict(
         base_folder="cifar-10-batches-py",
@@ -623,7 +625,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CelebA
     FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         split=("train", "valid", "test", "all"),
         target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
     )
@@ -740,7 +742,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
     FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
 
-    CONFIGS = (
+    ADDITIONAL_CONFIGS = (
         *datasets_utils.combinations_grid(
             year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
         ),
@@ -929,7 +931,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
 
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _ANNOTATIONS_FOLDER = "annotations"
@@ -990,7 +992,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.LSUN
     REQUIRED_PACKAGES = ("lmdb",)
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
     )
@@ -1097,7 +1099,7 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase):
 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
 
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _SPLITS_FOLDER = "splits"
@@ -1157,7 +1159,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
 class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Omniglot
-    CONFIGS = datasets_utils.combinations_grid(background=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(background=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         target_folder = (
@@ -1237,7 +1239,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
 class USPSTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.USPS
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         num_images = 2 if config["train"] else 1
@@ -1259,7 +1261,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
     REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
 
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
     )
@@ -1345,7 +1347,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
     _TRAIN_FEATURE_TYPES = (torch.Tensor,)
     _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
 
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _NAME = "liberty"