# _optical_flow.py: optical flow datasets.
1
import itertools
2
3
4
5
import os
from abc import ABC, abstractmethod
from glob import glob
from pathlib import Path
6
from typing import Callable, List, Optional, Tuple, Union
7
8
9
10
11
12

import numpy as np
import torch
from PIL import Image

from ..io.image import _read_png_16
13
from .utils import _read_pfm, verify_str_arg
14
15
from .vision import VisionDataset

16
17
18
19
# Sample layout when a valid-flow mask is part of the result: (img1, img2, flow, valid_flow_mask).
T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
# Sample layout when no mask is returned: (img1, img2, flow).
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]


20
21
22
# Public dataset classes exported by this module.
__all__ = (
    "KittiFlow",
    "Sintel",
    "FlyingThings3D",
    "FlyingChairs",
    "HD1K",
)


class FlowDataset(ABC, VisionDataset):
    """Common base class for the optical flow datasets defined in this module.

    Subclasses populate ``_image_list`` with ``[img1_path, img2_path]`` pairs and, when
    ground truth is available, ``_flow_list`` with flow file paths that are indexed in
    step with ``_image_list``. They must also implement :meth:`_read_flow`.
    """

    # Some datasets like Kitti have a built-in valid_flow_mask, indicating which flow values are valid
    # For those we return (img1, img2, flow, valid_flow_mask), and for the rest we return (img1, img2, flow),
    # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
    _has_builtin_flow_mask = False

    def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
        """Store ``transforms`` and start with empty image and flow lists.

        Args:
            root (str or ``pathlib.Path``): Root directory forwarded to ``VisionDataset``.
            transforms (callable, optional): Called as
                ``transforms(img1, img2, flow, valid_flow_mask)`` in :meth:`__getitem__`.
        """
        super().__init__(root=root)
        self.transforms = transforms

        self._flow_list: List[str] = []
        self._image_list: List[List[str]] = []

    def _read_img(self, file_name: str) -> Image.Image:
        """Open ``file_name`` with PIL and return it in RGB mode."""
        img = Image.open(file_name)
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img

    @abstractmethod
    def _read_flow(self, file_name: str):
        # Return the flow or a tuple with the flow and the valid_flow_mask if _has_builtin_flow_mask is True
        pass

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return the sample at ``index``.

        Args:
            index (int): The index of the example to retrieve.

        Returns:
            tuple: ``(img1, img2, flow, valid_flow_mask)`` when the dataset has a built-in
            mask or when ``transforms`` produced one, ``(img1, img2, flow)`` otherwise.
            ``flow`` is None when the dataset holds no flow files.
        """
        img1 = self._read_img(self._image_list[index][0])
        img2 = self._read_img(self._image_list[index][1])

        if self._flow_list:  # it will be empty for some dataset when split="test"
            flow = self._read_flow(self._flow_list[index])
            if self._has_builtin_flow_mask:
                flow, valid_flow_mask = flow
            else:
                valid_flow_mask = None
        else:
            flow = valid_flow_mask = None

        if self.transforms is not None:
            img1, img2, flow, valid_flow_mask = self.transforms(img1, img2, flow, valid_flow_mask)

        if self._has_builtin_flow_mask or valid_flow_mask is not None:
            # The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
            return img1, img2, flow, valid_flow_mask
        else:
            return img1, img2, flow

    def __len__(self) -> int:
        """Return the number of image pairs in the dataset."""
        return len(self._image_list)

    def __rmul__(self, v: int) -> torch.utils.data.ConcatDataset:
        """Support ``v * dataset``: concatenate ``v`` references to this dataset."""
        return torch.utils.data.ConcatDataset([self] * v)

83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

class Sintel(FlowDataset):
    """`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            Sintel
                test
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                training
                    clean
                        scene_1
                        scene_2
                        ...
                    final
                        scene_1
                        scene_2
                        ...
                    flow
                        scene_1
                        scene_2
                        ...

    Args:
        root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
            details on the different passes.
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = ["clean", "final"] if pass_name == "both" else [pass_name]

        root = Path(root) / "Sintel"
        flow_root = root / "training" / "flow"
        # Only the train split is renamed on disk; "test" is used as the directory name as-is.
        split_dir = "training" if split == "train" else split

        for pass_dir in passes:
            image_root = root / split_dir / pass_dir
            # os.listdir returns entries in arbitrary order; sort so that sample indices are reproducible.
            for scene in sorted(os.listdir(image_root)):
                frames = sorted(glob(str(image_root / scene / "*.png")))
                # Consecutive frames of a scene form the (img1, img2) pairs.
                self._image_list += [[first, second] for first, second in zip(frames, frames[1:])]

                if split == "train":
                    self._flow_list += sorted(glob(str(flow_root / scene / "*.flo")))

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            ``flow`` is None if ``split="test"``.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        """Read a Sintel ``.flo`` ground-truth file."""
        return _read_flo(file_name)


class KittiFlow(FlowDataset):
    """`KITTI <http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow>`__ dataset for optical flow (2015).

    The dataset is expected to have the following structure: ::

        root
            KittiFlow
                testing
                    image_2
                training
                    image_2
                    flow_occ

    Args:
        root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        # "train" -> "training", "test" -> "testing"
        root = Path(root) / "KittiFlow" / (split + "ing")
        images1 = sorted(glob(str(root / "image_2" / "*_10.png")))
        images2 = sorted(glob(str(root / "image_2" / "*_11.png")))

        if not images1 or not images2:
            raise FileNotFoundError(
                "Could not find the Kitti flow images. Please make sure the directory structure is correct."
            )

        # Each "*_10.png" frame is paired with the "*_11.png" frame at the same sorted position.
        for img1, img2 in zip(images1, images2):
            self._image_list += [[img1, img2]]

        if split == "train":
            self._flow_list = sorted(glob(str(root / "flow_occ" / "*_10.png")))

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)``
            where ``valid_flow_mask`` is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
            ``split="test"``.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        """Read a 16-bit PNG flow file and return ``(flow, valid_flow_mask)``."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)


232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt


    Args:
        root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"

        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        # One integer per flow file: 1 selects the sample for "train", 2 for "val".
        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        wanted_id = 1 if split == "train" else 2
        for i, flow_file in enumerate(flows):
            if split_list[i] == wanted_id:
                self._flow_list += [flow_file]
                # Sorted images alternate img1/img2, so flow i belongs to images 2i and 2i + 1.
                self._image_list += [[images[2 * i], images[2 * i + 1]]]

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            Both the "train" and the "val" split come with ground-truth flow.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        """Read a FlyingChairs ``.flo`` ground-truth file."""
        return _read_flo(file_name)


300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
class FlyingThings3D(FlowDataset):
    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            FlyingThings3D
                frames_cleanpass
                    TEST
                    TRAIN
                frames_finalpass
                    TEST
                    TRAIN
                optical_flow
                    TEST
                    TRAIN

    Args:
        root (str or ``pathlib.Path``): Root directory of the FlyingThings3D Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
            details on the different passes.
        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
            ``valid_flow_mask`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(
        self,
        root: Union[str, Path],
        split: str = "train",
        pass_name: str = "clean",
        camera: str = "left",
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        split = split.upper()  # on-disk directories are named TRAIN / TEST

        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
        cameras = ["left", "right"] if camera == "both" else [camera]

        root = Path(root) / "FlyingThings3D"

        directions = ("into_future", "into_past")
        for pass_dir, cam, direction in itertools.product(passes, cameras, directions):
            image_dirs = sorted(glob(str(root / pass_dir / split / "*/*")))
            image_dirs = sorted(Path(image_dir) / cam for image_dir in image_dirs)

            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
            flow_dirs = sorted(Path(flow_dir) / direction / cam for flow_dir in flow_dirs)

            if not image_dirs or not flow_dirs:
                raise FileNotFoundError(
                    "Could not find the FlyingThings3D flow images. "
                    "Please make sure the directory structure is correct."
                )

            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
                images = sorted(glob(str(image_dir / "*.png")))
                flows = sorted(glob(str(flow_dir / "*.pfm")))
                for i in range(len(flows) - 1):
                    if direction == "into_future":
                        # Forward pair: frame i -> frame i + 1, using the flow file of frame i.
                        self._image_list += [[images[i], images[i + 1]]]
                        self._flow_list += [flows[i]]
                    elif direction == "into_past":
                        # Backward pair: frame i + 1 -> frame i, using the flow file of frame i + 1.
                        self._image_list += [[images[i + 1], images[i]]]
                        self._flow_list += [flows[i + 1]]

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
            Both the "train" and the "test" split come with ground-truth flow.
            If a valid flow mask is generated within the ``transforms`` parameter,
            a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name: str) -> np.ndarray:
        """Read a FlyingThings3D ``.pfm`` ground-truth file."""
        return _read_pfm(file_name)


398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        root = Path(root) / "hd1k"
        if split == "train":
            # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
            for seq_idx in range(36):
                flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
                images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
                # The last flow file of a sequence is skipped: it has no following frame to pair with.
                for i in range(len(flows) - 1):
                    self._flow_list += [flows[i]]
                    self._image_list += [[images[i], images[i + 1]]]
        else:
            images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
            images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
            for image1, image2 in zip(images1, images2):
                self._image_list += [[image1, image2]]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name: str) -> Tuple[np.ndarray, np.ndarray]:
        """Read a 16-bit PNG flow file and return ``(flow, valid_flow_mask)``."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index: int) -> Union[T1, T2]:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` where ``valid_flow_mask``
            is a numpy boolean mask of shape (H, W)
            indicating which flow values are valid. The flow is a numpy array of
            shape (2, H, W) and the images are PIL images. ``flow`` and ``valid_flow_mask`` are None if
            ``split="test"``.
        """
        return super().__getitem__(index)


465
def _read_flo(file_name: str) -> np.ndarray:
    """Read .flo file in Middlebury format.

    Args:
        file_name (str): Path to the ``.flo`` file.

    Returns:
        numpy.ndarray: float32 flow of shape (2, H, W).

    Raises:
        ValueError: If the magic number is wrong or the file holds less data than
            its header announces.
    """
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    # Everything needs to be in little Endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    with open(file_name, "rb") as f:
        magic = np.fromfile(f, "c", count=4).tobytes()
        if magic != b"PIEH":
            raise ValueError("Magic number incorrect. Invalid .flo file")

        # Width then height, both int32. Index the elements explicitly: calling int() on a
        # whole 1-element array is deprecated in recent NumPy releases.
        header = np.fromfile(f, "<i4", count=2)
        if header.size != 2:
            raise ValueError("Truncated header. Invalid .flo file")
        w, h = int(header[0]), int(header[1])

        data = np.fromfile(f, "<f4", count=2 * w * h)
        if data.size != 2 * w * h:
            raise ValueError("Truncated flow data. Invalid .flo file")
        # The file stores interleaved (u, v) values row by row; move the channel axis first.
        return data.reshape(h, w, 2).transpose(2, 0, 1)
480
481


482
def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
    """Read a 16-bit PNG that stores the flow and its validity mask.

    The first two channels hold the encoded flow and the third channel marks which
    pixels carry a valid flow value.

    Args:
        file_name (str): Path to the PNG file.

    Returns:
        tuple: ``(flow, valid_flow_mask)`` as numpy arrays, where ``valid_flow_mask`` is boolean.
    """
    flow_and_valid = _read_png_16(file_name).to(torch.float32)
    flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
    flow = (flow - 2**15) / 64  # This conversion is explained somewhere on the kitti archive
    valid_flow_mask = valid_flow_mask.bool()

    # For consistency with other datasets, we convert to numpy
    return flow.numpy(), valid_flow_mask.numpy()