Unverified Commit 93723b48 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added typing annotations to datasets/_optical_flow (#6845)



* style: Added typing annotations to datasets/_optical_flow

* style: Reverted back to str typing
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 5785e2b0
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -13,6 +14,10 @@ from .utils import _read_pfm, verify_str_arg ...@@ -13,6 +14,10 @@ from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
__all__ = ( __all__ = (
"KittiFlow", "KittiFlow",
"Sintel", "Sintel",
...@@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset): ...@@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False _has_builtin_flow_mask = False
def __init__(self, root, transforms=None): def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root=root) super().__init__(root=root)
self.transforms = transforms self.transforms = transforms
self._flow_list = [] self._flow_list: List[str] = []
self._image_list = [] self._image_list: List[List[str]] = []
def _read_img(self, file_name): def _read_img(self, file_name: str) -> Image.Image:
img = Image.open(file_name) img = Image.open(file_name)
if img.mode != "RGB": if img.mode != "RGB":
img = img.convert("RGB") img = img.convert("RGB")
return img return img
@abstractmethod @abstractmethod
def _read_flow(self, file_name): def _read_flow(self, file_name: str):
# Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
pass pass
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
img1 = self._read_img(self._image_list[index][0]) img1 = self._read_img(self._image_list[index][0])
img2 = self._read_img(self._image_list[index][1]) img2 = self._read_img(self._image_list[index][1])
...@@ -70,10 +75,10 @@ class FlowDataset(ABC, VisionDataset): ...@@ -70,10 +75,10 @@ class FlowDataset(ABC, VisionDataset):
else: else:
return img1, img2, flow return img1, img2, flow
def __len__(self): def __len__(self) -> int:
return len(self._image_list) return len(self._image_list)
def __rmul__(self, v): def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
return torch.utils.data.ConcatDataset([self] * v) return torch.utils.data.ConcatDataset([self] * v)
...@@ -118,7 +123,13 @@ class Sintel(FlowDataset): ...@@ -118,7 +123,13 @@ class Sintel(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root, split="train", pass_name="clean", transforms=None): def __init__(
self,
root: str,
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -139,7 +150,7 @@ class Sintel(FlowDataset): ...@@ -139,7 +150,7 @@ class Sintel(FlowDataset):
if split == "train": if split == "train":
self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -154,7 +165,7 @@ class Sintel(FlowDataset): ...@@ -154,7 +165,7 @@ class Sintel(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> np.ndarray:
return _read_flo(file_name) return _read_flo(file_name)
...@@ -180,7 +191,7 @@ class KittiFlow(FlowDataset): ...@@ -180,7 +191,7 @@ class KittiFlow(FlowDataset):
_has_builtin_flow_mask = True _has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -200,7 +211,7 @@ class KittiFlow(FlowDataset): ...@@ -200,7 +211,7 @@ class KittiFlow(FlowDataset):
if split == "train": if split == "train":
self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -215,7 +226,7 @@ class KittiFlow(FlowDataset): ...@@ -215,7 +226,7 @@ class KittiFlow(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name) return _read_16bits_png_with_flow_and_valid_mask(file_name)
...@@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset): ...@@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root, split="train", transforms=None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "val")) verify_str_arg(split, "split", valid_values=("train", "val"))
...@@ -268,7 +279,7 @@ class FlyingChairs(FlowDataset): ...@@ -268,7 +279,7 @@ class FlyingChairs(FlowDataset):
self._flow_list += [flows[i]] self._flow_list += [flows[i]]
self._image_list += [[images[2 * i], images[2 * i + 1]]] self._image_list += [[images[2 * i], images[2 * i + 1]]]
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -283,7 +294,7 @@ class FlyingChairs(FlowDataset): ...@@ -283,7 +294,7 @@ class FlyingChairs(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> np.ndarray:
return _read_flo(file_name) return _read_flo(file_name)
...@@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset): ...@@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
""" """
def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): def __init__(
self,
root: str,
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -359,7 +377,7 @@ class FlyingThings3D(FlowDataset): ...@@ -359,7 +377,7 @@ class FlyingThings3D(FlowDataset):
self._image_list += [[images[i + 1], images[i]]] self._image_list += [[images[i + 1], images[i]]]
self._flow_list += [flows[i + 1]] self._flow_list += [flows[i + 1]]
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -374,7 +392,7 @@ class FlyingThings3D(FlowDataset): ...@@ -374,7 +392,7 @@ class FlyingThings3D(FlowDataset):
""" """
return super().__getitem__(index) return super().__getitem__(index)
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> np.ndarray:
return _read_pfm(file_name) return _read_pfm(file_name)
...@@ -401,7 +419,7 @@ class HD1K(FlowDataset): ...@@ -401,7 +419,7 @@ class HD1K(FlowDataset):
_has_builtin_flow_mask = True _has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None): def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
...@@ -426,10 +444,10 @@ class HD1K(FlowDataset): ...@@ -426,10 +444,10 @@ class HD1K(FlowDataset):
"Could not find the HD1K images. Please make sure the directory structure is correct." "Could not find the HD1K images. Please make sure the directory structure is correct."
) )
def _read_flow(self, file_name): def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name) return _read_16bits_png_with_flow_and_valid_mask(file_name)
def __getitem__(self, index): def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index. """Return example at given index.
Args: Args:
...@@ -445,7 +463,7 @@ class HD1K(FlowDataset): ...@@ -445,7 +463,7 @@ class HD1K(FlowDataset):
return super().__getitem__(index) return super().__getitem__(index)
def _read_flo(file_name): def _read_flo(file_name: str) -> np.ndarray:
"""Read .flo file in Middlebury format""" """Read .flo file in Middlebury format"""
# Code adapted from: # Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
...@@ -462,7 +480,7 @@ def _read_flo(file_name): ...@@ -462,7 +480,7 @@ def _read_flo(file_name):
return data.reshape(h, w, 2).transpose(2, 0, 1) return data.reshape(h, w, 2).transpose(2, 0, 1)
def _read_16bits_png_with_flow_and_valid_mask(file_name): def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
flow_and_valid = _read_png_16(file_name).to(torch.float32) flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :] flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment