Unverified Commit 7f424379 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add FlyingChairs dataset for optical flow (#4860)

parent eb48a1d8
......@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FashionMNIST
Flickr8k
Flickr30k
FlyingChairs
HMDB51
ImageNet
INaturalist
......
......@@ -8,6 +8,7 @@ import pathlib
import random
import shutil
import string
import struct
import tarfile
import unittest
import unittest.mock
......@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
digits = "".join(itertools.chain(*digits))
return "".join(random.choice(digits) for _ in range(length))
def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format.

    The file consists of the ``b"PIEH"`` tag, the width and the height as 32-bit integers, and then
    ``2 * h * w`` 32-bit floats holding the ramp ``0, 1, 2, ...`` so that a reader can be checked against
    known values.

    Args:
        h (int): Height of the fake flow field.
        w (int): Width of the fake flow field.
        file_name (str or path-like): File to write; it is created or overwritten.
    """
    num_values = 2 * h * w
    # "<" pins little-endian, standard-size fields. The .flo format is little-endian, so the fake file
    # must not depend on the byte order of the machine that runs the tests.
    content = b"PIEH" + struct.pack(f"<2i{num_values}f", w, h, *range(num_values))
    with open(file_name, "wb") as f:
        f.write(content)
......@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Sintel
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
# which is something we want to avoid.
_FAKE_FLOW = "Fake Flow"
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
FLOW_H, FLOW_W = 3, 4
def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "Sintel"
......@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
num_examples=num_images_per_scene,
)
# For the ground truth flow value we just create empty files so that they're properly discovered,
# see comment above about EXTRA_PATCHES
flow_root = root / "training" / "flow"
for scene_id in range(num_scenes):
scene_dir = flow_root / f"scene_{scene_id}"
os.makedirs(scene_dir)
for i in range(num_images_per_scene - 1):
open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
file_name = str(scene_dir / f"frame_000{i}.flo")
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
# with e.g. num_images_per_scene = 3, for a single scene we have 3 images
# which are frame_0000, frame_0001 and frame_0002
......@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(split="train") as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset:
assert flow == self._FAKE_FLOW
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
# Make sure flow is always None for test split
with self.create_dataset(split="test") as (dataset, _):
......@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
assert flow is None
def test_bad_input(self):
with pytest.raises(ValueError, match="split must be either"):
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
with self.create_dataset(split="bad"):
pass
with pytest.raises(ValueError, match="pass_name must be either"):
with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
with self.create_dataset(pass_name="bad"):
pass
......@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
assert valid is None
def test_bad_input(self):
with pytest.raises(ValueError, match="split must be either"):
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
with self.create_dataset(split="bad"):
pass
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
    """Exercises ``datasets.FlyingChairs`` on a tiny dataset that is generated on the fly."""

    DATASET_CLASS = datasets.FlyingChairs
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Height and width of every fake flow file written by inject_fake_data.
    FLOW_H, FLOW_W = 3, 4

    def _make_split_file(self, root, num_examples):
        """Write a fake ``FlyingChairs_train_val.txt`` into ``root``.

        One line is written per example, in random order: ``1`` for a train example, ``2`` for a val example.
        Users of the real dataset are asked to download this file from the authors' website instead.
        """
        labels = [1] * num_examples["train"] + [2] * num_examples["val"]
        random.shuffle(labels)

        split_path = root / "FlyingChairs_train_val.txt"
        with open(str(split_path), "w+") as split_file:
            split_file.writelines(f"{label}\n" for label in labels)

    def inject_fake_data(self, tmpdir, config):
        """Create fake images, flow files and the split file; return the example count of the requested split."""
        per_split = {"train": 5, "val": 3}
        total = sum(per_split.values())
        root = pathlib.Path(tmpdir) / "FlyingChairs"

        # First image of every pair.
        datasets_utils.create_image_folder(
            root,
            name="data",
            file_name_fn=lambda image_idx: f"00{image_idx}_img1.ppm",
            num_examples=total,
        )
        # Second image of every pair.
        datasets_utils.create_image_folder(
            root,
            name="data",
            file_name_fn=lambda image_idx: f"00{image_idx}_img2.ppm",
            num_examples=total,
        )

        # One decodable flow file per image pair, stored next to the images.
        data_dir = root / "data"
        for idx in range(total):
            flo_path = str(data_dir / f"00{idx}_flow.flo")
            datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=flo_path)

        self._make_split_file(root, per_split)

        split = config["split"]
        return per_split[split]

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        """Flow must exist for every image pair and decode to the ramp written by the fake .flo files."""
        with self.create_dataset(config=config) as (dataset, _):
            # There must be flow files, and exactly as many of them as there are image pairs.
            assert dataset._flow_list
            assert len(dataset._flow_list) == len(dataset._image_list)

            expected_shape = (2, self.FLOW_H, self.FLOW_W)
            for _img1, _img2, flow in dataset:
                assert flow.shape == expected_shape
                ramp = np.arange(flow.size).reshape(flow.shape)
                np.testing.assert_allclose(flow, ramp)
if __name__ == "__main__":
unittest.main()
from ._optical_flow import KittiFlow, Sintel
from ._optical_flow import KittiFlow, Sintel, FlyingChairs
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
......@@ -74,4 +74,5 @@ __all__ = (
"LFWPairs",
"KittiFlow",
"Sintel",
"FlyingChairs",
)
......@@ -8,12 +8,14 @@ import torch
from PIL import Image
from ..io.image import _read_png_16
from .utils import verify_str_arg
from .vision import VisionDataset
__all__ = (
"KittiFlow",
"Sintel",
"FlyingChairs",
)
......@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
def __init__(self, root, split="train", pass_name="clean", transforms=None):
super().__init__(root=root, transforms=transforms)
if split not in ("train", "test"):
raise ValueError("split must be either 'train' or 'test'")
if pass_name not in ("clean", "final"):
raise ValueError("pass_name must be either 'clean' or 'final'")
verify_str_arg(split, "split", valid_values=("train", "test"))
verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final"))
root = Path(root) / "Sintel"
......@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
def __init__(self, root, split="train", transforms=None):
super().__init__(root=root, transforms=transforms)
if split not in ("train", "test"):
raise ValueError("split must be either 'train' or 'test'")
verify_str_arg(split, "split", valid_values=("train", "test"))
root = Path(root) / "Kitti" / (split + "ing")
images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
......@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset):
return _read_16bits_png_with_flow_and_valid_mask(file_name)
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt


    Args:
        root (string): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"val"`` (raised by ``verify_str_arg``).
        FileNotFoundError: If ``FlyingChairs_train_val.txt`` is missing from ``root/FlyingChairs``.
    """

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        # Sorting is what pairs the files up: with the naming shown in the class docstring, the two images
        # of example i land at positions 2 * i and 2 * i + 1, and its flow file at position i.
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"

        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        # ndmin=1 keeps the result one-dimensional: without it np.loadtxt squeezes a one-line file into a
        # 0-d array, which cannot be indexed in the loop below.
        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32, ndmin=1)

        # The split file has one entry per example: 1 means train, 2 means val.
        wanted_id = 1 if split == "train" else 2
        for i, flow_file in enumerate(flows):
            if split_list[i] == wanted_id:
                self._flow_list.append(flow_file)
                self._image_list.append([images[2 * i], images[2 * i + 1]])

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        """Decode the .flo file at ``file_name`` with ``_read_flo``."""
        return _read_flo(file_name)
def _read_flo(file_name):
"""Read .flo file in Middlebury format"""
# Code adapted from:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment