"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "c5e43637945c335a46a6c4dda4d2d874c07465c0"
Unverified Commit b0aa60ca authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added missing typing annotations in datasets/_stereo_matching (#6846)



* style: Added missing typing annotations in datasets/_stereo_matching

* style: Specified typing

* style: Specified type annotations further

* style: Specified typing of __getitem__

* style: Fixed linting
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 7aef1153
...@@ -6,7 +6,7 @@ import shutil ...@@ -6,7 +6,7 @@ import shutil
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, cast, List, Optional, Tuple, Union
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -14,6 +14,9 @@ from PIL import Image ...@@ -14,6 +14,9 @@ from PIL import Image
from .utils import _read_pfm, download_and_extract_archive, verify_str_arg from .utils import _read_pfm, download_and_extract_archive, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], np.ndarray]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
__all__ = () __all__ = ()
_read_pfm_file = functools.partial(_read_pfm, slice_channels=1) _read_pfm_file = functools.partial(_read_pfm, slice_channels=1)
...@@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -24,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
_has_built_in_disparity_mask = False _has_built_in_disparity_mask = False
def __init__(self, root: str, transforms: Optional[Callable] = None): def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
""" """
Args: Args:
root(str): Root directory of the dataset. root(str): Root directory of the dataset.
...@@ -58,7 +61,11 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -58,7 +61,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
img = img.convert("RGB") img = img.convert("RGB")
return img return img
def _scan_pairs(self, paths_left_pattern: str, paths_right_pattern: Optional[str] = None): def _scan_pairs(
self,
paths_left_pattern: str,
paths_right_pattern: Optional[str] = None,
) -> List[Tuple[str, Optional[str]]]:
left_paths = list(sorted(glob(paths_left_pattern))) left_paths = list(sorted(glob(paths_left_pattern)))
...@@ -85,11 +92,11 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -85,11 +92,11 @@ class StereoMatchingDataset(ABC, VisionDataset):
return paths return paths
@abstractmethod @abstractmethod
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
# function that returns a disparity map and an occlusion map # function that returns a disparity map and an occlusion map
pass pass
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -120,7 +127,7 @@ class StereoMatchingDataset(ABC, VisionDataset): ...@@ -120,7 +127,7 @@ class StereoMatchingDataset(ABC, VisionDataset):
) = self.transforms(imgs, dsp_maps, valid_masks) ) = self.transforms(imgs, dsp_maps, valid_masks)
if self._has_built_in_disparity_mask or valid_masks[0] is not None: if self._has_built_in_disparity_mask or valid_masks[0] is not None:
return imgs[0], imgs[1], dsp_maps[0], valid_masks[0] return imgs[0], imgs[1], dsp_maps[0], cast(np.ndarray, valid_masks[0])
else: else:
return imgs[0], imgs[1], dsp_maps[0] return imgs[0], imgs[1], dsp_maps[0]
...@@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -156,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, transforms: Optional[Callable] = None): def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "carla-highres" root = Path(root) / "carla-highres"
...@@ -171,13 +178,13 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -171,13 +178,13 @@ class CarlaStereo(StereoMatchingDataset):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities = disparities self._disparities = disparities
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path) disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -189,7 +196,7 @@ class CarlaStereo(StereoMatchingDataset): ...@@ -189,7 +196,7 @@ class CarlaStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class Kitti2012Stereo(StereoMatchingDataset): class Kitti2012Stereo(StereoMatchingDataset):
...@@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -233,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -250,7 +257,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -250,7 +257,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
else: else:
self._disparities = list((None, None) for _ in self._images) self._disparities = list((None, None) for _ in self._images)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -261,7 +268,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -261,7 +268,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -274,7 +281,7 @@ class Kitti2012Stereo(StereoMatchingDataset): ...@@ -274,7 +281,7 @@ class Kitti2012Stereo(StereoMatchingDataset):
generate a valid mask. generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class Kitti2015Stereo(StereoMatchingDataset): class Kitti2015Stereo(StereoMatchingDataset):
...@@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -321,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -338,7 +345,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -338,7 +345,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
else: else:
self._disparities = list((None, None) for _ in self._images) self._disparities = list((None, None) for _ in self._images)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[Optional[np.ndarray], None]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -349,7 +356,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -349,7 +356,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -362,7 +369,7 @@ class Kitti2015Stereo(StereoMatchingDataset): ...@@ -362,7 +369,7 @@ class Kitti2015Stereo(StereoMatchingDataset):
generate a valid mask. generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class Middlebury2014Stereo(StereoMatchingDataset): class Middlebury2014Stereo(StereoMatchingDataset):
...@@ -479,7 +486,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -479,7 +486,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
use_ambient_views: bool = False, use_ambient_views: bool = False,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
download: bool = False, download: bool = False,
): ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test", "additional")) verify_str_arg(split, "split", valid_values=("train", "test", "additional"))
...@@ -558,7 +565,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -558,7 +565,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
file_path = random.choice(ambient_file_paths) # type: ignore file_path = random.choice(ambient_file_paths) # type: ignore
return super()._read_img(file_path) return super()._read_img(file_path)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has not disparity maps # test split has not disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -569,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -569,7 +576,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask return disparity_map, valid_mask
def _download_dataset(self, root: str): def _download_dataset(self, root: str) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings # train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014" root = Path(root) / "Middlebury2014"
...@@ -608,7 +615,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -608,7 +615,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
# cleanup MiddEval3 directory # cleanup MiddEval3 directory
shutil.rmtree(str(root / "MiddEval3")) shutil.rmtree(str(root / "MiddEval3"))
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T2:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -619,7 +626,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): ...@@ -619,7 +626,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
The disparity is a numpy array of shape (1, H, W) and the images are PIL images. The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
``valid_mask`` is implicitly ``None`` for `split=test`. ``valid_mask`` is implicitly ``None`` for `split=test`.
""" """
return super().__getitem__(index) return cast(T2, super().__getitem__(index))
class CREStereo(StereoMatchingDataset): class CREStereo(StereoMatchingDataset):
...@@ -670,7 +677,7 @@ class CREStereo(StereoMatchingDataset): ...@@ -670,7 +677,7 @@ class CREStereo(StereoMatchingDataset):
self, self,
root: str, root: str,
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
): ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "CREStereo" root = Path(root) / "CREStereo"
...@@ -688,14 +695,14 @@ class CREStereo(StereoMatchingDataset): ...@@ -688,14 +695,14 @@ class CREStereo(StereoMatchingDataset):
disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
self._disparities += disparities self._disparities += disparities
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze the disparity map into (C, H, W) format # unsqueeze the disparity map into (C, H, W) format
disparity_map = disparity_map[None, :, :] / 32.0 disparity_map = disparity_map[None, :, :] / 32.0
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -707,7 +714,7 @@ class CREStereo(StereoMatchingDataset): ...@@ -707,7 +714,7 @@ class CREStereo(StereoMatchingDataset):
``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
generate a valid mask. generate a valid mask.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class FallingThingsStereo(StereoMatchingDataset): class FallingThingsStereo(StereoMatchingDataset):
...@@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -755,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None): def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "FallingThings" root = Path(root) / "FallingThings"
...@@ -782,7 +789,7 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -782,7 +789,7 @@ class FallingThingsStereo(StereoMatchingDataset):
right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png") right_disparity_pattern = str(root / s / split_prefix[s] / "*.right.depth.png")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
# (H, W) image # (H, W) image
depth = np.asarray(Image.open(file_path)) depth = np.asarray(Image.open(file_path))
# as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt # as per https://research.nvidia.com/sites/default/files/pubs/2018-06_Falling-Things/readme_0.txt
...@@ -799,7 +806,7 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -799,7 +806,7 @@ class FallingThingsStereo(StereoMatchingDataset):
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -811,7 +818,7 @@ class FallingThingsStereo(StereoMatchingDataset): ...@@ -811,7 +818,7 @@ class FallingThingsStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class SceneFlowStereo(StereoMatchingDataset): class SceneFlowStereo(StereoMatchingDataset):
...@@ -874,7 +881,7 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -874,7 +881,7 @@ class SceneFlowStereo(StereoMatchingDataset):
variant: str = "FlyingThings3D", variant: str = "FlyingThings3D",
pass_name: str = "clean", pass_name: str = "clean",
transforms: Optional[Callable] = None, transforms: Optional[Callable] = None,
): ) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "SceneFlow" root = Path(root) / "SceneFlow"
...@@ -905,13 +912,13 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -905,13 +912,13 @@ class SceneFlowStereo(StereoMatchingDataset):
right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm") right_disparity_pattern = str(root / "disparity" / prefix_directories[variant] / "right" / "*.pfm")
self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities += self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = _read_pfm_file(file_path) disparity_map = _read_pfm_file(file_path)
disparity_map = np.abs(disparity_map) # ensure that the disparity is positive disparity_map = np.abs(disparity_map) # ensure that the disparity is positive
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -923,7 +930,7 @@ class SceneFlowStereo(StereoMatchingDataset): ...@@ -923,7 +930,7 @@ class SceneFlowStereo(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class SintelStereo(StereoMatchingDataset): class SintelStereo(StereoMatchingDataset):
...@@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -973,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None): def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
...@@ -1014,7 +1021,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1014,7 +1021,7 @@ class SintelStereo(StereoMatchingDataset):
return occlusion_path, outofframe_path return occlusion_path, outofframe_path
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -1034,7 +1041,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1034,7 +1041,7 @@ class SintelStereo(StereoMatchingDataset):
valid_mask = np.logical_and(off_mask, valid_mask) valid_mask = np.logical_and(off_mask, valid_mask)
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T2:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -1045,7 +1052,7 @@ class SintelStereo(StereoMatchingDataset): ...@@ -1045,7 +1052,7 @@ class SintelStereo(StereoMatchingDataset):
The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst The disparity is a numpy array of shape (1, H, W) and the images are PIL images whilst
the valid_mask is a numpy array of shape (H, W). the valid_mask is a numpy array of shape (H, W).
""" """
return super().__getitem__(index) return cast(T2, super().__getitem__(index))
class InStereo2k(StereoMatchingDataset): class InStereo2k(StereoMatchingDataset):
...@@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1080,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
""" """
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
root = Path(root) / "InStereo2k" / split root = Path(root) / "InStereo2k" / split
...@@ -1095,14 +1102,14 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1095,14 +1102,14 @@ class InStereo2k(StereoMatchingDataset):
right_disparity_pattern = str(root / "*" / "right_disp.png") right_disparity_pattern = str(root / "*" / "right_disp.png")
self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern) self._disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Tuple[np.ndarray, None]:
disparity_map = np.asarray(Image.open(file_path), dtype=np.float32) disparity_map = np.asarray(Image.open(file_path), dtype=np.float32)
# unsqueeze disparity to (C, H, W) # unsqueeze disparity to (C, H, W)
disparity_map = disparity_map[None, :, :] / 1024.0 disparity_map = disparity_map[None, :, :] / 1024.0
valid_mask = None valid_mask = None
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T1:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -1114,7 +1121,7 @@ class InStereo2k(StereoMatchingDataset): ...@@ -1114,7 +1121,7 @@ class InStereo2k(StereoMatchingDataset):
If a ``valid_mask`` is generated within the ``transforms`` parameter, If a ``valid_mask`` is generated within the ``transforms`` parameter,
a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned. a 4-tuple with ``(img_left, img_right, disparity, valid_mask)`` is returned.
""" """
return super().__getitem__(index) return cast(T1, super().__getitem__(index))
class ETH3DStereo(StereoMatchingDataset): class ETH3DStereo(StereoMatchingDataset):
...@@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1169,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):
_has_built_in_disparity_mask = True _has_built_in_disparity_mask = True
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms) super().__init__(root, transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -1189,7 +1196,7 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1189,7 +1196,7 @@ class ETH3DStereo(StereoMatchingDataset):
disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm") disparity_pattern = str(root / anot_dir / "*" / "disp0GT.pfm")
self._disparities = self._scan_pairs(disparity_pattern, None) self._disparities = self._scan_pairs(disparity_pattern, None)
def _read_disparity(self, file_path: str) -> Tuple: def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.ndarray, np.ndarray]]:
# test split has no disparity maps # test split has no disparity maps
if file_path is None: if file_path is None:
return None, None return None, None
...@@ -1201,7 +1208,7 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1201,7 +1208,7 @@ class ETH3DStereo(StereoMatchingDataset):
valid_mask = np.asarray(valid_mask).astype(bool) valid_mask = np.asarray(valid_mask).astype(bool)
return disparity_map, valid_mask return disparity_map, valid_mask
def __getitem__(self, index: int) -> Tuple: def __getitem__(self, index: int) -> T2:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -1214,4 +1221,4 @@ class ETH3DStereo(StereoMatchingDataset): ...@@ -1214,4 +1221,4 @@ class ETH3DStereo(StereoMatchingDataset):
generate a valid mask. generate a valid mask.
Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test. Both ``disparity`` and ``valid_mask`` are ``None`` if the dataset split is test.
""" """
return super().__getitem__(index) return cast(T2, super().__getitem__(index))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment