Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a57e45c8
Unverified
Commit
a57e45c8
authored
Dec 08, 2021
by
Nicolas Hug
Committed by
GitHub
Dec 08, 2021
Browse files
Some updates to optical flow datasets (#5004)
parent
1be7afd6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
32 deletions
+48
-32
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+48
-32
No files found.
torchvision/datasets/_optical_flow.py
View file @
a57e45c8
...
...
@@ -24,9 +24,9 @@ __all__ = (
class
FlowDataset
(
ABC
,
VisionDataset
):
# Some datasets like Kitti have a built-in valid
mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what
`
valid
`
should be.
# Some datasets like Kitti have a built-in valid
_flow_
mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid
_flow_mask
), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what valid
_flow_mask
should be.
_has_builtin_flow_mask
=
False
def
__init__
(
self
,
root
,
transforms
=
None
):
...
...
@@ -38,11 +38,14 @@ class FlowDataset(ABC, VisionDataset):
self
.
_image_list
=
[]
def
_read_img
(
self
,
file_name
):
return
Image
.
open
(
file_name
)
img
=
Image
.
open
(
file_name
)
if
img
.
mode
!=
"RGB"
:
img
=
img
.
convert
(
"RGB"
)
return
img
@
abstractmethod
def
_read_flow
(
self
,
file_name
):
# Return the flow or a tuple with the flow and the valid
mask if _has_builtin_flow_mask is True
# Return the flow or a tuple with the flow and the valid
_flow_
mask if _has_builtin_flow_mask is True
pass
def
__getitem__
(
self
,
index
):
...
...
@@ -53,23 +56,27 @@ class FlowDataset(ABC, VisionDataset):
if
self
.
_flow_list
:
# it will be empty for some dataset when split="test"
flow
=
self
.
_read_flow
(
self
.
_flow_list
[
index
])
if
self
.
_has_builtin_flow_mask
:
flow
,
valid
=
flow
flow
,
valid
_flow_mask
=
flow
else
:
valid
=
None
valid
_flow_mask
=
None
else
:
flow
=
valid
=
None
flow
=
valid
_flow_mask
=
None
if
self
.
transforms
is
not
None
:
img1
,
img2
,
flow
,
valid
=
self
.
transforms
(
img1
,
img2
,
flow
,
valid
)
img1
,
img2
,
flow
,
valid
_flow_mask
=
self
.
transforms
(
img1
,
img2
,
flow
,
valid
_flow_mask
)
if
self
.
_has_builtin_flow_mask
:
return
img1
,
img2
,
flow
,
valid
if
self
.
_has_builtin_flow_mask
or
valid_flow_mask
is
not
None
:
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
return
img1
,
img2
,
flow
,
valid_flow_mask
else
:
return
img1
,
img2
,
flow
def
__len__
(
self
):
return
len
(
self
.
_image_list
)
def
__rmul__
(
self
,
v
):
return
torch
.
utils
.
data
.
ConcatDataset
([
self
]
*
v
)
class
Sintel
(
FlowDataset
):
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
...
...
@@ -107,8 +114,8 @@ class Sintel(FlowDataset):
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
``valid
_flow_mask
`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
...
...
@@ -140,9 +147,11 @@ class Sintel(FlowDataset):
index(int): The index of the example to retrieve
Returns:
tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
3-tuple with ``(img1, img2, None)`` is returned.
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return
super
().
__getitem__
(
index
)
...
...
@@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
root (string): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
"""
_has_builtin_flow_mask
=
True
...
...
@@ -199,11 +208,11 @@ class KittiFlow(FlowDataset):
index(int): The index of the example to retrieve
Returns:
tuple:
If ``split="train"`` a
4-tuple with ``(img1, img2, flow,
valid)``
where ``valid`` is a numpy boolean mask of shape (H, W)
tuple:
A
4-tuple with ``(img1, img2, flow,
valid_flow_mask)``
where ``valid
_flow_mask
`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images.
If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned
.
shape (2, H, W) and the images are PIL images.
``flow`` and ``valid_flow_mask`` are None if
``split="test"``
.
"""
return
super
().
__getitem__
(
index
)
...
...
@@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
root (string): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
``valid
_flow_mask
`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
...
...
@@ -269,6 +278,9 @@ class FlyingChairs(FlowDataset):
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="val"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return
super
().
__getitem__
(
index
)
...
...
@@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
details on the different passes.
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
``valid
_flow_mask
`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
...
...
@@ -357,6 +369,9 @@ class FlyingThings3D(FlowDataset):
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
return
super
().
__getitem__
(
index
)
...
...
@@ -382,7 +397,7 @@ class HD1K(FlowDataset):
root (string): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
"""
_has_builtin_flow_mask
=
True
...
...
@@ -422,11 +437,11 @@ class HD1K(FlowDataset):
index(int): The index of the example to retrieve
Returns:
tuple:
If ``split="train"`` a
4-tuple with ``(img1, img2, flow,
valid)`` where ``valid``
is a numpy boolean mask of shape (H, W)
tuple:
A
4-tuple with ``(img1, img2, flow,
valid_flow_mask)`` where ``valid_flow_mask``
is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images.
If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned
.
shape (2, H, W) and the images are PIL images.
``flow`` and ``valid_flow_mask`` are None if
``split="test"``
.
"""
return
super
().
__getitem__
(
index
)
...
...
@@ -451,11 +466,12 @@ def _read_flo(file_name):
def
_read_16bits_png_with_flow_and_valid_mask
(
file_name
):
flow_and_valid
=
_read_png_16
(
file_name
).
to
(
torch
.
float32
)
flow
,
valid
=
flow_and_valid
[:
2
,
:,
:],
flow_and_valid
[
2
,
:,
:]
flow
,
valid
_flow_mask
=
flow_and_valid
[:
2
,
:,
:],
flow_and_valid
[
2
,
:,
:]
flow
=
(
flow
-
2
**
15
)
/
64
# This conversion is explained somewhere on the kitti archive
valid_flow_mask
=
valid_flow_mask
.
bool
()
# For consistency with other datasets, we convert to numpy
return
flow
.
numpy
(),
valid
.
numpy
()
return
flow
.
numpy
(),
valid
_flow_mask
.
numpy
()
def
_read_pfm
(
file_name
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment