"...text-generation-inference.git" did not exist on "68e9d6ab333715008c542467c8d5202cf4692253"
Unverified Commit 7f424379 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add FlyingChairs dataset for optical flow (#4860)

parent eb48a1d8
...@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FashionMNIST FashionMNIST
Flickr8k Flickr8k
Flickr30k Flickr30k
FlyingChairs
HMDB51 HMDB51
ImageNet ImageNet
INaturalist INaturalist
......
...@@ -8,6 +8,7 @@ import pathlib ...@@ -8,6 +8,7 @@ import pathlib
import random import random
import shutil import shutil
import string import string
import struct
import tarfile import tarfile
import unittest import unittest
import unittest.mock import unittest.mock
...@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str: ...@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
digits = "".join(itertools.chain(*digits)) digits = "".join(itertools.chain(*digits))
return "".join(random.choice(digits) for _ in range(length)) return "".join(random.choice(digits) for _ in range(length))
def make_fake_flo_file(h, w, file_name):
    """Write a fake optical-flow file in Middlebury ``.flo`` format.

    The file consists of the magic ``PIEH`` tag, the width and height as
    32-bit ints, then ``2 * h * w`` float32 values counting up from 0.
    """
    num_values = 2 * h * w
    header = b"PIEH" + struct.pack("ii", w, h)
    payload = struct.pack(f"{num_values}f", *range(num_values))
    with open(file_name, "wb") as f:
        f.write(header + payload)
...@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase): ...@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class SintelTestCase(datasets_utils.ImageDatasetTestCase): class SintelTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Sintel DATASET_CLASS = datasets.Sintel
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final")) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files, FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
# which is something we want to # avoid.
_FAKE_FLOW = "Fake Flow" FLOW_H, FLOW_W = 3, 4
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "Sintel" root = pathlib.Path(tmpdir) / "Sintel"
...@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
num_examples=num_images_per_scene, num_examples=num_images_per_scene,
) )
# For the ground truth flow value we just create empty files so that they're properly discovered,
# see comment above about EXTRA_PATCHES
flow_root = root / "training" / "flow" flow_root = root / "training" / "flow"
for scene_id in range(num_scenes): for scene_id in range(num_scenes):
scene_dir = flow_root / f"scene_{scene_id}" scene_dir = flow_root / f"scene_{scene_id}"
os.makedirs(scene_dir) os.makedirs(scene_dir)
for i in range(num_images_per_scene - 1): for i in range(num_images_per_scene - 1):
open(str(scene_dir / f"frame_000{i}.flo"), "a").close() file_name = str(scene_dir / f"frame_000{i}.flo")
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
        # with e.g. num_images_per_scene = 3, for a single scene we have 3 images
        # which are frame_0000, frame_0001 and frame_0002
...@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
with self.create_dataset(split="train") as (dataset, _): with self.create_dataset(split="train") as (dataset, _):
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
for _, _, flow in dataset: for _, _, flow in dataset:
assert flow == self._FAKE_FLOW assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
# Make sure flow is always None for test split # Make sure flow is always None for test split
with self.create_dataset(split="test") as (dataset, _): with self.create_dataset(split="test") as (dataset, _):
...@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
assert flow is None assert flow is None
def test_bad_input(self): def test_bad_input(self):
with pytest.raises(ValueError, match="split must be either"): with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
with self.create_dataset(split="bad"): with self.create_dataset(split="bad"):
pass pass
with pytest.raises(ValueError, match="pass_name must be either"): with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
with self.create_dataset(pass_name="bad"): with self.create_dataset(pass_name="bad"):
pass pass
...@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
assert valid is None assert valid is None
def test_bad_input(self): def test_bad_input(self):
with pytest.raises(ValueError, match="split must be either"): with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
with self.create_dataset(split="bad"): with self.create_dataset(split="bad"):
pass pass
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.FlyingChairs
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    FLOW_H, FLOW_W = 3, 4

    def _make_split_file(self, root, num_examples):
        # Fake stand-in for the real split file users are asked to download
        # from the authors' website: one split id (1=train, 2=val) per line.
        ids = [1] * num_examples["train"] + [2] * num_examples["val"]
        random.shuffle(ids)
        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
            split_file.writelines(f"{split_id}\n" for split_id in ids)

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "FlyingChairs"
        examples_per_split = {"train": 5, "val": 3}
        total = sum(examples_per_split.values())

        # Each example is an (img1, img2) pair plus a flow file, all in "data".
        for suffix in ("img1", "img2"):
            datasets_utils.create_image_folder(
                root,
                name="data",
                # bind suffix as a default to avoid late-binding in the loop
                file_name_fn=lambda image_idx, suffix=suffix: f"00{image_idx}_{suffix}.ppm",
                num_examples=total,
            )

        for example_idx in range(total):
            flo_path = str(root / "data" / f"00{example_idx}_flow.flo")
            datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=flo_path)

        self._make_split_file(root, examples_per_split)

        return examples_per_split[config["split"]]

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        # Flow must always exist, there must be as many flow values as (pairs
        # of) images, and decoding must recover the values written by
        # make_fake_flo_file.
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from ._optical_flow import KittiFlow, Sintel from ._optical_flow import KittiFlow, Sintel, FlyingChairs
from .caltech import Caltech101, Caltech256 from .caltech import Caltech101, Caltech256
from .celeba import CelebA from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
...@@ -74,4 +74,5 @@ __all__ = ( ...@@ -74,4 +74,5 @@ __all__ = (
"LFWPairs", "LFWPairs",
"KittiFlow", "KittiFlow",
"Sintel", "Sintel",
"FlyingChairs",
) )
...@@ -8,12 +8,14 @@ import torch ...@@ -8,12 +8,14 @@ import torch
from PIL import Image from PIL import Image
from ..io.image import _read_png_16 from ..io.image import _read_png_16
from .utils import verify_str_arg
from .vision import VisionDataset from .vision import VisionDataset
__all__ = ( __all__ = (
"KittiFlow", "KittiFlow",
"Sintel", "Sintel",
"FlyingChairs",
) )
...@@ -109,11 +111,8 @@ class Sintel(FlowDataset): ...@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
def __init__(self, root, split="train", pass_name="clean", transforms=None): def __init__(self, root, split="train", pass_name="clean", transforms=None):
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
if split not in ("train", "test"): verify_str_arg(split, "split", valid_values=("train", "test"))
raise ValueError("split must be either 'train' or 'test'") verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final"))
if pass_name not in ("clean", "final"):
raise ValueError("pass_name must be either 'clean' or 'final'")
root = Path(root) / "Sintel" root = Path(root) / "Sintel"
...@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset): ...@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
def __init__(self, root, split="train", transforms=None): def __init__(self, root, split="train", transforms=None):
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
if split not in ("train", "test"): verify_str_arg(split, "split", valid_values=("train", "test"))
raise ValueError("split must be either 'train' or 'test'")
root = Path(root) / "Kitti" / (split + "ing") root = Path(root) / "Kitti" / (split + "ing")
images1 = sorted(glob(str(root / "image_2" / "*_10.png"))) images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
...@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset): ...@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset):
return _read_16bits_png_with_flow_and_valid_mask(file_name) return _read_16bits_png_with_flow_and_valid_mask(file_name)
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt

    Args:
        root (string): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"
        if not (root / split_file_name).exists():
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        split_ids = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        # Split ids in the file are 1 for "train" and 2 for "val"; example i
        # owns flow file i and the image pair at positions 2*i and 2*i+1.
        wanted_id = 1 if split == "train" else 2
        for example_idx, flow_file in enumerate(flows):
            if split_ids[example_idx] == wanted_id:
                self._flow_list.append(flow_file)
                self._image_list.append([images[2 * example_idx], images[2 * example_idx + 1]])

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        # Decode the Middlebury .flo file belonging to an example.
        return _read_flo(file_name)
def _read_flo(file_name): def _read_flo(file_name):
"""Read .flo file in Middlebury format""" """Read .flo file in Middlebury format"""
# Code adapted from: # Code adapted from:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment