Commit b0aa60ca (unverified), authored by F-G Fernandez, committed by GitHub
Browse files

Added missing typing annotations in datasets/_stereo_matching (#6846)



* style: Added missing typing annotations in datasets/_stereo_matching

* style: Specified typing

* style: Specified type annotations further

* style: Specified typing of __getitem__

* style: Fixed linting
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 7aef1153
......@@ -6,7 +6,7 @@ import shutil
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, cast, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
......@@ -14,6 +14,9 @@ from PIL import Image
from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
__all__ = ()
_read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
......@@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
......@@ -58,7 +61,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
img = img.convert("RGB")
return img
def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None):
def _scan_pairs(
self,
paths_left_pattern: str,
paths_right_pattern: Optional[str] = None,
) -> List[Tuple[str, Optional[str]]]:
left_paths = list(sorted(glob(paths_left_pattern)))
......@@ -85,11 +92,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
return paths
@abstractmethod
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
# function that returns a disparity map and an occlusion map
pass
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.
Args:
......@@ -120,7 +127,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
) = self.transforms(imgs, dsp_maps, valid_masks)
if self._has_built_in_disparity_mask or valid_masks[0] is not None:
return imgs[0], imgs[1], dsp_maps[0], valid_masks[0]
return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
else:
return imgs[0], imgs[1], dsp_maps[0]
......@@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, transforms: Optional[Callable] = None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
root = Path(root) / "carla-highres"
......@@ -171,13 +178,13 @@ class CarlaStereo(StereoMatchingDataset):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities = disparities
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -189,7 +196,7 @@ class CarlaStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class Kitti2012Stereo(StereoMatchingDataset):
......@@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -250,7 +257,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
else:
self._disparities = list((None, None) for _ in self._images)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps
if file_path is None:
return None, None
......@@ -261,7 +268,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -274,7 +281,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class Kitti2015Stereo(StereoMatchingDataset):
......@@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -338,7 +345,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
else:
self._disparities = list((None, None) for _ in self._images)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps
if file_path is None:
return None, None
......@@ -349,7 +356,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -362,7 +369,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class Middlebury2014Stereo(StereoMatchingDataset):
......@@ -479,7 +486,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
use_ambient_views: bool = False,
transforms: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
......@@ -558,7 +565,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
file_path = random.choice(ambient_file_paths) # type: ignore
return super()._read_img(file_path)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps
if file_path is None:
return None, None
......@@ -569,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask
def _download_dataset(self, root: str):
def _download_dataset(self, root: str) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
......@@ -608,7 +615,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
# cleanup MiddEval3 directory
shutil.rmtree(str(root / "MiddEval3"))
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T2:
"""Return example at given index.
Args:
......@@ -619,7 +626,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
``valid_mask`` is implicitly ``None`` for `split=test`.
"""
return super().__getitem__(index)
return cast(T2, super().__getitem__(index))
class CREStereo(StereoMatchingDataset):
......@@ -670,7 +677,7 @@ class CREStereo(StereoMatchingDataset):
self,
root: str,
transforms: Optional[Callable] = None,
):
) -> None:
super().__init__(root, transforms)
root = Path(root) / "CREStereo"
......@@ -688,14 +695,14 @@ class CREStereo(StereoMatchingDataset):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :] / 32.0
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -707,7 +714,7 @@ class CREStereo(StereoMatchingDataset):
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
generate a valid mask.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class FallingThingsStereo(StereoMatchingDataset):
......@@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None):
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
root = Path(root) / "FallingThings"
......@@ -782,7 +789,7 @@ class FallingThingsStereo(StereoMatchingDataset):
right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
# (H, W) image
depth = np.asarray(Image.open(file_path))
# as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
......@@ -799,7 +806,7 @@ class FallingThingsStereo(StereoMatchingDataset):
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -811,7 +818,7 @@ class FallingThingsStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class SceneFlowStereo(StereoMatchingDataset):
......@@ -874,7 +881,7 @@ class SceneFlowStereo(StereoMatchingDataset):
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
):
) -> None:
super().__init__(root, transforms)
root = Path(root) / "SceneFlow"
......@@ -905,13 +912,13 @@ class SceneFlowStereo(StereoMatchingDataset):
right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -923,7 +930,7 @@ class SceneFlowStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class SintelStereo(StereoMatchingDataset):
......@@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None):
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
......@@ -1014,7 +1021,7 @@ class SintelStereo(StereoMatchingDataset):
return occlusion_path, outofframe_path
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
if file_path is None:
return None, None
......@@ -1034,7 +1041,7 @@ class SintelStereo(StereoMatchingDataset):
valid_mask = np.logical_and(off_mask, valid_mask)
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T2:
"""Return example at given index.
Args:
......@@ -1045,7 +1052,7 @@ class SintelStereo(StereoMatchingDataset):
The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
the valid_mask is a numpy array of shape (H, W).
"""
return super().__getitem__(index)
return cast(T2, super().__getitem__(index))
class InStereo2k(StereoMatchingDataset):
......@@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
root = Path(root) / "InStereo2k" / split
......@@ -1095,14 +1102,14 @@ class InStereo2k(StereoMatchingDataset):
right_disparity_pattern = str(root / "*" / "right_disp.png")
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze disparity to (C, H, W)
disparity_map = disparity_map[None, :, :] / 1024.0
valid_mask = None
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T1:
"""Return example at given index.
Args:
......@@ -1114,7 +1121,7 @@ class InStereo2k(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
"""
return super().__getitem__(index)
return cast(T1, super().__getitem__(index))
class ETH3DStereo(StereoMatchingDataset):
......@@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -1189,7 +1196,7 @@ class ETH3DStereo(StereoMatchingDataset):
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
self._disparities = self._scan_pairs(disparity_pattern, None)
def _read_disparity(self, file_path: str) -> Tuple:
def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps
if file_path is None:
return None, None
......@@ -1201,7 +1208,7 @@ class ETH3DStereo(StereoMatchingDataset):
valid_mask = np.asarray(valid_mask).astype(bool)
return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple:
def __getitem__(self, index: int) -> T2:
"""Return example at given index.
Args:
......@@ -1214,4 +1221,4 @@ class ETH3DStereo(StereoMatchingDataset):
generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
"""
return super().__getitem__(index)
return cast(T2, super().__getitem__(index))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment