Unverified Commit 96aa3d92 authored by Ponku's avatar Ponku Committed by GitHub
Browse files

Added SceneFlow variant datasets (#6345)

* added SceneFlow variant datasets

* Changed split name to variant name

* removed trailing commented code line
parent 95fcd83a
......@@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
SceneFlowStereo
Image pairs
~~~~~~~~~~~
......
......@@ -13,7 +13,7 @@ import string
import unittest
import xml.etree.ElementTree as ET
import zipfile
from typing import Union
from typing import Callable, Tuple, Union
import datasets_utils
import numpy as np
......@@ -2841,5 +2841,80 @@ class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
datasets_utils.shape_test_for_stereo(left, right, disparity)
class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.SceneFlowStereo`` run against a generated fake directory tree."""

    DATASET_CLASS = datasets.SceneFlowStereo
    # Every (variant, pass_name) combination is exercised by the generic test machinery.
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both")
    )
    # Samples are (left image, right image, disparity); the disparity may be an array or None.
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    @staticmethod
    def _create_pfm_folder(
        root: str, name: str, file_name_fn: Callable[..., str], num_examples: int, size: Tuple[int, int]
    ) -> None:
        """Create ``root / name`` and fill it with ``num_examples`` fake PFM files.

        Args:
            root: Parent directory of the folder to create.
            name: Name of the folder created inside ``root``.
            file_name_fn: Maps a file index to the file name used for that index.
            num_examples: Number of PFM files to write.
            size: Two values forwarded, in order, as the first two arguments of
                ``datasets_utils.make_fake_pfm_file``.
        """
        folder = pathlib.Path(root) / name
        os.makedirs(folder, exist_ok=True)
        for index in range(num_examples):
            target = folder / file_name_fn(index)
            datasets_utils.make_fake_pfm_file(size[0], size[1], target)

    def inject_fake_data(self, tmpdir, config):
        """Build a fake ``SceneFlow/<variant>`` tree in ``tmpdir`` and return the expected sample count.

        Each requested pass directory receives one scene folder per example, holding a
        single ``left`` and ``right`` PNG; matching PFM files are written under ``disparity``.
        """
        dataset_root = pathlib.Path(tmpdir) / "SceneFlow"
        variant_root = dataset_root / config["variant"]
        # Creating the variant directory also creates the "SceneFlow" parent.
        os.makedirs(variant_root, exist_ok=True)

        scenes_per_variant = {"FlyingThings3D": 4, "Driving": 6, "Monkaa": 5}
        scene_count = scenes_per_variant.get(config["variant"], 0)

        frame_dirs_by_pass = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }
        frame_dir_names = frame_dirs_by_pass.get(config["pass_name"], [])

        for frame_dir_name in frame_dir_names:
            frames_root = variant_root / frame_dir_name
            disparity_root = variant_root / "disparity"
            os.makedirs(frames_root, exist_ok=True)
            os.makedirs(disparity_root, exist_ok=True)

            # Same visiting order as nested loops: all scenes for "left", then all for "right".
            for side, scene_idx in itertools.product(["left", "right"], range(scene_count)):
                scene_name = f"scene_{scene_idx:06d}"

                frames_scene = frames_root / scene_name
                os.makedirs(frames_scene, exist_ok=True)
                datasets_utils.create_image_folder(
                    root=frames_scene,
                    name=side,
                    file_name_fn=lambda i: f"{i:06d}.png",
                    num_examples=1,
                    size=(3, 200, 100),
                )

                disparity_scene = disparity_root / scene_name
                os.makedirs(disparity_scene, exist_ok=True)
                self._create_pfm_folder(
                    root=disparity_scene,
                    name=side,
                    file_name_fn=lambda i: f"{i:06d}.pfm",
                    num_examples=1,
                    size=(100, 200),
                )

        # With "both", every scene is present once per pass directory.
        return scene_count * 2 if config["pass_name"] == "both" else scene_count

    def test_splits(self):
        """Check sample shapes for every variant combined with each single pass."""
        variant_names = ["FlyingThings3D", "Driving", "Monkaa"]
        single_passes = ["clean", "final"]
        for variant_name in variant_names:
            for pass_name in single_passes:
                with self.create_dataset(variant=variant_name, pass_name=pass_name) as (dataset, _):
                    for left, right, disparity in dataset:
                        datasets_utils.shape_test_for_stereo(left, right, disparity)

    def test_bad_input(self):
        """An unknown ``variant`` must be rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument variant"), self.create_dataset(
            variant="bad"
        ):
            pass
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
......@@ -109,4 +109,5 @@ __all__ = (
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
"SceneFlowStereo",
)
......@@ -359,3 +359,109 @@ class Kitti2015Stereo(StereoMatchingDataset):
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)
class SceneFlowStereo(StereoMatchingDataset):
    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.

    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.

    The dataset is expected to have the following structure: ::

        root
            SceneFlow
                Monkaa
                    frames_cleanpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        scene2
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                    frames_finalpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        ...
                    ...
                    disparity
                        scene1
                            left
                                img1.pfm
                                img2.pfm
                            right
                                img1.pfm
                                img2.pfm
                FlyingThings3D
                    ...
                    ...

    Args:
        root (string): Root directory where SceneFlow is located.
        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        variant: str = "FlyingThings3D",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms)

        scene_flow_root = Path(root) / "SceneFlow"

        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))

        frame_dirs_by_pass = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }
        frame_dirs = frame_dirs_by_pass[pass_name]

        variant_root = scene_flow_root / variant

        # The disparity glob patterns do not depend on the pass, so build them once.
        disparity_left = str(variant_root / "disparity" / "*" / "left" / "*.pfm")
        disparity_right = str(variant_root / "disparity" / "*" / "right" / "*.pfm")

        for frame_dir in frame_dirs:
            images_left = str(variant_root / frame_dir / "*" / "left" / "*.png")
            images_right = str(variant_root / frame_dir / "*" / "right" / "*.png")
            self._images += self._scan_pairs(images_left, images_right)

            # Disparities are scanned once per pass, so the same disparity files
            # are appended again for each set of images that was added.
            self._disparities += self._scan_pairs(disparity_left, disparity_right)

    def _read_disparity(self, file_path: str) -> Tuple:
        """Load a PFM disparity map and return ``(disparity_map, valid_mask)``.

        The absolute value is taken so that the disparity is positive. No validity
        mask is produced here, so the second element is always ``None``.
        """
        disparity_map = np.abs(_read_pfm_file(file_path))
        return disparity_map, None

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return super().__getitem__(index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment