Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
1bd131c7
Unverified
Commit
1bd131c7
authored
Nov 05, 2021
by
Nicolas Hug
Committed by
GitHub
Nov 05, 2021
Browse files
Add FlyingThings3D dataset for optical flow (#4858)
parent
7f424379
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
198 additions
and
8 deletions
+198
-8
docs/source/datasets.rst
docs/source/datasets.rst
+1
-0
test/datasets_utils.py
test/datasets_utils.py
+9
-7
test/test_datasets.py
test/test_datasets.py
+67
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+119
-0
No files found.
docs/source/datasets.rst
View file @
1bd131c7
...
...
@@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr8k
Flickr30k
FlyingChairs
FlyingThings3D
HMDB51
ImageNet
INaturalist
...
...
test/datasets_utils.py
View file @
1bd131c7
...
...
@@ -204,7 +204,6 @@ class DatasetTestCase(unittest.TestCase):
``transforms``, or ``download``.
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
available, the tests are skipped.
- EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
The fake data should resemble the original data as close as necessary, while containing only few examples. During
...
...
@@ -256,8 +255,6 @@ class DatasetTestCase(unittest.TestCase):
ADDITIONAL_CONFIGS
=
None
REQUIRED_PACKAGES
=
None
EXTRA_PATCHES
=
None
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
_TRANSFORM_KWARGS
=
{
"transform"
,
...
...
@@ -383,9 +380,6 @@ class DatasetTestCase(unittest.TestCase):
if
patch_checks
:
patchers
.
update
(
self
.
_patch_checks
())
if
self
.
EXTRA_PATCHES
is
not
None
:
patchers
.
update
(
self
.
EXTRA_PATCHES
)
with
get_tmp_dir
()
as
tmpdir
:
args
=
self
.
dataset_args
(
tmpdir
,
complete_config
)
info
=
self
.
_inject_fake_data
(
tmpdir
,
complete_config
)
if
inject_fake_data
else
None
...
...
@@ -393,7 +387,7 @@ class DatasetTestCase(unittest.TestCase):
with
self
.
_maybe_apply_patches
(
patchers
),
disable_console_output
():
dataset
=
self
.
DATASET_CLASS
(
*
args
,
**
complete_config
,
**
special_kwargs
)
yield
dataset
,
info
yield
dataset
,
info
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -925,6 +919,14 @@ def create_random_string(length: int, *digits: str) -> str:
return
""
.
join
(
random
.
choice
(
digits
)
for
_
in
range
(
length
))
def make_fake_pfm_file(h, w, file_name):
    """Write a fake .pfm flow file of size (h, w), filled with 0 .. 3*h*w - 1."""
    num_values = 3 * h * w
    # Everything is packed little-endian: the "-1.0" scale marker and the "<" format.
    header = f"PF\n{w} {h}\n-1.0\n".encode()
    payload = struct.pack(f"<{num_values}f", *range(num_values))
    with open(file_name, "wb") as f:
        f.write(header + payload)
def
make_fake_flo_file
(
h
,
w
,
file_name
):
"""Creates a fake flow file in .flo format."""
values
=
list
(
range
(
2
*
h
*
w
))
...
...
test/test_datasets.py
View file @
1bd131c7
...
...
@@ -2048,5 +2048,72 @@ class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
np
.
testing
.
assert_allclose
(
flow
,
np
.
arange
(
flow
.
size
).
reshape
(
flow
.
shape
))
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data test case for :class:`torchvision.datasets.FlyingThings3D`."""

    DATASET_CLASS = datasets.FlyingThings3D
    # Exercise every combination of the dataset's string arguments.
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
    )
    # Each example is (img1, img2, flow); flow may be None for some datasets' test splits.
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Spatial size of the fake flow maps written by inject_fake_data.
    FLOW_H, FLOW_W = 3, 4

    def inject_fake_data(self, tmpdir, config):
        """Create a fake FlyingThings3D directory tree and return the expected example count."""
        root = pathlib.Path(tmpdir) / "FlyingThings3D"

        # Use different counts per split so a split mix-up changes the expected length.
        num_images_per_camera = 3 if config["split"] == "train" else 4
        passes = ("frames_cleanpass", "frames_finalpass")
        splits = ("TRAIN", "TEST")
        letters = ("A", "B", "C")
        subfolders = ("0000", "0001")
        cameras = ("left", "right")
        # Image layout: root/<pass>/<SPLIT>/<letter>/<subfolder>/<camera>/00<i>.png
        for pass_name, split, letter, subfolder, camera in itertools.product(
            passes, splits, letters, subfolders, cameras
        ):
            current_folder = root / pass_name / split / letter / subfolder
            datasets_utils.create_image_folder(
                current_folder,
                name=camera,
                file_name_fn=lambda image_idx: f"00{image_idx}.png",
                num_examples=num_images_per_camera,
            )

        # Flow layout: root/optical_flow/<SPLIT>/<letter>/<subfolder>/<direction>/<camera>/<i>.pfm
        directions = ("into_future", "into_past")
        for split, letter, subfolder, direction, camera in itertools.product(
            splits, letters, subfolders, directions, cameras
        ):
            current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera
            os.makedirs(str(current_folder), exist_ok=True)
            for i in range(num_images_per_camera):
                datasets_utils.make_fake_pfm_file(self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm"))

        num_cameras = 2 if config["camera"] == "both" else 1
        num_passes = 2 if config["pass_name"] == "both" else 1
        # N images per folder yield N - 1 consecutive (img1, img2) pairs per direction;
        # NOTE(review): the dataset appears to iterate both directions internally, which
        # this count relies on — confirm against FlyingThings3D.__init__.
        num_examples = (
            (num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(splits) * num_passes
        )
        return num_examples

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        # The flow list must be non-empty and aligned one-to-one with the image pairs.
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                # We don't check the values because the reshaping and flipping makes it hard to figure out

    def test_bad_input(self):
        # Each string argument must reject unknown values with a descriptive ValueError.
        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
            with self.create_dataset(split="bad"):
                pass

        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
            with self.create_dataset(pass_name="bad"):
                pass

        with pytest.raises(ValueError, match="Unknown value 'bad' for argument camera"):
            with self.create_dataset(camera="bad"):
                pass
if __name__ == "__main__":
    # Allow running this test module directly (outside pytest).
    unittest.main()
torchvision/datasets/__init__.py
View file @
1bd131c7
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
,
FlyingThings3D
from
.caltech
import
Caltech101
,
Caltech256
from
.celeba
import
CelebA
from
.cifar
import
CIFAR10
,
CIFAR100
...
...
@@ -75,4 +75,5 @@ __all__ = (
"KittiFlow"
,
"Sintel"
,
"FlyingChairs"
,
"FlyingThings3D"
,
)
torchvision/datasets/_optical_flow.py
View file @
1bd131c7
import
itertools
import
os
import
re
from
abc
import
ABC
,
abstractmethod
from
glob
import
glob
from
pathlib
import
Path
...
...
@@ -15,6 +17,7 @@ from .vision import VisionDataset
# Public optical-flow dataset classes exported by this module.
__all__ = (
    "KittiFlow",
    "Sintel",
    "FlyingThings3D",
    "FlyingChairs",
)
...
...
@@ -271,6 +274,94 @@ class FlyingChairs(FlowDataset):
return
_read_flo
(
file_name
)
class FlyingThings3D(FlowDataset):
    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.

    The dataset is expected to have the following structure: ::

        root
            FlyingThings3D
                frames_cleanpass
                    TEST
                    TRAIN
                frames_finalpass
                    TEST
                    TRAIN
                optical_flow
                    TEST
                    TRAIN

    Args:
        root (string): Root directory of the FlyingThings3D Dataset.
        split (string, optional): The dataset split, either "train" (default) or "test"
        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
            details on the different passes.
        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))
        # On-disk split directories are upper-case (TRAIN / TEST).
        split = split.upper()

        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
        # Map the user-facing pass name to the on-disk directory name(s).
        passes = {
            "clean": ["frames_cleanpass"],
            "final": ["frames_finalpass"],
            "both": ["frames_cleanpass", "frames_finalpass"],
        }[pass_name]

        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
        cameras = ["left", "right"] if camera == "both" else [camera]

        root = Path(root) / "FlyingThings3D"

        directions = ("into_future", "into_past")
        # NOTE: the loop variable shadows the ``pass_name`` argument, which is no
        # longer needed past this point.
        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
            # Scene dirs look like <pass>/<SPLIT>/<letter>/<subfolder>.
            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
            image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs])

            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
            flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs])

            if not image_dirs or not flow_dirs:
                raise FileNotFoundError(
                    "Could not find the FlyingThings3D flow images. "
                    "Please make sure the directory structure is correct."
                )

            # image_dirs and flow_dirs are sorted over the same scene paths, so
            # zip pairs each scene's images with its flow maps positionally.
            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
                images = sorted(glob(str(image_dir / "*.png")))
                flows = sorted(glob(str(flow_dir / "*.pfm")))
                for i in range(len(flows) - 1):
                    if direction == "into_future":
                        # Forward pair: flow i maps frame i to frame i + 1.
                        self._image_list += [[images[i], images[i + 1]]]
                        self._flow_list += [flows[i]]
                    elif direction == "into_past":
                        # Backward pair: flow i + 1 maps frame i + 1 to frame i.
                        self._image_list += [[images[i + 1], images[i]]]
                        self._flow_list += [flows[i + 1]]

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        # FlyingThings3D stores flow as .pfm; the base class calls this hook.
        return _read_pfm(file_name)
def
_read_flo
(
file_name
):
"""Read .flo file in Middlebury format"""
# Code adapted from:
...
...
@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
# For consistency with other datasets, we convert to numpy
return
flow
.
numpy
(),
valid
.
numpy
()
def
_read_pfm
(
file_name
):
"""Read flow in .pfm format"""
with
open
(
file_name
,
"rb"
)
as
f
:
header
=
f
.
readline
().
rstrip
()
if
header
!=
b
"PF"
:
raise
ValueError
(
"Invalid PFM file"
)
dim_match
=
re
.
match
(
rb
"^(\d+)\s(\d+)\s$"
,
f
.
readline
())
if
not
dim_match
:
raise
Exception
(
"Malformed PFM header."
)
w
,
h
=
(
int
(
dim
)
for
dim
in
dim_match
.
groups
())
scale
=
float
(
f
.
readline
().
rstrip
())
if
scale
<
0
:
# little-endian
endian
=
"<"
scale
=
-
scale
else
:
endian
=
">"
# big-endian
data
=
np
.
fromfile
(
f
,
dtype
=
endian
+
"f"
)
data
=
data
.
reshape
(
h
,
w
,
3
).
transpose
(
2
,
0
,
1
)
data
=
np
.
flip
(
data
,
axis
=
1
)
# flip on h dimension
data
=
data
[:
2
,
:,
:]
return
data
.
astype
(
np
.
float32
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment