Unverified Commit b9f4ed93 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add HD1K dataset for optical flow (#4890)

parent 22ff44fd
......@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr30k
FlyingChairs
FlyingThings3D
HD1K
HMDB51
ImageNet
INaturalist
......
......@@ -2126,5 +2126,47 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
pass
class HD1KTestCase(KittiFlowTestCase):
    """Dataset test case for ``datasets.HD1K`` backed by generated fake images."""

    DATASET_CLASS = datasets.HD1K

    def inject_fake_data(self, tmpdir, config):
        """Create a fake ``hd1k`` directory tree under ``tmpdir``.

        Both the training folders (``hd1k_input/image_2`` and
        ``hd1k_flow_gt/flow_occ``) and the test folder
        (``hd1k_challenge/image_2``) are populated regardless of the
        requested split; only the number of sequences depends on it.

        Args:
            tmpdir: Directory in which the fake dataset is created.
            config (dict): Test configuration; only ``config["split"]`` is read.

        Returns:
            int: ``num_sequences * (frames_per_sequence - 1)``, i.e. the number
            of consecutive frame pairs the dataset is expected to expose for
            the requested split.
        """
        is_train = config["split"] == "train"
        base_dir = pathlib.Path(tmpdir) / "hd1k"
        sequence_count = 4 if is_train else 3
        frames_per_train_sequence = 3

        # (parent folder, sub-folder) pairs holding the training frames and flows.
        train_layout = (("hd1k_input", "image_2"), ("hd1k_flow_gt", "flow_occ"))

        for seq_idx in range(sequence_count):
            # Training data: input frames and ground-truth flows share file names.
            def train_file_name(image_idx):
                return f"{seq_idx:06d}_{image_idx}.png"

            for parent, folder in train_layout:
                datasets_utils.create_image_folder(
                    base_dir / parent,
                    name=folder,
                    file_name_fn=train_file_name,
                    num_examples=frames_per_train_sequence,
                )

            # Test data: exactly two frames (suffixes 10 and 11) per sequence.
            for challenge_file in (f"{seq_idx:06d}_10.png", f"{seq_idx:06d}_11.png"):
                datasets_utils.create_image_folder(
                    base_dir / "hd1k_challenge",
                    name="image_2",
                    file_name_fn=lambda _: challenge_file,
                    num_examples=1,
                )

        frames_per_sequence = frames_per_train_sequence if is_train else 2
        return sequence_count * (frames_per_sequence - 1)
if __name__ == "__main__":
unittest.main()
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
......@@ -76,4 +76,5 @@ __all__ = (
"Sintel",
"FlyingChairs",
"FlyingThings3D",
"HD1K",
)
......@@ -19,6 +19,7 @@ __all__ = (
"Sintel",
"FlyingThings3D",
"FlyingChairs",
"HD1K",
)
......@@ -363,6 +364,73 @@ class FlyingThings3D(FlowDataset):
return _read_pfm(file_name)
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (string): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
    """

    # The flow reader used by this class returns a valid mask together with the flow.
    # NOTE(review): how the base class consumes this flag is not visible here.
    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        """Index the image pairs (and, for training, the flow files) of one split.

        Raises:
            FileNotFoundError: If no image pair could be found under ``root``.
        """
        super().__init__(root=root, transforms=transforms)
        verify_str_arg(split, "split", valid_values=("train", "test"))

        base_dir = Path(root) / "hd1k"

        if split == "train":
            flow_dir = base_dir / "hd1k_flow_gt" / "flow_occ"
            image_dir = base_dir / "hd1k_input" / "image_2"
            # There are 36 "sequences"; each one is globbed separately so that
            # a pair never straddles the end of seq i and the start of seq i + 1.
            for seq_idx in range(36):
                pattern = f"{seq_idx:06d}_*.png"
                flows = sorted(glob(str(flow_dir / pattern)))
                images = sorted(glob(str(image_dir / pattern)))
                # The last flow of a sequence has no following frame to pair with.
                for i, flow_file in enumerate(flows[:-1]):
                    self._flow_list += [flow_file]
                    self._image_list += [[images[i], images[i + 1]]]
        else:
            challenge_dir = base_dir / "hd1k_challenge" / "image_2"
            first_frames = sorted(glob(str(challenge_dir / "*10.png")))
            second_frames = sorted(glob(str(challenge_dir / "*11.png")))
            self._image_list += [[frame1, frame2] for frame1, frame2 in zip(first_frames, second_frames)]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name):
        """Read the flow stored in ``file_name`` via the 16-bit PNG flow-and-mask reader."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index):
        """Return example at given index.

        The lookup itself is delegated to the parent class.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
            valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. If ``split="test"``, a
            4-tuple with ``(img1, img2, None, None)`` is returned.
        """
        return super().__getitem__(index)
def _read_flo(file_name):
"""Read .flo file in Middlebury format"""
# Code adapted from:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment