Unverified Commit 1bd131c7 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add FlyingThings3D dataset for optical flow (#4858)

parent 7f424379
...@@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr8k Flickr8k
Flickr30k Flickr30k
FlyingChairs FlyingChairs
FlyingThings3D
HMDB51 HMDB51
ImageNet ImageNet
INaturalist INaturalist
......
...@@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase): ...@@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase):
``transforms``, or ``download``. ``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped. available, the tests are skipped.
- EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on. Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
The fake data should resemble the original data as close as necessary, while containing only few examples. During The fake data should resemble the original data as close as necessary, while containing only few examples. During
...@@ -256,8 +255,6 @@ class DatasetTestCase(unittest.TestCase): ...@@ -256,8 +255,6 @@ class DatasetTestCase(unittest.TestCase):
ADDITIONAL_CONFIGS = None ADDITIONAL_CONFIGS = None
REQUIRED_PACKAGES = None REQUIRED_PACKAGES = None
EXTRA_PATCHES = None
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS. # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS = { _TRANSFORM_KWARGS = {
"transform", "transform",
...@@ -383,9 +380,6 @@ class DatasetTestCase(unittest.TestCase): ...@@ -383,9 +380,6 @@ class DatasetTestCase(unittest.TestCase):
if patch_checks: if patch_checks:
patchers.update(self._patch_checks()) patchers.update(self._patch_checks())
if self.EXTRA_PATCHES is not None:
patchers.update(self.EXTRA_PATCHES)
with get_tmp_dir() as tmpdir: with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, complete_config) args = self.dataset_args(tmpdir, complete_config)
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
...@@ -393,7 +387,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -393,7 +387,7 @@ class DatasetTestCase(unittest.TestCase):
with self._maybe_apply_patches(patchers), disable_console_output(): with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs) dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
yield dataset, info yield dataset, info
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -925,6 +919,14 @@ def create_random_string(length: int, *digits: str) -> str: ...@@ -925,6 +919,14 @@ def create_random_string(length: int, *digits: str) -> str:
return "".join(random.choice(digits) for _ in range(length)) return "".join(random.choice(digits) for _ in range(length))
def make_fake_pfm_file(h, w, file_name):
    """Write a fake ``h x w`` flow file in .pfm format to ``file_name``.

    The payload is simply the sequence 0, 1, 2, ... so that tests can
    deterministically verify values after reading the file back.
    """
    num_values = 3 * h * w
    # A negative scale (-1.0) in the PFM header signals little-endian data,
    # hence the "<" prefix in the struct format below.
    header = f"PF \n{w} {h} \n-1.0\n".encode()
    payload = struct.pack("<" + "f" * num_values, *range(num_values))
    with open(file_name, "wb") as f:
        f.write(header + payload)
def make_fake_flo_file(h, w, file_name): def make_fake_flo_file(h, w, file_name):
"""Creates a fake flow file in .flo format.""" """Creates a fake flow file in .flo format."""
values = list(range(2 * h * w)) values = list(range(2 * h * w))
......
...@@ -2048,5 +2048,72 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2048,5 +2048,72 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape)) np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data tests for the FlyingThings3D optical-flow dataset."""

    DATASET_CLASS = datasets.FlyingThings3D
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
    )
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Spatial size of the fake flow maps written in inject_fake_data.
    FLOW_H, FLOW_W = 3, 4

    def inject_fake_data(self, tmpdir, config):
        """Create a minimal on-disk FlyingThings3D layout; return the expected dataset length."""
        root = pathlib.Path(tmpdir) / "FlyingThings3D"

        frames_per_camera = 3 if config["split"] == "train" else 4

        pass_folders = ("frames_cleanpass", "frames_finalpass")
        split_folders = ("TRAIN", "TEST")
        scene_letters = ("A", "B", "C")
        scene_ids = ("0000", "0001")
        camera_names = ("left", "right")

        # Images live under <root>/<pass>/<SPLIT>/<letter>/<scene>/<camera>/00<i>.png
        image_combos = itertools.product(pass_folders, split_folders, scene_letters, scene_ids, camera_names)
        for pass_folder, split_folder, letter, scene, camera_name in image_combos:
            datasets_utils.create_image_folder(
                root / pass_folder / split_folder / letter / scene,
                name=camera_name,
                file_name_fn=lambda image_idx: f"00{image_idx}.png",
                num_examples=frames_per_camera,
            )

        # Flow maps live under <root>/optical_flow/<SPLIT>/<letter>/<scene>/<direction>/<camera>/<i>.pfm
        flow_directions = ("into_future", "into_past")
        flow_combos = itertools.product(split_folders, scene_letters, scene_ids, flow_directions, camera_names)
        for split_folder, letter, scene, direction, camera_name in flow_combos:
            flow_folder = root / "optical_flow" / split_folder / letter / scene / direction / camera_name
            os.makedirs(str(flow_folder), exist_ok=True)
            for frame_idx in range(frames_per_camera):
                datasets_utils.make_fake_pfm_file(
                    self.FLOW_H, self.FLOW_W, file_name=str(flow_folder / f"{frame_idx}.pfm")
                )

        # Each consecutive frame pair yields one example, per camera/pass/scene/split.
        cameras_used = 2 if config["camera"] == "both" else 1
        passes_used = 2 if config["pass_name"] == "both" else 1
        return (
            (frames_per_camera - 1)
            * cameras_used
            * len(scene_ids)
            * len(scene_letters)
            * len(split_folders)
            * passes_used
        )

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        """Every sample's flow must be a (2, FLOW_H, FLOW_W) array."""
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                # We don't check the values because the reshaping and flipping makes it hard to figure out

    def test_bad_input(self):
        """Invalid split / pass_name / camera values must raise ValueError."""
        for kwarg, message in (
            ("split", "Unknown value 'bad' for argument split"),
            ("pass_name", "Unknown value 'bad' for argument pass_name"),
            ("camera", "Unknown value 'bad' for argument camera"),
        ):
            with pytest.raises(ValueError, match=message):
                with self.create_dataset(**{kwarg: "bad"}):
                    pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from ._optical_flow import KittiFlow, Sintel, FlyingChairs from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
from .caltech import Caltech101, Caltech256 from .caltech import Caltech101, Caltech256
from .celeba import CelebA from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
...@@ -75,4 +75,5 @@ __all__ = ( ...@@ -75,4 +75,5 @@ __all__ = (
"KittiFlow", "KittiFlow",
"Sintel", "Sintel",
"FlyingChairs", "FlyingChairs",
"FlyingThings3D",
) )
import itertools
import os import os
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
...@@ -15,6 +17,7 @@ from .vision import VisionDataset ...@@ -15,6 +17,7 @@ from .vision import VisionDataset
__all__ = ( __all__ = (
"KittiFlow", "KittiFlow",
"Sintel", "Sintel",
"FlyingThings3D",
"FlyingChairs", "FlyingChairs",
) )
...@@ -271,6 +274,94 @@ class FlyingChairs(FlowDataset): ...@@ -271,6 +274,94 @@ class FlyingChairs(FlowDataset):
return _read_flo(file_name) return _read_flo(file_name)
class FlyingThings3D(FlowDataset):
    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            FlyingThings3D
                frames_cleanpass
                    TEST
                    TRAIN
                frames_finalpass
                    TEST
                    TRAIN
                optical_flow
                    TEST
                    TRAIN

    Args:
        root (string): Root directory of the FlyingThings3D Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
            details on the different passes.
        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        # On-disk split folders are upper-case (TRAIN / TEST).
        split = split.upper()

        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        if pass_name == "both":
            pass_folders = ["frames_cleanpass", "frames_finalpass"]
        else:
            pass_folders = [f"frames_{pass_name}pass"]

        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
        camera_names = ["left", "right"] if camera == "both" else [camera]

        root = Path(root) / "FlyingThings3D"

        for pass_folder, camera_name, direction in itertools.product(
            pass_folders, camera_names, ("into_future", "into_past")
        ):
            # Scene folders are two levels deep (<letter>/<scene id>); sort for determinism.
            image_dirs = sorted(glob(str(root / pass_folder / split / "*/*")))
            image_dirs = sorted([Path(image_dir) / camera_name for image_dir in image_dirs])

            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
            flow_dirs = sorted([Path(flow_dir) / direction / camera_name for flow_dir in flow_dirs])

            if not image_dirs or not flow_dirs:
                raise FileNotFoundError(
                    "Could not find the FlyingThings3D flow images. "
                    "Please make sure the directory structure is correct."
                )

            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
                images = sorted(glob(str(image_dir / "*.png")))
                flows = sorted(glob(str(flow_dir / "*.pfm")))
                for i in range(len(flows) - 1):
                    # "into_future": flow i maps frame i -> i+1;
                    # "into_past":  flow i+1 maps frame i+1 -> i.
                    if direction == "into_future":
                        self._image_list.append([images[i], images[i + 1]])
                        self._flow_list.append(flows[i])
                    else:  # into_past
                        self._image_list.append([images[i + 1], images[i]])
                        self._flow_list.append(flows[i + 1])

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        # Flow maps are stored on disk in .pfm format.
        return _read_pfm(file_name)
def _read_flo(file_name): def _read_flo(file_name):
"""Read .flo file in Middlebury format""" """Read .flo file in Middlebury format"""
# Code adapted from: # Code adapted from:
...@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name): ...@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
# For consistency with other datasets, we convert to numpy # For consistency with other datasets, we convert to numpy
return flow.numpy(), valid.numpy() return flow.numpy(), valid.numpy()
def _read_pfm(file_name):
"""Read flow in .pfm format"""
with open(file_name, "rb") as f:
header = f.readline().rstrip()
if header != b"PF":
raise ValueError("Invalid PFM file")
dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
if not dim_match:
raise Exception("Malformed PFM header.")
w, h = (int(dim) for dim in dim_match.groups())
scale = float(f.readline().rstrip())
if scale < 0: # little-endian
endian = "<"
scale = -scale
else:
endian = ">" # big-endian
data = np.fromfile(f, dtype=endian + "f")
data = data.reshape(h, w, 3).transpose(2, 0, 1)
data = np.flip(data, axis=1) # flip on h dimension
data = data[:2, :, :]
return data.astype(np.float32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment