Unverified Commit 2e8c124f authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Improve dataset test infrastructure (#3450)

* always use default config as base

* fix test_all_configs decorator

* lint

* add a utility function to create a random string

* move output check of inject_fake_data to dedicated method

* always disable download and extract functionality
parent a24191ed
...@@ -6,6 +6,8 @@ import inspect ...@@ -6,6 +6,8 @@ import inspect
import itertools import itertools
import os import os
import pathlib import pathlib
import random
import string
import unittest import unittest
import unittest.mock import unittest.mock
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
...@@ -32,6 +34,7 @@ __all__ = [ ...@@ -32,6 +34,7 @@ __all__ = [
"create_image_folder", "create_image_folder",
"create_video_file", "create_video_file",
"create_video_folder", "create_video_folder",
"create_random_string",
] ]
...@@ -93,14 +96,6 @@ def requires_lazy_imports(*modules): ...@@ -93,14 +96,6 @@ def requires_lazy_imports(*modules):
return outer_wrapper return outer_wrapper
# As of Python 3.7 this is provided by contextlib
# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
# TODO: If the minimum Python requirement is >= 3.7, replace this
@contextlib.contextmanager
def nullcontext(enter_result=None):
yield enter_result
def test_all_configs(test): def test_all_configs(test):
"""Decorator to run test against all configurations. """Decorator to run test against all configurations.
...@@ -116,7 +111,7 @@ def test_all_configs(test): ...@@ -116,7 +111,7 @@ def test_all_configs(test):
@functools.wraps(test) @functools.wraps(test)
def wrapper(self): def wrapper(self):
for config in self.CONFIGS: for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
with self.subTest(**config): with self.subTest(**config):
test(self, config) test(self, config)
...@@ -207,6 +202,8 @@ class DatasetTestCase(unittest.TestCase): ...@@ -207,6 +202,8 @@ class DatasetTestCase(unittest.TestCase):
CONFIGS = None CONFIGS = None
REQUIRED_PACKAGES = None REQUIRED_PACKAGES = None
_DEFAULT_CONFIG = None
_TRANSFORM_KWARGS = { _TRANSFORM_KWARGS = {
"transform", "transform",
"target_transform", "target_transform",
...@@ -268,7 +265,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -268,7 +265,7 @@ class DatasetTestCase(unittest.TestCase):
self, self,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True, inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None, patch_checks: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]: ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
r"""Create the dataset in a temporary directory. r"""Create the dataset in a temporary directory.
...@@ -278,8 +275,8 @@ class DatasetTestCase(unittest.TestCase): ...@@ -278,8 +275,8 @@ class DatasetTestCase(unittest.TestCase):
default configuration is used. default configuration is used.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset. creating the dataset.
disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``. omitted defaults to the same value as ``inject_fake_data``.
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
overlap with ``config``. overlap with ``config``.
...@@ -288,43 +285,28 @@ class DatasetTestCase(unittest.TestCase): ...@@ -288,43 +285,28 @@ class DatasetTestCase(unittest.TestCase):
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data` info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details. for details.
""" """
if config is None: default_config = self._DEFAULT_CONFIG.copy()
config = self.CONFIGS[0].copy() if config is not None:
default_config.update(config)
config = default_config
if patch_checks is None:
patch_checks = inject_fake_data
special_kwargs, other_kwargs = self._split_kwargs(kwargs) special_kwargs, other_kwargs = self._split_kwargs(kwargs)
if "download" in self._HAS_SPECIAL_KWARG:
special_kwargs["download"] = False
config.update(other_kwargs) config.update(other_kwargs)
if disable_download_extract is None: patchers = self._patch_download_extract()
disable_download_extract = inject_fake_data if patch_checks:
patchers.update(self._patch_checks())
with get_tmp_dir() as tmpdir: with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, config) args = self.dataset_args(tmpdir, config)
info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
if inject_fake_data: with self._maybe_apply_patches(patchers), disable_console_output():
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
else:
info = None
cm = self._disable_download_extract if disable_download_extract else nullcontext
with cm(special_kwargs), disable_console_output():
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs) dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
yield dataset, info yield dataset, info
...@@ -352,19 +334,17 @@ class DatasetTestCase(unittest.TestCase): ...@@ -352,19 +334,17 @@ class DatasetTestCase(unittest.TestCase):
@classmethod @classmethod
def _populate_private_class_attributes(cls): def _populate_private_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__) argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
cls._DEFAULT_CONFIG = {
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
if kwarg not in cls._SPECIAL_KWARGS
}
cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args} cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
@classmethod @classmethod
def _process_optional_public_class_attributes(cls): def _process_optional_public_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
if cls.CONFIGS is None:
config = {
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
if kwarg not in cls._SPECIAL_KWARGS
}
cls.CONFIGS = (config,)
if cls.REQUIRED_PACKAGES is not None: if cls.REQUIRED_PACKAGES is not None:
try: try:
for pkg in cls.REQUIRED_PACKAGES: for pkg in cls.REQUIRED_PACKAGES:
...@@ -380,28 +360,44 @@ class DatasetTestCase(unittest.TestCase): ...@@ -380,28 +360,44 @@ class DatasetTestCase(unittest.TestCase):
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS} other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
return special_kwargs, other_kwargs return special_kwargs, other_kwargs
@contextlib.contextmanager def _inject_fake_data(self, tmpdir, config):
def _disable_download_extract(self, special_kwargs): info = self.inject_fake_data(tmpdir, config)
inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs if info is None:
if inject_download_kwarg: raise UsageError(
special_kwargs["download"] = False "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
return info
def _patch_download_extract(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
def _patch_checks(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__ module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
@contextlib.contextmanager
def _maybe_apply_patches(self, patchers):
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
mocks = {} mocks = {}
for function, kwargs in itertools.chain( for patcher in patchers:
zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)),
zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)),
):
with contextlib.suppress(AttributeError): with contextlib.suppress(AttributeError):
patcher = unittest.mock.patch(f"{module}.{function}", **kwargs) mocks[patcher.target] = stack.enter_context(patcher)
mocks[function] = stack.enter_context(patcher) yield mocks
try:
yield mocks
finally:
if inject_download_kwarg:
del special_kwargs["download"]
def test_not_found_or_corrupted(self): def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)): with self.assertRaises((FileNotFoundError, RuntimeError)):
...@@ -469,13 +465,13 @@ class ImageDatasetTestCase(DatasetTestCase): ...@@ -469,13 +465,13 @@ class ImageDatasetTestCase(DatasetTestCase):
self, self,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True, inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None, patch_checks: Optional[bool] = None,
**kwargs: Any, **kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]: ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
with super().create_dataset( with super().create_dataset(
config=config, config=config,
inject_fake_data=inject_fake_data, inject_fake_data=inject_fake_data,
disable_download_extract=disable_download_extract, patch_checks=patch_checks,
**kwargs, **kwargs,
) as (dataset, info): ) as (dataset, info):
# PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access # PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
...@@ -711,3 +707,18 @@ def create_video_folder( ...@@ -711,3 +707,18 @@ def create_video_folder(
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs) create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples) for idx in range(num_examples)
] ]
def create_random_string(length: int, *digits: str) -> str:
"""Create a random string.
Args:
length (int): Number of characters in the generated string.
*characters (str): Characters to sample from. If omitted defaults to :attr:`string.ascii_lowercase`.
"""
if not digits:
digits = string.ascii_lowercase
else:
digits = "".join(itertools.chain(*digits))
return "".join(random.choice(digits) for _ in range(length))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment