Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a57e45c8
Unverified
Commit
a57e45c8
authored
Dec 08, 2021
by
Nicolas Hug
Committed by
GitHub
Dec 08, 2021
Browse files
Some updates to optical flow datasets (#5004)
parent
1be7afd6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
32 deletions
+48
-32
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+48
-32
No files found.
torchvision/datasets/_optical_flow.py
View file @
a57e45c8
...
@@ -24,9 +24,9 @@ __all__ = (
...
@@ -24,9 +24,9 @@ __all__ = (
class
FlowDataset
(
ABC
,
VisionDataset
):
class
FlowDataset
(
ABC
,
VisionDataset
):
# Some datasets like Kitti have a built-in valid
mask, indicating which flow values are valid
# Some datasets like Kitti have a built-in valid
_flow_
mask, indicating which flow values are valid
# For those we return (img1, img2, flow, valid), and for the rest we return (img1, img2, flow),
# For those we return (img1, img2, flow, valid
_flow_mask
), and for the rest we return (img1, img2, flow),
# and it's up to whatever consumes the dataset to decide what
`
valid
`
should be.
# and it's up to whatever consumes the dataset to decide what valid
_flow_mask
should be.
_has_builtin_flow_mask
=
False
_has_builtin_flow_mask
=
False
def
__init__
(
self
,
root
,
transforms
=
None
):
def
__init__
(
self
,
root
,
transforms
=
None
):
...
@@ -38,11 +38,14 @@ class FlowDataset(ABC, VisionDataset):
...
@@ -38,11 +38,14 @@ class FlowDataset(ABC, VisionDataset):
self
.
_image_list
=
[]
self
.
_image_list
=
[]
def
_read_img
(
self
,
file_name
):
def
_read_img
(
self
,
file_name
):
return
Image
.
open
(
file_name
)
img
=
Image
.
open
(
file_name
)
if
img
.
mode
!=
"RGB"
:
img
=
img
.
convert
(
"RGB"
)
return
img
@
abstractmethod
@
abstractmethod
def
_read_flow
(
self
,
file_name
):
def
_read_flow
(
self
,
file_name
):
# Return the flow or a tuple with the flow and the valid
mask if _has_builtin_flow_mask is True
# Return the flow or a tuple with the flow and the valid
_flow_
mask if _has_builtin_flow_mask is True
pass
pass
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
...
@@ -53,23 +56,27 @@ class FlowDataset(ABC, VisionDataset):
...
@@ -53,23 +56,27 @@ class FlowDataset(ABC, VisionDataset):
if
self
.
_flow_list
:
# it will be empty for some dataset when split="test"
if
self
.
_flow_list
:
# it will be empty for some dataset when split="test"
flow
=
self
.
_read_flow
(
self
.
_flow_list
[
index
])
flow
=
self
.
_read_flow
(
self
.
_flow_list
[
index
])
if
self
.
_has_builtin_flow_mask
:
if
self
.
_has_builtin_flow_mask
:
flow
,
valid
=
flow
flow
,
valid
_flow_mask
=
flow
else
:
else
:
valid
=
None
valid
_flow_mask
=
None
else
:
else
:
flow
=
valid
=
None
flow
=
valid
_flow_mask
=
None
if
self
.
transforms
is
not
None
:
if
self
.
transforms
is
not
None
:
img1
,
img2
,
flow
,
valid
=
self
.
transforms
(
img1
,
img2
,
flow
,
valid
)
img1
,
img2
,
flow
,
valid
_flow_mask
=
self
.
transforms
(
img1
,
img2
,
flow
,
valid
_flow_mask
)
if
self
.
_has_builtin_flow_mask
:
if
self
.
_has_builtin_flow_mask
or
valid_flow_mask
is
not
None
:
return
img1
,
img2
,
flow
,
valid
# The `or valid_flow_mask is not None` part is here because the mask can be generated within a transform
return
img1
,
img2
,
flow
,
valid_flow_mask
else
:
else
:
return
img1
,
img2
,
flow
return
img1
,
img2
,
flow
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
_image_list
)
return
len
(
self
.
_image_list
)
def
__rmul__
(
self
,
v
):
return
torch
.
utils
.
data
.
ConcatDataset
([
self
]
*
v
)
class
Sintel
(
FlowDataset
):
class
Sintel
(
FlowDataset
):
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
"""`Sintel <http://sintel.is.tue.mpg.de/>`_ Dataset for optical flow.
...
@@ -107,8 +114,8 @@ class Sintel(FlowDataset):
...
@@ -107,8 +114,8 @@ class Sintel(FlowDataset):
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
details on the different passes.
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``valid
_flow_mask
`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
"""
...
@@ -140,9 +147,11 @@ class Sintel(FlowDataset):
...
@@ -140,9 +147,11 @@ class Sintel(FlowDataset):
index(int): The index of the example to retrieve
index(int): The index of the example to retrieve
Returns:
Returns:
tuple: If ``split="train"`` a 3-tuple with ``(img1, img2, flow)``.
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images. If `split="test"`, a
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
3-tuple with ``(img1, img2, None)`` is returned.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
"""
return
super
().
__getitem__
(
index
)
return
super
().
__getitem__
(
index
)
...
@@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
...
@@ -167,7 +176,7 @@ class KittiFlow(FlowDataset):
root (string): Root directory of the KittiFlow Dataset.
root (string): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
"""
"""
_has_builtin_flow_mask
=
True
_has_builtin_flow_mask
=
True
...
@@ -199,11 +208,11 @@ class KittiFlow(FlowDataset):
...
@@ -199,11 +208,11 @@ class KittiFlow(FlowDataset):
index(int): The index of the example to retrieve
index(int): The index of the example to retrieve
Returns:
Returns:
tuple:
If ``split="train"`` a
4-tuple with ``(img1, img2, flow,
tuple:
A
4-tuple with ``(img1, img2, flow,
valid_flow_mask)``
valid)``
where ``valid`` is a numpy boolean mask of shape (H, W)
where ``valid
_flow_mask
`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images.
If `split="test"`, a
shape (2, H, W) and the images are PIL images.
``flow`` and ``valid_flow_mask`` are None if
4-tuple with ``(img1, img2, None, None)`` is returned
.
``split="test"``
.
"""
"""
return
super
().
__getitem__
(
index
)
return
super
().
__getitem__
(
index
)
...
@@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
...
@@ -232,8 +241,8 @@ class FlyingChairs(FlowDataset):
root (string): Root directory of the FlyingChairs Dataset.
root (string): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``valid
_flow_mask
`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
"""
...
@@ -269,6 +278,9 @@ class FlyingChairs(FlowDataset):
...
@@ -269,6 +278,9 @@ class FlyingChairs(FlowDataset):
Returns:
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="val"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
"""
return
super
().
__getitem__
(
index
)
return
super
().
__getitem__
(
index
)
...
@@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
...
@@ -300,8 +312,8 @@ class FlyingThings3D(FlowDataset):
details on the different passes.
details on the different passes.
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
``valid
_flow_mask
`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
"""
...
@@ -357,6 +369,9 @@ class FlyingThings3D(FlowDataset):
...
@@ -357,6 +369,9 @@ class FlyingThings3D(FlowDataset):
Returns:
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
``flow`` is None if ``split="test"``.
If a valid flow mask is generated within the ``transforms`` parameter,
a 4-tuple with ``(img1, img2, flow, valid_flow_mask)`` is returned.
"""
"""
return
super
().
__getitem__
(
index
)
return
super
().
__getitem__
(
index
)
...
@@ -382,7 +397,7 @@ class HD1K(FlowDataset):
...
@@ -382,7 +397,7 @@ class HD1K(FlowDataset):
root (string): Root directory of the HD1K Dataset.
root (string): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid
_flow_mask
`` and returns a transformed version.
"""
"""
_has_builtin_flow_mask
=
True
_has_builtin_flow_mask
=
True
...
@@ -422,11 +437,11 @@ class HD1K(FlowDataset):
...
@@ -422,11 +437,11 @@ class HD1K(FlowDataset):
index(int): The index of the example to retrieve
index(int): The index of the example to retrieve
Returns:
Returns:
tuple:
If ``split="train"`` a
4-tuple with ``(img1, img2, flow,
tuple:
A
4-tuple with ``(img1, img2, flow,
valid_flow_mask)`` where ``valid_flow_mask``
valid)`` where ``valid``
is a numpy boolean mask of shape (H, W)
is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images.
If `split="test"`, a
shape (2, H, W) and the images are PIL images.
``flow`` and ``valid_flow_mask`` are None if
4-tuple with ``(img1, img2, None, None)`` is returned
.
``split="test"``
.
"""
"""
return
super
().
__getitem__
(
index
)
return
super
().
__getitem__
(
index
)
...
@@ -451,11 +466,12 @@ def _read_flo(file_name):
...
@@ -451,11 +466,12 @@ def _read_flo(file_name):
def
_read_16bits_png_with_flow_and_valid_mask
(
file_name
):
def
_read_16bits_png_with_flow_and_valid_mask
(
file_name
):
flow_and_valid
=
_read_png_16
(
file_name
).
to
(
torch
.
float32
)
flow_and_valid
=
_read_png_16
(
file_name
).
to
(
torch
.
float32
)
flow
,
valid
=
flow_and_valid
[:
2
,
:,
:],
flow_and_valid
[
2
,
:,
:]
flow
,
valid
_flow_mask
=
flow_and_valid
[:
2
,
:,
:],
flow_and_valid
[
2
,
:,
:]
flow
=
(
flow
-
2
**
15
)
/
64
# This conversion is explained somewhere on the kitti archive
flow
=
(
flow
-
2
**
15
)
/
64
# This conversion is explained somewhere on the kitti archive
valid_flow_mask
=
valid_flow_mask
.
bool
()
# For consistency with other datasets, we convert to numpy
# For consistency with other datasets, we convert to numpy
return
flow
.
numpy
(),
valid
.
numpy
()
return
flow
.
numpy
(),
valid
_flow_mask
.
numpy
()
def
_read_pfm
(
file_name
):
def
_read_pfm
(
file_name
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment