Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
7408cb51
Unverified
Commit
7408cb51
authored
Nov 09, 2021
by
Nicolas Hug
Committed by
GitHub
Nov 09, 2021
Browse files
Add pass_name='both' for Sintel dataset (#4888)
parent
43524b61
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
test/test_datasets.py
test/test_datasets.py
+3
-2
torchvision/datasets/_optical_flow.py
torchvision/datasets/_optical_flow.py
+12
-11
No files found.
test/test_datasets.py
View file @
7408cb51
...
@@ -1873,7 +1873,7 @@ class LFWPairsTestCase(LFWPeopleTestCase):
...
@@ -1873,7 +1873,7 @@ class LFWPairsTestCase(LFWPeopleTestCase):
class
SintelTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
SintelTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
Sintel
DATASET_CLASS
=
datasets
.
Sintel
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
))
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
split
=
(
"train"
,
"test"
),
pass_name
=
(
"clean"
,
"final"
,
"both"
))
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
np
.
ndarray
,
type
(
None
)))
FEATURE_TYPES
=
(
PIL
.
Image
.
Image
,
PIL
.
Image
.
Image
,
(
np
.
ndarray
,
type
(
None
)))
FLOW_H
,
FLOW_W
=
3
,
4
FLOW_H
,
FLOW_W
=
3
,
4
...
@@ -1909,7 +1909,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -1909,7 +1909,8 @@ class SintelTestCase(datasets_utils.ImageDatasetTestCase):
# which are frame_0000, frame_0001 and frame_0002
# which are frame_0000, frame_0001 and frame_0002
# They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002),
# They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002),
# that is 3 - 1 = 2 examples. Hence the formula below
# that is 3 - 1 = 2 examples. Hence the formula below
num_examples
=
(
num_images_per_scene
-
1
)
*
num_scenes
num_passes
=
2
if
config
[
"pass_name"
]
==
"both"
else
1
num_examples
=
(
num_images_per_scene
-
1
)
*
num_scenes
*
num_passes
return
num_examples
return
num_examples
def
test_flow
(
self
):
def
test_flow
(
self
):
...
...
torchvision/datasets/_optical_flow.py
View file @
7408cb51
...
@@ -103,7 +103,7 @@ class Sintel(FlowDataset):
...
@@ -103,7 +103,7 @@ class Sintel(FlowDataset):
Args:
Args:
root (string): Root directory of the Sintel Dataset.
root (string): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default)
or
"final". See link above for
pass_name (string, optional): The pass to use, either "clean" (default)
,
"final"
, or "both"
. See link above for
details on the different passes.
details on the different passes.
transforms (callable, optional): A function/transform that takes in
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
``img1, img2, flow, valid`` and returns a transformed version.
...
@@ -115,21 +115,22 @@ class Sintel(FlowDataset):
...
@@ -115,21 +115,22 @@ class Sintel(FlowDataset):
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
super
().
__init__
(
root
=
root
,
transforms
=
transforms
)
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
split
,
"split"
,
valid_values
=
(
"train"
,
"test"
))
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"clean"
,
"final"
))
verify_str_arg
(
pass_name
,
"pass_name"
,
valid_values
=
(
"clean"
,
"final"
,
"both"
))
passes
=
[
"clean"
,
"final"
]
if
pass_name
==
"both"
else
[
pass_name
]
root
=
Path
(
root
)
/
"Sintel"
root
=
Path
(
root
)
/
"Sintel"
split_dir
=
"training"
if
split
==
"train"
else
split
image_root
=
root
/
split_dir
/
pass_name
flow_root
=
root
/
"training"
/
"flow"
flow_root
=
root
/
"training"
/
"flow"
for
scene
in
os
.
listdir
(
image_root
):
for
pass_name
in
passes
:
image_list
=
sorted
(
glob
(
str
(
image_root
/
scene
/
"*.png"
)))
split_dir
=
"training"
if
split
==
"train"
else
split
for
i
in
range
(
len
(
image_list
)
-
1
):
image_root
=
root
/
split_dir
/
pass_name
self
.
_image_list
+=
[[
image_list
[
i
],
image_list
[
i
+
1
]]]
for
scene
in
os
.
listdir
(
image_root
):
image_list
=
sorted
(
glob
(
str
(
image_root
/
scene
/
"*.png"
)))
for
i
in
range
(
len
(
image_list
)
-
1
):
self
.
_image_list
+=
[[
image_list
[
i
],
image_list
[
i
+
1
]]]
if
split
==
"train"
:
if
split
==
"train"
:
self
.
_flow_list
+=
sorted
(
glob
(
str
(
flow_root
/
scene
/
"*.flo"
)))
self
.
_flow_list
+=
sorted
(
glob
(
str
(
flow_root
/
scene
/
"*.flo"
)))
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
"""Return example at given index.
"""Return example at given index.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment