import contextlib
import functools
import importlib
import inspect
import itertools
import os
import pathlib
import platform
import random
import shutil
import string
import struct
import tarfile
import unittest
import unittest.mock
import zipfile
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np

import PIL
import PIL.Image
import pytest
import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torch.utils._pytree import tree_any
from torch.utils.data import DataLoader
from torchvision import tv_tensors
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.transforms.functional import get_dimensions
from torchvision.transforms.v2.functional import get_size


__all__ = [
    "UsageError",
    "lazy_importer",
    "test_all_configs",
    "DatasetTestCase",
    "ImageDatasetTestCase",
    "VideoDatasetTestCase",
    "create_image_or_video_tensor",
    "create_image_file",
    "create_image_folder",
    "create_video_file",
    "create_video_folder",
    "make_tar",
    "make_zip",
    "create_random_string",
]


class UsageError(Exception):
    """Should be raised in case an error happens in the setup rather than the test."""


class LazyImporter:
    r"""Lazy importer for additional dependencies.

    Some datasets require additional packages that are not direct dependencies of torchvision. Instances of this class
    provide modules listed in MODULES as attributes. They are only imported when accessed.
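
    For example (illustrative), accessing ``lazy_importer.av`` imports :mod:`av` only at that point and raises a
    :class:`UsageError` with an installation hint if the package is not available.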

    """
    MODULES = (
        "av",
        "lmdb",
        "pycocotools",
        "requests",
        "scipy.io",
        "scipy.sparse",
        "h5py",
    )

    def __init__(self):
        modules = defaultdict(list)
        for module in self.MODULES:
            module, *submodules = module.split(".", 1)
            if submodules:
                modules[module].append(submodules[0])
            else:
                # This introduces the module so that it is known when we later iterate over the dictionary.
                modules.__missing__(module)

        for module, submodules in modules.items():
            # We need the quirky 'module=module' and 'submodules=submodules' arguments to the lambda since otherwise
            # the lookup for these names would happen at call time rather than at definition time. Without them, every
            # property would try to import the last item in 'modules'.
            setattr(
                type(self),
                module,
                property(lambda self, module=module, submodules=submodules: LazyImporter._import(module, submodules)),
            )

    @staticmethod
    def _import(package, subpackages):
        try:
            module = importlib.import_module(package)
        except ImportError as error:
            raise UsageError(
                f"Failed to import module '{package}'. "
                f"This probably means that the current test case needs '{package}' installed, "
                f"but it is not a dependency of torchvision. "
                f"You need to install it manually, for example 'pip install {package}'."
            ) from error

        for name in subpackages:
            importlib.import_module(f".{name}", package=package)

        return module


lazy_importer = LazyImporter()


def requires_lazy_imports(*modules):
    def outer_wrapper(fn):
        @functools.wraps(fn)
        def inner_wrapper(*args, **kwargs):
            for module in modules:
                getattr(lazy_importer, module.replace(".", "_"))
            return fn(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper


def test_all_configs(test):
    """Decorator to run test against all configurations.

    Add this as a decorator to an arbitrary test to run it against all configurations. This includes
    :attr:`DatasetTestCase.DEFAULT_CONFIG` and :attr:`DatasetTestCase.ADDITIONAL_CONFIGS`.

    The current configuration is provided as the first parameter for the test:

    .. code-block::

        @test_all_configs
        def test_foo(self, config):
            pass

    .. note::

        This will try to remove duplicate configurations. During this process it will not preserve a potential
        ordering of the configurations or an inner ordering of a configuration.
    """

    def maybe_remove_duplicates(configs):
        try:
            return [dict(config_) for config_ in {tuple(sorted(config.items())) for config in configs}]
        except TypeError:
            # A TypeError will be raised if a value of any config is not hashable, e.g. a list. In that case duplicate
            # removal would be a lot more elaborate, and we simply bail out.
            return configs

    @functools.wraps(test)
    def wrapper(self):
        configs = []
        if self.DEFAULT_CONFIG is not None:
            configs.append(self.DEFAULT_CONFIG)
        if self.ADDITIONAL_CONFIGS is not None:
            configs.extend(self.ADDITIONAL_CONFIGS)

        if not configs:
            configs = [self._KWARG_DEFAULTS.copy()]
        else:
            configs = maybe_remove_duplicates(configs)

        for config in configs:
            with self.subTest(**config):
                test(self, config)

    return wrapper


class DatasetTestCase(unittest.TestCase):
    """Abstract base class for all dataset testcases.

    You have to overwrite the following class attributes:

        - DATASET_CLASS (torchvision.datasets.VisionDataset): Class of dataset to be tested.
        - FEATURE_TYPES (Sequence[Any]): Types of the elements returned by index access of the dataset. Instead of
            providing these manually, you can instead subclass ``ImageDatasetTestCase`` or ``VideoDatasetTestCase`` to
            get a reasonable default that should work for most cases. Each entry of the sequence may be a tuple,
            to indicate multiple possible values.

    Optionally, you can overwrite the following class attributes:

        - DEFAULT_CONFIG (Dict[str, Any]): Config that will be used by default. If omitted, this defaults to all
            keyword arguments of the dataset minus ``transform``, ``target_transform``, ``transforms``, and
            ``download``. Overwrite this if you want to use a default value for a parameter for which the dataset does
            not provide one.
        - ADDITIONAL_CONFIGS (Sequence[Dict[str, Any]]): Additional configs that should be tested. Each dictionary can
            contain an arbitrary combination of dataset parameters that are **not** ``transform``, ``target_transform``,
            ``transforms``, or ``download``.
        - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
            available, the tests are skipped.

    Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
    The fake data should resemble the original data as closely as necessary, while containing only a few examples. During
    the creation of the dataset, the check-, download-, and extract-functions from ``torchvision.datasets.utils`` are
    disabled.
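
    A minimal test case could look like this; ``MyDataset`` and the injected file layout are purely illustrative:

    .. code-block::

        class MyDatasetTestCase(DatasetTestCase):
            DATASET_CLASS = torchvision.datasets.MyDataset  # hypothetical dataset class
            FEATURE_TYPES = (PIL.Image.Image, int)

            def inject_fake_data(self, tmpdir, config):
                num_examples = 3
                create_image_folder(tmpdir, "images", lambda idx: f"{idx}.png", num_examples)
                return num_examples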

    Without further configuration, the testcase will test if

    1. the dataset raises a :class:`FileNotFoundError` or a :class:`RuntimeError` if the data files are not found or
       corrupted,
    2. the dataset inherits from `torchvision.datasets.VisionDataset`,
    3. the dataset can be turned into a string,
    4. the feature types of a returned example matches ``FEATURE_TYPES``,
    5. the number of examples matches the injected fake data, and
    6. the dataset calls ``transform``, ``target_transform``, or ``transforms`` if available when accessing data.

    Cases 3. to 6. are tested against all configurations, i.e. ``DEFAULT_CONFIG`` and ``ADDITIONAL_CONFIGS``.

    To add dataset-specific tests, create a new method that takes no arguments with ``test_`` as a name prefix:

    .. code-block::

        def test_foo(self):
            pass

    If you want to run the test against all configs, add the ``@test_all_configs`` decorator to the definition and
    accept a single argument:

    .. code-block::

        @test_all_configs
        def test_bar(self, config):
            pass

    Within the test you can use the ``create_dataset()`` method that yields the dataset as well as additional
    information provided by the ``inject_fake_data()`` method:

    .. code-block::

        def test_baz(self):
            with self.create_dataset() as (dataset, info):
                pass
    """

    DATASET_CLASS = None
    FEATURE_TYPES = None

    DEFAULT_CONFIG = None
    ADDITIONAL_CONFIGS = None
    REQUIRED_PACKAGES = None

    # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
    _TRANSFORM_KWARGS = {
        "transform",
        "target_transform",
        "transforms",
    }
    # These keyword arguments get a 'special' treatment and should not be set in DEFAULT_CONFIG or ADDITIONAL_CONFIGS.
    _SPECIAL_KWARGS = {
        *_TRANSFORM_KWARGS,
        "download",
    }

    # These fields are populated during setUpClass() within _populate_private_class_attributes()

    # This will be a dictionary containing all keyword arguments with their respective default values extracted from
    # the dataset constructor.
    _KWARG_DEFAULTS = None
    # This will be a set of all _SPECIAL_KWARGS that the dataset constructor takes.
    _HAS_SPECIAL_KWARG = None

    # These functions are disabled during dataset creation in create_dataset().
    _CHECK_FUNCTIONS = {
        "check_md5",
        "check_integrity",
    }
    _DOWNLOAD_EXTRACT_FUNCTIONS = {
        "download_url",
        "download_file_from_google_drive",
        "extract_archive",
        "download_and_extract_archive",
    }

    def dataset_args(self, tmpdir: str, config: Dict[str, Any]) -> Sequence[Any]:
        """Define positional arguments passed to the dataset.

        .. note::

            The default behavior is only valid if the dataset to be tested has ``root`` as the only required parameter.
            Otherwise, you need to overwrite this method.

        Args:
            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                to be created and in turn also for the fake data injected here.
            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
                fields for all dataset parameters with default values.

        Returns:
            (Tuple[str]): ``tmpdir`` which corresponds to ``root`` for most datasets.
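
        For a hypothetical dataset that additionally requires the path to an annotation file, this could be
        overwritten as:

        .. code-block::

            def dataset_args(self, tmpdir, config):
                return tmpdir, os.path.join(tmpdir, "annotations.json")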
        """
        return (tmpdir,)

    def inject_fake_data(self, tmpdir: str, config: Dict[str, Any]) -> Union[int, Dict[str, Any]]:
        """Inject fake data for dataset into a temporary directory.

        During the creation of the dataset, the download and extract logic is disabled. Thus, the fake data injected
        here needs to resemble the raw data, i.e. the state of the dataset directly after the files are downloaded and
        potentially extracted.

        Args:
            tmpdir (str): Path to a temporary directory. For most cases this acts as root directory for the dataset
                to be created and in turn also for the fake data injected here.
            config (Dict[str, Any]): Configuration that will be passed to the dataset constructor. It provides at least
                fields for all dataset parameters with default values.

        Needs to return one of the following:

            1. (int): Number of examples in the dataset to be created, or
            2. (Dict[str, Any]): Additional information about the injected fake data. Must contain the field
                ``"num_examples"`` that corresponds to the number of examples in the dataset to be created.
        """
        raise NotImplementedError("You need to provide fake data in order for the tests to run.")

    @contextlib.contextmanager
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        patch_checks: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        r"""Create the dataset in a temporary directory.

        The configuration passed to the dataset is populated to contain at least all parameters with default values.
        For this the following order of precedence is used:

        1. Parameters in :attr:`kwargs`.
        2. Configuration in :attr:`config`.
        3. Configuration in :attr:`~DatasetTestCase.DEFAULT_CONFIG`.
        4. Default parameters of the dataset.
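
        For example (illustrative), if ``DEFAULT_CONFIG = dict(split="train")`` and a test calls
        ``self.create_dataset(dict(split="test"))``, the dataset is created with ``split="test"``.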

        Args:
            config (Optional[Dict[str, Any]]): Configuration that will be used to create the dataset.
            inject_fake_data (bool): If ``True`` (default) inject the fake data with :meth:`.inject_fake_data` before
                creating the dataset.
            patch_checks (Optional[bool]): If ``True`` disable integrity check logic while creating the dataset. If
                omitted defaults to the same value as ``inject_fake_data``.
            **kwargs (Any): Additional parameters passed to the dataset. These parameters take precedence in case they
                overlap with ``config``.

        Yields:
            dataset (torchvision.datasets.VisionDataset): Dataset.
            info (Dict[str, Any]): Additional information about the injected fake data. See :meth:`.inject_fake_data`
                for details.
        """
        if patch_checks is None:
            patch_checks = inject_fake_data

        special_kwargs, other_kwargs = self._split_kwargs(kwargs)

        complete_config = self._KWARG_DEFAULTS.copy()
        if self.DEFAULT_CONFIG:
            complete_config.update(self.DEFAULT_CONFIG)
        if config:
            complete_config.update(config)
        if other_kwargs:
            complete_config.update(other_kwargs)

        if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False):
            # Never download during tests: override a truthy 'download' argument with False.
            special_kwargs["download"] = False

        patchers = self._patch_download_extract()
        if patch_checks:
            patchers.update(self._patch_checks())

        with get_tmp_dir() as tmpdir:
            args = self.dataset_args(tmpdir, complete_config)
            info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None

            with self._maybe_apply_patches(patchers), disable_console_output():
                dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)

            yield dataset, info

    @classmethod
    def setUpClass(cls):
        cls._verify_required_public_class_attributes()
        cls._populate_private_class_attributes()
        cls._process_optional_public_class_attributes()
        super().setUpClass()

    @classmethod
    def _verify_required_public_class_attributes(cls):
        if cls.DATASET_CLASS is None:
            raise UsageError(
                "The class attribute 'DATASET_CLASS' needs to be overwritten. "
                "It should contain the class of the dataset to be tested."
            )
        if cls.FEATURE_TYPES is None:
            raise UsageError(
                "The class attribute 'FEATURE_TYPES' needs to be overwritten. "
                "It should contain a sequence of types that the dataset returns when accessed by index."
            )

    @classmethod
    def _populate_private_class_attributes(cls):
        defaults = []
        for cls_ in cls.DATASET_CLASS.__mro__:
            if cls_ is torchvision.datasets.VisionDataset:
                break

            argspec = inspect.getfullargspec(cls_.__init__)

            if not argspec.defaults:
                continue

            defaults.append(
                {
                    kwarg: default
                    for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults)
                    if not kwarg.startswith("_")
                }
            )

            if not argspec.varkw:
                break

        kwarg_defaults = dict()
        for config in reversed(defaults):
            kwarg_defaults.update(config)

        has_special_kwargs = set()
        for name in cls._SPECIAL_KWARGS:
            if name not in kwarg_defaults:
                continue

            del kwarg_defaults[name]
            has_special_kwargs.add(name)

        cls._KWARG_DEFAULTS = kwarg_defaults
        cls._HAS_SPECIAL_KWARG = has_special_kwargs

    @classmethod
    def _process_optional_public_class_attributes(cls):
        def check_config(config, name):
            special_kwargs = tuple(f"'{name}'" for name in cls._SPECIAL_KWARGS if name in config)
            if special_kwargs:
                raise UsageError(
                    f"{name} contains a value for the parameter(s) {', '.join(special_kwargs)}. "
                    f"These are handled separately by the test case and should not be set here. "
                    f"If you need to test some custom behavior regarding these parameters, "
                    f"you need to write a custom test (*not* test case), e.g. test_custom_transform()."
                )

        if cls.DEFAULT_CONFIG is not None:
            check_config(cls.DEFAULT_CONFIG, "DEFAULT_CONFIG")

        if cls.ADDITIONAL_CONFIGS is not None:
            for idx, config in enumerate(cls.ADDITIONAL_CONFIGS):
                check_config(config, f"CONFIGS[{idx}]")

        if cls.REQUIRED_PACKAGES:
            missing_pkgs = []
            for pkg in cls.REQUIRED_PACKAGES:
                try:
                    importlib.import_module(pkg)
                except ImportError:
                    missing_pkgs.append(f"'{pkg}'")

            if missing_pkgs:
                raise unittest.SkipTest(
                    f"The package(s) {', '.join(missing_pkgs)} are required to load the dataset "
                    f"'{cls.DATASET_CLASS.__name__}', but are not installed."
                )

    def _split_kwargs(self, kwargs):
        special_kwargs = kwargs.copy()
        other_kwargs = {key: special_kwargs.pop(key) for key in set(special_kwargs.keys()) - self._SPECIAL_KWARGS}
        return special_kwargs, other_kwargs

    def _inject_fake_data(self, tmpdir, config):
        info = self.inject_fake_data(tmpdir, config)
        if info is None:
            raise UsageError(
                "The method 'inject_fake_data' needs to return at least an integer indicating the number of "
                "examples for the current configuration."
            )
        elif isinstance(info, int):
            info = dict(num_examples=info)
        elif not isinstance(info, dict):
            raise UsageError(
                f"The additional information returned by the method 'inject_fake_data' must be either an "
                f"integer indicating the number of examples for the current configuration or a dictionary with "
                f"the same content. Got {type(info)} instead."
            )
        elif "num_examples" not in info:
            raise UsageError(
                "The information dictionary returned by the method 'inject_fake_data' must contain a "
                "'num_examples' field that holds the number of examples for the current configuration."
            )
        return info

    def _patch_download_extract(self):
        module = inspect.getmodule(self.DATASET_CLASS).__name__
        return {unittest.mock.patch(f"{module}.{function}") for function in self._DOWNLOAD_EXTRACT_FUNCTIONS}

    def _patch_checks(self):
        module = inspect.getmodule(self.DATASET_CLASS).__name__
        return {unittest.mock.patch(f"{module}.{function}", return_value=True) for function in self._CHECK_FUNCTIONS}

    @contextlib.contextmanager
    def _maybe_apply_patches(self, patchers):
        with contextlib.ExitStack() as stack:
            mocks = {}
            for patcher in patchers:
                with contextlib.suppress(AttributeError):
                    mocks[patcher.target] = stack.enter_context(patcher)
            yield mocks

    def test_not_found_or_corrupted(self):
        with pytest.raises((FileNotFoundError, RuntimeError)):
            with self.create_dataset(inject_fake_data=False):
                pass

    def test_smoke(self):
        with self.create_dataset() as (dataset, _):
            assert isinstance(dataset, torchvision.datasets.VisionDataset)

    @test_all_configs
    def test_str_smoke(self, config):
        with self.create_dataset(config) as (dataset, _):
            assert isinstance(str(dataset), str)

    @test_all_configs
    def test_feature_types(self, config):
        with self.create_dataset(config) as (dataset, _):
            example = dataset[0]

            if len(self.FEATURE_TYPES) > 1:
                actual = len(example)
                expected = len(self.FEATURE_TYPES)
                assert actual == expected, (
                    "The number of returned features does not match the number of elements in FEATURE_TYPES: "
                    f"{actual} != {expected}"
                )
            else:
                example = (example,)

            for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
                with self.subTest(idx=idx):
                    assert isinstance(feature, expected_feature_type)

    @test_all_configs
    def test_num_examples(self, config):
        with self.create_dataset(config) as (dataset, info):
            assert len(list(dataset)) == len(dataset) == info["num_examples"]

    @test_all_configs
    def test_transforms(self, config):
        mock = unittest.mock.Mock(wraps=lambda *args: args[0] if len(args) == 1 else args)
        for kwarg in self._TRANSFORM_KWARGS:
            if kwarg not in self._HAS_SPECIAL_KWARG:
                continue

            mock.reset_mock()

            with self.subTest(kwarg=kwarg):
                with self.create_dataset(config, **{kwarg: mock}) as (dataset, _):
                    dataset[0]

                mock.assert_called()

    @test_all_configs
    def test_transforms_v2_wrapper(self, config):
        try:
            with self.create_dataset(config) as (dataset, info):
                for target_keys in [None, "all"]:
                    if target_keys is not None and self.DATASET_CLASS not in {
                        torchvision.datasets.CocoDetection,
                        torchvision.datasets.VOCDetection,
                        torchvision.datasets.Kitti,
                        torchvision.datasets.WIDERFace,
                    }:
                        with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
                            wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
                        continue

                    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
                    assert isinstance(wrapped_dataset, self.DATASET_CLASS)
                    assert len(wrapped_dataset) == info["num_examples"]

                    wrapped_sample = wrapped_dataset[0]
                    assert tree_any(
                        lambda item: isinstance(item, (tv_tensors.TVTensor, PIL.Image.Image)), wrapped_sample
                    )
        except TypeError as error:
            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
            if str(error).startswith(msg):
                pytest.skip(msg)
            raise error
        except RuntimeError as error:
            if "currently not supported by this wrapper" in str(error):
                pytest.skip("Config is currently not supported by this wrapper")
            raise error


class ImageDatasetTestCase(DatasetTestCase):
    """Abstract base class for image dataset testcases.

    - Overwrites the FEATURE_TYPES class attribute to expect a :class:`PIL.Image.Image` and an integer label.
    """

    FEATURE_TYPES = (PIL.Image.Image, int)

    @contextlib.contextmanager
    def create_dataset(
        self,
        config: Optional[Dict[str, Any]] = None,
        inject_fake_data: bool = True,
        patch_checks: Optional[bool] = None,
        **kwargs: Any,
    ) -> Iterator[Tuple[torchvision.datasets.VisionDataset, Dict[str, Any]]]:
        with super().create_dataset(
            config=config,
            inject_fake_data=inject_fake_data,
            patch_checks=patch_checks,
            **kwargs,
        ) as (dataset, info):
            # PIL.Image.open() only loads the image metadata upfront and keeps the file open until the first access
            # to the pixel data occurs. Trying to delete such a file results in a PermissionError on Windows. Thus, we
            # force-load opened images.
            # This problem only occurs during testing since some tests, e.g. DatasetTestCase.test_feature_types, open an
            # image but never use the underlying data. During normal operation it is reasonable to assume that the
            # user wants to work with the image they just opened rather than deleting the underlying file.
            with self._force_load_images():
                yield dataset, info

    @contextlib.contextmanager
    def _force_load_images(self):
        open = PIL.Image.open

        def new(fp, *args, **kwargs):
            image = open(fp, *args, **kwargs)
            if isinstance(fp, (str, pathlib.Path)):
                image.load()
            return image

        with unittest.mock.patch("PIL.Image.open", new=new):
            yield


class VideoDatasetTestCase(DatasetTestCase):
    """Abstract base class for video dataset testcases.

    - Overwrites the 'FEATURE_TYPES' class attribute to expect two :class:`torch.Tensor` s for the video and audio as
      well as an integer label.
    - Overwrites the 'REQUIRED_PACKAGES' class attribute to require PyAV (``av``).
    - Adds the 'FRAMES_PER_CLIP' class attribute. If 'frames_per_clip' is the last parameter without a default value
        in the dataset constructor and 'dataset_args()' does not provide a value for it, the value of the
        'FRAMES_PER_CLIP' class attribute is appended to the positional arguments.
    """

    FEATURE_TYPES = (torch.Tensor, torch.Tensor, int)
    REQUIRED_PACKAGES = ("av",)

    FRAMES_PER_CLIP = 1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_args = self._set_default_frames_per_clip(self.dataset_args)

    def _set_default_frames_per_clip(self, dataset_args):
        argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__)
        args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)]
        frames_per_clip_last = args_without_default[-1] == "frames_per_clip"

        @functools.wraps(dataset_args)
        def wrapper(tmpdir, config):
            args = dataset_args(tmpdir, config)
            if frames_per_clip_last and len(args) == len(args_without_default) - 1:
                args = (*args, self.FRAMES_PER_CLIP)

            return args

        return wrapper

    def test_output_format(self):
        for output_format in ["TCHW", "THWC"]:
            with self.create_dataset(output_format=output_format) as (dataset, _):
                for video, *_ in dataset:
                    if output_format == "TCHW":
                        num_frames, num_channels, *_ = video.shape
                    else:  # output_format == "THWC":
                        num_frames, *_, num_channels = video.shape

                assert num_frames == self.FRAMES_PER_CLIP
                assert num_channels == 3

    @test_all_configs
    def test_transforms_v2_wrapper(self, config):
        # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
        # or use the supported `"TCHW"`
        if config.setdefault("output_format", "TCHW") == "THWC":
            return

        super().test_transforms_v2_wrapper.__wrapped__(self, config)


def _no_collate(batch):
    return batch


def check_transforms_v2_wrapper_spawn(dataset, expected_size):
    # This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader.
    # We also check that transforms are applied correctly as a non-regression test for
    # https://github.com/pytorch/vision/issues/8066
    # Implicitly, this also checks that the wrapped datasets are pickleable.

    # To save CI/test time, we only check on Windows where "spawn" is the default
    if platform.system() != "Windows":
        pytest.skip("Multiprocessing spawning is only checked on Windows.")

    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)

    dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate)

    def resize_was_applied(item):
        # Checking the size of the output ensures that the Resize transform was correctly applied
        return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list(
            expected_size
        )

    for wrapped_sample in dataloader:
        assert tree_any(resize_was_applied, wrapped_sample)


def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
    r"""Create a random uint8 tensor.

    Args:
        size (Sequence[int]): Size of the tensor.
    """
    return torch.randint(0, 256, size, dtype=torch.uint8)


def create_image_file(
    root: Union[pathlib.Path, str], name: Union[pathlib.Path, str], size: Union[Sequence[int], int] = 10, **kwargs: Any
) -> pathlib.Path:
    """Create an image file from random data.

    Args:
        root (Union[str, pathlib.Path]): Root directory the image file will be placed in.
        name (Union[str, pathlib.Path]): Name of the image file.
        size (Union[Sequence[int], int]): Size of the image as ``(num_channels, height, width)``. If
            scalar, the value is used for the height and width. If the number of channels is not provided, three
            channels are assumed.
        kwargs (Any): Additional parameters passed to :meth:`PIL.Image.Image.save`.
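
    For example (illustrative), ``size=10`` creates a ``3 x 10 x 10`` image, while ``size=(1, 4, 8)`` creates a
    single-channel ``4 x 8`` image.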

    Returns:
        pathlib.Path: Path to the created image file.
    """
    if isinstance(size, int):
        size = (size, size)
    if len(size) == 2:
        size = (3, *size)
    if len(size) != 3:
        raise UsageError(
            f"The 'size' argument should either be an int or a sequence of length 2 or 3. Got {len(size)} instead"
        )

    image = create_image_or_video_tensor(size)
    file = pathlib.Path(root) / name

    # torch (num_channels x height x width) -> PIL (width x height x num_channels)
    image = image.permute(2, 1, 0)
    # For grayscale images PIL doesn't use a channel dimension
    if image.shape[2] == 1:
        image = torch.squeeze(image, 2)
    PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
    return file


def create_image_folder(
    root: Union[pathlib.Path, str],
    name: Union[pathlib.Path, str],
    file_name_fn: Callable[[int], str],
    num_examples: int,
    size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
    **kwargs: Any,
) -> List[pathlib.Path]:
    """Create a folder of random images.

    Args:
        root (Union[str, pathlib.Path]): Root directory the image folder will be placed in.
        name (Union[str, pathlib.Path]): Name of the image folder.
        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
        num_examples (int): Number of images to create.
        size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the images. If
            callable, will be called with the index of the corresponding file. If omitted, a random height and width
            between 3 and 10 pixels are selected on a per-image basis.
        kwargs (Any): Additional parameters passed to :func:`create_image_file`.

    Returns:
        List[pathlib.Path]: Paths to all created image files.

    .. seealso::

        - :func:`create_image_file`
    """
    if size is None:

        def size(idx: int) -> Tuple[int, int, int]:
            num_channels = 3
            height, width = torch.randint(3, 11, size=(2,), dtype=torch.int).tolist()
            return (num_channels, height, width)

    root = pathlib.Path(root) / name
    os.makedirs(root, exist_ok=True)

    return [
        create_image_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, **kwargs)
        for idx in range(num_examples)
    ]


def shape_test_for_stereo(
    left: PIL.Image.Image,
    right: PIL.Image.Image,
    disparity: Optional[np.ndarray] = None,
    valid_mask: Optional[np.ndarray] = None,
):
    left_dims = get_dimensions(left)
    right_dims = get_dimensions(right)
    c, h, w = left_dims
    # check that left and right are the same size
    assert left_dims == right_dims
    assert c == 3

    # check that the disparity has the same spatial dimensions
    # as the input
    if disparity is not None:
        assert disparity.ndim == 3
        assert disparity.shape == (1, h, w)

    if valid_mask is not None:
        # check that valid mask is the same size as the disparity
        _, dh, dw = disparity.shape
        mh, mw = valid_mask.shape
        assert dh == mh
        assert dw == mw


@requires_lazy_imports("av")
def create_video_file(
    root: Union[pathlib.Path, str],
    name: Union[pathlib.Path, str],
    size: Union[Sequence[int], int] = (1, 3, 10, 10),
    fps: float = 25,
    **kwargs: Any,
) -> pathlib.Path:
    """Create a video file from random data.

    Args:
        root (Union[str, pathlib.Path]): Root directory the video file will be placed in.
        name (Union[str, pathlib.Path]): Name of the video file.
        size (Union[Sequence[int], int]): Size of the video as ``(num_frames, num_channels, height, width)``.
            If scalar, the value is used for the height and width. If the leading dimensions are not provided,
            ``num_frames=1`` and ``num_channels=3`` are assumed.
        fps (float): Frame rate in frames per second.
        kwargs (Any): Additional parameters passed to :func:`torchvision.io.write_video`.
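
    For example (illustrative), ``size=8`` creates a video with a single ``3 x 8 x 8`` frame, while
    ``size=(2, 3, 10, 10)`` creates two ``10 x 10`` RGB frames.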

    Returns:
        pathlib.Path: Path to the created video file.

    Raises:
        UsageError: If PyAV is not available.
    """
    if isinstance(size, int):
        size = (size, size)
    if len(size) == 2:
        size = (3, *size)
    if len(size) == 3:
        size = (1, *size)
    if len(size) != 4:
        raise UsageError(
            f"The 'size' argument should either be an int or a sequence of length 2, 3, or 4. Got {len(size)} instead"
        )

    video = create_image_or_video_tensor(size)
    file = pathlib.Path(root) / name
    torchvision.io.write_video(str(file), video.permute(0, 2, 3, 1), fps, **kwargs)
    return file


@requires_lazy_imports("av")
def create_video_folder(
    root: Union[str, pathlib.Path],
    name: Union[str, pathlib.Path],
    file_name_fn: Callable[[int], str],
    num_examples: int,
    size: Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]] = None,
    fps=25,
    **kwargs,
) -> List[pathlib.Path]:
    """Create a folder of random videos.

    Args:
        root (Union[str, pathlib.Path]): Root directory the video folder will be placed in.
        name (Union[str, pathlib.Path]): Name of the video folder.
        file_name_fn (Callable[[int], str]): Should return a file name if called with the file index.
        num_examples (int): Number of videos to create.
        size (Optional[Union[Sequence[int], int, Callable[[int], Union[Sequence[int], int]]]]): Size of the videos. If
            callable, will be called with the index of the corresponding file. If omitted, a random even height and
            width between 4 and 10 pixels are selected on a per-video basis.
        fps (float): Frame rate in frames per second.
        kwargs (Any): Additional parameters passed to :func:`create_video_file`.

    Returns:
        List[pathlib.Path]: Paths to all created video files.

    Raises:
        UsageError: If PyAV is not available.

    .. seealso::

        - :func:`create_video_file`
    """
    if size is None:

        def size(idx):
            num_frames = 1
            num_channels = 3
            # The 'libx264' video codec, which is the default of torchvision.io.write_video, requires the height and
            # width of the video to be divisible by 2.
            height, width = (torch.randint(2, 6, size=(2,), dtype=torch.int) * 2).tolist()
            return (num_frames, num_channels, height, width)

    root = pathlib.Path(root) / name
    os.makedirs(root, exist_ok=True)

    return [
        create_video_file(root, file_name_fn(idx), size=size(idx) if callable(size) else size, fps=fps, **kwargs)
        for idx in range(num_examples)
    ]


def _split_files_or_dirs(root, *files_or_dirs):
    files = set()
    dirs = set()
    for file_or_dir in files_or_dirs:
        path = pathlib.Path(file_or_dir)
        if not path.is_absolute():
            path = root / path
        if path.is_file():
            files.add(path)
        else:
            dirs.add(path)
            for sub_file_or_dir in path.glob("**/*"):
                if sub_file_or_dir.is_file():
                    files.add(sub_file_or_dir)
                else:
                    dirs.add(sub_file_or_dir)

    if root in dirs:
        dirs.remove(root)

    return files, dirs


def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
    archive = pathlib.Path(root) / name
    if not files_or_dirs:
        # We need to invoke `Path.with_suffix("")` multiple times, since a single call only removes the last suffix if
        # multiple suffixes are present. For example, `pathlib.Path("foo.tar.gz").with_suffix("")` results in `foo.tar`.
        file_or_dir = archive
        for _ in range(len(archive.suffixes)):
            file_or_dir = file_or_dir.with_suffix("")
        if file_or_dir.exists():
            files_or_dirs = (file_or_dir,)
        else:
            raise ValueError("No file or dir provided.")

    files, dirs = _split_files_or_dirs(root, *files_or_dirs)

    with opener(archive) as fh:
        for file in sorted(files):
            adder(fh, file, file.relative_to(root))

    if remove:
        for file in files:
            os.remove(file)
        for dir in dirs:
            shutil.rmtree(dir, ignore_errors=True)

    return archive


def make_tar(root, name, *files_or_dirs, remove=True, compression=None):
    # TODO: detect compression from name
    return _make_archive(
        root,
        name,
        *files_or_dirs,
        opener=lambda archive: tarfile.open(archive, f"w:{compression}" if compression else "w"),
        adder=lambda fh, file, relative_file: fh.add(file, arcname=relative_file),
        remove=remove,
    )


def make_zip(root, name, *files_or_dirs, remove=True):
    return _make_archive(
        root,
        name,
        *files_or_dirs,
        opener=lambda archive: zipfile.ZipFile(archive, "w"),
        adder=lambda fh, file, relative_file: fh.write(file, arcname=relative_file),
        remove=remove,
    )
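

# Illustrative usage note: given fake files created under "<root>/images/", make_zip(root, "images.zip", "images")
# stores them with paths relative to 'root' and, because remove=True by default, deletes the original files and
# directories afterwards, leaving only the archive.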


def create_random_string(length: int, *digits: str) -> str:
    """Create a random string.

    Args:
        length (int): Number of characters in the generated string.
        *digits (str): Characters to sample from. If omitted, defaults to :attr:`string.ascii_lowercase`.
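
    For example (illustrative), ``create_random_string(4, string.digits)`` could return ``"3805"``, and passing
    multiple strings such as ``create_random_string(4, "ab", "cd")`` samples from ``"abcd"``.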
    """
    if not digits:
        digits = string.ascii_lowercase
    else:
        digits = "".join(itertools.chain(*digits))

    return "".join(random.choice(digits) for _ in range(length))


def make_fake_pfm_file(h, w, file_name):
    values = list(range(3 * h * w))
    # Note: we pack everything in little endian: the negative scale factor (-1.0) in the header signals little-endian
    # data, and "<" in struct.pack enforces it.
    content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
    with open(file_name, "wb") as f:
        f.write(content)


def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format."""
    # Everything needs to be in little Endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    values = list(range(2 * h * w))
    content = (
        struct.pack("<4c", *(c.encode() for c in "PIEH"))
        + struct.pack("<i", w)
        + struct.pack("<i", h)
        + struct.pack("<" + "f" * len(values), *values)
    )
    with open(file_name, "wb") as f:
        f.write(content)