Unverified Commit 15a9a93b authored by Ponku's avatar Ponku Committed by GitHub
Browse files

Added CREStereo dataset (#6351)


Co-authored-by: default avatarJoao Gomes <jdsgomes@fb.com>
parent 1d6a259c
......@@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
Kitti2012Stereo
Kitti2015Stereo
CREStereo
FallingThingsStereo
SceneFlowStereo
SintelStereo
......
......@@ -2841,6 +2841,37 @@ class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
datasets_utils.shape_test_for_stereo(left, right, disparity)
class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.CREStereo
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None))
def inject_fake_data(self, tmpdir, config):
crestereo_dir = pathlib.Path(tmpdir) / "CREStereo"
os.makedirs(crestereo_dir, exist_ok=True)
examples = {"tree": 2, "shapenet": 3, "reflective": 6, "hole": 5}
for category_name in ["shapenet", "reflective", "tree", "hole"]:
split_dir = crestereo_dir / category_name
os.makedirs(split_dir, exist_ok=True)
num_examples = examples[category_name]
for idx in range(num_examples):
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.jpg", size=(100, 100))
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.jpg", size=(100, 100))
# these are going to end up being gray scale images
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_left.disp.png", size=(1, 100, 100))
datasets_utils.create_image_file(root=split_dir, name=f"{idx}_right.disp.png", size=(1, 100, 100))
return sum(examples.values())
def test_splits(self):
with self.create_dataset() as (dataset, _):
for left, right, disparity, mask in dataset:
assert mask is None
datasets_utils.shape_test_for_stereo(left, right, disparity)
class FallingThingsStereoTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.FallingThingsStereo
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(variant=("single", "mixed", "both"))
......
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
......@@ -118,6 +119,7 @@ __all__ = (
"Kitti2012Stereo",
"Kitti2015Stereo",
"CarlaStereo",
"CREStereo",
"FallingThingsStereo",
"SceneFlowStereo",
"SintelStereo",
......
......@@ -363,6 +363,94 @@ class Kitti2015Stereo(StereoMatchingDataset):
return super().__getitem__(index)
class CREStereo(StereoMatchingDataset):
"""Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.
The dataset is expected to have the following structure: ::
root
CREStereo
tree
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
img2_left.jpg
img2_right.jpg
img2_left.disp.jpg
img2_right.disp.jpg
...
shapenet
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
...
reflective
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
...
hole
img1_left.jpg
img1_right.jpg
img1_left.disp.jpg
img1_right.disp.jpg
...
Args:
root (str): Root directory of the dataset.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
_has_built_in_disparity_mask = True
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
):
super().__init__(root, transforms)
root = Path(root) / "CREStereo"
dirs = ["shapenet", "reflective", "tree", "hole"]
for s in dirs:
left_image_pattern = str(root / s / "*_left.jpg")
right_image_pattern = str(root / s / "*_right.jpg")
imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
self._images += imgs
left_disparity_pattern = str(root / s / "*_left.disp.png")
right_disparity_pattern = str(root / s / "*_right.disp.png")
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities
def _read_disparity(self, file_path: str) -> Tuple:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :] / 256.0
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
"""Return example at given index.
Args:
index(int): The index of the example to retrieve
Returns:
tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
generate a valid mask.
"""
return super().__getitem__(index)
class FallingThingsStereo(StereoMatchingDataset):
"""`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment