Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1bd131c7
"src/vscode:/vscode.git/clone" did not exist on "ac5a1e28fc9cc233863bcfb2abb9eef6807f156f"
Unverified
Commit
1bd131c7
authored
Nov 05, 2021
by
Nicolas Hug
Committed by
GitHub
Nov 05, 2021
Browse files
Add FlyingThings3D dataset for optical flow (#4858)
parent
7f424379
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
198 additions
and
8 deletions
+198
-8
docs/source/datasets.rst
docs/source/datasets.rst
+1
-0
test/datasets_utils.py
test/datasets_utils.py
+9
-7
test/test_datasets.py
test/test_datasets.py
+67
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+119
-0
No files found.
docs/source/datasets.rst
View file @
1bd131c7
...
...
@@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr8k
Flickr30k
FlyingChairs
FlyingThings3D
HMDB51
ImageNet
INaturalist
...
...
test/datasets_utils.py
View file @
1bd131c7
...
...
@@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase):
``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
- EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
The fake data should resemble the original data as close as necessary, while containing only few examples. During
...
...
@@ -256,8 +255,6 @@ class DatasetTestCase(unittest.TestCase):
ADDITIONAL_CONFIGS
=
None
REQUIRED_PACKAGES
=
None
EXTRA_PATCHES
=
None
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS
=
{
"transform"
,
...
...
@@ -383,9 +380,6 @@ class DatasetTestCase(unittest.TestCase):
if
patch_checks
:
patchers
.
update
(
self
.
_patch_checks
())
if
self
.
EXTRA_PATCHES
is
not
None
:
patchers
.
update
(
self
.
EXTRA_PATCHES
)
with
get_tmp_dir
()
as
tmpdir
:
args
=
self
.
dataset_args
(
tmpdir
,
complete_config
)
info
=
self
.
_inject_fake_data
(
tmpdir
,
complete_config
)
if
inject_fake_data
else
None
...
...
@@ -925,6 +919,14 @@ def create_random_string(length: int, *digits: str) -> str:
return
""
.
join
(
random
.
choice
(
digits
)
for
_
in
range
(
length
))
def
make_fake_pfm_file
(
h
,
w
,
file_name
):
values
=
list
(
range
(
3
*
h
*
w
))
# Note: we pack everything in little endian: -1.0, and "<"
content
=
f
"PF
\n
{
w
}
{
h
}
\n
-1.0
\n
"
.
encode
()
+
struct
.
pack
(
"<"
+
"f"
*
len
(
values
),
*
values
)
with
open
(
file_name
,
"wb"
)
as
f
:
f
.
write
(
content
)
def
make_fake_flo_file
(
h
,
w
,
file_name
):
"""Creates a fake flow file in .flo format."""
values
=
list
(
range
(
2
*
h
*
w
))
...
...
test/test_datasets.py
View file @
1bd131c7
...
...
@@ -2048,5 +2048,72 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
np
.
testing
.
assert_allclose
(
flow
,
np
.
arange
(
flow
.
size
).
reshape
(
flow
.
shape
))
class
FlyingThings3DTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
FlyingThings3D
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
,
"both"
),
camera
=
(
"left"
,
"right"
,
"both"
)
)
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
np
.
ndarray
,
type
(
None
)))
FLOW_H
,
FLOW_W
=
3
,
4
def
inject_fake_data
(
self
,
tmpdir
,
config
):
root
=
pathlib
.
Path
(
tmpdir
)
/
"FlyingThings3D"
num_images_per_camera
=
3
if
config
[
"split"
]
==
"train"
else
4
passes
=
(
"frames_cleanpass"
,
"frames_finalpass"
)
splits
=
(
"TRAIN"
,
"TEST"
)
letters
=
(
"A"
,
"B"
,
"C"
)
subfolders
=
(
"0000"
,
"0001"
)
cameras
=
(
"left"
,
"right"
)
for
pass_name
,
split
,
letter
,
subfolder
,
camera
in
itertools
.
product
(
passes
,
splits
,
letters
,
subfolders
,
cameras
):
current_folder
=
root
/
pass_name
/
split
/
letter
/
subfolder
datasets_utils
.
create_image_folder
(
current_folder
,
name
=
camera
,
file_name_fn
=
lambda
image_idx
:
f
"00
{
image_idx
}
.png"
,
num_examples
=
num_images_per_camera
,
)
directions
=
(
"into_future"
,
"into_past"
)
for
split
,
letter
,
subfolder
,
direction
,
camera
in
itertools
.
product
(
splits
,
letters
,
subfolders
,
directions
,
cameras
):
current_folder
=
root
/
"optical_flow"
/
split
/
letter
/
subfolder
/
direction
/
camera
os
.
makedirs
(
str
(
current_folder
),
exist_ok
=
True
)
for
i
in
range
(
num_images_per_camera
):
datasets_utils
.
make_fake_pfm_file
(
self
.
FLOW_H
,
self
.
FLOW_W
,
file_name
=
str
(
current_folder
/
f
"
{
i
}
.pfm"
))
num_cameras
=
2
if
config
[
"camera"
]
==
"both"
else
1
num_passes
=
2
if
config
[
"pass_name"
]
==
"both"
else
1
num_examples
=
(
(
num_images_per_camera
-
1
)
*
num_cameras
*
len
(
subfolders
)
*
len
(
letters
)
*
len
(
splits
)
*
num_passes
)
return
num_examples
@
datasets_utils
.
test_all_configs
def
test_flow
(
self
,
config
):
with
self
.
create_dataset
(
config
=
config
)
as
(
dataset
,
_
):
assert
dataset
.
_flow_list
and
len
(
dataset
.
_flow_list
)
==
len
(
dataset
.
_image_list
)
for
_
,
_
,
flow
in
dataset
:
assert
flow
.
shape
==
(
2
,
self
.
FLOW_H
,
self
.
FLOW_W
)
# We don't check the values because the reshaping and flipping makes it hard to figure out
def
test_bad_input
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"Unknown value 'bad' for argument split"
):
with
self
.
create_dataset
(
split
=
"bad"
):
pass
with
pytest
.
raises
(
ValueError
,
match
=
"Unknown value 'bad' for argument pass_name"
):
with
self
.
create_dataset
(
pass_name
=
"bad"
):
pass
with
pytest
.
raises
(
ValueError
,
match
=
"Unknown value 'bad' for argument camera"
):
with
self
.
create_dataset
(
camera
=
"bad"
):
pass
if
__name__
==
"__main__"
:
unittest
.
main
()
torchvision/datasets/__init__.py
View file @
1bd131c7
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
,
FlyingThings3D
from
.caltech
import
Caltech101
,
Caltech256
from
.celeba
import
CelebA
from
.cifar
import
CIFAR10
,
CIFAR100
...
...
@@ -75,4 +75,5 @@ __all__ = (
"KittiFlow"
,
"Sintel"
,
"FlyingChairs"
,
"FlyingThings3D"
,
)
torchvision/datasets/_optical_flow.py
View file @
1bd131c7
import
itertools
import
os
import
re
from
abc
import
ABC
,
abstractmethod
from
glob
import
glob
from
pathlib
import
Path
...
...
@@ -15,6 +17,7 @@ from .vision import VisionDataset
__all__
=
(
"KittiFlow"
,
"Sintel"
,
"FlyingThings3D"
,
"FlyingChairs"
,
)
...
...
@@ -271,6 +274,94 @@ class FlyingChairs(FlowDataset):
return
_read_flo
(
file_name
)
class
FlyingThings3D
(
FlowDataset
):
"""`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
The dataset is expected to have the following structure: ::
root
FlyingThings3D
frames_cleanpass
TEST
TRAIN
frames_finalpass
TEST
TRAIN
optical_flow
TEST
TRAIN
Args:
root (string): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes.
camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``valid`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""
def
__init__
(
self
,
root
,
split
=
"train"
,
pass_name
=
"clean"
,
camera
=
"left"
,
transforms
=
None
):
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
split
=
split
.
upper
()
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"clean"
,
"final"
,
"both"
))
passes
=
{
"clean"
:
[
"frames_cleanpass"
],
"final"
:
[
"frames_finalpass"
],
"both"
:
[
"frames_cleanpass"
,
"frames_finalpass"
],
}[
pass_name
]
verify_str_arg
(
camera
,
"camera"
,
valid_values
=
(
"left"
,
"right"
,
"both"
))
cameras
=
[
"left"
,
"right"
]
if
camera
==
"both"
else
[
camera
]
root
=
Path
(
root
)
/
"FlyingThings3D"
directions
=
(
"into_future"
,
"into_past"
)
for
pass_name
,
camera
,
direction
in
itertools
.
product
(
passes
,
cameras
,
directions
):
image_dirs
=
sorted
(
glob
(
str
(
root
/
pass_name
/
split
/
"*/*"
)))
image_dirs
=
sorted
([
Path
(
image_dir
)
/
camera
for
image_dir
in
image_dirs
])
flow_dirs
=
sorted
(
glob
(
str
(
root
/
"optical_flow"
/
split
/
"*/*"
)))
flow_dirs
=
sorted
([
Path
(
flow_dir
)
/
direction
/
camera
for
flow_dir
in
flow_dirs
])
if
not
image_dirs
or
not
flow_dirs
:
raise
FileNotFoundError
(
"Could not find the FlyingThings3D flow images. "
"Please make sure the directory structure is correct."
)
for
image_dir
,
flow_dir
in
zip
(
image_dirs
,
flow_dirs
):
images
=
sorted
(
glob
(
str
(
image_dir
/
"*.png"
)))
flows
=
sorted
(
glob
(
str
(
flow_dir
/
"*.pfm"
)))
for
i
in
range
(
len
(
flows
)
-
1
):
if
direction
==
"into_future"
:
self
.
_image_list
+=
[[
images
[
i
],
images
[
i
+
1
]]]
self
.
_flow_list
+=
[
flows
[
i
]]
elif
direction
==
"into_past"
:
self
.
_image_list
+=
[[
images
[
i
+
1
],
images
[
i
]]]
self
.
_flow_list
+=
[
flows
[
i
+
1
]]
def
__getitem__
(
self
,
index
):
"""Return example at given index.
Args:
index(int): The index of the example to retrieve
Returns:
tuple: A 3-tuple with ``(img1, img2, flow)``.
The flow is a numpy array of shape (2, H, W) and the images are PIL images.
"""
return
super
().
__getitem__
(
index
)
def
_read_flow
(
self
,
file_name
):
return
_read_pfm
(
file_name
)
def
_read_flo
(
file_name
):
"""Read .flo file in Middlebury format"""
# Code adapted from:
...
...
@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
# For consistency with other datasets, we convert to numpy
return
flow
.
numpy
(),
valid
.
numpy
()
def
_read_pfm
(
file_name
):
"""Read flow in .pfm format"""
with
open
(
file_name
,
"rb"
)
as
f
:
header
=
f
.
readline
().
rstrip
()
if
header
!=
b
"PF"
:
raise
ValueError
(
"Invalid PFM file"
)
dim_match
=
re
.
match
(
rb
"^(\d+)\s(\d+)\s$"
,
f
.
readline
())
if
not
dim_match
:
raise
Exception
(
"Malformed PFM header."
)
w
,
h
=
(
int
(
dim
)
for
dim
in
dim_match
.
groups
())
scale
=
float
(
f
.
readline
().
rstrip
())
if
scale
<
0
:
# little-endian
endian
=
"<"
scale
=
-
scale
else
:
endian
=
">"
# big-endian
data
=
np
.
fromfile
(
f
,
dtype
=
endian
+
"f"
)
data
=
data
.
reshape
(
h
,
w
,
3
).
transpose
(
2
,
0
,
1
)
data
=
np
.
flip
(
data
,
axis
=
1
)
# flip on h dimension
data
=
data
[:
2
,
:,
:]
return
data
.
astype
(
np
.
float32
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment