"src/vscode:/vscode.git/clone" did not exist on "671149e03604ef82dd32a8bd419b598a29a4c32e"
Unverified commit 2e8c124f, authored by Philip Meier, committed by GitHub
Browse files

Improve dataset test infrastructure (#3450)

* always use default config as base

* fix test_all_configs decorator

* lint

* add a utility function to create a random string

* move output check of inject_fake_data to dedicated method

* always disable download and extract functionality
parent a24191ed
......@@ -6,6 +6,8 @@ import inspect
import itertools
import os
import pathlib
import random
import string
import unittest
import unittest.mock
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
......@@ -32,6 +34,7 @@ __all__ = [
"create_image_folder",
"create_video_file",
"create_video_folder",
"create_random_string",
]
......@@ -93,14 +96,6 @@ def requires_lazy_imports(*modules):
return outer_wrapper
# As of Python 3.7 this is provided by contextlib
# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
# TODO: If the minimum Python requirement is >= 3.7, replace this
@contextlib.contextmanager
def nullcontext(enter_result=None):
    """Context manager that does nothing besides handing back ``enter_result``.

    Args:
        enter_result: Object bound by the ``as`` clause of the ``with`` statement. Defaults to ``None``.

    Yields:
        The unchanged ``enter_result``.
    """
    # No setup or teardown happens around the single yield.
    yield enter_result
def test_all_configs(test):
"""Decorator to run test against all configurations.
......@@ -116,7 +111,7 @@ def test_all_configs(test):
@functools.wraps(test)
def wrapper(self):
for config in self.CONFIGS:
for config in self.CONFIGS or (self._DEFAULT_CONFIG,):
with self.subTest(**config):
test(self, config)
......@@ -207,6 +202,8 @@ class DatasetTestCase(unittest.TestCase):
CONFIGS = None
REQUIRED_PACKAGES = None
_DEFAULT_CONFIG = None
_TRANSFORM_KWARGS = {
"transform",
"target_transform",
......@@ -268,7 +265,7 @@ class DatasetTestCase(unittest.TestCase):
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None,
patch_checks: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
r"""Create the dataset in a temporary directory.
......@@ -278,8 +275,8 @@ class DatasetTestCase(unittest.TestCase):
default configuration is used.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset.
disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
omitted defaults to the same value as ``inject_fake_data``.
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
overlap with ``config``.
......@@ -288,43 +285,28 @@ class DatasetTestCase(unittest.TestCase):
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details.
"""
if config is None:
config = self.CONFIGS[0].copy()
default_config = self._DEFAULT_CONFIG.copy()
if config is not None:
default_config.update(config)
config = default_config
if patch_checks is None:
patch_checks = inject_fake_data
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
if "download" in self._HAS_SPECIAL_KWARG:
special_kwargs["download"] = False
config.update(other_kwargs)
if disable_download_extract is None:
disable_download_extract = inject_fake_data
patchers = self._patch_download_extract()
if patch_checks:
patchers.update(self._patch_checks())
with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, config)
info = self._inject_fake_data(tmpdir, config) if inject_fake_data else None
if inject_fake_data:
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
else:
info = None
cm = self._disable_download_extract if disable_download_extract else nullcontext
with cm(special_kwargs), disable_console_output():
with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
yield dataset, info
......@@ -352,19 +334,17 @@ class DatasetTestCase(unittest.TestCase):
@classmethod
def _populate_private_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
@classmethod
def _process_optional_public_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
if cls.CONFIGS is None:
config = {
cls._DEFAULT_CONFIG = {
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
if kwarg not in cls._SPECIAL_KWARGS
}
cls.CONFIGS = (config,)
cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
@classmethod
def _process_optional_public_class_attributes(cls):
if cls.REQUIRED_PACKAGES is not None:
try:
for pkg in cls.REQUIRED_PACKAGES:
......@@ -380,28 +360,44 @@ class DatasetTestCase(unittest.TestCase):
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
return special_kwargs, other_kwargs
@contextlib.contextmanager
def _disable_download_extract(self, special_kwargs):
inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs
if inject_download_kwarg:
special_kwargs["download"] = False
def _inject_fake_data(self, tmpdir, config):
info = self.inject_fake_data(tmpdir, config)
if info is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
elif isinstance(info, int):
info = dict(num_examples=info)
elif not isinstance(info, dict):
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an "
f"integer indicating the number of examples for the current configuration or a dictionary with "
f"the same content. Got {type(info)} instead."
)
elif "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
return info
def _patch_download_extract(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}
def _patch_checks(self):
module = inspect.getmodule(self.DATASET_CLASS).__name__
return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}
@contextlib.contextmanager
def _maybe_apply_patches(self, patchers):
with contextlib.ExitStack() as stack:
mocks = {}
for function, kwargs in itertools.chain(
zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)),
zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)),
):
for patcher in patchers:
with contextlib.suppress(AttributeError):
patcher = unittest.mock.patch(f"{module}.{function}", **kwargs)
mocks[function] = stack.enter_context(patcher)
try:
mocks[patcher.target] = stack.enter_context(patcher)
yield mocks
finally:
if inject_download_kwarg:
del special_kwargs["download"]
def test_not_found_or_corrupted(self):
with self.assertRaises((FileNotFoundError, RuntimeError)):
......@@ -469,13 +465,13 @@ class ImageDatasetTestCase(DatasetTestCase):
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None,
patch_checks: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
with super().create_dataset(
config=config,
inject_fake_data=inject_fake_data,
disable_download_extract=disable_download_extract,
patch_checks=patch_checks,
**kwargs,
) as (dataset, info):
# PIL.Image.open() only loads the image meta data upfront and keeps the file open until the first access
......@@ -711,3 +707,18 @@ def create_video_folder(
create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples)
]
def create_random_string(length: int, *digits: str) -> str:
    """Create a random string.

    Args:
        length (int): Number of characters in the generated string. Values ``<= 0`` result in an empty string.
        *digits (str): Characters to sample from. Multiple arguments are concatenated into a single pool, so a
            character that occurs more than once is proportionally more likely to be drawn. If omitted
            defaults to :attr:`string.ascii_lowercase`.

    Returns:
        str: String of ``length`` characters, each drawn independently with :func:`random.choice`.
    """
    # Flatten all passed character groups into one pool; fall back to lowercase ASCII if none were given.
    pool = "".join(itertools.chain.from_iterable(digits)) if digits else string.ascii_lowercase
    return "".join(random.choice(pool) for _ in range(length))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment