Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7f424379
"...text-generation-inference.git" did not exist on "68e9d6ab333715008c542467c8d5202cf4692253"
Unverified
Commit
7f424379
authored
Nov 05, 2021
by
Nicolas Hug
Committed by
GitHub
Nov 05, 2021
Browse files
Add FlyingChairs dataset for optical flow (#4860)
parent
eb48a1d8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
144 additions
and
20 deletions
+144
-20
docs/source/datasets.rst
docs/source/datasets.rst
+1
-0
test/datasets_utils.py
test/datasets_utils.py
+9
-0
test/test_datasets.py
test/test_datasets.py
+62
-12
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-1
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+70
-7
No files found.
docs/source/datasets.rst
View file @
7f424379
...
@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
...
@@ -43,6 +43,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
FashionMNIST
FashionMNIST
Flickr8k
Flickr8k
Flickr30k
Flickr30k
FlyingChairs
HMDB51
HMDB51
ImageNet
ImageNet
INaturalist
INaturalist
...
...
test/datasets_utils.py
View file @
7f424379
...
@@ -8,6 +8,7 @@ import pathlib
...
@@ -8,6 +8,7 @@ import pathlib
import
random
import
random
import
shutil
import
shutil
import
string
import
string
import
struct
import
tarfile
import
tarfile
import
unittest
import
unittest
import
unittest.mock
import
unittest.mock
...
@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
...
@@ -922,3 +923,11 @@ def create_random_string(length: int, *digits: str) -> str:
digits
=
""
.
join
(
itertools
.
chain
(
*
digits
))
digits
=
""
.
join
(
itertools
.
chain
(
*
digits
))
return
""
.
join
(
random
.
choice
(
digits
)
for
_
in
range
(
length
))
return
""
.
join
(
random
.
choice
(
digits
)
for
_
in
range
(
length
))
def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format."""
    # A .flo file is the magic tag "PIEH", then width and height as 32-bit ints,
    # followed by 2*h*w float32 flow values. We fill them with 0, 1, 2, ... so
    # tests can verify the decoded flow deterministically.
    num_values = 2 * h * w
    header = b"PIEH" + struct.pack("ii", w, h)
    payload = struct.pack(f"{num_values}f", *range(num_values))
    with open(file_name, "wb") as f:
        f.write(header + payload)
test/test_datasets.py
View file @
7f424379
...
@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase):
...
@@ -1874,11 +1874,9 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class
SintelTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
SintelTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
Sintel
DATASET_CLASS
=
datasets
.
Sintel
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
))
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
))
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
np
.
ndarray
,
type
(
None
)))
# which is something we want to avoid.
_FAKE_FLOW
=
"Fake Flow"
FLOW_H
,
FLOW_W
=
3
,
4
EXTRA_PATCHES
=
{
unittest
.
mock
.
patch
(
"torchvision.datasets.Sintel._read_flow"
,
return_value
=
_FAKE_FLOW
)}
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
type
(
_FAKE_FLOW
),
type
(
None
)))
def
inject_fake_data
(
self
,
tmpdir
,
config
):
def
inject_fake_data
(
self
,
tmpdir
,
config
):
root
=
pathlib
.
Path
(
tmpdir
)
/
"Sintel"
root
=
pathlib
.
Path
(
tmpdir
)
/
"Sintel"
...
@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1899,14 +1897,13 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
num_examples
=
num_images_per_scene
,
num_examples
=
num_images_per_scene
,
)
)
# For the ground truth flow value we just create empty files so that they're properly discovered,
# see comment above about EXTRA_PATCHES
flow_root
=
root
/
"training"
/
"flow"
flow_root
=
root
/
"training"
/
"flow"
for
scene_id
in
range
(
num_scenes
):
for
scene_id
in
range
(
num_scenes
):
scene_dir
=
flow_root
/
f
"scene_
{
scene_id
}
"
scene_dir
=
flow_root
/
f
"scene_
{
scene_id
}
"
os
.
makedirs
(
scene_dir
)
os
.
makedirs
(
scene_dir
)
for
i
in
range
(
num_images_per_scene
-
1
):
for
i
in
range
(
num_images_per_scene
-
1
):
open
(
str
(
scene_dir
/
f
"frame_000
{
i
}
.flo"
),
"a"
).
close
()
file_name
=
str
(
scene_dir
/
f
"frame_000
{
i
}
.flo"
)
datasets_utils
.
make_fake_flo_file
(
h
=
self
.
FLOW_H
,
w
=
self
.
FLOW_W
,
file_name
=
file_name
)
# with e.g. num_images_per_scene = 3, for a single scene we have 3 images
# with e.g. num_images_per_scene = 3, for a single scene we have 3 images
# which are frame_0000, frame_0001 and frame_0002
# which are frame_0000, frame_0001 and frame_0002
...
@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1920,7 +1917,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
with
self
.
create_dataset
(
split
=
"train"
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
split
=
"train"
)
as
(
dataset
,
_
):
assert
dataset
.
_flow_list
and
len
(
dataset
.
_flow_list
)
==
len
(
dataset
.
_image_list
)
assert
dataset
.
_flow_list
and
len
(
dataset
.
_flow_list
)
==
len
(
dataset
.
_image_list
)
for
_
,
_
,
flow
in
dataset
:
for
_
,
_
,
flow
in
dataset
:
assert
flow
==
self
.
_FAKE_FLOW
assert
flow
.
shape
==
(
2
,
self
.
FLOW_H
,
self
.
FLOW_W
)
np
.
testing
.
assert_allclose
(
flow
,
np
.
arange
(
flow
.
size
).
reshape
(
flow
.
shape
))
# Make sure flow is always None for test split
# Make sure flow is always None for test split
with
self
.
create_dataset
(
split
=
"test"
)
as
(
dataset
,
_
):
with
self
.
create_dataset
(
split
=
"test"
)
as
(
dataset
,
_
):
...
@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1929,11 +1927,11 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
assert
flow
is
None
assert
flow
is
None
def
test_bad_input
(
self
):
def
test_bad_input
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"
split must be either
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Unknown value 'bad' for argument split
"
):
with
self
.
create_dataset
(
split
=
"bad"
):
with
self
.
create_dataset
(
split
=
"bad"
):
pass
pass
with
pytest
.
raises
(
ValueError
,
match
=
"
pass_name must be either
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Unknown value 'bad' for argument pass_name
"
):
with
self
.
create_dataset
(
pass_name
=
"bad"
):
with
self
.
create_dataset
(
pass_name
=
"bad"
):
pass
pass
...
@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1993,10 +1991,62 @@ class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
assert
valid
is
None
assert
valid
is
None
def
test_bad_input
(
self
):
def
test_bad_input
(
self
):
with
pytest
.
raises
(
ValueError
,
match
=
"
split must be either
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"
Unknown value 'bad' for argument split
"
):
with
self
.
create_dataset
(
split
=
"bad"
):
with
self
.
create_dataset
(
split
=
"bad"
):
pass
pass
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.FlyingChairs
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    FLOW_H, FLOW_W = 3, 4

    def _make_split_file(self, root, num_examples):
        # We create a fake split file here, but users are asked to download the real one from the authors website
        # Split id 1 marks a "train" example, 2 marks a "val" example.
        ids = [1] * num_examples["train"] + [2] * num_examples["val"]
        random.shuffle(ids)
        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
            split_file.writelines(f"{split_id}\n" for split_id in ids)

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "FlyingChairs"
        num_examples = {"train": 5, "val": 3}
        num_examples_total = sum(num_examples.values())

        # Each example is a pair of images (img1, img2) plus one flow file.
        for suffix in ("img1", "img2"):
            datasets_utils.create_image_folder(
                root,
                name="data",
                file_name_fn=lambda image_idx, suffix=suffix: f"00{image_idx}_{suffix}.ppm",
                num_examples=num_examples_total,
            )
        for example_idx in range(num_examples_total):
            flow_path = str(root / "data" / f"00{example_idx}_flow.flo")
            datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=flow_path)

        self._make_split_file(root, num_examples)

        return num_examples[config["split"]]

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        # Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
        # Also make sure the flow is properly decoded
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                # make_fake_flo_file writes 0, 1, 2, ... so the decoded flow must match arange
                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
torchvision/datasets/__init__.py
View file @
7f424379
from
._optical_flow
import
KittiFlow
,
Sintel
from
._optical_flow
import
KittiFlow
,
Sintel
,
FlyingChairs
from
.caltech
import
Caltech101
,
Caltech256
from
.caltech
import
Caltech101
,
Caltech256
from
.celeba
import
CelebA
from
.celeba
import
CelebA
from
.cifar
import
CIFAR10
,
CIFAR100
from
.cifar
import
CIFAR10
,
CIFAR100
...
@@ -74,4 +74,5 @@ __all__ = (
...
@@ -74,4 +74,5 @@ __all__ = (
"LFWPairs"
,
"LFWPairs"
,
"KittiFlow"
,
"KittiFlow"
,
"Sintel"
,
"Sintel"
,
"FlyingChairs"
,
)
)
torchvision/datasets/_optical_flow.py
View file @
7f424379
...
@@ -8,12 +8,14 @@ import torch
...
@@ -8,12 +8,14 @@ import torch
from
PIL
import
Image
from
PIL
import
Image
from
..io.image
import
_read_png_16
from
..io.image
import
_read_png_16
from
.utils
import
verify_str_arg
from
.vision
import
VisionDataset
from
.vision
import
VisionDataset
__all__
=
(
__all__
=
(
"KittiFlow"
,
"KittiFlow"
,
"Sintel"
,
"Sintel"
,
"FlyingChairs"
,
)
)
...
@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
...
@@ -109,11 +111,8 @@ class Sintel(FlowDataset):
def
__init__
(
self
,
root
,
split
=
"train"
,
pass_name
=
"clean"
,
transforms
=
None
):
def
__init__
(
self
,
root
,
split
=
"train"
,
pass_name
=
"clean"
,
transforms
=
None
):
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
if
split
not
in
(
"train"
,
"test"
):
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
raise
ValueError
(
"split must be either 'train' or 'test'"
)
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"clean"
,
"final"
))
if
pass_name
not
in
(
"clean"
,
"final"
):
raise
ValueError
(
"pass_name must be either 'clean' or 'final'"
)
root
=
Path
(
root
)
/
"Sintel"
root
=
Path
(
root
)
/
"Sintel"
...
@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
...
@@ -171,8 +170,7 @@ class KittiFlow(FlowDataset):
def
__init__
(
self
,
root
,
split
=
"train"
,
transforms
=
None
):
def
__init__
(
self
,
root
,
split
=
"train"
,
transforms
=
None
):
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
if
split
not
in
(
"train"
,
"test"
):
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
raise
ValueError
(
"split must be either 'train' or 'test'"
)
root
=
Path
(
root
)
/
"Kitti"
/
(
split
+
"ing"
)
root
=
Path
(
root
)
/
"Kitti"
/
(
split
+
"ing"
)
images1
=
sorted
(
glob
(
str
(
root
/
"image_2"
/
"*_10.png"
)))
images1
=
sorted
(
glob
(
str
(
root
/
"image_2"
/
"*_10.png"
)))
...
@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset):
...
@@ -208,6 +206,71 @@ class KittiFlow(FlowDataset):
return
_read_16bits_png_with_flow_and_valid_mask
(
file_name
)
return
_read_16bits_png_with_flow_and_valid_mask
(
file_name
)
class FlyingChairs(FlowDataset):
    """`FlyingChairs <https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs>`_ Dataset for optical flow.

    You will also need to download the FlyingChairs_train_val.txt file from the dataset page.

    The dataset is expected to have the following structure: ::

        root
            FlyingChairs
                data
                    00001_flow.flo
                    00001_img1.ppm
                    00001_img2.ppm
                    ...
                FlyingChairs_train_val.txt

    Args:
        root (string): Root directory of the FlyingChairs Dataset.
        split (string, optional): The dataset split, either "train" (default) or "val"
        transforms (callable, optional): A function/transform that takes in
            ``img1, img2, flow, valid`` and returns a transformed version.
            ``valid`` is expected for consistency with other datasets which
            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
    """

    def __init__(self, root, split="train", transforms=None):
        super().__init__(root=root, transforms=transforms)

        verify_str_arg(split, "split", valid_values=("train", "val"))

        root = Path(root) / "FlyingChairs"
        images = sorted(glob(str(root / "data" / "*.ppm")))
        flows = sorted(glob(str(root / "data" / "*.flo")))

        split_file_name = "FlyingChairs_train_val.txt"
        if not os.path.exists(root / split_file_name):
            raise FileNotFoundError(
                "The FlyingChairs_train_val.txt file was not found - please download it from the dataset page (see docstring)."
            )

        # One id per example: 1 selects "train", 2 selects "val".
        split_list = np.loadtxt(str(root / split_file_name), dtype=np.int32)
        wanted_id = 1 if split == "train" else 2
        for i in range(len(flows)):
            if split_list[i] == wanted_id:
                self._flow_list.append(flows[i])
                # images are sorted, so img1/img2 of example i sit at 2*i and 2*i+1
                self._image_list.append([images[2 * i], images[2 * i + 1]])

    def __getitem__(self, index):
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 3-tuple with ``(img1, img2, flow)``.
            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
        """
        return super().__getitem__(index)

    def _read_flow(self, file_name):
        return _read_flo(file_name)
def
_read_flo
(
file_name
):
def
_read_flo
(
file_name
):
"""Read .flo file in Middlebury format"""
"""Read .flo file in Middlebury format"""
# Code adapted from:
# Code adapted from:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment