Unverified Commit 50a35717 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add Kitti and Sintel datasets for optical flow (#4845)

parent f5af07b4
...@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
INaturalist INaturalist
Kinetics400 Kinetics400
Kitti Kitti
KittiFlow
KMNIST KMNIST
LFWPeople LFWPeople
LFWPairs LFWPairs
...@@ -60,6 +61,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -60,6 +61,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
SBDataset SBDataset
SBU SBU
SEMEION SEMEION
Sintel
STL10 STL10
SVHN SVHN
UCF101 UCF101
......
...@@ -203,6 +203,7 @@ class DatasetTestCase(unittest.TestCase): ...@@ -203,6 +203,7 @@ class DatasetTestCase(unittest.TestCase):
``transforms``, or ``download``. ``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped. available, the tests are skipped.
- EXTRA_PATCHES (set): Additional patches to add for each test, e.g. to mock a specific function
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on. Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
The fake data should resemble the original data as close as necessary, while containing only few examples. During The fake data should resemble the original data as close as necessary, while containing only few examples. During
...@@ -254,6 +255,8 @@ class DatasetTestCase(unittest.TestCase): ...@@ -254,6 +255,8 @@ class DatasetTestCase(unittest.TestCase):
ADDITIONAL_CONFIGS = None ADDITIONAL_CONFIGS = None
REQUIRED_PACKAGES = None REQUIRED_PACKAGES = None
EXTRA_PATCHES = None
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS. # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS = { _TRANSFORM_KWARGS = {
"transform", "transform",
...@@ -379,6 +382,9 @@ class DatasetTestCase(unittest.TestCase): ...@@ -379,6 +382,9 @@ class DatasetTestCase(unittest.TestCase):
if patch_checks: if patch_checks:
patchers.update(self._patch_checks()) patchers.update(self._patch_checks())
if self.EXTRA_PATCHES is not None:
patchers.update(self.EXTRA_PATCHES)
with get_tmp_dir() as tmpdir: with get_tmp_dir() as tmpdir:
args = self.dataset_args(tmpdir, complete_config) args = self.dataset_args(tmpdir, complete_config)
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
......
...@@ -1871,5 +1871,132 @@ class LFWPairsTestCase(LFWPeopleTestCase): ...@@ -1871,5 +1871,132 @@ class LFWPairsTestCase(LFWPeopleTestCase):
datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250) datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250)
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for the Sintel optical flow dataset."""

    DATASET_CLASS = datasets.Sintel
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))

    # The flow reader is mocked out so we don't have to generate real (readable)
    # .flo files on disk for the tests.
    _FAKE_FLOW = "Fake Flow"
    EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "Sintel"
        num_scenes = 2
        frames_per_scene = 3 if config["split"] == "train" else 4

        for split_dir in ("training", "test"):
            for pass_name in ("clean", "final"):
                image_root = root / split_dir / pass_name
                for scene_id in range(num_scenes):
                    datasets_utils.create_image_folder(
                        image_root,
                        name=str(image_root / f"scene_{scene_id}"),
                        file_name_fn=lambda image_idx: f"frame_000{image_idx}.png",
                        num_examples=frames_per_scene,
                    )

        # Ground-truth flow files are created empty: they only need to be
        # *discovered* by the dataset, since reading is mocked via EXTRA_PATCHES.
        flow_root = root / "training" / "flow"
        for scene_id in range(num_scenes):
            scene_dir = flow_root / f"scene_{scene_id}"
            os.makedirs(scene_dir)
            for frame_idx in range(frames_per_scene - 1):
                open(str(scene_dir / f"frame_000{frame_idx}.flo"), "a").close()

        # Consecutive frames are paired: e.g. with 3 frames per scene
        # (frame_0000, frame_0001, frame_0002) the pairs are
        # (frame_0000, frame_0001) and (frame_0001, frame_0002),
        # i.e. 3 - 1 = 2 examples per scene.
        return (frames_per_scene - 1) * num_scenes

    def test_flow(self):
        # Train split: one flow entry per image pair, and the mocked flow value is returned.
        with self.create_dataset(split="train") as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow == self._FAKE_FLOW

        # Test split: no ground truth, flow must always be None.
        with self.create_dataset(split="test") as (dataset, _):
            assert dataset._image_list and not dataset._flow_list
            for _, _, flow in dataset:
                assert flow is None

    def test_bad_input(self):
        for bad_kwargs, pattern in (
            (dict(split="bad"), "split must be either"),
            (dict(pass_name="bad"), "pass_name must be either"),
        ):
            with pytest.raises(ValueError, match=pattern):
                with self.create_dataset(**bad_kwargs):
                    pass
class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for the KITTI 2015 optical flow dataset."""

    DATASET_CLASS = datasets.KittiFlow
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "Kitti"
        num_examples = 2 if config["split"] == "train" else 3

        for split_dir in ("training", "testing"):
            # Both frames of each pair live in the same image_2 folder,
            # distinguished by the _10 / _11 filename suffix.
            for suffix in ("10", "11"):
                datasets_utils.create_image_folder(
                    root / split_dir,
                    name="image_2",
                    file_name_fn=lambda image_idx, s=suffix: f"{image_idx}_{s}.png",
                    num_examples=num_examples,
                )

        # Real KITTI ground-truth flows are 16-bit pngs. create_image_folder()
        # produces 8-bit pngs, but that's fine here: the flow reader can still
        # decode them, and the (garbage) flow values are never checked.
        datasets_utils.create_image_folder(
            root / "training",
            name="flow_occ",
            file_name_fn=lambda image_idx: f"{image_idx}_10.png",
            num_examples=num_examples,
        )

        return num_examples

    def test_flow_and_valid(self):
        # Train split: one flow entry per image pair; flow is (2, H, W) and the
        # valid mask matches the flow's spatial dimensions.
        with self.create_dataset(split="train") as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow, valid in dataset:
                two, h, w = flow.shape
                assert two == 2
                assert valid.shape == (h, w)

        # Test split: no ground truth, so flow and valid are always None.
        with self.create_dataset(split="test") as (dataset, _):
            assert dataset._image_list and not dataset._flow_list
            for _, _, flow, valid in dataset:
                assert flow is None
                assert valid is None

    def test_bad_input(self):
        with pytest.raises(ValueError, match="split must be either"):
            with self.create_dataset(split="bad"):
                pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from ._optical_flow import KittiFlow, Sintel
from .caltech import Caltech101, Caltech256 from .caltech import Caltech101, Caltech256
from .celeba import CelebA from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
...@@ -71,4 +72,6 @@ __all__ = ( ...@@ -71,4 +72,6 @@ __all__ = (
"INaturalist", "INaturalist",
"LFWPeople", "LFWPeople",
"LFWPairs", "LFWPairs",
"KittiFlow",
"Sintel",
) )
import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from ..io.image import _read_png_16
from .vision import VisionDataset
__all__ = (
"KittiFlow",
"Sintel",
)
class FlowDataset(ABC, VisionDataset):
    """Base class for optical flow datasets.

    Subclasses populate ``_image_list`` (pairs of image paths) and, when ground
    truth is available, ``_flow_list`` (one flow path per pair), and implement
    :meth:`_read_flow`.
    """

    # Datasets such as Kitti ship a built-in mask indicating which flow values
    # are valid. Those return (img1, img2, flow, valid); the others return
    # (img1, img2, flow) and leave `valid` up to whatever consumes the dataset.
    _has_builtin_flow_mask = False

    def __init__(self, root, transforms=None):
        super().__init__(root=root)
        self.transforms = transforms
        self._flow_list = []
        self._image_list = []

    def _read_img(self, file_name):
        return Image.open(file_name)

    @abstractmethod
    def _read_flow(self, file_name):
        # Returns the flow, or a (flow, valid) tuple when _has_builtin_flow_mask is True.
        pass

    def __getitem__(self, index):
        img1_path, img2_path = self._image_list[index]
        img1 = self._read_img(img1_path)
        img2 = self._read_img(img2_path)

        flow = valid = None
        if self._flow_list:  # empty for some datasets when split="test"
            flow = self._read_flow(self._flow_list[index])
            if self._has_builtin_flow_mask:
                flow, valid = flow

        if self.transforms is not None:
            img1, img2, flow, valid = self.transforms(img1, img2, flow, valid)

        if self._has_builtin_flow_mask:
            return img1, img2, flow, valid
        return img1, img2, flow

    def __len__(self):
        return len(self._image_list)
class Sintel(FlowDataset):
    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                testing
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                training
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                    flow
                        scene_1
                        scene_2
                        ...

    Args:
        root (string): Root directory of the Sintel Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default) or "final". See link above for
            details on the different passes.
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", transforms=None):
        super().__init__(root=root, transforms=transforms)

        if split not in ("train", "test"):
            raise ValueError("split must be either 'train' or 'test'")
        if pass_name not in ("clean", "final"):
            raise ValueError("pass_name must be either 'clean' or 'final'")

        root = Path(root) / "Sintel"
        split_dir = "training" if split == "train" else split
        image_root = root / split_dir / pass_name
        # Ground truth only exists under training/; for split="test" it is
        # simply never globbed below.
        flow_root = root / "training" / "flow"

        for scene in os.listdir(image_root):
            frames = sorted(glob(str(image_root / scene / "*.png")))
            # Each pair of consecutive frames forms one example.
            self._image_list += [[frame, next_frame] for frame, next_frame in zip(frames, frames[1:])]
            if split == "train":
                self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
            3-tuple with ``(img1, img2, None)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        return _read_flo(file_name)
class KittiFlow(FlowDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).

    The dataset is expected to have the following structure: ::

        root
            Kitti
                testing
                    image_2
                training
                    image_2
                    flow_occ

    Args:
        root (string): Root directory of the KittiFlow Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
    """

    # KITTI pngs encode a validity mask alongside the flow.
    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        if split not in ("train", "test"):
            raise ValueError("split must be either 'train' or 'test'")

        root = Path(root) / "Kitti" / (split + "ing")
        # The two frames of each pair share the image_2 folder and are
        # distinguished by the _10 / _11 suffix.
        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))

        if not images1 or not images2:
            raise FileNotFoundError(
                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
            )

        self._image_list += [[img1, img2] for img1, img2 in zip(images1, images2)]

        if split == "train":
            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
            valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. If `split="test"`, a
            4-tuple with ``(img1, img2, None, None)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        return _read_16bits_png_with_flow_and_valid_mask(file_name)
def _read_flo(file_name):
"""Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
with open(file_name, "rb") as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
raise ValueError("Magic number incorrect. Invalid .flo file")
w = int(np.fromfile(f, np.int32, count=1))
h = int(np.fromfile(f, np.int32, count=1))
data = np.fromfile(f, np.float32, count=2 * w * h)
return data.reshape(2, h, w)
def _read_16bits_png_with_flow_and_valid_mask(file_name):
    """Decode a KITTI 16-bit flow png into a ``(flow, valid)`` pair of numpy arrays."""
    raw = _read_png_16(file_name).to(torch.float32)
    # First two channels are the encoded flow, the third is the validity mask.
    flow = raw[:2, :, :]
    valid = raw[2, :, :]
    # Undo the KITTI encoding of the stored values (offset by 2**15, scaled by
    # 64) — this conversion is explained somewhere on the kitti archive.
    flow = (flow - 2 ** 15) / 64
    # For consistency with other datasets, convert to numpy.
    return flow.numpy(), valid.numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment