import itertools import os from abc import ABC, abstractmethod from glob import glob from pathlib import Path import numpy as np import torch from PIL import Image from ..io.image import _read_png_16 from .utils import verify_str_arg, _read_pfm from .vision import VisionDataset __all__ = ( "KittiFlow", "Sintel", "FlyingThings3D", "FlyingChairs", "HD1K", ) class FlowDataset(ABC, VisionDataset): # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow), # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. _has_builtin_flow_mask = False def __init__(self, root, transforms=None): super().__init__(root=root) self.transforms = transforms self._flow_list = [] self._image_list = [] def _read_img(self, file_name): img = Image.open(file_name) if img.mode != "RGB": img = img.convert("RGB") return img @abstractmethod def _read_flow(self, file_name): # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True pass def __getitem__(self, index): img1 = self._read_img(self._image_list[index][0]) img2 = self._read_img(self._image_list[index][1]) if self._flow_list: # it will be empty for some dataset when split="test" flow = self._read_flow(self._flow_list[index]) if self._has_builtin_flow_mask: flow, valid_flow_mask = flow else: valid_flow_mask = None else: flow = valid_flow_mask = None if self.transforms is not None: img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask) if self._has_builtin_flow_mask or valid_flow_mask is not None: # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform return img1, img2, flow, valid_flow_mask else: return img1, img2, flow def __len__(self): return len(self._image_list) def __rmul__(self, v): return torch.utils.data.ConcatDataset([self] * v) class Sintel(FlowDataset): """`Sintel `_ Dataset for optical flow. The dataset is expected to have the following structure: :: root Sintel testing clean scene_1 scene_2 ... final scene_1 scene_2 ... training clean scene_1 scene_2 ... final scene_1 scene_2 ... flow scene_1 scene_2 ... Args: root (string): Root directory of the Sintel Dataset. split (string, optional): The dataset split, either "train" (default) or "test" pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for details on the different passes. transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ def __init__(self, root, split="train", pass_name="clean", transforms=None): super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both")) passes = ["clean", "final"] if pass_name == "both" else [pass_name] root = Path(root) / "Sintel" flow_root = root / "training" / "flow" for pass_name in passes: split_dir = "training" if split == "train" else split image_root = root / split_dir / pass_name for scene in os.listdir(image_root): image_list = sorted(glob(str(image_root / scene / "*.png"))) for i in range(len(image_list) - 1): self._image_list += [[image_list[i], image_list[i + 1]]] if split == "train": self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) def __getitem__(self, index): """Return example at given index. Args: index(int): The index of the example to retrieve Returns: tuple: A 3-tuple with ``(img1, img2, flow)``. The flow is a numpy array of shape (2, H, W) and the images are PIL images. ``flow`` is None if ``split="test"``. If a valid flow mask is generated within the ``transforms`` parameter, a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned. """ return super().__getitem__(index) def _read_flow(self, file_name): return _read_flo(file_name) class KittiFlow(FlowDataset): """`KITTI `__ dataset for optical flow (2015). The dataset is expected to have the following structure: :: root KittiFlow testing image_2 training image_2 flow_occ Args: root (string): Root directory of the KittiFlow Dataset. split (string, optional): The dataset split, either "train" (default) or "test" transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. """ _has_builtin_flow_mask = True def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) root = Path(root) / "KittiFlow" / (split + "ing") images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images2 = sorted(glob(str(root / "image_2" / "*_11.png"))) if not images1 or not images2: raise FileNotFoundError( "Could not find the Kitti flow images. Please make sure the directory structure is correct." ) for img1, img2 in zip(images1, images2): self._image_list += [[img1, img2]] if split == "train": self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png"))) def __getitem__(self, index): """Return example at given index. Args: index(int): The index of the example to retrieve Returns: tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W) indicating which flow values are valid. The flow is a numpy array of shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if ``split="test"``. """ return super().__getitem__(index) def _read_flow(self, file_name): return _read_16bits_png_with_flow_and_valid_mask(file_name) class FlyingChairs(FlowDataset): """`FlyingChairs `_ Dataset for optical flow. You will also need to download the FlyingChairs_train_val.txt file from the dataset page. The dataset is expected to have the following structure: :: root FlyingChairs data 00001_flow.flo 00001_img1.ppm 00001_img2.ppm ... FlyingChairs_train_val.txt Args: root (string): Root directory of the FlyingChairs Dataset. split (string, optional): The dataset split, either "train" (default) or "val" transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "val")) root = Path(root) / "FlyingChairs" images = sorted(glob(str(root / "data" / "*.ppm"))) flows = sorted(glob(str(root / "data" / "*.flo"))) split_file_name = "FlyingChairs_train_val.txt" if not os.path.exists(root / split_file_name): raise FileNotFoundError( "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)." ) split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32) for i in range(len(flows)): split_id = split_list[i] if (split == "train" and split_id == 1) or (split == "val" and split_id == 2): self._flow_list += [flows[i]] self._image_list += [[images[2 * i], images[2 * i + 1]]] def __getitem__(self, index): """Return example at given index. Args: index(int): The index of the example to retrieve Returns: tuple: A 3-tuple with ``(img1, img2, flow)``. The flow is a numpy array of shape (2, H, W) and the images are PIL images. ``flow`` is None if ``split="val"``. If a valid flow mask is generated within the ``transforms`` parameter, a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned. """ return super().__getitem__(index) def _read_flow(self, file_name): return _read_flo(file_name) class FlyingThings3D(FlowDataset): """`FlyingThings3D `_ dataset for optical flow. The dataset is expected to have the following structure: :: root FlyingThings3D frames_cleanpass TEST TRAIN frames_finalpass TEST TRAIN optical_flow TEST TRAIN Args: root (string): Root directory of the intel FlyingThings3D Dataset. split (string, optional): The dataset split, either "train" (default) or "test" pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for details on the different passes. camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both". transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. ``valid_flow_mask`` is expected for consistency with other datasets which return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None): super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) split = split.upper() verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both")) passes = { "clean": ["frames_cleanpass"], "final": ["frames_finalpass"], "both": ["frames_cleanpass", "frames_finalpass"], }[pass_name] verify_str_arg(camera, "camera", valid_values=("left", "right", "both")) cameras = ["left", "right"] if camera == "both" else [camera] root = Path(root) / "FlyingThings3D" directions = ("into_future", "into_past") for pass_name, camera, direction in itertools.product(passes, cameras, directions): image_dirs = sorted(glob(str(root / pass_name / split / "*/*"))) image_dirs = sorted(Path(image_dir) / camera for image_dir in image_dirs) flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*"))) flow_dirs = sorted(Path(flow_dir) / direction / camera for flow_dir in flow_dirs) if not image_dirs or not flow_dirs: raise FileNotFoundError( "Could not find the FlyingThings3D flow images. " "Please make sure the directory structure is correct." ) for image_dir, flow_dir in zip(image_dirs, flow_dirs): images = sorted(glob(str(image_dir / "*.png"))) flows = sorted(glob(str(flow_dir / "*.pfm"))) for i in range(len(flows) - 1): if direction == "into_future": self._image_list += [[images[i], images[i + 1]]] self._flow_list += [flows[i]] elif direction == "into_past": self._image_list += [[images[i + 1], images[i]]] self._flow_list += [flows[i + 1]] def __getitem__(self, index): """Return example at given index. Args: index(int): The index of the example to retrieve Returns: tuple: A 3-tuple with ``(img1, img2, flow)``. The flow is a numpy array of shape (2, H, W) and the images are PIL images. ``flow`` is None if ``split="test"``. If a valid flow mask is generated within the ``transforms`` parameter, a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned. """ return super().__getitem__(index) def _read_flow(self, file_name): return _read_pfm(file_name) class HD1K(FlowDataset): """`HD1K `__ dataset for optical flow. The dataset is expected to have the following structure: :: root hd1k hd1k_challenge image_2 hd1k_flow_gt flow_occ hd1k_input image_2 Args: root (string): Root directory of the HD1K Dataset. split (string, optional): The dataset split, either "train" (default) or "test" transforms (callable, optional): A function/transform that takes in ``img1, img2, flow, valid_flow_mask`` and returns a transformed version. """ _has_builtin_flow_mask = True def __init__(self, root, split="train", transforms=None): super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) root = Path(root) / "hd1k" if split == "train": # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop for seq_idx in range(36): flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png"))) images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png"))) for i in range(len(flows) - 1): self._flow_list += [flows[i]] self._image_list += [[images[i], images[i + 1]]] else: images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png"))) images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png"))) for image1, image2 in zip(images1, images2): self._image_list += [[image1, image2]] if not self._image_list: raise FileNotFoundError( "Could not find the HD1K images. Please make sure the directory structure is correct." ) def _read_flow(self, file_name): return _read_16bits_png_with_flow_and_valid_mask(file_name) def __getitem__(self, index): """Return example at given index. Args: index(int): The index of the example to retrieve Returns: tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W) indicating which flow values are valid. The flow is a numpy array of shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if ``split="test"``. """ return super().__getitem__(index) def _read_flo(file_name): """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy # Everything needs to be in little Endian according to # https://vision.middlebury.edu/flow/code/flow-code/README.txt with open(file_name, "rb") as f: magic = np.fromfile(f, "c", count=4).tobytes() if magic != b"PIEH": raise ValueError("Magic number incorrect. Invalid .flo file") w = int(np.fromfile(f, "