Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7f424379
Unverified
Commit
7f424379
authored
Nov 05, 2021
by
Nicolas Hug
Committed by
GitHub
Nov 05, 2021
Browse files
Add FlyingChairs dataset for optical flow (#4860)
parent
eb48a1d8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
144 additions
and
20 deletions
+144
-20
docs/source/datasets.rst
docs/source/datasets.rst
+1
-0
test/datasets_utils.py
test/datasets_utils.py
+9
-0
test/test_datasets.py
test/test_datasets.py
+62
-12
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+70
-7
No files found.
docs/source/datasets.rst
View file @
7f424379
...
@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
...
@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FashionMNIST
FashionMNIST
Flickr8k
Flickr8k
Flickr30k
Flickr30k
FlyingChairs
HMDB51
HMDB51
ImageNet
ImageNet
INaturalist
INaturalist
...
...
test/datasets_utils.py
View file @
7f424379
...
@@ -8,6 +8,7 @@ import pathlib
...
@@ -8,6 +8,7 @@ import pathlib
import
random
import
random
import
shutil
import
shutil
import
string
import
string
import
struct
import
tarfile
import
tarfile
import
unittest
import
unittest
import
unittest.mock
import
unittest.mock
...
@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
...
@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
digits
=
""
.
join
(
itertools
.
chain
(
*
digits
))
digits
=
""
.
join
(
itertools
.
chain
(
*
digits
))
return
""
.
join
(
random
.
choice
(
digits
)
for
_
in
range
(
length
))
return
""
.
join
(
random
.
choice
(
digits
)
for
_
in
range
(
length
))
def make_fake_flo_file(h, w, file_name):
    """Write a fake optical-flow file in Middlebury ``.flo`` format.

    The file layout is the 4-byte magic tag ``PIEH``, the width and height as
    native ints, then ``2 * h * w`` float32 values (here simply 0, 1, 2, ...).

    Args:
        h (int): flow field height.
        w (int): flow field width.
        file_name (str): destination path for the fake file.
    """
    num_values = 2 * h * w
    header = b"PIEH" + struct.pack("ii", w, h)
    payload = struct.pack("f" * num_values, *range(num_values))
    with open(file_name, "wb") as f:
        f.write(header + payload)
test/test_datasets.py
View file @
7f424379
...
@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase):
...
@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class
SintelTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
SintelTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
Sintel
DATASET_CLASS
=
datasets
.
Sintel
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
))
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
))
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
np
.
ndarray
,
type
(
None
)))
# which is something we want to avoid.
_FAKE_FLOW
=
"Fake Flow"
FLOW_H
,
FLOW_W
=
3
,
4
EXTRA_PATCHES
=
{
unittest
.
mock
.
patch
(
"torchvision.datasets.Sintel._read_flow"
,
return_value
=
_FAKE_FLOW
)}
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
type
(
_FAKE_FLOW
),
type
(
None
)))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
def
inject_fake_data
(
self
,
tmpdir
,
config
):
root
=
pathlib
.
Path
(
tmpdir
)
/
"Sintel"
root
=
pathlib
.
Path
(
tmpdir
)
/
"Sintel"
...
@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
num_examples
=
num_images_per_scene
,
num_examples
=
num_images_per_scene
,
)
)
# For the ground truth flow value we just create empty files so that they're properly discovered,
# see comment above about EXTRA_PATCHES
flow_root
=
root
/
"training"
/
"flow"
flow_root
=
root
/
"training"
/
"flow"
for
scene_id
in
range
(
num_scenes
):
for
scene_id
in
range
(
num_scenes
):
scene_dir
=
flow_root
/
f
"scene_
{
scene_id
}
"
scene_dir
=
flow_root
/
f
"scene_
{
scene_id
}
"
os
.
makedirs
(
scene_dir
)
os
.
makedirs
(
scene_dir
)
for
i
in
range
(
num_images_per_scene
-
1
):
for
i
in
range
(
num_images_per_scene
-
1
):
open
(
str
(
scene_dir
/
f
"frame_000
{
i
}
.flo"
),
"a"
).
close
()
file_name
=
str
(
scene_dir
/
f
"frame_000
{
i
}
.flo"
)
datasets_utils
.
make_fake_flo_file
(
h
=
self
.
FLOW_H
,
w
=
self
.
FLOW_W
,
file_name
=
file_name
)
# with e.g. num_images_per_scene = 3, for a single scene we have 3 images
# with e.g. num_images_per_scene = 3, for a single scene we have 3 images
# which are frame_0000, frame_0001 and frame_0002
# which are frame_0000, frame_0001 and frame_0002
...
@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
with
self
.
create_dataset
(
split
=
"train"
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
split
=
"train"
)
as
(
dataset
,
_
):
assert
dataset
.
_flow_list
and
len
(
dataset
.
_flow_list
)
==
len
(
dataset
.
_image_list
)
assert
dataset
.
_flow_list
and
len
(
dataset
.
_flow_list
)
==
len
(
dataset
.
_image_list
)
for
_
,
_
,
flow
in
dataset
:
for
_
,
_
,
flow
in
dataset
:
assert
flow
==
self
.
_FAKE_FLOW
assert
flow
.
shape
==
(
2
,
self
.
FLOW_H
,
self
.
FLOW_W
)
np
.
testing
.
assert_allclose
(
flow
,
np
.
arange
(
flow
.
size
).
reshape
(
flow
.
shape
))
# Make sure flow is always None for test split
# Make sure flow is always None for test split
with
self
.
create_dataset
(
split
=
"test"
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
split
=
"test"
)
as
(
dataset
,
_
):
...
@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
assert
flow
is
None
assert
flow
is
None
def
test_bad_input
(
self
):
def
test_bad_input
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"
split must be either
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Unknown value 'bad' for argument split
"
):
with
self
.
create_dataset
(
split
=
"bad"
):
with
self
.
create_dataset
(
split
=
"bad"
):
pass
pass
with
pytest
.
raises
(
ValueError
,
match
=
"
pass_name must be either
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Unknown value 'bad' for argument pass_name
"
):
with
self
.
create_dataset
(
pass_name
=
"bad"
):
with
self
.
create_dataset
(
pass_name
=
"bad"
):
pass
pass
...
@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
assert
valid
is
None
assert
valid
is
None
def
test_bad_input
(
self
):
def
test_bad_input
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"
split must be either
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Unknown value 'bad' for argument split
"
):
with
self
.
create_dataset
(
split
=
"bad"
):
with
self
.
create_dataset
(
split
=
"bad"
):
pass
pass
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for the FlyingChairs optical-flow dataset."""

    DATASET_CLASS = datasets.FlyingChairs
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Dimensions of the fake flow fields written by inject_fake_data.
    FLOW_H, FLOW_W = 3, 4

    def _make_split_file(self, root, num_examples):
        # We create a fake split file here; real users are asked to download
        # the real one from the author's website.
        ids = [1] * num_examples["train"] + [2] * num_examples["val"]
        random.shuffle(ids)
        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
            split_file.writelines(f"{split_id}\n" for split_id in ids)

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "FlyingChairs"

        num_examples = {"train": 5, "val": 3}
        total = sum(num_examples.values())

        # First image of each pair.
        datasets_utils.create_image_folder(
            root,
            name="data",
            file_name_fn=lambda idx: f"00{idx}_img1.ppm",
            num_examples=total,
        )
        # Second image of each pair.
        datasets_utils.create_image_folder(
            root,
            name="data",
            file_name_fn=lambda idx: f"00{idx}_img2.ppm",
            num_examples=total,
        )
        # One decodable fake .flo file per image pair.
        for idx in range(total):
            datasets_utils.make_fake_flo_file(
                h=self.FLOW_H, w=self.FLOW_W, file_name=str(root / "data" / f"00{idx}_flow.flo")
            )

        self._make_split_file(root, num_examples)

        return num_examples[config["split"]]

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        # Flow must always exist, there must be as many flow values as (pairs
        # of) images, and the flow must decode to the values we wrote.
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list
            assert len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
torchvision/datasets/__init__.py
View file @
7f424379
from
._optical_flow
import
KittiFlow
,
Sintel
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
from
.caltech
import
Caltech101
,
Caltech256
from
.caltech
import
Caltech101
,
Caltech256
from
.celeba
import
CelebA
from
.celeba
import
CelebA
from
.cifar
import
CIFAR10
,
CIFAR100
from
.cifar
import
CIFAR10
,
CIFAR100
...
@@ -74,4 +74,5 @@ __all__ = (
...
@@ -74,4 +74,5 @@ __all__ = (
"LFWPairs"
,
"LFWPairs"
,
"KittiFlow"
,
"KittiFlow"
,
"Sintel"
,
"Sintel"
,
"FlyingChairs"
,
)
)
torchvision/datasets/_optical_flow.py
View file @
7f424379
...
@@ -8,12 +8,14 @@ import torch
...
@@ -8,12 +8,14 @@ import torch
from
PIL
import
Image
from
PIL
import
Image
from
..io.image
import
_read_png_16
from
..io.image
import
_read_png_16
from
.utils
import
verify_str_arg
from
.vision
import
VisionDataset
from
.vision
import
VisionDataset
__all__
=
(
__all__
=
(
"KittiFlow"
,
"KittiFlow"
,
"Sintel"
,
"Sintel"
,
"FlyingChairs"
,
)
)
...
@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
...
@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
def
__init__
(
self
,
root
,
split
=
"train"
,
pass_name
=
"clean"
,
transforms
=
None
):
def
__init__
(
self
,
root
,
split
=
"train"
,
pass_name
=
"clean"
,
transforms
=
None
):
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
if
split
not
in
(
"train"
,
"test"
):
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
raise
ValueError
(
"split must be either 'train' or 'test'"
)
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"clean"
,
"final"
))
if
pass_name
not
in
(
"clean"
,
"final"
):
raise
ValueError
(
"pass_name must be either 'clean' or 'final'"
)
root
=
Path
(
root
)
/
"Sintel"
root
=
Path
(
root
)
/
"Sintel"
...
@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
...
@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
def
__init__
(
self
,
root
,
split
=
"train"
,
transforms
=
None
):
def
__init__
(
self
,
root
,
split
=
"train"
,
transforms
=
None
):
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
if
split
not
in
(
"train"
,
"test"
):
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
raise
ValueError
(
"split must be either 'train' or 'test'"
)
root
=
Path
(
root
)
/
"Kitti"
/
(
split
+
"ing"
)
root
=
Path
(
root
)
/
"Kitti"
/
(
split
+
"ing"
)
images1
=
sorted
(
glob
(
str
(
root
/
"image_2"
/
"*_10.png"
)))
images1
=
sorted
(
glob
(
str
(
root
/
"image_2"
/
"*_10.png"
)))
...
@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset):
...
@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset):
return
_read_16bits_png_with_flow_and_valid_mask
(
file_name
)
return
_read_16bits_png_with_flow_and_valid_mask
(
file_name
)
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt

    Args:
        root (string): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"

        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        # The split file marks each example with 1 (train) or 2 (val).
        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        wanted_id = 1 if split == "train" else 2
        for idx, flow_file in enumerate(flows):
            if split_list[idx] == wanted_id:
                self._flow_list.append(flow_file)
                # Images come in pairs: 2*idx is img1, 2*idx + 1 is img2.
                self._image_list.append([images[2 * idx], images[2 * idx + 1]])

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        # Delegate to the module-level Middlebury .flo reader.
        return _read_flo(file_name)
def
_read_flo
(
file_name
):
def
_read_flo
(
file_name
):
"""Read .flo file in Middlebury format"""
"""Read .flo file in Middlebury format"""
# Code adapted from:
# Code adapted from:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment