Unverified Commit 7408cb51 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add pass_name='both' for Sintel dataset (#4888)

parent 43524b61
...@@ -1873,7 +1873,7 @@ class LFWPairsTestCase(LFWPeopleTestCase): ...@@ -1873,7 +1873,7 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class SintelTestCase(datasets_utils.ImageDatasetTestCase): class SintelTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.Sintel DATASET_CLASS = datasets.Sintel
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final")) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final", "both"))
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None))) FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
FLOW_H, FLOW_W = 3, 4 FLOW_H, FLOW_W = 3, 4
...@@ -1909,7 +1909,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase): ...@@ -1909,7 +1909,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
# which are frame_0000, frame_0001 and frame_0002 # which are frame_0000, frame_0001 and frame_0002
# They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002), # They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002),
# that is 3 - 1 = 2 examples. Hence the formula below # that is 3 - 1 = 2 examples. Hence the formula below
num_examples = (num_images_per_scene - 1) * num_scenes num_passes = 2 if config["pass_name"] == "both" else 1
num_examples = (num_images_per_scene - 1) * num_scenes * num_passes
return num_examples return num_examples
def test_flow(self): def test_flow(self):
......
...@@ -103,7 +103,7 @@ class Sintel(FlowDataset): ...@@ -103,7 +103,7 @@ class Sintel(FlowDataset):
Args: Args:
root (string): Root directory of the Sintel Dataset. root (string): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test" split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final". See link above for pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes. details on the different passes.
transforms (callable, optional): A function/transform that takes in transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version. ``img1, img2, flow, valid`` and returns a transformed version.
...@@ -115,21 +115,22 @@ class Sintel(FlowDataset): ...@@ -115,21 +115,22 @@ class Sintel(FlowDataset):
super().__init__(root=root, transforms=transforms) super().__init__(root=root, transforms=transforms)
verify_str_arg(split, "split", valid_values=("train", "test")) verify_str_arg(split, "split", valid_values=("train", "test"))
verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final")) verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
passes = ["clean", "final"] if pass_name == "both" else [pass_name]
root = Path(root) / "Sintel" root = Path(root) / "Sintel"
split_dir = "training" if split == "train" else split
image_root = root / split_dir / pass_name
flow_root = root / "training" / "flow" flow_root = root / "training" / "flow"
for scene in os.listdir(image_root): for pass_name in passes:
image_list = sorted(glob(str(image_root / scene / "*.png"))) split_dir = "training" if split == "train" else split
for i in range(len(image_list) - 1): image_root = root / split_dir / pass_name
self._image_list += [[image_list[i], image_list[i + 1]]] for scene in os.listdir(image_root):
image_list = sorted(glob(str(image_root / scene / "*.png")))
for i in range(len(image_list) - 1):
self._image_list += [[image_list[i], image_list[i + 1]]]
if split == "train": if split == "train":
self._flow_list += sorted(glob(str(flow_root / scene / "*.flo"))) self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))
def __getitem__(self, index): def __getitem__(self, index):
"""Return example at given index. """Return example at given index.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment