Unverified Commit 09600b90 authored by Philip Meier, committed by GitHub

Enable custom default config for dataset tests (#3578)

parent 661d89fd
@@ -117,19 +117,45 @@ def requires_lazy_imports(*modules):
 def test_all_configs(test):
     """Decorator to run test against all configurations.
 
-    Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
-    provided as the first parameter:
+    Add this as decorator to an arbitrary test to run it against all configurations. This includes
+    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.
+
+    The current configuration is provided as the first parameter for the test:
 
     .. code-block::
 
-        @test_all_configs
+        @test_all_configs()
         def test_foo(self, config):
             pass
+
+    .. note::
+
+        This will try to remove duplicate configurations. During this process it will not preserve a potential
+        ordering of the configurations or an inner ordering of a configuration.
     """
+
+    def maybe_remove_duplicates(configs):
+        try:
+            return [dict(config_) for config_ in set(tuple(sorted(config.items())) for config in configs)]
+        except TypeError:
+            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
+            # removal would be a lot more elaborate and we simply bail out.
+            return configs
+
     @functools.wraps(test)
     def wrapper(self):
-        for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
+        configs = []
+        if self.DEFAULT_CONFIG is not None:
+            configs.append(self.DEFAULT_CONFIG)
+        if self.ADDITIONAL_CONFIGS is not None:
+            configs.extend(self.ADDITIONAL_CONFIGS)
+
+        if not configs:
+            configs = [self._KWARG_DEFAULTS.copy()]
+        else:
+            configs = maybe_remove_duplicates(configs)
+
+        for config in configs:
             with self.subTest(**config):
                 test(self, config)
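
For context, the deduplication in maybe_remove_duplicates() works by freezing each config dict into a hashable form so a set can drop repeats. A standalone sketch of the same idea (the configs are made up for illustration):

    # Standalone sketch of the deduplication in maybe_remove_duplicates(); configs are made up.
    configs = [{"split": "train"}, {"split": "train"}, {"split": "val"}]
    try:
        # Dicts are unhashable, so each config is frozen into a sorted tuple of items first.
        unique = [dict(items) for items in set(tuple(sorted(c.items())) for c in configs)]
    except TypeError:
        # Unhashable values (e.g. lists) cannot be frozen; bail out, keeping duplicates.
        unique = configs
    assert len(unique) == 2  # the ordering of the surviving configs is not preserved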
@@ -166,9 +192,13 @@ class DatasetTestCase(unittest.TestCase):
     Optionally, you can overwrite the following class attributes:
 
-    - CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can contain an
-      arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
-      ``transforms``, or ``download``. The first element will be used as default configuration.
+    - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
+      keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
+      ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
+      not provide one.
+    - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
+      contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
+      ``transforms``, or ``download``.
     - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
       available, the tests are skipped.
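
For illustration, a test case opting into the new attributes might look like the following sketch; the dataset class and the "split" parameter are hypothetical:

    class MyDatasetTestCase(datasets_utils.ImageDatasetTestCase):
        DATASET_CLASS = datasets.MyDataset  # hypothetical dataset class
        # Supplies a default for a parameter the constructor has no default for.
        DEFAULT_CONFIG = dict(split="train")
        # Each of these configs is tested in addition to DEFAULT_CONFIG.
        ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
        REQUIRED_PACKAGES = ("scipy",)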
@@ -218,22 +248,31 @@ class DatasetTestCase(unittest.TestCase):
     DATASET_CLASS = None
     FEATURE_TYPES = None
-    CONFIGS = None
+    DEFAULT_CONFIG = None
+    ADDITIONAL_CONFIGS = None
     REQUIRED_PACKAGES = None
 
-    _DEFAULT_CONFIG = None
+    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
     _TRANSFORM_KWARGS = {
         "transform",
         "target_transform",
         "transforms",
     }
+    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
     _SPECIAL_KWARGS = {
         *_TRANSFORM_KWARGS,
         "download",
     }
+
+    # These fields are populated during setUpClass() within _populate_private_class_attributes()
+
+    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
+    # the dataset constructor.
+    _KWARG_DEFAULTS = None
+    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
     _HAS_SPECIAL_KWARG = None
+
+    # These functions are disabled during dataset creation in create_dataset().
     _CHECK_FUNCTIONS = {
         "check_md5",
         "check_integrity",
@@ -256,7 +295,8 @@ class DatasetTestCase(unittest.TestCase):
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
+                least fields for all dataset parameters with default values.
 
         Returns:
             (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
@@ -273,7 +313,8 @@ class DatasetTestCase(unittest.TestCase):
         Args:
             tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                 to be created and in turn also for the fake data injected here.
-            config (Dict[str, Any]): Configuration that will be used to create the dataset.
+            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at
+                least fields for all dataset parameters with default values.
 
         Needs to return one of the following:
@@ -293,9 +334,16 @@ class DatasetTestCase(unittest.TestCase):
     ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
         r"""Create the dataset in a temporary directory.
 
+        The configuration passed to the dataset is populated to contain at least all parameters with default values.
+        For this the following order of precedence is used:
+
+        1. Parameters in :attr:`kwargs`.
+        2. Configuration in :attr:`config`.
+        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
+        4. Default parameters of the dataset.
+
         Args:
-            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
-                default configuration is used.
+            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
             inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
                 creating the dataset.
             patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
@@ -308,30 +356,33 @@ class DatasetTestCase(unittest.TestCase):
             info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
                 for details.
         """
-        default_config = self._DEFAULT_CONFIG.copy()
-        if config is not None:
-            default_config.update(config)
-        config = default_config
-
         if patch_checks is None:
             patch_checks = inject_fake_data
 
         special_kwargs, other_kwargs = self._split_kwargs(kwargs)
+
+        complete_config = self._KWARG_DEFAULTS.copy()
+        if self.DEFAULT_CONFIG:
+            complete_config.update(self.DEFAULT_CONFIG)
+        if config:
+            complete_config.update(config)
+        if other_kwargs:
+            complete_config.update(other_kwargs)
+
         if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
             # override the download param to False if its default is truthy
             special_kwargs["download"] = False
 
-        config.update(other_kwargs)
-
         patchers = self._patch_download_extract()
         if patch_checks:
             patchers.update(self._patch_checks())
 
         with get_tmp_dir() as tmpdir:
-            args = self.dataset_args(tmpdir, config)
-            info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
+            args = self.dataset_args(tmpdir, complete_config)
+            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
 
             with self._maybe_apply_patches(patchers), disable_console_output():
-                dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
+                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
 
                 yield dataset, info
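
The documented precedence maps directly onto the successive dict.update() calls above, with later updates overwriting earlier entries. A minimal sketch (keys and values are made up):

    # Lowest to highest precedence; each update() overwrites earlier entries.
    complete_config = {"split": "train", "mode": "fine"}  # 4. dataset defaults (_KWARG_DEFAULTS)
    complete_config.update({"split": "extra"})            # 3. DEFAULT_CONFIG
    complete_config.update({"mode": "coarse"})            # 2. config argument
    complete_config.update({"split": "test"})             # 1. keyword arguments
    assert complete_config == {"split": "test", "mode": "coarse"}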
@@ -357,26 +408,69 @@ class DatasetTestCase(unittest.TestCase):
     @classmethod
     def _populate_private_class_attributes(cls):
-        argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
-
-        cls._DEFAULT_CONFIG = {
-            kwarg: default
-            for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
-            if kwarg not in cls._SPECIAL_KWARGS
-        }
-
-        cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
+        defaults = []
+        for cls_ in cls.DATASET_CLASS.__mro__:
+            if cls_ is torchvision.datasets.VisionDataset:
+                break
+
+            argspec = inspect.getfullargspec(cls_.__init__)
+
+            if not argspec.defaults:
+                continue
+
+            defaults.append(
+                {kwarg: default for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)}
+            )
+
+            if not argspec.varkw:
+                break
+
+        kwarg_defaults = dict()
+        for config in reversed(defaults):
+            kwarg_defaults.update(config)
+
+        has_special_kwargs = set()
+        for name in cls._SPECIAL_KWARGS:
+            if name not in kwarg_defaults:
+                continue
+
+            del kwarg_defaults[name]
+            has_special_kwargs.add(name)
+
+        cls._KWARG_DEFAULTS = kwarg_defaults
+        cls._HAS_SPECIAL_KWARG = has_special_kwargs
 
     @classmethod
     def _process_optional_public_class_attributes(cls):
+        def check_config(config, name):
+            special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config)
+            if special_kwargs:
+                raise UsageError(
+                    f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
+                    f"These are handled separately by the test case and should not be set here. "
+                    f"If you need to test some custom behavior regarding these parameters, "
+                    f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
+                )
+
+        if cls.DEFAULT_CONFIG is not None:
+            check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")
+
+        if cls.ADDITIONAL_CONFIGS is not None:
+            for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
+                check_config(config, f"ADDITIONAL_CONFIGS[{idx}]")
+
-        if cls.REQUIRED_PACKAGES is not None:
-            try:
-                for pkg in cls.REQUIRED_PACKAGES:
-                    importlib.import_module(pkg)
-            except ImportError as error:
-                raise unittest.SkipTest(
-                    f"The package '{error.name}' is required to load the dataset '{cls.DATASET_CLASS.__name__}' but is "
-                    f"not installed."
-                )
+        if cls.REQUIRED_PACKAGES:
+            missing_pkgs = []
+            for pkg in cls.REQUIRED_PACKAGES:
+                try:
+                    importlib.import_module(pkg)
+                except ImportError:
+                    missing_pkgs.append(f"'{pkg}'")
+
+            if missing_pkgs:
+                raise unittest.SkipTest(
+                    f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
+                    f"'{cls.DATASET_CLASS.__name__}', but are not installed."
+                )
 
     def _split_kwargs(self, kwargs):
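
The MRO walk above relies on inspect.getfullargspec() pairing the trailing positional parameters with their defaults. A standalone sketch with a hypothetical constructor:

    import inspect

    class Hypothetical:  # hypothetical dataset constructor, for illustration only
        def __init__(self, root, split="train", download=False):
            pass

    argspec = inspect.getfullargspec(Hypothetical.__init__)
    # argspec.defaults aligns with the *last* len(defaults) entries of argspec.args,
    # which is exactly the slice used in _populate_private_class_attributes().
    kwarg_defaults = dict(zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
    assert kwarg_defaults == {"split": "train", "download": False}
    # "download" is in _SPECIAL_KWARGS, so it would be recorded in _HAS_SPECIAL_KWARG
    # and removed from _KWARG_DEFAULTS.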
...
@@ -369,7 +369,9 @@ class Caltech101TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Caltech101
     FEATURE_TYPES = (PIL.Image.Image, (int, np.ndarray, tuple))
-    CONFIGS = datasets_utils.combinations_grid(target_type=("category", "annotation", ["category", "annotation"]))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        target_type=("category", "annotation", ["category", "annotation"])
+    )
     REQUIRED_PACKAGES = ("scipy",)
 
     def inject_fake_data(self, tmpdir, config):
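
Throughout these test cases, combinations_grid() expands its keyword arguments into one config per element of their cross product. A sketch of the assumed behavior, written out with itertools:

    import itertools

    def combinations_grid(**kwargs):
        # Assumed behavior of datasets_utils.combinations_grid: the cross product
        # of all keyword argument values, one config dict per combination.
        return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

    assert combinations_grid(fold=(1, 2), train=(True, False)) == [
        {"fold": 1, "train": True},
        {"fold": 1, "train": False},
        {"fold": 2, "train": True},
        {"fold": 2, "train": False},
    ]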
@@ -466,7 +468,7 @@ class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
 class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.WIDERFace
     FEATURE_TYPES = (PIL.Image.Image, (dict, type(None)))  # test split returns None as target
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test'))
 
     def inject_fake_data(self, tmpdir, config):
         widerface_dir = pathlib.Path(tmpdir) / 'widerface'
@@ -521,7 +523,7 @@ class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase):
 class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.ImageNet
     REQUIRED_PACKAGES = ('scipy',)
-    CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val'))
 
     def inject_fake_data(self, tmpdir, config):
         tmpdir = pathlib.Path(tmpdir)
@@ -551,7 +553,7 @@ class ImageNetTestCase(datasets_utils.ImageDatasetTestCase):
 class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CIFAR10
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _VERSION_CONFIG = dict(
         base_folder="cifar-10-batches-py",
@@ -623,7 +625,7 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.CelebA
     FEATURE_TYPES = (PIL.Image.Image, (torch.Tensor, int, tuple, type(None)))
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         split=("train", "valid", "test", "all"),
         target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
     )
@@ -740,7 +742,7 @@ class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.VOCSegmentation
     FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image)
-    CONFIGS = (
+    ADDITIONAL_CONFIGS = (
         *datasets_utils.combinations_grid(
             year=[f"20{year:02d}" for year in range(7, 13)], image_set=("train", "val", "trainval")
         ),
@@ -929,7 +931,7 @@ class CocoCaptionsTestCase(CocoDetectionTestCase):
 class UCF101TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.UCF101
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _ANNOTATIONS_FOLDER = "annotations"
@@ -990,7 +992,7 @@ class LSUNTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.LSUN
     REQUIRED_PACKAGES = ("lmdb",)
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         classes=("train", "test", "val", ["bedroom_train", "church_outdoor_train"])
     )
@@ -1097,7 +1099,7 @@ class Kinetics400TestCase(datasets_utils.VideoDatasetTestCase):
 class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
     DATASET_CLASS = datasets.HMDB51
-    CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(fold=(1, 2, 3), train=(True, False))
 
     _VIDEO_FOLDER = "videos"
     _SPLITS_FOLDER = "splits"
@@ -1157,7 +1159,7 @@ class HMDB51TestCase(datasets_utils.VideoDatasetTestCase):
 class OmniglotTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.Omniglot
-    CONFIGS = datasets_utils.combinations_grid(background=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(background=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         target_folder = (
@@ -1237,7 +1239,7 @@ class SEMEIONTestCase(datasets_utils.ImageDatasetTestCase):
 class USPSTestCase(datasets_utils.ImageDatasetTestCase):
     DATASET_CLASS = datasets.USPS
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     def inject_fake_data(self, tmpdir, config):
         num_images = 2 if config["train"] else 1
@@ -1259,7 +1261,7 @@ class SBDatasetTestCase(datasets_utils.ImageDatasetTestCase):
     REQUIRED_PACKAGES = ("scipy.io", "scipy.sparse")
-    CONFIGS = datasets_utils.combinations_grid(
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
         image_set=("train", "val", "train_noval"), mode=("boundaries", "segmentation")
     )
@@ -1345,7 +1347,7 @@ class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
     _TRAIN_FEATURE_TYPES = (torch.Tensor,)
     _TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)
-    CONFIGS = datasets_utils.combinations_grid(train=(True, False))
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
 
     _NAME = "liberty"
...