Unverified Commit 96aa3d92 authored by Ponku's avatar Ponku Committed by GitHub
Browse files

Added SceneFlow variant datasets (#6345)

* added SceneFlow variant datasets

* Changed split name to variant name

* removed trailing commented code line
parent 95fcd83a
...@@ -111,6 +111,7 @@ Stereo Matching ...@@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo CarlaStereo
Kitti2012Stereo Kitti2012Stereo
Kitti2015Stereo Kitti2015Stereo
SceneFlowStereo
Image pairs Image pairs
~~~~~~~~~~~ ~~~~~~~~~~~
......
...@@ -13,7 +13,7 @@ import string ...@@ -13,7 +13,7 @@ import string
import unittest import unittest
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import zipfile import zipfile
from typing import Union from typing import Callable, Tuple, Union
import datasets_utils import datasets_utils
import numpy as np import numpy as np
...@@ -2841,5 +2841,80 @@ class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2841,5 +2841,80 @@ class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
datasets_utils.shape_test_for_stereo(left, right, disparity) datasets_utils.shape_test_for_stereo(left, right, disparity)
class SceneFlowStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.SceneFlowStereo`` run against a synthetic SceneFlow directory tree."""

    DATASET_CLASS = datasets.SceneFlowStereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        variant=("FlyingThings3D", "Driving", "Monkaa"), pass_name=("clean", "final", "both")
    )
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    @staticmethod
    def _create_pfm_folder(
        root: str, name: str, file_name_fn: Callable[..., str], num_examples: int, size: Tuple[int, int]
    ) -> None:
        """Create the folder ``root / name`` and write ``num_examples`` fake PFM files into it.

        Args:
            root: Parent directory of the folder to create.
            name: Name of the folder to create below ``root``.
            file_name_fn: Called with the example index; returns the file name to write.
            num_examples: Number of PFM files to generate.
            size: Two values forwarded, in order, to ``datasets_utils.make_fake_pfm_file``.
        """
        folder = pathlib.Path(root) / name
        os.makedirs(folder, exist_ok=True)
        # ``map`` is lazy, so each name is still produced right before its file is written.
        for file_name in map(file_name_fn, range(num_examples)):
            datasets_utils.make_fake_pfm_file(size[0], size[1], folder / file_name)

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` with fake frames and disparities for the requested configuration.

        Returns:
            The number of examples the dataset is expected to expose for ``config``.
        """

        def png_name(idx):
            return f"{idx:06d}.png"

        def pfm_name(idx):
            return f"{idx:06d}.pfm"

        variant = config["variant"]
        pass_name = config["pass_name"]

        # Creating the variant directory also creates the enclosing "SceneFlow" directory.
        variant_dir = pathlib.Path(tmpdir) / "SceneFlow" / variant
        os.makedirs(variant_dir, exist_ok=True)

        # Unknown variants / passes fall back to "nothing to create" instead of raising, so that
        # the dataset itself gets the chance to reject the bad argument.
        scenes_per_variant = {"FlyingThings3D": 4, "Driving": 6, "Monkaa": 5}
        scene_count = scenes_per_variant.get(variant, 0)

        pass_dirs_by_name = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }
        disparity_dir = variant_dir / "disparity"

        for pass_dir_name in pass_dirs_by_name.get(pass_name, []):
            frames_dir = variant_dir / pass_dir_name
            for folder in (frames_dir, disparity_dir):
                os.makedirs(folder, exist_ok=True)

            # Every scene gets one frame and one disparity map per view.
            for side, scene_idx in itertools.product(("left", "right"), range(scene_count)):
                scene_name = f"scene_{scene_idx:06d}"

                frames_scene_dir = frames_dir / scene_name
                os.makedirs(frames_scene_dir, exist_ok=True)
                datasets_utils.create_image_folder(
                    root=frames_scene_dir,
                    name=side,
                    file_name_fn=png_name,
                    num_examples=1,
                    size=(3, 200, 100),
                )

                disparity_scene_dir = disparity_dir / scene_name
                os.makedirs(disparity_scene_dir, exist_ok=True)
                self._create_pfm_folder(
                    root=disparity_scene_dir,
                    name=side,
                    file_name_fn=pfm_name,
                    num_examples=1,
                    size=(100, 200),
                )

        # With both passes present, each scene is counted once per pass.
        return scene_count * 2 if pass_name == "both" else scene_count

    def test_splits(self):
        """Check sample shapes for every variant with the single-pass settings."""
        for variant_name in ["FlyingThings3D", "Driving", "Monkaa"]:
            for pass_name in ["clean", "final"]:
                with self.create_dataset(variant=variant_name, pass_name=pass_name) as (dataset, _):
                    for left, right, disparity in dataset:
                        datasets_utils.shape_test_for_stereo(left, right, disparity)

    def test_bad_input(self):
        """An unknown ``variant`` must be rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument variant"):
            with self.create_dataset(variant="bad"):
                pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo, SceneFlowStereo
from .caltech import Caltech101, Caltech256 from .caltech import Caltech101, Caltech256
from .celeba import CelebA from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
...@@ -109,4 +109,5 @@ __all__ = ( ...@@ -109,4 +109,5 @@ __all__ = (
"Kitti2012Stereo", "Kitti2012Stereo",
"Kitti2015Stereo", "Kitti2015Stereo",
"CarlaStereo", "CarlaStereo",
"SceneFlowStereo",
) )
...@@ -359,3 +359,109 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -359,3 +359,109 @@ class Kitti2015Stereo(StereoMatchingDataset):
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return super().__getitem__(index)
class SceneFlowStereo(StereoMatchingDataset):
    """Dataset interface for `Scene Flow <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ datasets.
    This interface provides access to the `FlyingThings3D`, `Monkaa` and `Driving` datasets.

    The dataset is expected to have the following structure: ::

        root
            SceneFlow
                Monkaa
                    frames_cleanpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        scene2
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                    frames_finalpass
                        scene1
                            left
                                img1.png
                                img2.png
                            right
                                img1.png
                                img2.png
                        ...
                        ...
                    disparity
                        scene1
                            left
                                img1.pfm
                                img2.pfm
                            right
                                img1.pfm
                                img2.pfm
                FlyingThings3D
                    ...
                    ...

    Args:
        root (string): Root directory where SceneFlow is located.
        variant (string): Which dataset variant to use, "FlyingThings3D" (default), "Monkaa" or "Driving".
        pass_name (string): Which pass to use, "clean" (default), "final" or "both".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        variant: str = "FlyingThings3D",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms)

        root = Path(root) / "SceneFlow"

        verify_str_arg(variant, "variant", valid_values=("FlyingThings3D", "Driving", "Monkaa"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))

        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        root = root / variant

        # The disparity files live next to the pass folders and do not depend on the pass,
        # so scan them a single time and reuse the result for every selected pass.
        left_disparity_pattern = str(root / "disparity" / "*" / "left" / "*.pfm")
        right_disparity_pattern = str(root / "disparity" / "*" / "right" / "*.pfm")
        disparity_pairs = list(self._scan_pairs(left_disparity_pattern, right_disparity_pattern))

        for p in passes:
            left_image_pattern = str(root / p / "*" / "left" / "*.png")
            right_image_pattern = str(root / p / "*" / "right" / "*.png")
            self._images += self._scan_pairs(left_image_pattern, right_image_pattern)
            # One disparity entry per image pair: the same disparities are appended once per pass.
            self._disparities += disparity_pairs

    def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
        """Read a PFM disparity file.

        Args:
            file_path (str): Path of the ``.pfm`` file to read.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``disparity_map`` holds the absolute values
            read from the file and ``valid_mask`` is always ``None``.
        """
        disparity_map = _read_pfm_file(file_path)
        disparity_map = np.abs(disparity_map)  # ensure that the disparity is positive
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            If a ``valid_mask`` is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
        """
        return super().__getitem__(index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment