Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
d5bddd25
Unverified
Commit
d5bddd25
authored
May 31, 2021
by
Ziyi Wu
Committed by
GitHub
May 31, 2021
Browse files
[Enhance] Reuse some functions in `Datasets` loading data (#583)
* move _get_data to utils * add comment
parent
25a736f7
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
75 deletions
+38
-75
mmdet3d/datasets/custom_3d.py
mmdet3d/datasets/custom_3d.py
+3
-25
mmdet3d/datasets/custom_3d_seg.py
mmdet3d/datasets/custom_3d_seg.py
+3
-25
mmdet3d/datasets/nuscenes_mono_dataset.py
mmdet3d/datasets/nuscenes_mono_dataset.py
+3
-25
mmdet3d/datasets/utils.py
mmdet3d/datasets/utils.py
+29
-0
No files found.
mmdet3d/datasets/custom_3d.py
View file @
d5bddd25
...
...
@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
from
mmdet.datasets
import
DATASETS
from
..core.bbox
import
get_box_type
from
.pipelines
import
Compose
from
.utils
import
get_loading_pipeline
from
.utils
import
extract_result_dict
,
get_loading_pipeline
@
DATASETS
.
register_module
()
...
...
@@ -293,28 +293,6 @@ class Custom3DDataset(Dataset):
return
Compose
(
loading_pipeline
)
return
Compose
(
pipeline
)
@
staticmethod
def
_get_data
(
results
,
key
):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if
key
not
in
results
.
keys
():
return
None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data
=
results
[
key
]
if
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
):
data
=
data
[
0
]
if
isinstance
(
data
,
mmcv
.
parallel
.
DataContainer
):
data
=
data
.
_data
return
data
def
_extract_data
(
self
,
index
,
pipeline
,
key
,
load_annos
=
False
):
"""Load data using input pipeline and extract data according to key.
...
...
@@ -341,9 +319,9 @@ class Custom3DDataset(Dataset):
# extract data items according to keys
if
isinstance
(
key
,
str
):
data
=
self
.
_get_data
(
example
,
key
)
data
=
extract_result_dict
(
example
,
key
)
else
:
data
=
[
self
.
_get_data
(
example
,
k
)
for
k
in
key
]
data
=
[
extract_result_dict
(
example
,
k
)
for
k
in
key
]
if
load_annos
:
self
.
test_mode
=
original_test_mode
...
...
mmdet3d/datasets/custom_3d_seg.py
View file @
d5bddd25
...
...
@@ -8,7 +8,7 @@ from torch.utils.data import Dataset
from
mmdet.datasets
import
DATASETS
from
mmseg.datasets
import
DATASETS
as
SEG_DATASETS
from
.pipelines
import
Compose
from
.utils
import
get_loading_pipeline
from
.utils
import
extract_result_dict
,
get_loading_pipeline
@
DATASETS
.
register_module
()
...
...
@@ -399,28 +399,6 @@ class Custom3DSegDataset(Dataset):
return
Compose
(
loading_pipeline
)
return
Compose
(
pipeline
)
@
staticmethod
def
_get_data
(
results
,
key
):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if
key
not
in
results
.
keys
():
return
None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data
=
results
[
key
]
if
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
):
data
=
data
[
0
]
if
isinstance
(
data
,
mmcv
.
parallel
.
DataContainer
):
data
=
data
.
_data
return
data
def
_extract_data
(
self
,
index
,
pipeline
,
key
,
load_annos
=
False
):
"""Load data using input pipeline and extract data according to key.
...
...
@@ -447,9 +425,9 @@ class Custom3DSegDataset(Dataset):
# extract data items according to keys
if
isinstance
(
key
,
str
):
data
=
self
.
_get_data
(
example
,
key
)
data
=
extract_result_dict
(
example
,
key
)
else
:
data
=
[
self
.
_get_data
(
example
,
k
)
for
k
in
key
]
data
=
[
extract_result_dict
(
example
,
k
)
for
k
in
key
]
if
load_annos
:
self
.
test_mode
=
original_test_mode
...
...
mmdet3d/datasets/nuscenes_mono_dataset.py
View file @
d5bddd25
...
...
@@ -13,7 +13,7 @@ from mmdet.datasets import DATASETS, CocoDataset
from
..core
import
show_multi_modality_result
from
..core.bbox
import
CameraInstance3DBoxes
,
get_box_type
,
mono_cam_box2vis
from
.pipelines
import
Compose
from
.utils
import
get_loading_pipeline
from
.utils
import
extract_result_dict
,
get_loading_pipeline
@
DATASETS
.
register_module
()
...
...
@@ -541,28 +541,6 @@ class NuScenesMonoDataset(CocoDataset):
self
.
show
(
results
,
out_dir
,
pipeline
=
pipeline
)
return
results_dict
@
staticmethod
def
_get_data
(
results
,
key
):
"""Extract and return the data corresponding to key in result dict.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if
key
not
in
results
.
keys
():
return
None
# results[key] may be data or list[data]
# data may be wrapped inside DataContainer
data
=
results
[
key
]
if
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
):
data
=
data
[
0
]
if
isinstance
(
data
,
mmcv
.
parallel
.
DataContainer
):
data
=
data
.
_data
return
data
def
_extract_data
(
self
,
index
,
pipeline
,
key
,
load_annos
=
False
):
"""Load data using input pipeline and extract data according to key.
...
...
@@ -590,9 +568,9 @@ class NuScenesMonoDataset(CocoDataset):
# extract data items according to keys
if
isinstance
(
key
,
str
):
data
=
self
.
_get_data
(
example
,
key
)
data
=
extract_result_dict
(
example
,
key
)
else
:
data
=
[
self
.
_get_data
(
example
,
k
)
for
k
in
key
]
data
=
[
extract_result_dict
(
example
,
k
)
for
k
in
key
]
return
data
...
...
mmdet3d/datasets/utils.py
View file @
d5bddd25
import
mmcv
# yapf: disable
from
mmdet3d.datasets.pipelines
import
(
Collect3D
,
DefaultFormatBundle3D
,
LoadAnnotations3D
,
...
...
@@ -108,3 +110,30 @@ def get_loading_pipeline(pipeline):
'The data pipeline in your config file must include '
\
'loading step.'
return
loading_pipeline
def
extract_result_dict
(
results
,
key
):
"""Extract and return the data corresponding to key in result dict.
``results`` is a dict output from `pipeline(input_dict)`, which is the
loaded data from ``Dataset`` class.
The data terms inside may be wrapped in list, tuple and DataContainer, so
this function essentially extracts data from these wrappers.
Args:
results (dict): Data loaded using pipeline.
key (str): Key of the desired data.
Returns:
np.ndarray | torch.Tensor | None: Data term.
"""
if
key
not
in
results
.
keys
():
return
None
# results[key] may be data or list[data] or tuple[data]
# data may be wrapped inside DataContainer
data
=
results
[
key
]
if
isinstance
(
data
,
(
list
,
tuple
)):
data
=
data
[
0
]
if
isinstance
(
data
,
mmcv
.
parallel
.
DataContainer
):
data
=
data
.
_data
return
data
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment