Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
b9f4ed93
Unverified
Commit
b9f4ed93
authored
Nov 12, 2021
by
Nicolas Hug
Committed by
GitHub
Nov 12, 2021
Browse files
Add HD1K dataset for optical flow (#4890)
parent
22ff44fd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
1 deletion
+113
-1
docs/source/datasets.rst
docs/source/datasets.rst
+1
-0
test/test_datasets.py
test/test_datasets.py
+42
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+68
-0
No files found.
docs/source/datasets.rst
View file @
b9f4ed93
...
...
@@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr30k
FlyingChairs
FlyingThings3D
HD1K
HMDB51
ImageNet
INaturalist
...
...
test/test_datasets.py
View file @
b9f4ed93
...
...
@@ -2126,5 +2126,47 @@ class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
pass
class HD1KTestCase(KittiFlowTestCase):
    """Dataset test case for ``datasets.HD1K``, built on the KittiFlow test case."""

    DATASET_CLASS = datasets.HD1K

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir`` with a fake HD1K directory tree.

        Both the training folders (``hd1k_input`` / ``hd1k_flow_gt``) and the
        test folder (``hd1k_challenge``) are filled regardless of the split;
        ``config["split"]`` only affects how many sequences are created and
        the number of examples reported back.

        Args:
            tmpdir: Temporary directory under which the ``hd1k`` folder is created.
            config: Test configuration; only ``config["split"]`` is read.

        Returns:
            int: Expected number of examples, i.e. one less than the number of
            frames per sequence, summed over all sequences.
        """
        hd1k_root = pathlib.Path(tmpdir) / "hd1k"
        is_train = config["split"] == "train"

        num_sequences = 4 if is_train else 3
        num_examples_per_train_sequence = 3

        # (parent folder, sub-folder) pairs that hold the training files.
        train_folders = (
            ("hd1k_input", "image_2"),
            ("hd1k_flow_gt", "flow_occ"),
        )

        for seq_idx in range(num_sequences):
            # Training data
            for parent, sub_folder in train_folders:
                datasets_utils.create_image_folder(
                    hd1k_root / parent,
                    name=sub_folder,
                    file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
                    num_examples=num_examples_per_train_sequence,
                )

            # Test data: a single pair of frames (suffixes 10 and 11) per sequence
            datasets_utils.create_image_folder(
                hd1k_root / "hd1k_challenge",
                name="image_2",
                file_name_fn=lambda _: f"{seq_idx:06d}_10.png",
                num_examples=1,
            )
            datasets_utils.create_image_folder(
                hd1k_root / "hd1k_challenge",
                name="image_2",
                file_name_fn=lambda _: f"{seq_idx:06d}_11.png",
                num_examples=1,
            )

        if is_train:
            num_examples_per_sequence = num_examples_per_train_sequence
        else:
            num_examples_per_sequence = 2

        # Frames are consumed as consecutive pairs, hence the "- 1".
        return num_sequences * (num_examples_per_sequence - 1)
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
torchvision/datasets/__init__.py
View file @
b9f4ed93
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
,
FlyingThings3D
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
,
FlyingThings3D
,
HD1K
from
.caltech
import
Caltech101
,
Caltech256
from
.celeba
import
CelebA
from
.cifar
import
CIFAR10
,
CIFAR100
...
...
@@ -76,4 +76,5 @@ __all__ = (
"Sintel"
,
"FlyingChairs"
,
"FlyingThings3D"
,
"HD1K"
,
)
torchvision/datasets/_optical_flow.py
View file @
b9f4ed93
...
...
@@ -19,6 +19,7 @@ __all__ = (
"Sintel"
,
"FlyingThings3D"
,
"FlyingChairs"
,
"HD1K"
,
)
...
...
@@ -363,6 +364,73 @@ class FlyingThings3D(FlowDataset):
return
_read_pfm
(
file_name
)
class HD1K(FlowDataset):
    """`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ optical flow dataset.

    The files are expected to be laid out as follows: ::

        root
            hd1k
                hd1k_challenge
                    image_2
                hd1k_flow_gt
                    flow_occ
                hd1k_input
                    image_2

    Args:
        root (string): Root directory of the HD1K Dataset.
        split (string, optional): Which split to load, either "train" (default) or "test".
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
    """

    _has_builtin_flow_mask = True

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "test"))

        base_dir = Path(root) / "hd1k"

        if split != "train":
            # Test split: pair each "*10.png" frame with the matching "*11.png" frame.
            challenge_dir = base_dir / "hd1k_challenge" / "image_2"
            first_frames = sorted(glob(str(challenge_dir / "*10.png")))
            second_frames = sorted(glob(str(challenge_dir / "*11.png")))
            for first, second in zip(first_frames, second_frames):
                self._image_list += [[first, second]]
        else:
            flow_dir = base_dir / "hd1k_flow_gt" / "flow_occ"
            image_dir = base_dir / "hd1k_input" / "image_2"
            # The 36 "sequences" are globbed one at a time so that the last
            # frame of sequence i is never paired with the first frame of
            # sequence i + 1.
            for seq_idx in range(36):
                pattern = f"{seq_idx:06d}_*.png"
                seq_flows = sorted(glob(str(flow_dir / pattern)))
                seq_images = sorted(glob(str(image_dir / pattern)))
                for pos in range(len(seq_flows) - 1):
                    self._flow_list += [seq_flows[pos]]
                    self._image_list += [[seq_images[pos], seq_images[pos + 1]]]

        if not self._image_list:
            raise FileNotFoundError(
                "Could not find the HD1K images. Please make sure the directory structure is correct."
            )

    def _read_flow(self, file_name):
        """Read a flow file via ``_read_16bits_png_with_flow_and_valid_mask``."""
        return _read_16bits_png_with_flow_and_valid_mask(file_name)

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: With ``split="train"``, a 4-tuple ``(img1, img2, flow,
            valid)``. ``valid`` is a numpy boolean mask of shape (H, W) marking
            the valid flow values, ``flow`` is a numpy array of shape
            (2, H, W), and the images are PIL images. With ``split="test"``,
            a 4-tuple ``(img1, img2, None, None)`` is returned.
        """
        return super().__getitem__(index)
def
_read_flo
(
file_name
):
"""Read .flo file in Middlebury format"""
# Code adapted from:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment