Unverified Commit 17088a68 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove videos from test for DatasetFolder (#7216)

parent c974742c
...@@ -1528,27 +1528,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase): ...@@ -1528,27 +1528,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder DATASET_CLASS = datasets.DatasetFolder
# The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader _EXTENSIONS = ("jpg", "png")
# that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
FEATURE_TYPES = (str, int)
_IMAGE_EXTENSIONS = ("jpg", "png")
_VIDEO_EXTENSIONS = ("avi", "mp4")
_EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)
# DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required. # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of both is required.
# We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the # We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
# 'test_is_valid_file()' method. # 'test_is_valid_file()' method.
DEFAULT_CONFIG = dict(extensions=_EXTENSIONS) DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
ADDITIONAL_CONFIGS = ( ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(extensions=[(ext,) for ext in _EXTENSIONS])
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
dict(extensions=_IMAGE_EXTENSIONS),
*datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
dict(extensions=_VIDEO_EXTENSIONS),
)
def dataset_args(self, tmpdir, config): def dataset_args(self, tmpdir, config):
return tmpdir, lambda x: x return tmpdir, datasets.folder.pil_loader
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"]) extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])
...@@ -1559,14 +1548,8 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1559,14 +1548,8 @@ class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
if ext not in extensions: if ext not in extensions:
continue continue
create_example_folder = (
datasets_utils.create_image_folder
if ext in self._IMAGE_EXTENSIONS
else datasets_utils.create_video_folder
)
num_examples = torch.randint(1, 3, size=()).item() num_examples = torch.randint(1, 3, size=()).item()
create_example_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples) datasets_utils.create_image_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), num_examples)
num_examples_total += num_examples num_examples_total += num_examples
classes.append(cls) classes.append(cls)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment