Unverified Commit b9f4ed93 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add HD1K dataset for optical flow (#4890)

parent 22ff44fd
...@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas ...@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr30k Flickr30k
FlyingChairs FlyingChairs
FlyingThings3D FlyingThings3D
HD1K
HMDB51 HMDB51
ImageNet ImageNet
INaturalist INaturalist
......
...@@ -2126,5 +2126,47 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2126,5 +2126,47 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
pass pass
class HD1KTestCase(KittiFlowTestCase):
DATASET_CLASS = datasets.HD1K
def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "hd1k"
num_sequences = 4 if config["split"] == "train" else 3
num_examples_per_train_sequence = 3
for seq_idx in range(num_sequences):
# Training data
datasets_utils.create_image_folder(
root / "hd1k_input",
name="image_2",
file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
num_examples=num_examples_per_train_sequence,
)
datasets_utils.create_image_folder(
root / "hd1k_flow_gt",
name="flow_occ",
file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
num_examples=num_examples_per_train_sequence,
)
# Test data
datasets_utils.create_image_folder(
root / "hd1k_challenge",
name="image_2",
file_name_fn=lambda _: f"{seq_idx:06d}_10.png",
num_examples=1,
)
datasets_utils.create_image_folder(
root / "hd1k_challenge",
name="image_2",
file_name_fn=lambda _: f"{seq_idx:06d}_11.png",
num_examples=1,
)
num_examples_per_sequence = num_examples_per_train_sequence if config["split"] == "train" else 2
return num_sequences * (num_examples_per_sequence - 1)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K
from .caltech import Caltech101, Caltech256 from .caltech import Caltech101, Caltech256
from .celeba import CelebA from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100 from .cifar import CIFAR10, CIFAR100
...@@ -76,4 +76,5 @@ __all__ = ( ...@@ -76,4 +76,5 @@ __all__ = (
"Sintel", "Sintel",
"FlyingChairs", "FlyingChairs",
"FlyingThings3D", "FlyingThings3D",
"HD1K",
) )
...@@ -19,6 +19,7 @@ __all__ = ( ...@@ -19,6 +19,7 @@ __all__ = (
"Sintel", "Sintel",
"FlyingThings3D", "FlyingThings3D",
"FlyingChairs", "FlyingChairs",
"HD1K",
) )
...@@ -363,6 +364,73 @@ class FlyingThings3D(FlowDataset): ...@@ -363,6 +364,73 @@ class FlyingThings3D(FlowDataset):
return _read_pfm(file_name) return _read_pfm(file_name)
class HD1K(FlowDataset):
"""`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.
The dataset is expected to have the following structure: ::
root
hd1k
hd1k_challenge
image_2
hd1k_flow_gt
flow_occ
hd1k_input
image_2
Args:
root (string): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
"""
_has_builtin_flow_mask = True
def __init__(self, root, split="train", transforms=None):
super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test"))
root = Path(root) / "hd1k"
if split == "train":
# There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
for seq_idx in range(36):
flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
for i in range(len(flows) - 1):
self._flow_list += [flows[i]]
self._image_list += [[images[i], images[i + 1]]]
else:
images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
for image1, image2 in zip(images1, images2):
self._image_list += [[image1, image2]]
if not self._image_list:
raise FileNotFoundError(
"Could not find the HD1K images. Please make sure the directory structure is correct."
)
def _read_flow(self, file_name):
return _read_16bits_png_with_flow_and_valid_mask(file_name)
def __getitem__(self, index):
"""Return example at given index.
Args:
index(int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned.
"""
return super().__getitem__(index)
def _read_flo(file_name): def _read_flo(file_name):
"""Read .flo file in Middlebury format""" """Read .flo file in Middlebury format"""
# Code adapted from: # Code adapted from:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment