"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f5aa5f587c7b583d08d202d33ae1e29df787b2d7"
Unverified Commit 6bea4ef5 authored by Ponku, committed by GitHub

Splitting Stereo Dataset PR(#6269) (#6311)

* Broken down PR(#6269). Added an additional dataset

* Removed some types. Store None instead of "". Merged test util functions.

* minor mypy fixes. minor doc fixes

* Reformatted docstring

* Added additional line-skips
parent ea0be26b
@@ -101,6 +101,17 @@ Optical Flow
    KittiFlow
    Sintel

Stereo Matching
~~~~~~~~~~~~~~~

.. autosummary::
    :toctree: generated/
    :template: class_dataset.rst

    CarlaStereo
    Kitti2012Stereo
    Kitti2015Stereo

Image pairs
~~~~~~~~~~~
...
@@ -16,6 +16,8 @@ import zipfile
from collections import defaultdict
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np
import PIL
import PIL.Image
import pytest
@@ -23,6 +25,7 @@ import torch
import torchvision.datasets
import torchvision.io
from common_utils import disable_console_output, get_tmp_dir
from torchvision.transforms.functional import get_dimensions

__all__ = [
@@ -748,6 +751,33 @@ def create_image_folder(
]
def shape_test_for_stereo(
    left: PIL.Image.Image,
    right: PIL.Image.Image,
    disparity: Optional[np.ndarray] = None,
    valid_mask: Optional[np.ndarray] = None,
):
    left_dims = get_dimensions(left)
    right_dims = get_dimensions(right)
    c, h, w = left_dims
    # check that left and right are the same size
    assert left_dims == right_dims
    assert c == 3

    # check that the disparity has the same spatial dimensions
    # as the input
    if disparity is not None:
        assert disparity.ndim == 3
        assert disparity.shape == (1, h, w)

    if valid_mask is not None:
        # check that valid mask is the same size as the disparity
        _, dh, dw = disparity.shape
        mh, mw = valid_mask.shape
        assert dh == mh
        assert dw == mw
@requires_lazy_imports("av")
def create_video_file(
    root: Union[pathlib.Path, str],
...
@@ -13,6 +13,7 @@ import string
import unittest
import xml.etree.ElementTree as ET
import zipfile
from typing import Union

import datasets_utils
import numpy as np
@@ -2671,5 +2672,174 @@ class RenderedSST2TestCase(datasets_utils.ImageDatasetTestCase):
        return len(sampled_classes) * num_images_per_class[config["split"]]
class Kitti2012StereoTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.Kitti2012Stereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        kitti_dir = pathlib.Path(tmpdir) / "Kitti2012"
        os.makedirs(kitti_dir, exist_ok=True)

        split_dir = kitti_dir / (config["split"] + "ing")
        os.makedirs(split_dir, exist_ok=True)

        num_examples = {"train": 4, "test": 3}.get(config["split"], 0)

        datasets_utils.create_image_folder(
            root=split_dir,
            name="colored_0",
            file_name_fn=lambda i: f"{i:06d}_10.png",
            num_examples=num_examples,
            size=(3, 100, 200),
        )
        datasets_utils.create_image_folder(
            root=split_dir,
            name="colored_1",
            file_name_fn=lambda i: f"{i:06d}_10.png",
            num_examples=num_examples,
            size=(3, 100, 200),
        )

        if config["split"] == "train":
            datasets_utils.create_image_folder(
                root=split_dir,
                name="disp_noc",
                file_name_fn=lambda i: f"{i:06d}.png",
                num_examples=num_examples,
                # Kitti2012 uses a single channel image for disparities
                size=(1, 100, 200),
            )

        return num_examples

    def test_train_splits(self):
        for split in ["train"]:
            with self.create_dataset(split=split) as (dataset, _):
                for left, right, disparity, mask in dataset:
                    assert mask is None
                    datasets_utils.shape_test_for_stereo(left, right, disparity)

    def test_test_split(self):
        for split in ["test"]:
            with self.create_dataset(split=split) as (dataset, _):
                for left, right, disparity, mask in dataset:
                    assert mask is None
                    assert disparity is None
                    datasets_utils.shape_test_for_stereo(left, right)

    def test_bad_input(self):
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
            with self.create_dataset(split="bad"):
                pass
class Kitti2015StereoTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.Kitti2015Stereo
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))

    def inject_fake_data(self, tmpdir, config):
        kitti_dir = pathlib.Path(tmpdir) / "Kitti2015"
        os.makedirs(kitti_dir, exist_ok=True)

        split_dir = kitti_dir / (config["split"] + "ing")
        os.makedirs(split_dir, exist_ok=True)

        num_examples = {"train": 4, "test": 6}.get(config["split"], 0)

        datasets_utils.create_image_folder(
            root=split_dir,
            name="image_2",
            file_name_fn=lambda i: f"{i:06d}_10.png",
            num_examples=num_examples,
            size=(3, 100, 200),
        )
        datasets_utils.create_image_folder(
            root=split_dir,
            name="image_3",
            file_name_fn=lambda i: f"{i:06d}_10.png",
            num_examples=num_examples,
            size=(3, 100, 200),
        )

        if config["split"] == "train":
            datasets_utils.create_image_folder(
                root=split_dir,
                name="disp_occ_0",
                file_name_fn=lambda i: f"{i:06d}.png",
                num_examples=num_examples,
                # Kitti2015 uses a single channel image for disparities
                size=(1, 100, 200),
            )
            datasets_utils.create_image_folder(
                root=split_dir,
                name="disp_occ_1",
                file_name_fn=lambda i: f"{i:06d}.png",
                num_examples=num_examples,
                # Kitti2015 uses a single channel image for disparities
                size=(1, 100, 200),
            )

        return num_examples

    def test_train_splits(self):
        for split in ["train"]:
            with self.create_dataset(split=split) as (dataset, _):
                for left, right, disparity, mask in dataset:
                    assert mask is None
                    datasets_utils.shape_test_for_stereo(left, right, disparity)

    def test_test_split(self):
        for split in ["test"]:
            with self.create_dataset(split=split) as (dataset, _):
                for left, right, disparity, mask in dataset:
                    assert mask is None
                    assert disparity is None
                    datasets_utils.shape_test_for_stereo(left, right)

    def test_bad_input(self):
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
            with self.create_dataset(split="bad"):
                pass
class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.CarlaStereo
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, None))

    @staticmethod
    def _create_scene_folders(num_examples: int, root_dir: Union[str, pathlib.Path]):
        # make the root_dir if it does not exist
        os.makedirs(root_dir, exist_ok=True)

        for i in range(num_examples):
            scene_dir = pathlib.Path(root_dir) / f"scene_{i}"
            os.makedirs(scene_dir, exist_ok=True)
            # populate with left and right images
            datasets_utils.create_image_file(root=scene_dir, name="im0.png", size=(100, 100))
            datasets_utils.create_image_file(root=scene_dir, name="im1.png", size=(100, 100))
            datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp0GT.pfm"))
            datasets_utils.make_fake_pfm_file(100, 100, file_name=str(scene_dir / "disp1GT.pfm"))

    def inject_fake_data(self, tmpdir, config):
        carla_dir = pathlib.Path(tmpdir) / "carla-highres"
        os.makedirs(carla_dir, exist_ok=True)

        split_dir = pathlib.Path(carla_dir) / "trainingF"
        os.makedirs(split_dir, exist_ok=True)

        num_examples = 6
        self._create_scene_folders(num_examples=num_examples, root_dir=split_dir)

        return num_examples

    def test_train_splits(self):
        with self.create_dataset() as (dataset, _):
            for left, right, disparity in dataset:
                datasets_utils.shape_test_for_stereo(left, right, disparity)
if __name__ == "__main__":
    unittest.main()
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import CarlaStereo, Kitti2012Stereo, Kitti2015Stereo
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
@@ -105,4 +106,7 @@ __all__ = (
    "FGVCAircraft",
    "EuroSAT",
    "RenderedSST2",
    "Kitti2012Stereo",
    "Kitti2015Stereo",
    "CarlaStereo",
)
import functools
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
from PIL import Image

from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset

__all__ = ()

_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)


class StereoMatchingDataset(ABC, VisionDataset):
    """Base interface for Stereo matching datasets"""

    _has_built_in_disparity_mask = False
    def __init__(self, root: str, transforms: Optional[Callable] = None):
        """
        Args:
            root(str): Root directory of the dataset.
            transforms(callable, optional): A function/transform that takes in Tuples of
                (images, disparities, valid_masks) and returns a transformed version of each of them.
                images is a Tuple of (``PIL.Image``, ``PIL.Image``)
                disparities is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (1, H, W)
                valid_masks is a Tuple of (``np.ndarray``, ``np.ndarray``) with shape (H, W)
                In some cases, when a dataset does not provide disparities, the ``disparities`` and
                ``valid_masks`` can be Tuples containing None values.

                For training splits, the datasets generally provide a minimal guarantee of:
                    images: (``PIL.Image``, ``PIL.Image``)
                    disparities: (``np.ndarray``, ``None``) with shape (1, H, W)
                Optionally, based on the dataset, it can return a ``mask`` as well:
                    valid_masks: (``np.ndarray | None``, ``None``) with shape (H, W)

                For some test splits, the datasets provide outputs that look like:
                    images: (``PIL.Image``, ``PIL.Image``)
                    disparities: (``None``, ``None``)
                Optionally, based on the dataset, it can return a ``mask`` as well:
                    valid_masks: (``None``, ``None``)
        """
        super().__init__(root=root)
        self.transforms = transforms

        self._images = []  # type: ignore
        self._disparities = []  # type: ignore

    def _read_img(self, file_path: str) -> Image.Image:
        img = Image.open(file_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img

    def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None):
        left_paths = list(sorted(glob(paths_left_pattern)))

        right_paths: List[Union[None, str]]
        if paths_right_pattern:
            right_paths = list(sorted(glob(paths_right_pattern)))
        else:
            right_paths = list(None for _ in left_paths)

        if not left_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_left_pattern}")

        if not right_paths:
            raise FileNotFoundError(f"Could not find any files matching the patterns: {paths_right_pattern}")

        if len(left_paths) != len(right_paths):
            raise ValueError(
                f"Found {len(left_paths)} left files but {len(right_paths)} right files using:\n "
                f"left pattern: {paths_left_pattern}\n"
                f"right pattern: {paths_right_pattern}\n"
            )

        paths = list((left, right) for left, right in zip(left_paths, right_paths))
        return paths

    @abstractmethod
    def _read_disparity(self, file_path: str) -> Tuple:
        # function that returns a disparity map and an occlusion map
        pass

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3- or 4-tuple with ``(img_left, img_right, disparity, Optional[valid_mask])`` where ``valid_mask``
                can be a numpy boolean mask of shape (H, W) if the dataset provides a file
                indicating which disparity pixels are valid. The disparity is a numpy array of
                shape (1, H, W) and the images are PIL images. ``disparity`` is None for
                datasets whose ``split="test"`` does not include annotations.
        """
        img_left = self._read_img(self._images[index][0])
        img_right = self._read_img(self._images[index][1])

        dsp_map_left, valid_mask_left = self._read_disparity(self._disparities[index][0])
        dsp_map_right, valid_mask_right = self._read_disparity(self._disparities[index][1])

        imgs = (img_left, img_right)
        dsp_maps = (dsp_map_left, dsp_map_right)
        valid_masks = (valid_mask_left, valid_mask_right)

        if self.transforms is not None:
            (
                imgs,
                dsp_maps,
                valid_masks,
            ) = self.transforms(imgs, dsp_maps, valid_masks)

        if self._has_built_in_disparity_mask or valid_masks[0] is not None:
            return imgs[0], imgs[1], dsp_maps[0], valid_masks[0]
        else:
            return imgs[0], imgs[1], dsp_maps[0]

    def __len__(self) -> int:
        return len(self._images)
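

# Example (illustrative sketch only): a ``transforms`` callable with the
# (images, disparities, valid_masks) signature described in ``StereoMatchingDataset``.
# The function name and the ``> 0`` heuristic are assumptions chosen for illustration;
# any callable with the same input/output structure works. It fills in a boolean
# validity mask from non-zero disparity pixels when the dataset does not ship one,
# which makes ``__getitem__`` return the 4-tuple form.
def example_mask_from_disparity_transform(images, disparities, valid_masks):
    new_masks = tuple(
        mask if mask is not None else (None if disp is None else disp[0] > 0)
        for disp, mask in zip(disparities, valid_masks)
    )
    # images and disparities are passed through unchanged; only the masks are filled in
    return images, disparities, new_masks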
class CarlaStereo(StereoMatchingDataset):
"""
Carla simulator data linked in the `CREStereo github repo <https://github.com/megvii-research/CREStereo>`_.
The dataset is expected to have the following structure: ::
root
carla-highres
trainingF
scene1
img0.png
img1.png
disp0GT.pfm
disp1GT.pfm
calib.txt
scene2
img0.png
img1.png
disp0GT.pfm
disp1GT.pfm
calib.txt
...
Args:
root (string): Root directory where `carla-highres` is located.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, transforms: Optional[Callable] = None):
super().__init__(root, transforms)
root = Path(root) / "carla-highres"
left_image_pattern = str(root / "trainingF" / "*" / "im0.png")
right_image_pattern = str(root / "trainingF" / "*" / "im1.png")
imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
self._images = imgs
left_disparity_pattern = str(root / "trainingF" / "*" / "disp0GT.pfm")
right_disparity_pattern = str(root / "trainingF" / "*" / "disp1GT.pfm")
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities = disparities
def _read_disparity(self, file_path: str) -> Tuple:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
"""Return example at given index.
Args:
index(int): The index of the example to retrieve
Returns:
tuple: A 3-tuple with ``(img_left, img_right, disparity)``.
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
"""
return super().__getitem__(index)
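

# Usage sketch for CarlaStereo (assumption: ``"datasets"`` is a placeholder root
# directory that already contains the extracted ``carla-highres`` folder laid out
# as in the docstring above):
#
#     dataset = CarlaStereo(root="datasets")
#     img_left, img_right, disparity = dataset[0]
#     # img_left / img_right are RGB PIL images, disparity is a (1, H, W) numpy array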
class Kitti2012Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the `2012 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php>`_.
    Uses the RGB images for consistency with KITTI 2015.

    The dataset is expected to have the following structure: ::

        root
            Kitti2012
                testing
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                training
                    colored_0
                        1_10.png
                        2_10.png
                        ...
                    colored_1
                        1_10.png
                        2_10.png
                        ...
                    disp_noc
                        1.png
                        2.png
                        ...
                    calib

    Args:
        root (string): Root directory where `Kitti2012` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "Kitti2012" / (split + "ing")

        left_img_pattern = str(root / "colored_0" / "*_10.png")
        right_img_pattern = str(root / "colored_1" / "*_10.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        if split == "train":
            disparity_pattern = str(root / "disp_noc" / "*.png")
            self._disparities = self._scan_pairs(disparity_pattern, None)
        else:
            self._disparities = list((None, None) for _ in self._images)

    def _read_disparity(self, file_path: str) -> Tuple:
        # test split has no disparity maps
        if file_path is None:
            return None, None

        disparity_map = np.asarray(Image.open(file_path)) / 256.0
        # unsqueeze the disparity map into (C, H, W) format
        disparity_map = disparity_map[None, :, :]
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
                The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
                ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
                generate a valid mask.
                Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        return super().__getitem__(index)
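

# Usage sketch for Kitti2012Stereo (assumption: ``"datasets"`` is a placeholder root
# that already contains the ``Kitti2012`` folder laid out as in the docstring above):
#
#     train_set = Kitti2012Stereo(root="datasets", split="train")
#     left, right, disparity, valid_mask = train_set[0]   # disparity: (1, H, W) array, valid_mask: None
#
#     test_set = Kitti2012Stereo(root="datasets", split="test")
#     left, right, disparity, valid_mask = test_set[0]    # disparity and valid_mask are both None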
class Kitti2015Stereo(StereoMatchingDataset):
    """
    KITTI dataset from the `2015 stereo evaluation benchmark <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php>`_.

    The dataset is expected to have the following structure: ::

        root
            Kitti2015
                testing
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                training
                    image_2
                        img1.png
                        img2.png
                        ...
                    image_3
                        img1.png
                        img2.png
                        ...
                    disp_occ_0
                        img1.png
                        img2.png
                        ...
                    disp_occ_1
                        img1.png
                        img2.png
                        ...
                    calib

    Args:
        root (string): Root directory where `Kitti2015` is located.
        split (string, optional): The dataset split of scenes, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
        super().__init__(root, transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "Kitti2015" / (split + "ing")

        left_img_pattern = str(root / "image_2" / "*.png")
        right_img_pattern = str(root / "image_3" / "*.png")
        self._images = self._scan_pairs(left_img_pattern, right_img_pattern)

        if split == "train":
            left_disparity_pattern = str(root / "disp_occ_0" / "*.png")
            right_disparity_pattern = str(root / "disp_occ_1" / "*.png")
            self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
        else:
            self._disparities = list((None, None) for _ in self._images)

    def _read_disparity(self, file_path: str) -> Tuple:
        # test split has no disparity maps
        if file_path is None:
            return None, None

        disparity_map = np.asarray(Image.open(file_path)) / 256.0
        # unsqueeze the disparity map into (C, H, W) format
        disparity_map = disparity_map[None, :, :]
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
                The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
                ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
                generate a valid mask.
                Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
        """
        return super().__getitem__(index)