Unverified Commit 93723b48 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added typing annotations to datasets/_optical_flow (#6845)



* style: Added typing annotations to datasets/_optical_flow

* style: Reverted back to str typing
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 5785e2b0
......@@ -3,6 +3,7 @@ import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -13,6 +14,10 @@ from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]
__all__ = (
"KittiFlow",
"Sintel",
......@@ -28,26 +33,26 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False
def __init__(self, root, transforms=None):
def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
super().__init__(root=root)
self.transforms = transforms
self._flow_list = []
self._image_list = []
self._flow_list: List[str] = []
self._image_list: List[List[str]] = []
def _read_img(self, file_name):
def _read_img(self, file_name: str) -> Image.Image:
img = Image.open(file_name)
if img.mode != "RGB":
img = img.convert("RGB")
return img
@abstractmethod
def _read_flow(self, file_name):
def _read_flow(self, file_name: str):
# Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
pass
def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[T1, T2]:
img1 = self._read_img(self._image_list[index][0])
img2 = self._read_img(self._image_list[index][1])
......@@ -70,10 +75,10 @@ class FlowDataset(ABC, VisionDataset):
else:
return img1, img2, flow
def __len__(self):
def __len__(self) -> int:
return len(self._image_list)
def __rmul__(self, v):
def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
return torch.utils.data.ConcatDataset([self] * v)
......@@ -118,7 +123,13 @@ class Sintel(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
def __init__(self, root, split="train", pass_name="clean", transforms=None):
def __init__(
self,
root: str,
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -139,7 +150,7 @@ class Sintel(FlowDataset):
if split == "train":
self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.
Args:
......@@ -154,7 +165,7 @@ class Sintel(FlowDataset):
"""
return super().__getitem__(index)
def _read_flow(self, file_name):
def _read_flow(self, file_name: str) -> np.ndarray:
return _read_flo(file_name)
......@@ -180,7 +191,7 @@ class KittiFlow(FlowDataset):
_has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -200,7 +211,7 @@ class KittiFlow(FlowDataset):
if split == "train":
self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.
Args:
......@@ -215,7 +226,7 @@ class KittiFlow(FlowDataset):
"""
return super().__getitem__(index)
def _read_flow(self, file_name):
def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name)
......@@ -245,7 +256,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
def __init__(self, root, split="train", transforms=None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "val"))
......@@ -268,7 +279,7 @@ class FlyingChairs(FlowDataset):
self._flow_list += [flows[i]]
self._image_list += [[images[2 * i], images[2 * i + 1]]]
def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.
Args:
......@@ -283,7 +294,7 @@ class FlyingChairs(FlowDataset):
"""
return super().__getitem__(index)
def _read_flow(self, file_name):
def _read_flow(self, file_name: str) -> np.ndarray:
return _read_flo(file_name)
......@@ -316,7 +327,14 @@ class FlyingThings3D(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
def __init__(
self,
root: str,
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -359,7 +377,7 @@ class FlyingThings3D(FlowDataset):
self._image_list += [[images[i + 1], images[i]]]
self._flow_list += [flows[i + 1]]
def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.
Args:
......@@ -374,7 +392,7 @@ class FlyingThings3D(FlowDataset):
"""
return super().__getitem__(index)
def _read_flow(self, file_name):
def _read_flow(self, file_name: str) -> np.ndarray:
return _read_pfm(file_name)
......@@ -401,7 +419,7 @@ class HD1K(FlowDataset):
_has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None):
def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
......@@ -426,10 +444,10 @@ class HD1K(FlowDataset):
"Could not find the HD1K images. Please make sure the directory structure is correct."
)
def _read_flow(self, file_name):
def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
return _read_16bits_png_with_flow_and_valid_mask(file_name)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[T1, T2]:
"""Return example at given index.
Args:
......@@ -445,7 +463,7 @@ class HD1K(FlowDataset):
return super().__getitem__(index)
def _read_flo(file_name):
def _read_flo(file_name: str) -> np.ndarray:
"""Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
......@@ -462,7 +480,7 @@ def _read_flo(file_name):
return data.reshape(h, w, 2).transpose(2, 0, 1)
def _read_16bits_png_with_flow_and_valid_mask(file_name):
def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment