Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
15a9a93b
Unverified
Commit
15a9a93b
authored
Sep 08, 2022
by
Ponku
Committed by
GitHub
Sep 08, 2022
Browse files
Added CREStereo dataset (#6351)
Co-authored-by:
Joao Gomes
<
jdsgomes@fb.com
>
parent
1d6a259c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
122 additions
and
0 deletions
+122
-0
docs/source/datasets.rst
docs/source/datasets.rst
+1
-0
test/test_datasets.py
test/test_datasets.py
+31
-0
torchvision/datasets/__init__.py
torchvision/datasets/__init__.py
+2
-0
torchvision/datasets/_stereo_matching.py
torchvision/datasets/_stereo_matching.py
+88
-0
No files found.
docs/source/datasets.rst
View file @
15a9a93b
...
@@ -111,6 +111,7 @@ Stereo Matching
...
@@ -111,6 +111,7 @@ Stereo Matching
CarlaStereo
CarlaStereo
Kitti2012Stereo
Kitti2012Stereo
Kitti2015Stereo
Kitti2015Stereo
CREStereo
FallingThingsStereo
FallingThingsStereo
SceneFlowStereo
SceneFlowStereo
SintelStereo
SintelStereo
...
...
test/test_datasets.py
View file @
15a9a93b
...
@@ -2841,6 +2841,37 @@ class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
...
@@ -2841,6 +2841,37 @@ class CarlaStereoTestCase(datasets_utils.ImageDatasetTestCase):
datasets_utils
.
shape_test_for_stereo
(
left
,
right
,
disparity
)
datasets_utils
.
shape_test_for_stereo
(
left
,
right
,
disparity
)
class CREStereoTestCase(datasets_utils.ImageDatasetTestCase):
    """Dataset test case for ``datasets.CREStereo`` backed by a generated fake directory tree."""

    DATASET_CLASS = datasets.CREStereo
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, np.ndarray, type(None))

    def inject_fake_data(self, tmpdir, config):
        """Populate ``tmpdir/CREStereo/<category>`` with fake stereo samples.

        Returns:
            int: Total number of samples created across all categories.
        """
        dataset_root = pathlib.Path(tmpdir) / "CREStereo"
        dataset_root.mkdir(parents=True, exist_ok=True)

        samples_per_category = {"tree": 2, "shapenet": 3, "reflective": 6, "hole": 5}

        # (file name suffix, image size) for every file belonging to one sample,
        # listed in the order the files are written.
        file_specs = (
            ("_left.jpg", (100, 100)),
            ("_right.jpg", (100, 100)),
            # these are going to end up being gray scale images
            ("_left.disp.png", (1, 100, 100)),
            ("_right.disp.png", (1, 100, 100)),
        )

        for category in ("shapenet", "reflective", "tree", "hole"):
            category_dir = dataset_root / category
            category_dir.mkdir(parents=True, exist_ok=True)

            for sample_idx in range(samples_per_category[category]):
                for suffix, image_size in file_specs:
                    datasets_utils.create_image_file(
                        root=category_dir,
                        name=f"{sample_idx}{suffix}",
                        size=image_size,
                    )

        return sum(samples_per_category.values())

    def test_splits(self):
        """Every sample must have a ``None`` mask and pass the stereo shape check."""
        with self.create_dataset() as (dataset, _):
            for sample in dataset:
                left, right, disparity, mask = sample
                assert mask is None
                datasets_utils.shape_test_for_stereo(left, right, disparity)
class
FallingThingsStereoTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
class
FallingThingsStereoTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
DATASET_CLASS
=
datasets
.
FallingThingsStereo
DATASET_CLASS
=
datasets
.
FallingThingsStereo
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
variant
=
(
"single"
,
"mixed"
,
"both"
))
ADDITIONAL_CONFIGS
=
datasets_utils
.
combinations_grid
(
variant
=
(
"single"
,
"mixed"
,
"both"
))
...
...
torchvision/datasets/__init__.py
View file @
15a9a93b
from
._optical_flow
import
FlyingChairs
,
FlyingThings3D
,
HD1K
,
KittiFlow
,
Sintel
from
._optical_flow
import
FlyingChairs
,
FlyingThings3D
,
HD1K
,
KittiFlow
,
Sintel
from
._stereo_matching
import
(
from
._stereo_matching
import
(
CarlaStereo
,
CarlaStereo
,
CREStereo
,
ETH3DStereo
,
ETH3DStereo
,
FallingThingsStereo
,
FallingThingsStereo
,
InStereo2k
,
InStereo2k
,
...
@@ -118,6 +119,7 @@ __all__ = (
...
@@ -118,6 +119,7 @@ __all__ = (
"Kitti2012Stereo"
,
"Kitti2012Stereo"
,
"Kitti2015Stereo"
,
"Kitti2015Stereo"
,
"CarlaStereo"
,
"CarlaStereo"
,
"CREStereo"
,
"FallingThingsStereo"
,
"FallingThingsStereo"
,
"SceneFlowStereo"
,
"SceneFlowStereo"
,
"SintelStereo"
,
"SintelStereo"
,
...
...
torchvision/datasets/_stereo_matching.py
View file @
15a9a93b
...
@@ -363,6 +363,94 @@ class Kitti2015Stereo(StereoMatchingDataset):
...
@@ -363,6 +363,94 @@ class Kitti2015Stereo(StereoMatchingDataset):
return
super
().
__getitem__
(
index
)
return
super
().
__getitem__
(
index
)
class CREStereo(StereoMatchingDataset):
    """Synthetic dataset used in training the `CREStereo <https://arxiv.org/pdf/2203.11483.pdf>`_ architecture.
    Dataset details on the official paper `repo <https://github.com/megvii-research/CREStereo>`_.

    The dataset is expected to have the following structure: ::

        root
            CREStereo
                tree
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    img2_left.jpg
                    img2_right.jpg
                    img2_left.disp.png
                    img2_right.disp.png
                    ...
                shapenet
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                reflective
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...
                hole
                    img1_left.jpg
                    img1_right.jpg
                    img1_left.disp.png
                    img1_right.disp.png
                    ...

    Args:
        root (str): Root directory of the dataset.
        transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
    """

    _has_built_in_disparity_mask = True

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
    ):
        """Index the image and disparity pairs of every category under ``root/CREStereo``."""
        super().__init__(root, transforms)

        root = Path(root) / "CREStereo"

        # The category order determines the order of the samples in the dataset.
        dirs = ["shapenet", "reflective", "tree", "hole"]

        for s in dirs:
            left_image_pattern = str(root / s / "*_left.jpg")
            right_image_pattern = str(root / s / "*_right.jpg")
            imgs = self._scan_pairs(left_image_pattern, right_image_pattern)
            self._images += imgs

            # Disparity maps are looked up as PNG files next to the JPG images.
            left_disparity_pattern = str(root / s / "*_left.disp.png")
            right_disparity_pattern = str(root / s / "*_right.disp.png")
            disparities = self._scan_pairs(left_disparity_pattern, right_disparity_pattern)
            self._disparities += disparities

    def _read_disparity(self, file_path: str) -> Tuple:
        """Load one disparity file.

        Args:
            file_path (str): Path of the disparity image to read.

        Returns:
            tuple: ``(disparity_map, valid_mask)`` where ``disparity_map`` is a float32
            numpy array with a leading channel axis and ``valid_mask`` is always ``None``.
        """
        # Open through a context manager so the underlying file is closed as soon as
        # the pixel data has been copied into the numpy array.
        with Image.open(file_path) as disparity_image:
            disparity_map = np.asarray(disparity_image, dtype=np.float32)
        # unsqueeze the disparity map into (C, H, W) format
        # NOTE(review): the 256.0 divisor presumably undoes a fixed-point encoding of
        # the stored values -- confirm against the dataset documentation.
        disparity_map = disparity_map[None, :, :] / 256.0
        valid_mask = None
        return disparity_map, valid_mask

    def __getitem__(self, index: int) -> Tuple:
        """Return example at given index.

        Args:
            index(int): The index of the example to retrieve

        Returns:
            tuple: A 4-tuple with ``(img_left, img_right, disparity, valid_mask)``.
            The disparity is a numpy array of shape (1, H, W) and the images are PIL images.
            ``valid_mask`` is implicitly ``None`` if the ``transforms`` parameter does not
            generate a valid mask.
        """
        return super().__getitem__(index)
class
FallingThingsStereo
(
StereoMatchingDataset
):
class
FallingThingsStereo
(
StereoMatchingDataset
):
"""`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
"""`FallingThings <https://research.nvidia.com/publication/2018-06_falling-things-synthetic-dataset-3d-object-detection-and-pose-estimation>`_ dataset.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment