Unverified commit 50a35717 authored by Nicolas Hug, committed by GitHub

Add Kitti and Sintel datasets for optical flow (#4845)

parent f5af07b4
@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
INaturalist
Kinetics400
Kitti
KittiFlow
KMNIST
LFWPeople
LFWPairs
@@ -60,6 +61,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
SBDataset
SBU
SEMEION
Sintel
STL10
SVHN
UCF101
@@ -203,6 +203,7 @@ class DatasetTestCase(unittest.TestCase):
``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
- EXTRA_PATCHES (set): Additional patches to add for each test, e.g. to mock a specific function.
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
The fake data should resemble the original data as closely as necessary, while containing only a few examples. During
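For illustration, a minimal sketch of a subclass wiring these hooks together (``FakeFlow`` and the patched ``_read_flow`` below are hypothetical placeholders, not part of this PR):

class FakeFlowTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.FakeFlow  # hypothetical dataset class
    REQUIRED_PACKAGES = ("numpy",)
    # Mock out the (hypothetical) flow reader in every test.
    EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.FakeFlow._read_flow", return_value=None)}

    def inject_fake_data(self, tmpdir, config):
        # Create just enough fake files for the dataset to discover, and
        # return the number of examples the tests should expect.
        ...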
@@ -254,6 +255,8 @@ class DatasetTestCase(unittest.TestCase):
ADDITIONAL_CONFIGS = None
REQUIRED_PACKAGES = None
EXTRA_PATCHES = None
# These keyword arguments are checked by test_transforms if they are available in DATASET_CLASS.
_TRANSFORM_KWARGS = {
"transform",
@@ -379,6 +382,9 @@ class DatasetTestCase(unittest.TestCase):
if patch_checks:
patchers.update(self._patch_checks())
if self.EXTRA_PATCHES is not None:
patchers.update(self.EXTRA_PATCHES)
with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, complete_config)
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
@@ -386,7 +392,7 @@ class DatasetTestCase(unittest.TestCase):
with self._maybe_apply_patches(patchers), disable_console_output():
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
yield dataset, info
@classmethod
def setUpClass(cls):
@@ -1871,5 +1871,132 @@ class LFWPairsTestCase(LFWPeopleTestCase):
datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250)
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Sintel
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
# which is something we want to avoid.
_FAKE_FLOW = "Fake Flow"
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "Sintel"
num_images_per_scene = 3 if config["split"] == "train" else 4
num_scenes = 2
for split_dir in ("training", "test"):
for pass_name in ("clean", "final"):
image_root = root / split_dir / pass_name
for scene_id in range(num_scenes):
scene_dir = image_root / f"scene_{scene_id}"
datasets_utils.create_image_folder(
image_root,
name=str(scene_dir),
file_name_fn=lambda image_idx: f"frame_000{image_idx}.png",
num_examples=num_images_per_scene,
)
# For the ground truth flow values we just create empty files so that they're properly discovered;
# see the comment above about EXTRA_PATCHES.
flow_root = root / "training" / "flow"
for scene_id in range(num_scenes):
scene_dir = flow_root / f"scene_{scene_id}"
os.makedirs(scene_dir)
for i in range(num_images_per_scene - 1):
open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
# With e.g. num_images_per_scene = 3, a single scene has 3 images:
# frame_0000, frame_0001 and frame_0002.
# They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002),
# that is 3 - 1 = 2 examples. Hence the formula below.
num_examples = (num_images_per_scene - 1) * num_scenes
return num_examples
def test_flow(self):
# Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images
with self.create_dataset(split="train") as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset:
assert flow == self._FAKE_FLOW
# Make sure flow is always None for test split
with self.create_dataset(split="test") as (dataset, _):
assert dataset._image_list and not dataset._flow_list
for _, _, flow in dataset:
assert flow is None
def test_bad_input(self):
with pytest.raises(ValueError, match="split must be either"):
with self.create_dataset(split="bad"):
pass
with pytest.raises(ValueError, match="pass_name must be either"):
with self.create_dataset(pass_name="bad"):
pass
class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.KittiFlow
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "Kitti"
num_examples = 2 if config["split"] == "train" else 3
for split_dir in ("training", "testing"):
datasets_utils.create_image_folder(
root / split_dir,
name="image_2",
file_name_fn=lambda image_idx: f"{image_idx}_10.png",
num_examples=num_examples,
)
datasets_utils.create_image_folder(
root / split_dir,
name="image_2",
file_name_fn=lambda image_idx: f"{image_idx}_11.png",
num_examples=num_examples,
)
# For Kitti, the ground truth flows are encoded as 16-bit pngs.
# create_image_folder() will actually create 8-bit pngs, but it doesn't
# matter much: the flow reader will still be able to read the files, they
# will just contain garbage flow values - but we don't care about that here.
datasets_utils.create_image_folder(
root / "training",
name="flow_occ",
file_name_fn=lambda image_idx: f"{image_idx}_10.png",
num_examples=num_examples,
)
return num_examples
def test_flow_and_valid(self):
# Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images
# Also assert flow and valid are of the expected shape
with self.create_dataset(split="train") as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow, valid in dataset:
two, h, w = flow.shape
assert two == 2
assert valid.shape == (h, w)
# Make sure flow and valid are always None for test split
with self.create_dataset(split="test") as (dataset, _):
assert dataset._image_list and not dataset._flow_list
for _, _, flow, valid in dataset:
assert flow is None
assert valid is None
def test_bad_input(self):
with pytest.raises(ValueError, match="split must be either"):
with self.create_dataset(split="bad"):
pass
if __name__ == "__main__":
unittest.main()
from ._optical_flow import KittiFlow, Sintel
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
@@ -71,4 +72,6 @@ __all__ = (
"INaturalist",
"LFWPeople",
"LFWPairs",
"KittiFlow",
"Sintel",
)
import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from ..io.image import _read_png_16
from .vision import VisionDataset
__all__ = (
"KittiFlow",
"Sintel",
)
class FlowDataset(ABC, VisionDataset):
# Some datasets like Kitti have a built-in valid mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what `valid` should be.
_has_builtin_flow_mask = False
def __init__(self, root, transforms=None):
super().__init__(root=root)
self.transforms = transforms
self._flow_list = []
self._image_list = []
def _read_img(self, file_name):
return Image.open(file_name)
@abstractmethod
def _read_flow(self, file_name):
# Return the flow or a tuple with the flow and the valid mask if _has_builtin_flow_mask is True
pass
def __getitem__(self, index):
img1 = self._read_img(self._image_list[index][0])
img2 = self._read_img(self._image_list[index][1])
if self._flow_list:  # it will be empty for some datasets when split="test"
flow = self._read_flow(self._flow_list[index])
if self._has_builtin_flow_mask:
flow, valid = flow
else:
valid = None
else:
flow = valid = None
if self.transforms is not None:
img1, img2, flow, valid = self.transforms(img1, img2, flow, valid)
if self._has_builtin_flow_mask:
return img1, img2, flow, valid
else:
return img1, img2, flow
def __len__(self):
return len(self._image_list)
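Since consumers see either 3-tuples or 4-tuples depending on ``_has_builtin_flow_mask``, here is a hedged sketch (illustrative helper, not part of the PR) of handling both shapes uniformly:

def _iterate_flow_dataset(dataset):
    # Illustrative only: unify 3-tuple and 4-tuple samples.
    for sample in dataset:
        if len(sample) == 4:  # e.g. KittiFlow: built-in valid mask
            img1, img2, flow, valid = sample
        else:  # e.g. Sintel: no built-in mask, the consumer decides what valid means
            img1, img2, flow = sample
            valid = None
        yield img1, img2, flow, valid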
class Sintel(FlowDataset):
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
The dataset is expected to have the following structure: ::
root
Sintel
testing
clean
scene_1
scene_2
...
final
scene_1
scene_2
...
training
clean
scene_1
scene_2
...
final
scene_1
scene_2
...
flow
scene_1
scene_2
...
Args:
root (string): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final". See link above for
details on the different passes.
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
def __init__(self, root, split="train", pass_name="clean", transforms=None):
super().__init__(root=root, transforms=transforms)
if split not in ("train", "test"):
raise ValueError("split must be either 'train' or 'test'")
if pass_name not in ("clean", "final"):
raise ValueError("pass_name must be either 'clean' or 'final'")
root = Path(root) / "Sintel"
split_dir = "training" if split == "train" else split
image_root = root / split_dir / pass_name
flow_root = root / "training" / "flow"
for scene in os.listdir(image_root):
image_list = sorted(glob(str(image_root / scene / "*.png")))
for i in range(len(image_list) - 1):
self._image_list += [[image_list[i], image_list[i + 1]]]
if split == "train":
self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
def __getitem__(self, index):
"""Return example at given index.
Args:
index (int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images. If ``split="test"``, a
3-tuple with ``(img1, img2, None)`` is returned.
"""
return super().__getitem__(index)
def _read_flow(self, file_name):
return _read_flo(file_name)
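A hedged usage sketch for ``Sintel`` ("path/to/data" below is a placeholder root that must contain a ``Sintel`` directory laid out as in the docstring above):

from torchvision import datasets

dataset = datasets.Sintel(root="path/to/data", split="train", pass_name="clean")
img1, img2, flow = dataset[0]  # two PIL images and a (2, H, W) numpy flow array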
class KittiFlow(FlowDataset):
"""`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).
The dataset is expected to have the following structure: ::
root
Kitti
testing
image_2
training
image_2
flow_occ
Args:
root (string): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
"""
_has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None):
super().__init__(root=root, transforms=transforms)
if split not in ("train", "test"):
raise ValueError("split must be either 'train' or 'test'")
root = Path(root) / "Kitti" / (split + "ing")
images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
images2 = sorted(glob(str(root / "image_2" / "*_11.png")))
if not images1 or not images2:
raise FileNotFoundError(
"Could not find the Kitti flow images. Please make sure the directory structure is correct."
)
for img1, img2 in zip(images1, images2):
self._image_list += [[img1, img2]]
if split == "train":
self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
def __getitem__(self, index):
"""Return example at given index.
Args:
index (int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If ``split="test"``, a
4-tuple with ``(img1, img2, None, None)`` is returned.
"""
return super().__getitem__(index)
def _read_flow(self, file_name):
return _read_16bits_png_with_flow_and_valid_mask(file_name)
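Similarly, a hedged usage sketch for ``KittiFlow``, e.g. zeroing out invalid flow values with the built-in mask ("path/to/data" is a placeholder root containing a ``Kitti`` directory as in the docstring above):

from torchvision import datasets

dataset = datasets.KittiFlow(root="path/to/data", split="train")
img1, img2, flow, valid = dataset[0]  # flow: (2, H, W), valid: (H, W)
masked_flow = flow * valid[None]  # broadcast the mask over both flow channels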
def _read_flo(file_name):
"""Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (e.g. Intel x86) only!
with open(file_name, "rb") as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
raise ValueError("Magic number incorrect. Invalid .flo file")
w = int(np.fromfile(f, np.int32, count=1))
h = int(np.fromfile(f, np.int32, count=1))
data = np.fromfile(f, np.float32, count=2 * w * h)
return data.reshape(h, w, 2).transpose(2, 0, 1)  # (2, H, W); .flo interleaves u and v per pixel
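The tests above deliberately avoid generating real .flo files; for reference, a hedged sketch (hypothetical helper, not in the PR) of writing one that round-trips through _read_flo:

import numpy as np

def _write_flo(file_name, flow):
    # Hypothetical helper: write a (2, H, W) float32 flow array in the
    # Middlebury .flo layout. Little-endian architectures only, like the reader.
    two, h, w = flow.shape
    assert two == 2
    with open(file_name, "wb") as f:
        np.array([202021.25], dtype=np.float32).tofile(f)  # magic number
        np.array([w, h], dtype=np.int32).tofile(f)  # width, then height
        # .flo interleaves u and v per pixel, i.e. (H, W, 2) in C order
        np.ascontiguousarray(flow.astype(np.float32).transpose(1, 2, 0)).tofile(f)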
def _read_16bits_png_with_flow_and_valid_mask(file_name):
flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow, valid = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow = (flow - 2 ** 15) / 64  # conversion specified in the KITTI flow devkit readme
# For consistency with other datasets, we convert to numpy
return flow.numpy(), valid.numpy()
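For reference, a hedged sketch (not part of the PR) of the inverse encoding; actually writing the 16-bit png would require an external library such as PIL or cv2:

import numpy as np

def _encode_kitti_flow(flow, valid):
    # Invert the decoding above: raw_uint16 = flow * 64 + 2**15,
    # with the valid mask stored as the third channel.
    encoded = np.clip(flow * 64 + 2 ** 15, 0, 2 ** 16 - 1).astype(np.uint16)
    valid = valid.astype(np.uint16)[None]  # (1, H, W)
    return np.concatenate([encoded, valid], axis=0)  # (3, H, W) uint16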