Unverified Commit 22c548b0 authored by Philip Meier, committed by GitHub

[POC] Base class for dataset tests (#3402)

* add base class for dataset tests

* add better type hints

* add documentation to subclasses

* add utility functions to create files / folders of random images and videos

* fix imports

* remove class properties

* fix smoke test

* fix type hints

* fix random size generation

* add Caltech256 as example

* add utility function to create grid of combinations

* add CIFAR100 as example

* lint

* add missing import

* improve documentation

* create 1 frame videos by default

* remove obsolete check

* return path of files created with utility functions

* [test] close PIL file handles before deletion

* fix video folder creation

* generalize file handle closing

* fix lazy imports

* add test for transforms

* fix explanation comment

* lint

* force load opened PIL images

* lint

* copy default config to avoid inplace modification

* enable additional arg forwarding
parent c1f85d34
@@ -393,3 +393,11 @@ def int_dtypes():
def float_dtypes():
    return torch.testing.floating_types()
@contextlib.contextmanager
def disable_console_output():
with contextlib.ExitStack() as stack, open(os.devnull, "w") as devnull:
stack.enter_context(contextlib.redirect_stdout(devnull))
stack.enter_context(contextlib.redirect_stderr(devnull))
yield
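# Example (a minimal sketch): everything written to stdout or stderr inside the block is swallowed, which keeps
# noisy dataset constructors from cluttering the test output.
#
#     with disable_console_output():
#         print("this never reaches the console")
#     print("this does")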
import collections.abc
import contextlib
import functools
import importlib
import inspect
import itertools
import os
import pathlib
import unittest
import unittest.mock
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
import PIL
import PIL.Image
import torch
import torchvision.datasets
import torchvision.io
from common_utils import get_tmp_dir, disable_console_output
__all__ = [
"UsageError",
"lazy_importer",
"test_all_configs",
"DatasetTestCase",
"ImageDatasetTestCase",
"VideoDatasetTestCase",
"create_image_or_video_tensor",
"create_image_file",
"create_image_folder",
"create_video_file",
"create_video_folder",
]
class UsageError(RuntimeError):
"""Should be raised in case an error happens in the setup rather than the test."""
class LazyImporter:
r"""Lazy importer for additional dependicies.
Some datasets require additional packages that are no direct dependencies of torchvision. Instances of this class
provide modules listed in MODULES as attributes. They are only imported when accessed.
"""
MODULES = (
"av",
"lmdb",
"pandas",
"pycocotools",
"requests",
"scipy.io",
)
def __init__(self):
cls = type(self)
for module in self.MODULES:
            # We need the 'module=module' default argument in the lambda since otherwise the lookup for 'module'
            # would happen when the lambda is called rather than when it is defined. Thus, without it, every
            # property would try to import the last 'module' in MODULES.
setattr(cls, module.split(".", 1)[0], property(lambda self, module=module: LazyImporter._import(module)))
@staticmethod
def _import(module):
try:
importlib.import_module(module)
return importlib.import_module(module.split(".", 1)[0])
except ImportError as error:
raise UsageError(
f"Failed to import module '{module}'. "
f"This probably means that the current test case needs '{module}' installed, "
f"but it is not a dependency of torchvision. "
f"You need to install it manually, for example 'pip install {module}'."
) from error
lazy_importer = LazyImporter()
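# Example (a minimal sketch): modules are only imported on first attribute access, so instantiating LazyImporter
# is cheap. Accessing an attribute whose package is missing raises UsageError instead of a bare ImportError.
#
#     scipy = lazy_importer.scipy  # imports 'scipy.io', returns the top-level 'scipy' package
#     lazy_importer.av             # raises UsageError if PyAV is not installed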
def requires_lazy_imports(*modules):
def outer_wrapper(fn):
@functools.wraps(fn)
def inner_wrapper(*args, **kwargs):
            for module in modules:
                # LazyImporter registers each module under its top-level package name, so look it up the same way.
                getattr(lazy_importer, module.split(".", 1)[0])
return fn(*args, **kwargs)
return inner_wrapper
return outer_wrapper
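# Example (a minimal sketch with a hypothetical function): the decorator turns a missing optional dependency into
# a UsageError at call time rather than an ImportError at import time.
#
#     @requires_lazy_imports("av")
#     def make_fake_videos():
#         ...  # can safely assume PyAV is importable here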
# As of Python 3.7 this is provided by contextlib
# https://docs.python.org/3.7/library/contextlib.html#contextlib.nullcontext
# TODO: If the minimum Python requirement is >= 3.7, replace this
@contextlib.contextmanager
def nullcontext(enter_result=None):
yield enter_result
def test_all_configs(test):
"""Decorator to run test against all configurations.
Add this as decorator to an arbitrary test to run it against all configurations. The current configuration is
provided as the first parameter:
.. code-block::
@test_all_configs
def test_foo(self, config):
pass
"""
@functools.wraps(test)
def wrapper(self):
for config in self.CONFIGS:
with self.subTest(**config):
test(self, config)
return wrapper
def combinations_grid(**kwargs):
"""Creates a grid of input combinations.
Each element in the returned sequence is a dictionary containing one possible combination as values.
Example:
>>> combinations_grid(foo=("bar", "baz"), spam=("eggs", "ham"))
[
{'foo': 'bar', 'spam': 'eggs'},
{'foo': 'bar', 'spam': 'ham'},
{'foo': 'baz', 'spam': 'eggs'},
{'foo': 'baz', 'spam': 'ham'}
]
"""
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]
class DatasetTestCase(unittest.TestCase):
"""Abstract base class for all dataset testcases.
You have to overwrite the following class attributes:
- DATASET_CLASS (torchvision.datasets.VisionDataset): Class of dataset to be tested.
    - FEATURE_TYPES (Sequence[Any]): Types of the elements returned by index access of the dataset. Instead of
      providing these manually, you can instead subclass ``ImageDatasetTestCase`` or ``VideoDatasetTestCase`` to
      get a reasonable default that should work for most cases.
Optionally, you can overwrite the following class attributes:
    - CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can contain an
arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
``transforms``, or ``download``. The first element will be used as default configuration.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
    The fake data should resemble the original data as closely as necessary, while containing only a few examples.
    During the creation of the dataset, the check, download, and extract functions from
    ``torchvision.datasets.utils`` are disabled.
Without further configuration, the testcase will test if
1. the dataset raises a ``RuntimeError`` if the data files are not found,
2. the dataset inherits from `torchvision.datasets.VisionDataset`,
3. the dataset can be turned into a string,
    4. the feature types of a returned example match ``FEATURE_TYPES``,
5. the number of examples matches the injected fake data, and
6. the dataset calls ``transform``, ``target_transform``, or ``transforms`` if available when accessing data.
    Cases 3. to 6. are tested against all configurations in ``CONFIGS``.
    To add dataset-specific tests, create a new method that takes no arguments and whose name is prefixed with
    ``test_``:
.. code-block::
def test_foo(self):
pass
If you want to run the test against all configs, add the ``@test_all_configs`` decorator to the definition and
accept a single argument:
.. code-block::
@test_all_configs
def test_bar(self, config):
pass
Within the test you can use the ``create_dataset()`` method that yields the dataset as well as additional
    information provided by the ``inject_fake_data()`` method:
.. code-block::
def test_baz(self):
with self.create_dataset() as (dataset, info):
pass
"""
DATASET_CLASS = None
FEATURE_TYPES = None
CONFIGS = None
REQUIRED_PACKAGES = None
_TRANSFORM_KWARGS = {
"transform",
"target_transform",
"transforms",
}
_SPECIAL_KWARGS = {
*_TRANSFORM_KWARGS,
"download",
}
_HAS_SPECIAL_KWARG = None
_CHECK_FUNCTIONS = {
"check_md5",
"check_integrity",
}
_DOWNLOAD_EXTRACT_FUNCTIONS = {
"download_url",
"download_file_from_google_drive",
"extract_archive",
"download_and_extract_archive",
}
def inject_fake_data(
self, tmpdir: str, config: Dict[str, Any]
) -> Union[int, Dict[str, Any], Tuple[Sequence[Any], Union[int, Dict[str, Any]]]]:
"""Inject fake data for dataset into a temporary directory.
Args:
            tmpdir (str): Path to a temporary directory. In most cases this acts as the root directory for the
                dataset to be created and thus also for the fake data injected here.
config (Dict[str, Any]): Configuration that will be used to create the dataset.
Needs to return one of the following:
1. (int): Number of examples in the dataset to be created,
2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
``"num_examples"`` that corresponds to the number of examples in the dataset to be created, or
3. (Tuple[Sequence[Any], Union[int, Dict[str, Any]]]): Additional required parameters that are passed to
the dataset constructor. The second element corresponds to cases 1. and 2.
        If no ``args`` are returned (cases 1. and 2.), ``tmpdir`` is passed as the first parameter to the dataset
        constructor. In most cases this corresponds to ``root``. If the dataset has more parameters without default
        values, you need to pass them explicitly as explained in case 3.
"""
raise NotImplementedError("You need to provide fake data in order for the tests to run.")
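    # Example (a minimal sketch of the three possible return values; the numbers are hypothetical):
    #
    #     def inject_fake_data(self, tmpdir, config):
    #         ...  # create fake files below tmpdir
    #         return 4                                    # case 1: just the number of examples
    #         return dict(num_examples=4, extra="info")   # case 2: number of examples plus extra info
    #         return (tmpdir, "extra_arg"), 4             # case 3: dataset args plus case 1 or 2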
@contextlib.contextmanager
def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
r"""Create the dataset in a temporary directory.
Args:
config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset. If omitted, the
default configuration is used.
inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
creating the dataset.
disable_download_extract (Optional[bool]): If ``True`` disable download and extract logic while creating
the dataset. If ``None`` (default) this takes the same value as ``inject_fake_data``.
**kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
overlap with ``config``.
Yields:
dataset (torchvision.dataset.VisionDataset): Dataset.
info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
for details.
"""
if config is None:
config = self.CONFIGS[0].copy()
special_kwargs, other_kwargs = self._split_kwargs(kwargs)
config.update(other_kwargs)
if disable_download_extract is None:
disable_download_extract = inject_fake_data
with get_tmp_dir() as tmpdir:
output = self.inject_fake_data(tmpdir, config) if inject_fake_data else None
if output is None:
raise UsageError(
"The method 'inject_fake_data' needs to return at least an integer indicating the number of "
"examples for the current configuration."
)
if isinstance(output, collections.abc.Sequence) and len(output) == 2:
args, info = output
else:
args = (tmpdir,)
info = output
if isinstance(info, int):
info = dict(num_examples=info)
elif isinstance(info, dict):
if "num_examples" not in info:
raise UsageError(
"The information dictionary returned by the method 'inject_fake_data' must contain a "
"'num_examples' field that holds the number of examples for the current configuration."
)
else:
raise UsageError(
f"The additional information returned by the method 'inject_fake_data' must be either an integer "
f"indicating the number of examples for the current configuration or a dictionary with the the "
f"same content. Got {type(info)} instead."
)
cm = self._disable_download_extract if disable_download_extract else nullcontext
with cm(special_kwargs), disable_console_output():
dataset = self.DATASET_CLASS(*args, **config, **special_kwargs)
yield dataset, info
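    # Example (a minimal sketch; 'split' and 'my_transform' are hypothetical and depend on the dataset under
    # test): non-special keyword arguments override the config, while 'transform', 'target_transform',
    # 'transforms', and 'download' are forwarded to the dataset constructor untouched.
    #
    #     with self.create_dataset(split="train", transform=my_transform) as (dataset, info):
    #         assert len(dataset) == info["num_examples"]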
@classmethod
def setUpClass(cls):
cls._verify_required_public_class_attributes()
cls._populate_private_class_attributes()
cls._process_optional_public_class_attributes()
super().setUpClass()
@classmethod
def _verify_required_public_class_attributes(cls):
if cls.DATASET_CLASS is None:
raise UsageError(
"The class attribute 'DATASET_CLASS' needs to be overwritten. "
"It should contain the class of the dataset to be tested."
)
if cls.FEATURE_TYPES is None:
raise UsageError(
"The class attribute 'FEATURE_TYPES' needs to be overwritten. "
"It should contain a sequence of types that the dataset returns when accessed by index."
)
@classmethod
def _populate_private_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
cls._HAS_SPECIAL_KWARG = {name for name in cls._SPECIAL_KWARGS if name in argspec.args}
@classmethod
def _process_optional_public_class_attributes(cls):
argspec = inspect.getfullargspec(cls.DATASET_CLASS.__init__)
if cls.CONFIGS is None:
config = {
kwarg: default
for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults)
if kwarg not in cls._SPECIAL_KWARGS
}
cls.CONFIGS = (config,)
if cls.REQUIRED_PACKAGES is not None:
try:
for pkg in cls.REQUIRED_PACKAGES:
importlib.import_module(pkg)
except ImportError as error:
raise unittest.SkipTest(
f"The package '{error.name}' is required to load the dataset '{cls.DATASET_CLASS.__name__}' but is "
f"not installed."
)
def _split_kwargs(self, kwargs):
special_kwargs = kwargs.copy()
other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
return special_kwargs, other_kwargs
@contextlib.contextmanager
def _disable_download_extract(self, special_kwargs):
inject_download_kwarg = "download" in self._HAS_SPECIAL_KWARG and "download" not in special_kwargs
if inject_download_kwarg:
special_kwargs["download"] = False
module = inspect.getmodule(self.DATASET_CLASS).__name__
with contextlib.ExitStack() as stack:
mocks = {}
for function, kwargs in itertools.chain(
zip(self._CHECK_FUNCTIONS, [dict(return_value=True)] * len(self._CHECK_FUNCTIONS)),
zip(self._DOWNLOAD_EXTRACT_FUNCTIONS, [dict()] * len(self._DOWNLOAD_EXTRACT_FUNCTIONS)),
):
with contextlib.suppress(AttributeError):
patcher = unittest.mock.patch(f"{module}.{function}", **kwargs)
mocks[function] = stack.enter_context(patcher)
try:
yield mocks
finally:
if inject_download_kwarg:
del special_kwargs["download"]
def test_not_found(self):
with self.assertRaises(RuntimeError):
with self.create_dataset(inject_fake_data=False):
pass
def test_smoke(self):
with self.create_dataset() as (dataset, _):
self.assertIsInstance(dataset, torchvision.datasets.VisionDataset)
@test_all_configs
def test_str_smoke(self, config):
with self.create_dataset(config) as (dataset, _):
self.assertIsInstance(str(dataset), str)
@test_all_configs
def test_feature_types(self, config):
with self.create_dataset(config) as (dataset, _):
example = dataset[0]
actual = len(example)
expected = len(self.FEATURE_TYPES)
self.assertEqual(
actual,
expected,
f"The number of the returned features does not match the the number of elements in in FEATURE_TYPES: "
f"{actual} != {expected}",
)
for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
self.assertIsInstance(feature, expected_feature_type)
@test_all_configs
def test_num_examples(self, config):
with self.create_dataset(config) as (dataset, info):
self.assertEqual(len(dataset), info["num_examples"])
@test_all_configs
def test_transforms(self, config):
mock = unittest.mock.Mock(wraps=lambda *args: args[0] if len(args) == 1 else args)
for kwarg in self._TRANSFORM_KWARGS:
if kwarg not in self._HAS_SPECIAL_KWARG:
continue
mock.reset_mock()
with self.subTest(kwarg=kwarg):
with self.create_dataset(config, **{kwarg: mock}) as (dataset, _):
dataset[0]
mock.assert_called()
class ImageDatasetTestCase(DatasetTestCase):
"""Abstract base class for image dataset testcases.
- Overwrites the FEATURE_TYPES class attribute to expect a :class:`PIL.Image.Image` and an integer label.
"""
FEATURE_TYPES = (PIL.Image.Image, int)
@contextlib.contextmanager
def create_dataset(
self,
config: Optional[Dict[str, Any]] = None,
inject_fake_data: bool = True,
disable_download_extract: Optional[bool] = None,
**kwargs: Any,
) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
with super().create_dataset(
config=config,
inject_fake_data=inject_fake_data,
disable_download_extract=disable_download_extract,
**kwargs,
) as (dataset, info):
            # PIL.Image.open() only loads the image metadata upfront and keeps the file open until the pixel data
            # is accessed for the first time. Trying to delete such a file results in a PermissionError on Windows.
            # Thus, we force-load opened images.
            # This problem only occurs during testing, since some tests, e.g. DatasetTestCase.test_feature_types,
            # open an image but never use the underlying data. During normal operation it is reasonable to assume
            # that a user wants to work with the image they just opened rather than delete the underlying file.
with self._force_load_images():
yield dataset, info
@contextlib.contextmanager
def _force_load_images(self):
open = PIL.Image.open
def new(fp, *args, **kwargs):
image = open(fp, *args, **kwargs)
if isinstance(fp, (str, pathlib.Path)):
image.load()
return image
with unittest.mock.patch("PIL.Image.open", new=new):
yield
class VideoDatasetTestCase(DatasetTestCase):
"""Abstract base class for video dataset testcases.
- Overwrites the FEATURE_TYPES class attribute to expect two :class:`torch.Tensor` s for the video and audio as
well as an integer label.
- Overwrites the REQUIRED_PACKAGES class attribute to require PyAV (``av``).
"""
FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
REQUIRED_PACKAGES = ("av",)
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
r"""Create a random uint8 tensor.
Args:
size (Sequence[int]): Size of the tensor.
"""
return torch.randint(0, 256, size, dtype=torch.uint8)
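# Example (a minimal sketch): create_image_or_video_tensor((3, 10, 10)) returns a (3, 10, 10) uint8 tensor with
# values drawn uniformly from [0, 255].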
def create_image_file(
root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = 10, **kwargs: Any
) -> pathlib.Path:
"""Create an image file from random data.
Args:
root (Union[str, pathlib.Path]): Root directory the image file will be placed in.
name (Union[str, pathlib.Path]): Name of the image file.
        size (Union[Sequence[int], int]): Size of the image that represents the ``(num_channels, height, width)``. If
            scalar, the value is used for the height and width. If the number of channels is not provided, three
            channels are assumed.
kwargs (Any): Additional parameters passed to :meth:`PIL.Image.Image.save`.
Returns:
pathlib.Path: Path to the created image file.
"""
if isinstance(size, int):
size = (size, size)
if len(size) == 2:
size = (3, *size)
if len(size) != 3:
raise UsageError(
f"The 'size' argument should either be an int or a sequence of length 2 or 3. Got {len(size)} instead"
)
    image = create_image_or_video_tensor(size)
    file = pathlib.Path(root) / name
    # (num_channels, height, width) -> (height, width, num_channels), the layout PIL.Image.fromarray() expects
    PIL.Image.fromarray(image.permute(1, 2, 0).numpy()).save(file, **kwargs)
    return file
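# Example (a minimal sketch; the file name is hypothetical): creates a 3-channel 10x10 pixel image.
#
#     path = create_image_file(tmpdir, "img.png", size=10)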
def create_image_folder(
root: Union[pathlib.Path, str],
name: Union[pathlib.Path, str],
file_name_fn: Callable[[int], str],
num_examples: int,
size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
**kwargs: Any,
) -> List[pathlib.Path]:
"""Create a folder of random images.
Args:
root (Union[str, pathlib.Path]): Root directory the image folder will be placed in.
name (Union[str, pathlib.Path]): Name of the image folder.
file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
num_examples (int): Number of images to create.
size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the images. If
callable, will be called with the index of the corresponding file. If omitted, a random height and width
between 3 and 10 pixels is selected on a per-image basis.
kwargs (Any): Additional parameters passed to :func:`create_image_file`.
Returns:
List[pathlib.Path]: Paths to all created image files.
.. seealso::
- :func:`create_image_file`
"""
if size is None:
def size(idx: int) -> Tuple[int, int, int]:
num_channels = 3
height, width = torch.randint(3, 11, size=(2,), dtype=torch.int).tolist()
return (num_channels, height, width)
root = pathlib.Path(root) / name
os.makedirs(root)
return [
create_image_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
for idx in range(num_examples)
]
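# Example (a minimal sketch with hypothetical names): creates 'images/0.png' through 'images/2.png', each with a
# random height and width between 3 and 10 pixels, and returns the three paths.
#
#     paths = create_image_folder(tmpdir, "images", file_name_fn=lambda idx: f"{idx}.png", num_examples=3)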
@requires_lazy_imports("av")
def create_video_file(
root: Union[pathlib.Path, str],
name: Union[pathlib.Path, str],
size: Union[Sequence[int], int] = (1, 3, 10, 10),
fps: float = 25,
**kwargs: Any,
) -> pathlib.Path:
"""Create an video file from random data.
Args:
root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
name (Union[str, pathlib.Path]): Name of the video file.
size (Union[Sequence[int], int]): Size of the video that represents the
``(num_frames, num_channels, height, width)``. If scalar, the value is used for the height and width.
If not provided, ``num_frames=1`` and ``num_channels=3`` are assumed.
fps (float): Frame rate in frames per second.
kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`.
Returns:
        pathlib.Path: Path to the created video file.
Raises:
UsageError: If PyAV is not available.
"""
if isinstance(size, int):
size = (size, size)
if len(size) == 2:
size = (3, *size)
if len(size) == 3:
size = (1, *size)
if len(size) != 4:
raise UsageError(
f"The 'size' argument should either be an int or a sequence of length 2, 3, or 4. Got {len(size)} instead"
)
video = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name
torchvision.io.write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs)
return file
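# Example (a minimal sketch; the file name is hypothetical): creates a single-frame 10x10 RGB video at 25 fps,
# which matches the defaults for 'size' and 'fps'.
#
#     path = create_video_file(tmpdir, "video.avi")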
@requires_lazy_imports("av")
def create_video_folder(
root: Union[str, pathlib.Path],
name: Union[str, pathlib.Path],
file_name_fn: Callable[[int], str],
num_examples: int,
size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
fps=25,
**kwargs,
) -> List[pathlib.Path]:
"""Create a folder of random videos.
Args:
        root (Union[str, pathlib.Path]): Root directory the video folder will be placed in.
        name (Union[str, pathlib.Path]): Name of the video folder.
        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
        num_examples (int): Number of videos to create.
size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the videos. If
callable, will be called with the index of the corresponding file. If omitted, a random even height and
width between 4 and 10 pixels is selected on a per-video basis.
fps (float): Frame rate in frames per second.
kwargs (Any): Additional parameters passed to :func:`create_video_file`.
Returns:
List[pathlib.Path]: Paths to all created video files.
Raises:
UsageError: If PyAV is not available.
.. seealso::
- :func:`create_video_file`
"""
if size is None:
def size(idx):
num_frames = 1
num_channels = 3
# The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and
# width of the video to be divisible by 2.
height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int) * 2).tolist()
return (num_frames, num_channels, height, width)
root = pathlib.Path(root) / name
os.makedirs(root)
return [
        create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, fps=fps, **kwargs)
for idx in range(num_examples)
]
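# Example (a minimal sketch with hypothetical names): mirrors create_image_folder() for videos; 'size' may again
# be a scalar, a sequence, or a callable receiving the file index.
#
#     paths = create_video_folder(tmpdir, "videos", file_name_fn=lambda idx: f"{idx}.avi", num_examples=2)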
@@ -15,6 +15,10 @@ from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
import xml.etree.ElementTree as ET
from urllib.request import Request, urlopen
import itertools
import datasets_utils
import pathlib
import pickle
from torchvision import datasets
try:
@@ -466,5 +470,95 @@ class STL10Tester(DatasetTestcase):
        self.assertIsInstance(repr(dataset), str)
class Caltech256TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Caltech256
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / "caltech256" / "256_ObjectCategories"
categories = ((1, "ak47"), (127, "laptop-101"), (257, "clutter"))
num_images_per_category = 2
for idx, category in categories:
datasets_utils.create_image_folder(
tmpdir,
name=f"{idx:03d}.{category}",
file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg",
num_examples=num_images_per_category,
)
return num_images_per_category * len(categories)
class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CIFAR10
CONFIGS = datasets_utils.combinations_grid(train=(True, False))
_VERSION_CONFIG = dict(
base_folder="cifar-10-batches-py",
train_files=tuple(f"data_batch_{idx}" for idx in range(1, 6)),
test_files=("test_batch",),
labels_key="labels",
meta_file="batches.meta",
num_categories=10,
categories_key="label_names",
)
def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir) / self._VERSION_CONFIG["base_folder"]
os.makedirs(tmpdir)
num_images_per_file = 1
for name in itertools.chain(self._VERSION_CONFIG["train_files"], self._VERSION_CONFIG["test_files"]):
self._create_batch_file(tmpdir, name, num_images_per_file)
categories = self._create_meta_file(tmpdir)
return dict(
num_examples=num_images_per_file
* len(self._VERSION_CONFIG["train_files"] if config["train"] else self._VERSION_CONFIG["test_files"]),
categories=categories,
)
def _create_batch_file(self, root, name, num_images):
data = datasets_utils.create_image_or_video_tensor((num_images, 32 * 32 * 3))
labels = np.random.randint(0, self._VERSION_CONFIG["num_categories"], size=num_images).tolist()
self._create_binary_file(root, name, {"data": data, self._VERSION_CONFIG["labels_key"]: labels})
def _create_meta_file(self, root):
categories = [
f"{idx:0{len(str(self._VERSION_CONFIG['num_categories'] - 1))}d}"
for idx in range(self._VERSION_CONFIG["num_categories"])
]
self._create_binary_file(
root, self._VERSION_CONFIG["meta_file"], {self._VERSION_CONFIG["categories_key"]: categories}
)
return categories
def _create_binary_file(self, root, name, content):
with open(pathlib.Path(root) / name, "wb") as fh:
pickle.dump(content, fh)
def test_class_to_idx(self):
with self.create_dataset() as (dataset, info):
expected = {category: label for label, category in enumerate(info["categories"])}
actual = dataset.class_to_idx
self.assertEqual(actual, expected)
class CIFAR100TestCase(CIFAR10TestCase):
DATASET_CLASS = datasets.CIFAR100
_VERSION_CONFIG = dict(
base_folder="cifar-100-python",
train_files=("train",),
test_files=("test",),
labels_key="fine_labels",
meta_file="meta",
num_categories=100,
categories_key="fine_label_names",
)
if __name__ == "__main__":
    unittest.main()