Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
d7067e44
Unverified
Commit
d7067e44
authored
Dec 03, 2022
by
Wenwei Zhang
Committed by
GitHub
Dec 03, 2022
Browse files
Bump version to v1.1.0rc2
Bump to v1.1.0rc2
parents
28fe73d2
fb0e57e5
Changes
360
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
971 additions
and
769 deletions
+971
-769
mmdet3d/datasets/det3d_dataset.py
mmdet3d/datasets/det3d_dataset.py
+118
-37
mmdet3d/datasets/kitti2d_dataset.py
mmdet3d/datasets/kitti2d_dataset.py
+5
-5
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+33
-21
mmdet3d/datasets/lyft_dataset.py
mmdet3d/datasets/lyft_dataset.py
+13
-11
mmdet3d/datasets/nuscenes_dataset.py
mmdet3d/datasets/nuscenes_dataset.py
+33
-22
mmdet3d/datasets/s3dis_dataset.py
mmdet3d/datasets/s3dis_dataset.py
+150
-165
mmdet3d/datasets/scannet_dataset.py
mmdet3d/datasets/scannet_dataset.py
+56
-43
mmdet3d/datasets/seg3d_dataset.py
mmdet3d/datasets/seg3d_dataset.py
+53
-45
mmdet3d/datasets/semantickitti_dataset.py
mmdet3d/datasets/semantickitti_dataset.py
+33
-23
mmdet3d/datasets/sunrgbd_dataset.py
mmdet3d/datasets/sunrgbd_dataset.py
+12
-12
mmdet3d/datasets/transforms/__init__.py
mmdet3d/datasets/transforms/__init__.py
+8
-11
mmdet3d/datasets/transforms/compose.py
mmdet3d/datasets/transforms/compose.py
+0
-53
mmdet3d/datasets/transforms/dbsampler.py
mmdet3d/datasets/transforms/dbsampler.py
+10
-12
mmdet3d/datasets/transforms/formating.py
mmdet3d/datasets/transforms/formating.py
+1
-1
mmdet3d/datasets/transforms/loading.py
mmdet3d/datasets/transforms/loading.py
+63
-81
mmdet3d/datasets/transforms/test_time_aug.py
mmdet3d/datasets/transforms/test_time_aug.py
+9
-10
mmdet3d/datasets/transforms/transforms_3d.py
mmdet3d/datasets/transforms/transforms_3d.py
+137
-134
mmdet3d/datasets/waymo_dataset.py
mmdet3d/datasets/waymo_dataset.py
+54
-39
mmdet3d/evaluation/functional/waymo_utils/__init__.py
mmdet3d/evaluation/functional/waymo_utils/__init__.py
+2
-2
mmdet3d/evaluation/functional/waymo_utils/prediction_to_waymo.py
.../evaluation/functional/waymo_utils/prediction_to_waymo.py
+181
-42
No files found.
mmdet3d/datasets/det3d_dataset.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
copy
import
copy
import
os
from
os
import
path
as
osp
from
os
import
path
as
osp
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Set
,
Union
import
mmengine
import
mmengine
import
numpy
as
np
import
numpy
as
np
import
torch
from
mmengine.dataset
import
BaseDataset
from
mmengine.dataset
import
BaseDataset
from
mmengine.logging
import
print_log
from
terminaltables
import
AsciiTable
from
mmdet3d.datasets
import
DATASETS
from
mmdet3d.datasets
import
DATASETS
from
mmdet3d.structures
import
get_box_type
from
mmdet3d.structures
import
get_box_type
...
@@ -25,22 +29,22 @@ class Det3DDataset(BaseDataset):
...
@@ -25,22 +29,22 @@ class Det3DDataset(BaseDataset):
ann_file (str): Annotation file path. Defaults to ''.
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
information. Defaults to None.
data_prefix (dict
, optional
): Prefix for training data. Defaults to
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='velodyne', img='').
dict(pts='velodyne', img='').
pipeline (
l
ist[dict]
, optional
): Pipeline used for data processing.
pipeline (
L
ist[dict]): Pipeline used for data processing.
Defaults to
None
.
Defaults to
[]
.
modality (dict
, optional
): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used
as input,
as input,
it usually has following keys:
it usually has following keys:
- use_camera: bool
- use_camera: bool
- use_lidar: bool
- use_lidar: bool
Defaults to
`
dict(use_lidar=True, use_camera=False)
`
Defaults to dict(use_lidar=True, use_camera=False)
.
default_cam_key (str, optional): The default camera name adopted.
default_cam_key (str, optional): The default camera name adopted.
Defaults to None.
Defaults to None.
box_type_3d (str
, optional
): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
Defaults to 'LiDAR'. Available options include:
Defaults to 'LiDAR'
in this dataset
. Available options include:
- 'LiDAR': Box in LiDAR coordinates, usually for
- 'LiDAR': Box in LiDAR coordinates, usually for
outdoor point cloud 3d detection.
outdoor point cloud 3d detection.
...
@@ -48,16 +52,20 @@ class Det3DDataset(BaseDataset):
...
@@ -48,16 +52,20 @@ class Det3DDataset(BaseDataset):
indoor point cloud 3d detection.
indoor point cloud 3d detection.
- 'Camera': Box in camera coordinates, usually
- 'Camera': Box in camera coordinates, usually
for vision-based 3d detection.
for vision-based 3d detection.
filter_empty_gt (bool): Whether to filter the data with empty GT.
filter_empty_gt (bool, optional): Whether to filter the data with
If it's set to be True, the example with empty annotations after
empty GT. Defaults to True.
data pipeline will be dropped and a random example will be chosen
test_mode (bool, optional): Whether the dataset is in test mode.
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
load_eval_anns (bool
, optional
): Whether to load annotations
load_eval_anns (bool): Whether to load annotations
in test_mode,
in test_mode,
the annotation will be saved in `eval_ann_infos`,
the annotation will be saved in `eval_ann_infos`,
which can be
which can be
used in Evaluator. Defaults to True.
used in Evaluator. Defaults to True.
file_client_args (dict
, optional
): Configuration of file client.
file_client_args (dict): Configuration of file client.
Defaults to dict(backend='disk').
Defaults to dict(backend='disk').
show_ins_var (bool): For debugging purposes. Whether to show the variation
of the number of instances before and after going through the pipeline.
Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -71,8 +79,9 @@ class Det3DDataset(BaseDataset):
...
@@ -71,8 +79,9 @@ class Det3DDataset(BaseDataset):
box_type_3d
:
dict
=
'LiDAR'
,
box_type_3d
:
dict
=
'LiDAR'
,
filter_empty_gt
:
bool
=
True
,
filter_empty_gt
:
bool
=
True
,
test_mode
:
bool
=
False
,
test_mode
:
bool
=
False
,
load_eval_anns
=
True
,
load_eval_anns
:
bool
=
True
,
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
show_ins_var
:
bool
=
False
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
# init file client
# init file client
self
.
file_client
=
mmengine
.
FileClient
(
**
file_client_args
)
self
.
file_client
=
mmengine
.
FileClient
(
**
file_client_args
)
...
@@ -94,24 +103,31 @@ class Det3DDataset(BaseDataset):
...
@@ -94,24 +103,31 @@ class Det3DDataset(BaseDataset):
self
.
box_type_3d
,
self
.
box_mode_3d
=
get_box_type
(
box_type_3d
)
self
.
box_type_3d
,
self
.
box_mode_3d
=
get_box_type
(
box_type_3d
)
if
metainfo
is
not
None
and
'
CLASSES
'
in
metainfo
:
if
metainfo
is
not
None
and
'
classes
'
in
metainfo
:
# we allow to train on subset of self.METAINFO['
CLASSES
']
# we allow to train on subset of self.METAINFO['
classes
']
# map unselected labels to -1
# map unselected labels to -1
self
.
label_mapping
=
{
self
.
label_mapping
=
{
i
:
-
1
i
:
-
1
for
i
in
range
(
len
(
self
.
METAINFO
[
'
CLASSES
'
]))
for
i
in
range
(
len
(
self
.
METAINFO
[
'
classes
'
]))
}
}
self
.
label_mapping
[
-
1
]
=
-
1
self
.
label_mapping
[
-
1
]
=
-
1
for
label_idx
,
name
in
enumerate
(
metainfo
[
'
CLASSES
'
]):
for
label_idx
,
name
in
enumerate
(
metainfo
[
'
classes
'
]):
ori_label
=
self
.
METAINFO
[
'
CLASSES
'
].
index
(
name
)
ori_label
=
self
.
METAINFO
[
'
classes
'
].
index
(
name
)
self
.
label_mapping
[
ori_label
]
=
label_idx
self
.
label_mapping
[
ori_label
]
=
label_idx
self
.
num_ins_per_cat
=
{
name
:
0
for
name
in
metainfo
[
'classes'
]}
else
:
else
:
self
.
label_mapping
=
{
self
.
label_mapping
=
{
i
:
i
i
:
i
for
i
in
range
(
len
(
self
.
METAINFO
[
'
CLASSES
'
]))
for
i
in
range
(
len
(
self
.
METAINFO
[
'
classes
'
]))
}
}
self
.
label_mapping
[
-
1
]
=
-
1
self
.
label_mapping
[
-
1
]
=
-
1
self
.
num_ins_per_cat
=
{
name
:
0
for
name
in
self
.
METAINFO
[
'classes'
]
}
super
().
__init__
(
super
().
__init__
(
ann_file
=
ann_file
,
ann_file
=
ann_file
,
metainfo
=
metainfo
,
metainfo
=
metainfo
,
...
@@ -125,10 +141,25 @@ class Det3DDataset(BaseDataset):
...
@@ -125,10 +141,25 @@ class Det3DDataset(BaseDataset):
self
.
metainfo
[
'box_type_3d'
]
=
box_type_3d
self
.
metainfo
[
'box_type_3d'
]
=
box_type_3d
self
.
metainfo
[
'label_mapping'
]
=
self
.
label_mapping
self
.
metainfo
[
'label_mapping'
]
=
self
.
label_mapping
# used for showing variation of the number of instances before and
# after through the pipeline
self
.
show_ins_var
=
show_ins_var
# show statistics of this dataset
print_log
(
'-'
*
30
,
'current'
)
print_log
(
f
'The length of the dataset:
{
len
(
self
)
}
'
,
'current'
)
content_show
=
[[
'category'
,
'number'
]]
for
cat_name
,
num
in
self
.
num_ins_per_cat
.
items
():
content_show
.
append
([
cat_name
,
num
])
table
=
AsciiTable
(
content_show
)
print_log
(
f
'The number of instances per category in the dataset:
\n
{
table
.
table
}
'
,
# noqa: E501
'current'
)
def
_remove_dontcare
(
self
,
ann_info
:
dict
)
->
dict
:
def
_remove_dontcare
(
self
,
ann_info
:
dict
)
->
dict
:
"""Remove annotations that do not need to be cared.
"""Remove annotations that do not need to be cared.
-1 indicate dontcare in MMDet3d.
-1 indicate
s
dontcare in MMDet3d.
Args:
Args:
ann_info (dict): Dict of annotation infos. The
ann_info (dict): Dict of annotation infos. The
...
@@ -156,7 +187,7 @@ class Det3DDataset(BaseDataset):
...
@@ -156,7 +187,7 @@ class Det3DDataset(BaseDataset):
index (int): Index of the annotation data to get.
index (int): Index of the annotation data to get.
Returns:
Returns:
dict:
a
nnotation information.
dict:
A
nnotation information.
"""
"""
data_info
=
self
.
get_data_info
(
index
)
data_info
=
self
.
get_data_info
(
index
)
# test model
# test model
...
@@ -167,8 +198,8 @@ class Det3DDataset(BaseDataset):
...
@@ -167,8 +198,8 @@ class Det3DDataset(BaseDataset):
return
ann_info
return
ann_info
def
parse_ann_info
(
self
,
info
:
dict
)
->
Opt
ion
al
[
dict
]:
def
parse_ann_info
(
self
,
info
:
dict
)
->
Un
ion
[
dict
,
None
]:
"""Process the `instances` in data info to `ann_info`
"""Process the `instances` in data info to `ann_info`
.
In `Custom3DDataset`, we simply concatenate all the fields
In `Custom3DDataset`, we simply concatenate all the fields
in `instances` to `np.ndarray`, you can do the specific
in `instances` to `np.ndarray`, you can do the specific
...
@@ -179,7 +210,7 @@ class Det3DDataset(BaseDataset):
...
@@ -179,7 +210,7 @@ class Det3DDataset(BaseDataset):
info (dict): Info dict.
info (dict): Info dict.
Returns:
Returns:
dict
|
None: Processed `ann_info`
dict
or
None: Processed `ann_info`
.
"""
"""
# add s or gt prefix for most keys after concat
# add s or gt prefix for most keys after concat
# we only process 3d annotations here, the corresponding
# we only process 3d annotations here, the corresponding
...
@@ -223,14 +254,20 @@ class Det3DDataset(BaseDataset):
...
@@ -223,14 +254,20 @@ class Det3DDataset(BaseDataset):
ann_info
[
mapped_ann_name
]
=
temp_anns
ann_info
[
mapped_ann_name
]
=
temp_anns
ann_info
[
'instances'
]
=
info
[
'instances'
]
ann_info
[
'instances'
]
=
info
[
'instances'
]
for
label
in
ann_info
[
'gt_labels_3d'
]:
if
label
!=
-
1
:
cat_name
=
self
.
metainfo
[
'classes'
][
label
]
self
.
num_ins_per_cat
[
cat_name
]
+=
1
return
ann_info
return
ann_info
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
"""Process the raw data info.
"""Process the raw data info.
Convert all relative paths of the needed modality data files to
Convert all relative paths of the needed modality data files to
absolute paths, and process
absolute paths, and process
the `instances` field to
the `instances` field to
`ann_info` in training stage.
`ann_info` in training stage.
Args:
Args:
info (dict): Raw info dict.
info (dict): Raw info dict.
...
@@ -251,7 +288,7 @@ class Det3DDataset(BaseDataset):
...
@@ -251,7 +288,7 @@ class Det3DDataset(BaseDataset):
if
'lidar_sweeps'
in
info
:
if
'lidar_sweeps'
in
info
:
for
sweep
in
info
[
'lidar_sweeps'
]:
for
sweep
in
info
[
'lidar_sweeps'
]:
file_suffix
=
sweep
[
'lidar_points'
][
'lidar_path'
].
split
(
file_suffix
=
sweep
[
'lidar_points'
][
'lidar_path'
].
split
(
'/'
)[
-
1
]
os
.
sep
)[
-
1
]
if
'samples'
in
sweep
[
'lidar_points'
][
'lidar_path'
]:
if
'samples'
in
sweep
[
'lidar_points'
][
'lidar_path'
]:
sweep
[
'lidar_points'
][
'lidar_path'
]
=
osp
.
join
(
sweep
[
'lidar_points'
][
'lidar_path'
]
=
osp
.
join
(
self
.
data_prefix
[
'pts'
],
file_suffix
)
self
.
data_prefix
[
'pts'
],
file_suffix
)
...
@@ -291,7 +328,37 @@ class Det3DDataset(BaseDataset):
...
@@ -291,7 +328,37 @@ class Det3DDataset(BaseDataset):
return
info
return
info
def
prepare_data
(
self
,
index
:
int
)
->
Optional
[
dict
]:
def
_show_ins_var
(
self
,
old_labels
:
np
.
ndarray
,
new_labels
:
torch
.
Tensor
)
->
None
:
"""Show variation of the number of instances before and after through
the pipeline.
Args:
old_labels (np.ndarray): The labels before going through the pipeline.
new_labels (torch.Tensor): The labels after going through the pipeline.
"""
ori_num_per_cat
=
dict
()
for
label
in
old_labels
:
if
label
!=
-
1
:
cat_name
=
self
.
metainfo
[
'classes'
][
label
]
ori_num_per_cat
[
cat_name
]
=
ori_num_per_cat
.
get
(
cat_name
,
0
)
+
1
new_num_per_cat
=
dict
()
for
label
in
new_labels
:
if
label
!=
-
1
:
cat_name
=
self
.
metainfo
[
'classes'
][
label
]
new_num_per_cat
[
cat_name
]
=
new_num_per_cat
.
get
(
cat_name
,
0
)
+
1
content_show
=
[[
'category'
,
'new number'
,
'ori number'
]]
for
cat_name
,
num
in
ori_num_per_cat
.
items
():
new_num
=
new_num_per_cat
.
get
(
cat_name
,
0
)
content_show
.
append
([
cat_name
,
new_num
,
num
])
table
=
AsciiTable
(
content_show
)
print_log
(
'The number of instances per category after and before '
f
'through pipeline:
\n
{
table
.
table
}
'
,
'current'
)
def
prepare_data
(
self
,
index
:
int
)
->
Union
[
dict
,
None
]:
"""Data preparation for both training and testing stage.
"""Data preparation for both training and testing stage.
Called by `__getitem__` of dataset.
Called by `__getitem__` of dataset.
...
@@ -300,12 +367,12 @@ class Det3DDataset(BaseDataset):
...
@@ -300,12 +367,12 @@ class Det3DDataset(BaseDataset):
index (int): Index for accessing the target data.
index (int): Index for accessing the target data.
Returns:
Returns:
dict
|
None: Data dict of the corresponding index.
dict
or
None: Data dict of the corresponding index.
"""
"""
input_dict
=
self
.
get_data_info
(
index
)
ori_
input_dict
=
self
.
get_data_info
(
index
)
# deepcopy here to avoid inplace modification in pipeline.
# deepcopy here to avoid inplace modification in pipeline.
input_dict
=
copy
.
deepcopy
(
input_dict
)
input_dict
=
copy
.
deepcopy
(
ori_
input_dict
)
# box_type_3d (str): 3D box type.
# box_type_3d (str): 3D box type.
input_dict
[
'box_type_3d'
]
=
self
.
box_type_3d
input_dict
[
'box_type_3d'
]
=
self
.
box_type_3d
...
@@ -318,15 +385,29 @@ class Det3DDataset(BaseDataset):
...
@@ -318,15 +385,29 @@ class Det3DDataset(BaseDataset):
return
None
return
None
example
=
self
.
pipeline
(
input_dict
)
example
=
self
.
pipeline
(
input_dict
)
if
not
self
.
test_mode
and
self
.
filter_empty_gt
:
if
not
self
.
test_mode
and
self
.
filter_empty_gt
:
# after pipeline drop the example with empty annotations
# after pipeline drop the example with empty annotations
# return None so that `__getitem__` randomly picks another example
# return None so that `__getitem__` randomly picks another example
if
example
is
None
or
len
(
if
example
is
None
or
len
(
example
[
'data_samples'
].
gt_instances_3d
.
labels_3d
)
==
0
:
example
[
'data_samples'
].
gt_instances_3d
.
labels_3d
)
==
0
:
return
None
return
None
if
self
.
show_ins_var
:
if
'ann_info'
in
ori_input_dict
:
self
.
_show_ins_var
(
ori_input_dict
[
'ann_info'
][
'gt_labels_3d'
],
example
[
'data_samples'
].
gt_instances_3d
.
labels_3d
)
else
:
print_log
(
"'ann_info' is not in the input dict. It's probably that "
'the data is not in training mode'
,
'current'
,
level
=
30
)
return
example
return
example
def
get_cat_ids
(
self
,
idx
:
int
)
->
Lis
t
[
int
]:
def
get_cat_ids
(
self
,
idx
:
int
)
->
Se
t
[
int
]:
"""Get category ids by index. Dataset wrapped by ClassBalancedDataset
"""Get category ids by index. Dataset wrapped by ClassBalancedDataset
must implement this method.
must implement this method.
...
...
mmdet3d/datasets/kitti2d_dataset.py
View file @
d7067e44
...
@@ -36,7 +36,7 @@ class Kitti2DDataset(Det3DDataset):
...
@@ -36,7 +36,7 @@ class Kitti2DDataset(Det3DDataset):
Defaults to False.
Defaults to False.
"""
"""
CLASSES
=
(
'car'
,
'pedestrian'
,
'cyclist'
)
classes
=
(
'car'
,
'pedestrian'
,
'cyclist'
)
"""
"""
Annotation format:
Annotation format:
[
[
...
@@ -90,7 +90,7 @@ class Kitti2DDataset(Det3DDataset):
...
@@ -90,7 +90,7 @@ class Kitti2DDataset(Det3DDataset):
self
.
data_infos
=
mmengine
.
load
(
ann_file
)
self
.
data_infos
=
mmengine
.
load
(
ann_file
)
self
.
cat2label
=
{
self
.
cat2label
=
{
cat_name
:
i
cat_name
:
i
for
i
,
cat_name
in
enumerate
(
self
.
CLASSES
)
for
i
,
cat_name
in
enumerate
(
self
.
classes
)
}
}
return
self
.
data_infos
return
self
.
data_infos
...
@@ -122,7 +122,7 @@ class Kitti2DDataset(Det3DDataset):
...
@@ -122,7 +122,7 @@ class Kitti2DDataset(Det3DDataset):
difficulty
=
annos
[
'difficulty'
]
difficulty
=
annos
[
'difficulty'
]
# remove classes that is not needed
# remove classes that is not needed
selected
=
self
.
keep_arrays_by_name
(
gt_names
,
self
.
CLASSES
)
selected
=
self
.
keep_arrays_by_name
(
gt_names
,
self
.
classes
)
gt_bboxes
=
gt_bboxes
[
selected
]
gt_bboxes
=
gt_bboxes
[
selected
]
gt_names
=
gt_names
[
selected
]
gt_names
=
gt_names
[
selected
]
difficulty
=
difficulty
[
selected
]
difficulty
=
difficulty
[
selected
]
...
@@ -215,7 +215,7 @@ class Kitti2DDataset(Det3DDataset):
...
@@ -215,7 +215,7 @@ class Kitti2DDataset(Det3DDataset):
"""
"""
from
mmdet3d.structures.ops.transforms
import
bbox2result_kitti2d
from
mmdet3d.structures.ops.transforms
import
bbox2result_kitti2d
sample_idx
=
[
info
[
'image'
][
'image_idx'
]
for
info
in
self
.
data_infos
]
sample_idx
=
[
info
[
'image'
][
'image_idx'
]
for
info
in
self
.
data_infos
]
result_files
=
bbox2result_kitti2d
(
outputs
,
self
.
CLASSES
,
sample_idx
,
result_files
=
bbox2result_kitti2d
(
outputs
,
self
.
classes
,
sample_idx
,
out
)
out
)
return
result_files
return
result_files
...
@@ -237,5 +237,5 @@ class Kitti2DDataset(Det3DDataset):
...
@@ -237,5 +237,5 @@ class Kitti2DDataset(Det3DDataset):
]),
'KITTI data set only evaluate bbox'
]),
'KITTI data set only evaluate bbox'
gt_annos
=
[
info
[
'annos'
]
for
info
in
self
.
data_infos
]
gt_annos
=
[
info
[
'annos'
]
for
info
in
self
.
data_infos
]
ap_result_str
,
ap_dict
=
kitti_eval
(
ap_result_str
,
ap_dict
=
kitti_eval
(
gt_annos
,
result_files
,
self
.
CLASSES
,
eval_types
=
[
'bbox'
])
gt_annos
,
result_files
,
self
.
classes
,
eval_types
=
[
'bbox'
])
return
ap_result_str
,
ap_dict
return
ap_result_str
,
ap_dict
mmdet3d/datasets/kitti_dataset.py
View file @
d7067e44
...
@@ -18,13 +18,13 @@ class KittiDataset(Det3DDataset):
...
@@ -18,13 +18,13 @@ class KittiDataset(Det3DDataset):
Args:
Args:
data_root (str): Path of dataset root.
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file.
pipeline (
l
ist[dict]
, optional
): Pipeline used for data processing.
pipeline (
L
ist[dict]): Pipeline used for data processing.
Defaults to
None
.
Defaults to
[]
.
modality (dict
, optional
): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used
as input.
as input.
Defaults to
`
dict(use_lidar=True)
`
.
Defaults to dict(use_lidar=True).
default_cam_key (str
, optional
): The default camera name adopted.
default_cam_key (str): The default camera name adopted.
Defaults to 'CAM2'.
Defaults to 'CAM2'.
box_type_3d (str
, optional
): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
Defaults to 'LiDAR' in this dataset. Available options include:
Defaults to 'LiDAR' in this dataset. Available options include:
...
@@ -32,17 +32,28 @@ class KittiDataset(Det3DDataset):
...
@@ -32,17 +32,28 @@ class KittiDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
load_type (str): Type of loading mode. Defaults to 'frame_based'.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
- 'frame_based': Load all of the instances in the frame.
- 'mv_image_based': Load all of the instances in the frame and need
to convert to the FOV-based data type to support image-based
detector.
- 'fov_image_based': Only load the instances inside the default
cam, and need to convert to the FOV-based data type to support
image-based detector.
filter_empty_gt (bool): Whether to filter the data with empty GT.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
pcd_limit_range (
l
ist[float]
, optional
): The range of point cloud
pcd_limit_range (
L
ist[float]): The range of point cloud
used to filter
used to filter
invalid predicted boxes.
invalid predicted boxes.
Defaults to [0, -40, -3, 70.4, 40, 0.0].
Defaults to [0, -40, -3, 70.4, 40, 0.0].
"""
"""
# TODO: use full classes of kitti
# TODO: use full classes of kitti
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
(
'Pedestrian'
,
'Cyclist'
,
'Car'
,
'Van'
,
'Truck'
,
'
classes
'
:
(
'Pedestrian'
,
'Cyclist'
,
'Car'
,
'Van'
,
'Truck'
,
'Person_sitting'
,
'Tram'
,
'Misc'
)
'Person_sitting'
,
'Tram'
,
'Misc'
)
}
}
...
@@ -52,7 +63,7 @@ class KittiDataset(Det3DDataset):
...
@@ -52,7 +63,7 @@ class KittiDataset(Det3DDataset):
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
),
modality
:
dict
=
dict
(
use_lidar
=
True
),
default_cam_key
:
str
=
'CAM2'
,
default_cam_key
:
str
=
'CAM2'
,
task
:
str
=
'lidar_det
'
,
load_type
:
str
=
'frame_based
'
,
box_type_3d
:
str
=
'LiDAR'
,
box_type_3d
:
str
=
'LiDAR'
,
filter_empty_gt
:
bool
=
True
,
filter_empty_gt
:
bool
=
True
,
test_mode
:
bool
=
False
,
test_mode
:
bool
=
False
,
...
@@ -60,8 +71,9 @@ class KittiDataset(Det3DDataset):
...
@@ -60,8 +71,9 @@ class KittiDataset(Det3DDataset):
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
self
.
pcd_limit_range
=
pcd_limit_range
self
.
pcd_limit_range
=
pcd_limit_range
assert
task
in
(
'lidar_det'
,
'mono_det'
)
assert
load_type
in
(
'frame_based'
,
'mv_image_based'
,
self
.
task
=
task
'fov_image_based'
)
self
.
load_type
=
load_type
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
...
@@ -111,7 +123,7 @@ class KittiDataset(Det3DDataset):
...
@@ -111,7 +123,7 @@ class KittiDataset(Det3DDataset):
info
[
'plane'
]
=
plane_lidar
info
[
'plane'
]
=
plane_lidar
if
self
.
task
==
'mono_det'
:
if
self
.
load_type
==
'fov_image_based'
and
self
.
load_eval_anns
:
info
[
'instances'
]
=
info
[
'cam_instances'
][
self
.
default_cam_key
]
info
[
'instances'
]
=
info
[
'cam_instances'
][
self
.
default_cam_key
]
info
=
super
().
parse_data_info
(
info
)
info
=
super
().
parse_data_info
(
info
)
...
@@ -119,21 +131,21 @@ class KittiDataset(Det3DDataset):
...
@@ -119,21 +131,21 @@ class KittiDataset(Det3DDataset):
return
info
return
info
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
"""
Get annotation info according to the given index
.
"""
Process the `instances` in data info to `ann_info`
.
Args:
Args:
info (dict): Data information of single data sample.
info (dict): Data information of single data sample.
Returns:
Returns:
dict:
a
nnotation information consists of the following keys:
dict:
A
nnotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes.
3D ground truth bboxes.
- bbox_labels_3d (np.ndarray): Labels of ground truths.
- bbox_labels_3d (np.ndarray): Labels of ground truths.
- gt_bboxes (np.ndarray): 2D ground truth bboxes.
- gt_bboxes (np.ndarray): 2D ground truth bboxes.
- gt_labels (np.ndarray): Labels of ground truths.
- gt_labels (np.ndarray): Labels of ground truths.
- difficulty (int): Difficulty defined by KITTI.
- difficulty (int): Difficulty defined by KITTI.
0, 1, 2 represent easy, moderate and hard respectively.
0, 1, 2 represent easy, moderate and hard respectively.
"""
"""
ann_info
=
super
().
parse_ann_info
(
info
)
ann_info
=
super
().
parse_ann_info
(
info
)
if
ann_info
is
None
:
if
ann_info
is
None
:
...
@@ -142,7 +154,7 @@ class KittiDataset(Det3DDataset):
...
@@ -142,7 +154,7 @@ class KittiDataset(Det3DDataset):
ann_info
[
'gt_bboxes_3d'
]
=
np
.
zeros
((
0
,
7
),
dtype
=
np
.
float32
)
ann_info
[
'gt_bboxes_3d'
]
=
np
.
zeros
((
0
,
7
),
dtype
=
np
.
float32
)
ann_info
[
'gt_labels_3d'
]
=
np
.
zeros
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'gt_labels_3d'
]
=
np
.
zeros
(
0
,
dtype
=
np
.
int64
)
if
self
.
task
==
'mono_det'
:
if
self
.
load_type
in
[
'fov_image_based'
,
'mv_image_based'
]
:
ann_info
[
'gt_bboxes'
]
=
np
.
zeros
((
0
,
4
),
dtype
=
np
.
float32
)
ann_info
[
'gt_bboxes'
]
=
np
.
zeros
((
0
,
4
),
dtype
=
np
.
float32
)
ann_info
[
'gt_bboxes_labels'
]
=
np
.
array
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'gt_bboxes_labels'
]
=
np
.
array
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'centers_2d'
]
=
np
.
zeros
((
0
,
2
),
dtype
=
np
.
float32
)
ann_info
[
'centers_2d'
]
=
np
.
zeros
((
0
,
2
),
dtype
=
np
.
float32
)
...
...
mmdet3d/datasets/lyft_dataset.py
View file @
d7067e44
...
@@ -21,10 +21,10 @@ class LyftDataset(Det3DDataset):
...
@@ -21,10 +21,10 @@ class LyftDataset(Det3DDataset):
Args:
Args:
data_root (str): Path of dataset root.
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file.
pipeline (
l
ist[dict]
, optional
): Pipeline used for data processing.
pipeline (
L
ist[dict]): Pipeline used for data processing.
Defaults to
None
.
Defaults to
[]
.
modality (dict
, optional
): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used
as input.
as input.
Defaults to dict(use_camera=False, use_lidar=True).
Defaults to dict(use_camera=False, use_lidar=True).
box_type_3d (str): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
...
@@ -33,14 +33,16 @@ class LyftDataset(Det3DDataset):
...
@@ -33,14 +33,16 @@ class LyftDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
filter_empty_gt (bool): Whether to filter the data with empty GT.
Defaults to True.
If it's set to be True, the example with empty annotations after
test_mode (bool, optional): Whether the dataset is in test mode.
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
'
classes
'
:
(
'car'
,
'truck'
,
'bus'
,
'emergency_vehicle'
,
'other_vehicle'
,
(
'car'
,
'truck'
,
'bus'
,
'emergency_vehicle'
,
'other_vehicle'
,
'motorcycle'
,
'bicycle'
,
'pedestrian'
,
'animal'
)
'motorcycle'
,
'bicycle'
,
'pedestrian'
,
'animal'
)
}
}
...
@@ -66,16 +68,16 @@ class LyftDataset(Det3DDataset):
...
@@ -66,16 +68,16 @@ class LyftDataset(Det3DDataset):
**
kwargs
)
**
kwargs
)
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
"""
Get annotation info according to the given index
.
"""
Process the `instances` in data info to `ann_info`
.
Args:
Args:
info (dict): Data information of single data sample.
info (dict): Data information of single data sample.
Returns:
Returns:
dict:
a
nnotation information consists of the following keys:
dict:
A
nnotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes.
3D ground truth bboxes.
- gt_labels_3d (np.ndarray): Labels of 3D ground truths.
- gt_labels_3d (np.ndarray): Labels of 3D ground truths.
"""
"""
ann_info
=
super
().
parse_ann_info
(
info
)
ann_info
=
super
().
parse_ann_info
(
info
)
...
...
mmdet3d/datasets/nuscenes_dataset.py
View file @
d7067e44
...
@@ -22,9 +22,8 @@ class NuScenesDataset(Det3DDataset):
...
@@ -22,9 +22,8 @@ class NuScenesDataset(Det3DDataset):
Args:
Args:
data_root (str): Path of dataset root.
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file.
task (str, optional): Detection task. Defaults to 'lidar_det'.
pipeline (list[dict]): Pipeline used for data processing.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to [].
Defaults to None.
box_type_3d (str): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
...
@@ -33,20 +32,31 @@ class NuScenesDataset(Det3DDataset):
...
@@ -33,20 +32,31 @@ class NuScenesDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
modality (dict, optional): Modality to specify the sensor data used
load_type (str): Type of loading mode. Defaults to 'frame_based'.
as input. Defaults to dict(use_camera=False, use_lidar=True).
filter_empty_gt (bool, optional): Whether to filter empty GT.
- 'frame_based': Load all of the instances in the frame.
Defaults to True.
- 'mv_image_based': Load all of the instances in the frame and need
test_mode (bool, optional): Whether the dataset is in test mode.
to convert to the FOV-based data type to support image-based
detector.
- 'fov_image_based': Only load the instances inside the default
cam, and need to convert to the FOV-based data type to support
image-based detector.
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_camera=False, use_lidar=True).
filter_empty_gt (bool): Whether to filter the data with empty GT.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
with_velocity (bool
, optional
): Whether to include velocity prediction
with_velocity (bool): Whether to include velocity prediction
into the experiments. Defaults to True.
into the experiments. Defaults to True.
use_valid_flag (bool
, optional
): Whether to use `use_valid_flag` key
use_valid_flag (bool): Whether to use `use_valid_flag` key
in the info file as mask to filter gt_boxes and gt_names.
in the info file as mask to filter gt_boxes and gt_names.
Defaults to False.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
'
classes
'
:
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
'bicycle'
,
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
'bicycle'
,
'motorcycle'
,
'pedestrian'
,
'traffic_cone'
,
'barrier'
),
'motorcycle'
,
'pedestrian'
,
'traffic_cone'
,
'barrier'
),
'version'
:
'version'
:
...
@@ -56,9 +66,9 @@ class NuScenesDataset(Det3DDataset):
...
@@ -56,9 +66,9 @@ class NuScenesDataset(Det3DDataset):
def
__init__
(
self
,
def
__init__
(
self
,
data_root
:
str
,
data_root
:
str
,
ann_file
:
str
,
ann_file
:
str
,
task
:
str
=
'lidar_det'
,
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
box_type_3d
:
str
=
'LiDAR'
,
box_type_3d
:
str
=
'LiDAR'
,
load_type
:
str
=
'frame_based'
,
modality
:
dict
=
dict
(
modality
:
dict
=
dict
(
use_camera
=
False
,
use_camera
=
False
,
use_lidar
=
True
,
use_lidar
=
True
,
...
@@ -72,8 +82,9 @@ class NuScenesDataset(Det3DDataset):
...
@@ -72,8 +82,9 @@ class NuScenesDataset(Det3DDataset):
self
.
with_velocity
=
with_velocity
self
.
with_velocity
=
with_velocity
# TODO: Redesign multi-view data process in the future
# TODO: Redesign multi-view data process in the future
assert
task
in
(
'lidar_det'
,
'mono_det'
,
'multi-view_det'
)
assert
load_type
in
(
'frame_based'
,
'mv_image_based'
,
self
.
task
=
task
'fov_image_based'
)
self
.
load_type
=
load_type
assert
box_type_3d
.
lower
()
in
(
'lidar'
,
'camera'
)
assert
box_type_3d
.
lower
()
in
(
'lidar'
,
'camera'
)
super
().
__init__
(
super
().
__init__
(
...
@@ -108,16 +119,16 @@ class NuScenesDataset(Det3DDataset):
...
@@ -108,16 +119,16 @@ class NuScenesDataset(Det3DDataset):
return
filtered_annotations
return
filtered_annotations
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
"""
Get annotation info according to the given index
.
"""
Process the `instances` in data info to `ann_info`
.
Args:
Args:
info (dict): Data information of single data sample.
info (dict): Data information of single data sample.
Returns:
Returns:
dict:
a
nnotation information consists of the following keys:
dict:
A
nnotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes.
3D ground truth bboxes.
- gt_labels_3d (np.ndarray): Labels of ground truths.
- gt_labels_3d (np.ndarray): Labels of ground truths.
"""
"""
ann_info
=
super
().
parse_ann_info
(
info
)
ann_info
=
super
().
parse_ann_info
(
info
)
...
@@ -142,7 +153,7 @@ class NuScenesDataset(Det3DDataset):
...
@@ -142,7 +153,7 @@ class NuScenesDataset(Det3DDataset):
ann_info
[
'gt_bboxes_3d'
]
=
np
.
zeros
((
0
,
7
),
dtype
=
np
.
float32
)
ann_info
[
'gt_bboxes_3d'
]
=
np
.
zeros
((
0
,
7
),
dtype
=
np
.
float32
)
ann_info
[
'gt_labels_3d'
]
=
np
.
zeros
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'gt_labels_3d'
]
=
np
.
zeros
(
0
,
dtype
=
np
.
int64
)
if
self
.
task
==
'mono3
d'
:
if
self
.
load_type
in
[
'fov_image_based'
,
'mv_image_base
d'
]
:
ann_info
[
'gt_bboxes'
]
=
np
.
zeros
((
0
,
4
),
dtype
=
np
.
float32
)
ann_info
[
'gt_bboxes'
]
=
np
.
zeros
((
0
,
4
),
dtype
=
np
.
float32
)
ann_info
[
'gt_bboxes_labels'
]
=
np
.
array
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'gt_bboxes_labels'
]
=
np
.
array
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'attr_labels'
]
=
np
.
array
(
0
,
dtype
=
np
.
int64
)
ann_info
[
'attr_labels'
]
=
np
.
array
(
0
,
dtype
=
np
.
int64
)
...
@@ -152,7 +163,7 @@ class NuScenesDataset(Det3DDataset):
...
@@ -152,7 +163,7 @@ class NuScenesDataset(Det3DDataset):
# the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
# the nuscenes box center is [0.5, 0.5, 0.5], we change it to be
# the same as KITTI (0.5, 0.5, 0)
# the same as KITTI (0.5, 0.5, 0)
# TODO: Unify the coordinates
# TODO: Unify the coordinates
if
self
.
task
==
'mono_det'
:
if
self
.
load_type
in
[
'fov_image_based'
,
'mv_image_based'
]
:
gt_bboxes_3d
=
CameraInstance3DBoxes
(
gt_bboxes_3d
=
CameraInstance3DBoxes
(
ann_info
[
'gt_bboxes_3d'
],
ann_info
[
'gt_bboxes_3d'
],
box_dim
=
ann_info
[
'gt_bboxes_3d'
].
shape
[
-
1
],
box_dim
=
ann_info
[
'gt_bboxes_3d'
].
shape
[
-
1
],
...
@@ -167,7 +178,7 @@ class NuScenesDataset(Det3DDataset):
...
@@ -167,7 +178,7 @@ class NuScenesDataset(Det3DDataset):
return
ann_info
return
ann_info
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_data_info
(
self
,
info
:
dict
)
->
Union
[
List
[
dict
],
dict
]
:
"""Process the raw data info.
"""Process the raw data info.
The only difference with it in `Det3DDataset`
The only difference with it in `Det3DDataset`
...
@@ -177,10 +188,10 @@ class NuScenesDataset(Det3DDataset):
...
@@ -177,10 +188,10 @@ class NuScenesDataset(Det3DDataset):
info (dict): Raw info dict.
info (dict): Raw info dict.
Returns:
Returns:
dict: Has `ann_info` in training stage. And
List[dict] or
dict: Has `ann_info` in training stage. And
all path has been converted to absolute path.
all path has been converted to absolute path.
"""
"""
if
self
.
task
==
'mono_det
'
:
if
self
.
load_type
==
'mv_image_based
'
:
data_list
=
[]
data_list
=
[]
if
self
.
modality
[
'use_lidar'
]:
if
self
.
modality
[
'use_lidar'
]:
info
[
'lidar_points'
][
'lidar_path'
]
=
\
info
[
'lidar_points'
][
'lidar_path'
]
=
\
...
...
mmdet3d/datasets/s3dis_dataset.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
os
import
path
as
osp
from
os
import
path
as
osp
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -8,7 +8,6 @@ from mmdet3d.registry import DATASETS
...
@@ -8,7 +8,6 @@ from mmdet3d.registry import DATASETS
from
mmdet3d.structures
import
DepthInstance3DBoxes
from
mmdet3d.structures
import
DepthInstance3DBoxes
from
.det3d_dataset
import
Det3DDataset
from
.det3d_dataset
import
Det3DDataset
from
.seg3d_dataset
import
Seg3DDataset
from
.seg3d_dataset
import
Seg3DDataset
from
.transforms
import
Compose
@
DATASETS
.
register_module
()
@
DATASETS
.
register_module
()
...
@@ -19,138 +18,132 @@ class S3DISDataset(Det3DDataset):
...
@@ -19,138 +18,132 @@ class S3DISDataset(Det3DDataset):
often train on 5 of them and test on the remaining one. The one for
often train on 5 of them and test on the remaining one. The one for
test is Area_5 as suggested in `GSDN <https://arxiv.org/abs/2006.12356>`_.
test is Area_5 as suggested in `GSDN <https://arxiv.org/abs/2006.12356>`_.
To concatenate 5 areas during training
To concatenate 5 areas during training
`mm
det
.datasets.dataset_wrappers.ConcatDataset` should be used.
`mm
engine
.datasets.dataset_wrappers.ConcatDataset` should be used.
Args:
Args:
data_root (str): Path of dataset root.
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file.
pipeline (list[dict], optional): Pipeline used for data processing.
metainfo (dict, optional): Meta information for dataset, such as class
Defaults to None.
information. Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
data_prefix (dict): Prefix for data. Defaults to
Defaults to None.
dict(pts='points',
modality (dict, optional): Modality to specify the sensor data used
pts_instance_mask='instance_mask',
as input. Defaults to None.
pts_semantic_mask='semantic_mask').
box_type_3d (str, optional): Type of 3D box of this dataset.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_camera=False, use_lidar=True).
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
Defaults to 'Depth' in this dataset. Available options includes
Defaults to 'Depth' in this dataset. Available options includes
:
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
filter_empty_gt (bool): Whether to filter the data with empty GT.
Defaults to True.
If it's set to be True, the example with empty annotations after
test_mode (bool, optional): Whether the dataset is in test mode.
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
"""
"""
CLASSES
=
(
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
)
METAINFO
=
{
'classes'
:
(
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
),
# the valid ids of segmentation annotations
'seg_valid_class_ids'
:
(
7
,
8
,
9
,
10
,
11
),
'seg_all_class_ids'
:
tuple
(
range
(
1
,
14
))
# possibly with 'stair' class
}
def
__init__
(
self
,
def
__init__
(
self
,
data_root
,
data_root
:
str
,
ann_file
,
ann_file
:
str
,
pipeline
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
classes
=
None
,
data_prefix
:
dict
=
dict
(
modality
=
None
,
pts
=
'points'
,
box_type_3d
=
'Depth'
,
pts_instance_mask
=
'instance_mask'
,
filter_empty_gt
=
True
,
pts_semantic_mask
=
'semantic_mask'
),
test_mode
=
False
,
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
*
kwargs
):
modality
:
dict
=
dict
(
use_camera
=
False
,
use_lidar
=
True
),
box_type_3d
:
str
=
'Depth'
,
filter_empty_gt
:
bool
=
True
,
test_mode
:
bool
=
False
,
**
kwargs
)
->
None
:
# construct seg_label_mapping for semantic mask
seg_max_cat_id
=
len
(
self
.
METAINFO
[
'seg_all_class_ids'
])
seg_valid_cat_ids
=
self
.
METAINFO
[
'seg_valid_class_ids'
]
neg_label
=
len
(
seg_valid_cat_ids
)
seg_label_mapping
=
np
.
ones
(
seg_max_cat_id
+
1
,
dtype
=
np
.
int
)
*
neg_label
for
cls_idx
,
cat_id
in
enumerate
(
seg_valid_cat_ids
):
seg_label_mapping
[
cat_id
]
=
cls_idx
self
.
seg_label_mapping
=
seg_label_mapping
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
metainfo
=
metainfo
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
pipeline
=
pipeline
,
classes
=
classes
,
modality
=
modality
,
modality
=
modality
,
box_type_3d
=
box_type_3d
,
box_type_3d
=
box_type_3d
,
filter_empty_gt
=
filter_empty_gt
,
filter_empty_gt
=
filter_empty_gt
,
test_mode
=
test_mode
,
test_mode
=
test_mode
,
*
kwargs
)
**
kwargs
)
self
.
metainfo
[
'seg_label_mapping'
]
=
self
.
seg_label_mapping
assert
'use_camera'
in
self
.
modality
and
\
'use_lidar'
in
self
.
modality
assert
self
.
modality
[
'use_camera'
]
or
self
.
modality
[
'use_lidar'
]
def
get_ann
_info
(
self
,
in
dex
)
:
def
parse_data
_info
(
self
,
in
fo
:
dict
)
->
dict
:
"""
Get annotation info according to the given index
.
"""
Process the raw data info
.
Args:
Args:
in
dex (int): Index of the annotation data to ge
t.
in
fo (dict): Raw info dic
t.
Returns:
Returns:
dict: annotation information consists of the following keys:
dict: Has `ann_info` in training stage. And
all path has been converted to absolute path.
- gt_bboxes_3d (:obj:`DepthInstance3DBoxes`):
3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths.
- pts_instance_mask_path (str): Path of instance masks.
- pts_semantic_mask_path (str): Path of semantic masks.
"""
"""
# Use index to get the annos, thus the evalhook could also use this api
info
[
'pts_instance_mask_path'
]
=
osp
.
join
(
info
=
self
.
data_infos
[
index
]
self
.
data_prefix
.
get
(
'pts_instance_mask'
,
''
),
if
info
[
'annos'
][
'gt_num'
]
!=
0
:
info
[
'pts_instance_mask_path'
])
gt_bboxes_3d
=
info
[
'annos'
][
'gt_boxes_upright_depth'
].
astype
(
info
[
'pts_semantic_mask_path'
]
=
osp
.
join
(
np
.
float32
)
# k, 6
self
.
data_prefix
.
get
(
'pts_semantic_mask'
,
''
),
gt_labels_3d
=
info
[
'annos'
][
'class'
].
astype
(
np
.
int64
)
info
[
'pts_semantic_mask_path'
])
else
:
gt_bboxes_3d
=
np
.
zeros
((
0
,
6
),
dtype
=
np
.
float32
)
info
=
super
().
parse_data_info
(
info
)
gt_labels_3d
=
np
.
zeros
((
0
,
),
dtype
=
np
.
int64
)
# only be used in `PointSegClassMapping` in pipeline
# to map original semantic class to valid category ids.
# to target box structure
info
[
'seg_label_mapping'
]
=
self
.
seg_label_mapping
gt_bboxes_3d
=
DepthInstance3DBoxes
(
return
info
gt_bboxes_3d
,
box_dim
=
gt_bboxes_3d
.
shape
[
-
1
],
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
with_yaw
=
False
,
"""Process the `instances` in data info to `ann_info`.
origin
=
(
0.5
,
0.5
,
0.5
)).
convert_to
(
self
.
box_mode_3d
)
pts_instance_mask_path
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_instance_mask_path'
])
pts_semantic_mask_path
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_semantic_mask_path'
])
anns_results
=
dict
(
gt_bboxes_3d
=
gt_bboxes_3d
,
gt_labels_3d
=
gt_labels_3d
,
pts_instance_mask_path
=
pts_instance_mask_path
,
pts_semantic_mask_path
=
pts_semantic_mask_path
)
return
anns_results
def
get_data_info
(
self
,
index
):
"""Get data info according to the given index.
Args:
Args:
in
dex (int): Index of the sample data to ge
t.
in
fo (dict): Info dic
t.
Returns:
Returns:
dict: Data information that will be passed to the data
dict: Processed `ann_info`.
preprocessing transforms. It includes the following keys:
- pts_filename (str): Filename of point clouds.
- file_name (str): Filename of point clouds.
- ann_info (dict): Annotation info.
"""
"""
info
=
self
.
data_infos
[
index
]
ann_info
=
super
().
parse_ann_info
(
info
)
pts_filename
=
osp
.
join
(
self
.
data_root
,
info
[
'pts_path'
])
# empty gt
input_dict
=
dict
(
pts_filename
=
pts_filename
)
if
ann_info
is
None
:
ann_info
=
dict
()
ann_info
[
'gt_bboxes_3d'
]
=
np
.
zeros
((
0
,
6
),
dtype
=
np
.
float32
)
ann_info
[
'gt_labels_3d'
]
=
np
.
zeros
((
0
,
),
dtype
=
np
.
int64
)
# to target box structure
if
not
self
.
test_mode
:
ann_info
[
'gt_bboxes_3d'
]
=
DepthInstance3DBoxes
(
annos
=
self
.
get_ann_info
(
index
)
ann_info
[
'gt_bboxes_3d'
],
input_dict
[
'ann_info'
]
=
annos
box_dim
=
ann_info
[
'gt_bboxes_3d'
].
shape
[
-
1
],
if
self
.
filter_empty_gt
and
~
(
annos
[
'gt_labels_3d'
]
!=
-
1
).
any
():
with_yaw
=
False
,
return
None
origin
=
(
0.5
,
0.5
,
0.5
)).
convert_to
(
self
.
box_mode_3d
)
return
input_dict
return
ann_info
def
_build_default_pipeline
(
self
):
"""Build the default pipeline for this dataset."""
pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'DEPTH'
,
shift_height
=
False
,
load_dim
=
6
,
use_dim
=
[
0
,
1
,
2
,
3
,
4
,
5
]),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
self
.
CLASSES
,
with_label
=
False
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
])
]
return
Compose
(
pipeline
)
class
_S3DISSegDataset
(
Seg3DDataset
):
class
_S3DISSegDataset
(
Seg3DDataset
):
...
@@ -166,30 +159,31 @@ class _S3DISSegDataset(Seg3DDataset):
...
@@ -166,30 +159,31 @@ class _S3DISSegDataset(Seg3DDataset):
wrapper to concat all the provided data in different areas.
wrapper to concat all the provided data in different areas.
Args:
Args:
data_root (str): Path of dataset root.
data_root (str, optional): Path of dataset root, Defaults to None.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file. Defaults to ''.
pipeline (list[dict], optional): Pipeline used for data processing.
metainfo (dict, optional): Meta information for dataset, such as class
Defaults to None.
information. Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
data_prefix (dict): Prefix for training data. Defaults to
Defaults to None.
dict(pts='points', pts_instance_mask='', pts_semantic_mask='').
palette (list[list[int]], optional): The palette of segmentation map.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to None.
Defaults to [].
modality (dict, optional): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used as input.
as input. Defaults to None.
Defaults to dict(use_lidar=True, use_camera=False).
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
unannotated points. If None is given, set to len(self.classes) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
Defaults to None.
scene_idxs (np.ndarray
|
str, optional): Precomputed index to load
scene_idxs (np.ndarray
or
str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
data. For scenes with many points, we may sample it several times.
Defaults to None.
Defaults to None.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
'
classes
'
:
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
(
'ceiling'
,
'floor'
,
'wall'
,
'beam'
,
'column'
,
'window'
,
'door'
,
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
),
'table'
,
'chair'
,
'sofa'
,
'bookcase'
,
'board'
,
'clutter'
),
'
PALETTE
'
:
[[
0
,
255
,
0
],
[
0
,
0
,
255
],
[
0
,
255
,
255
],
[
255
,
255
,
0
],
'
palette
'
:
[[
0
,
255
,
0
],
[
0
,
0
,
255
],
[
0
,
255
,
255
],
[
255
,
255
,
0
],
[
255
,
0
,
255
],
[
100
,
100
,
255
],
[
200
,
200
,
100
],
[
255
,
0
,
255
],
[
100
,
100
,
255
],
[
200
,
200
,
100
],
[
170
,
120
,
200
],
[
255
,
0
,
0
],
[
200
,
100
,
100
],
[
170
,
120
,
200
],
[
255
,
0
,
0
],
[
200
,
100
,
100
],
[
10
,
200
,
100
],
[
200
,
200
,
200
],
[
50
,
50
,
50
]],
[
10
,
200
,
100
],
[
200
,
200
,
200
],
[
50
,
50
,
50
]],
...
@@ -204,12 +198,12 @@ class _S3DISSegDataset(Seg3DDataset):
...
@@ -204,12 +198,12 @@ class _S3DISSegDataset(Seg3DDataset):
ann_file
:
str
=
''
,
ann_file
:
str
=
''
,
metainfo
:
Optional
[
dict
]
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
data_prefix
:
dict
=
dict
(
data_prefix
:
dict
=
dict
(
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
),
pts
=
'points'
,
pts_
instance_mask
=
''
,
pts_
semantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
ignore_index
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
=
None
,
scene_idxs
:
Optional
[
Union
[
np
.
ndarray
,
str
]]
=
None
,
test_mode
=
False
,
test_mode
:
bool
=
False
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
...
@@ -223,7 +217,8 @@ class _S3DISSegDataset(Seg3DDataset):
...
@@ -223,7 +217,8 @@ class _S3DISSegDataset(Seg3DDataset):
test_mode
=
test_mode
,
test_mode
=
test_mode
,
**
kwargs
)
**
kwargs
)
def
get_scene_idxs
(
self
,
scene_idxs
):
def
get_scene_idxs
(
self
,
scene_idxs
:
Union
[
np
.
ndarray
,
str
,
None
])
->
np
.
ndarray
:
"""Compute scene_idxs for data sampling.
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points.
We sample more times for scenes with more points.
...
@@ -250,37 +245,40 @@ class S3DISSegDataset(_S3DISSegDataset):
...
@@ -250,37 +245,40 @@ class S3DISSegDataset(_S3DISSegDataset):
data downloading.
data downloading.
Args:
Args:
data_root (str): Path of dataset root.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_files (list[str]): Path of several annotation files.
ann_files (List[str]): Path of several annotation files.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to ''.
Defaults to None.
metainfo (dict, optional): Meta information for dataset, such as class
classes (tuple[str], optional): Classes used in the dataset.
information. Defaults to None.
Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
palette (list[list[int]], optional): The palette of segmentation map.
dict(pts='points', pts_instance_mask='', pts_semantic_mask='').
Defaults to None.
pipeline (List[dict]): Pipeline used for data processing.
modality (dict, optional): Modality to specify the sensor data used
Defaults to [].
as input. Defaults to None.
modality (dict): Modality to specify the sensor data used as input.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to dict(use_lidar=True, use_camera=False).
Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
unannotated points. If None is given, set to len(self.classes) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
Defaults to None.
scene_idxs (list[np.ndarray] | list[str], optional): Precomputed index
scene_idxs (List[np.ndarray] | List[str], optional): Precomputed index
to load data. For scenes with many points, we may sample it several
to load data. For scenes with many points, we may sample it
times. Defaults to None.
several times. Defaults to None.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
data_root
:
Optional
[
str
]
=
None
,
data_root
:
Optional
[
str
]
=
None
,
ann_files
:
str
=
''
,
ann_files
:
List
[
str
]
=
''
,
metainfo
:
Optional
[
dict
]
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
data_prefix
:
dict
=
dict
(
data_prefix
:
dict
=
dict
(
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
),
pts
=
'points'
,
pts_
instance_mask
=
''
,
pts_
semantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
ignore_index
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
=
None
,
scene_idxs
:
Optional
[
Union
[
List
[
np
.
ndarray
],
test_mode
=
False
,
List
[
str
]]]
=
None
,
test_mode
:
bool
=
False
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
# make sure that ann_files and scene_idxs have same length
# make sure that ann_files and scene_idxs have same length
...
@@ -298,7 +296,6 @@ class S3DISSegDataset(_S3DISSegDataset):
...
@@ -298,7 +296,6 @@ class S3DISSegDataset(_S3DISSegDataset):
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
scene_idxs
=
scene_idxs
[
0
],
scene_idxs
=
scene_idxs
[
0
],
test_mode
=
test_mode
,
test_mode
=
test_mode
,
serialize_data
=
False
,
**
kwargs
)
**
kwargs
)
datasets
=
[
datasets
=
[
...
@@ -312,56 +309,44 @@ class S3DISSegDataset(_S3DISSegDataset):
...
@@ -312,56 +309,44 @@ class S3DISSegDataset(_S3DISSegDataset):
ignore_index
=
ignore_index
,
ignore_index
=
ignore_index
,
scene_idxs
=
scene_idxs
[
i
],
scene_idxs
=
scene_idxs
[
i
],
test_mode
=
test_mode
,
test_mode
=
test_mode
,
serialize_data
=
False
,
**
kwargs
)
for
i
in
range
(
len
(
ann_files
))
**
kwargs
)
for
i
in
range
(
len
(
ann_files
))
]
]
# data_list and scene_idxs need to be concat
# data_list and scene_idxs need to be concat
self
.
concat_data_list
([
dst
.
data_list
for
dst
in
datasets
])
self
.
concat_data_list
([
dst
.
data_list
for
dst
in
datasets
])
self
.
concat_scene_idxs
([
dst
.
scene_idxs
for
dst
in
datasets
])
# set group flag for the sampler
# set group flag for the sampler
if
not
self
.
test_mode
:
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
self
.
_set_group_flag
()
def
concat_data_list
(
self
,
data_lists
)
:
def
concat_data_list
(
self
,
data_lists
:
List
[
List
[
dict
]])
->
None
:
"""Concat data_list from several datasets to form self.data_list.
"""Concat data_list from several datasets to form self.data_list.
Args:
Args:
data_lists (list[list[dict]])
data_lists (List[List[dict]]): List of dict containing
annotation information.
"""
"""
self
.
data_list
=
[
self
.
data_list
=
[
data
for
data_list
in
data_lists
for
data
in
data_list
data
for
data_list
in
data_lists
for
data
in
data_list
]
]
def
concat_scene_idxs
(
self
,
scene_idxs
):
"""Concat scene_idxs from several datasets to form self.scene_idxs.
Needs to manually add offset to scene_idxs[1, 2, ...].
Args:
scene_idxs (list[np.ndarray])
"""
self
.
scene_idxs
=
np
.
array
([],
dtype
=
np
.
int32
)
offset
=
0
for
one_scene_idxs
in
scene_idxs
:
self
.
scene_idxs
=
np
.
concatenate
(
[
self
.
scene_idxs
,
one_scene_idxs
+
offset
]).
astype
(
np
.
int32
)
offset
=
np
.
unique
(
self
.
scene_idxs
).
max
()
+
1
@
staticmethod
@
staticmethod
def
_duplicate_to_list
(
x
,
num
)
:
def
_duplicate_to_list
(
x
:
Any
,
num
:
int
)
->
list
:
"""Repeat x `num` times to form a list."""
"""Repeat x `num` times to form a list."""
return
[
x
for
_
in
range
(
num
)]
return
[
x
for
_
in
range
(
num
)]
def
_check_ann_files
(
self
,
ann_file
):
def
_check_ann_files
(
self
,
ann_file
:
Union
[
List
[
str
],
Tuple
[
str
],
str
])
->
List
[
str
]:
"""Make ann_files as list/tuple."""
"""Make ann_files as list/tuple."""
# ann_file could be str
# ann_file could be str
if
not
isinstance
(
ann_file
,
(
list
,
tuple
)):
if
not
isinstance
(
ann_file
,
(
list
,
tuple
)):
ann_file
=
self
.
_duplicate_to_list
(
ann_file
,
1
)
ann_file
=
self
.
_duplicate_to_list
(
ann_file
,
1
)
return
ann_file
return
ann_file
def
_check_scene_idxs
(
self
,
scene_idx
,
num
):
def
_check_scene_idxs
(
self
,
scene_idx
:
Union
[
str
,
List
[
Union
[
list
,
tuple
,
np
.
ndarray
]],
List
[
str
],
None
],
num
:
int
)
->
List
[
np
.
ndarray
]:
"""Make scene_idxs as list/tuple."""
"""Make scene_idxs as list/tuple."""
if
scene_idx
is
None
:
if
scene_idx
is
None
:
return
self
.
_duplicate_to_list
(
scene_idx
,
num
)
return
self
.
_duplicate_to_list
(
scene_idx
,
num
)
...
...
mmdet3d/datasets/scannet_dataset.py
View file @
d7067e44
...
@@ -26,13 +26,13 @@ class ScanNetDataset(Det3DDataset):
...
@@ -26,13 +26,13 @@ class ScanNetDataset(Det3DDataset):
metainfo (dict, optional): Meta information for dataset, such as class
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
information. Defaults to None.
data_prefix (dict): Prefix for data. Defaults to
data_prefix (dict): Prefix for data. Defaults to
`
dict(pts='points',
dict(pts='points',
pts_i
s
ntance_mask='instance_mask',
pts_in
s
tance_mask='instance_mask',
pts_semantic_mask='semantic_mask')
`
.
pts_semantic_mask='semantic_mask').
pipeline (
l
ist[dict]): Pipeline used for data processing.
pipeline (
L
ist[dict]): Pipeline used for data processing.
Defaults to
None
.
Defaults to
[]
.
modality (dict): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used
as input.
as input.
Defaults to
None
.
Defaults to
dict(use_camera=False, use_lidar=True)
.
box_type_3d (str): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
...
@@ -41,13 +41,15 @@ class ScanNetDataset(Det3DDataset):
...
@@ -41,13 +41,15 @@ class ScanNetDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool): Whether to filter empty GT.
filter_empty_gt (bool): Whether to filter the data with empty GT.
Defaults to True.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
'
classes
'
:
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
),
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
),
...
@@ -71,7 +73,7 @@ class ScanNetDataset(Det3DDataset):
...
@@ -71,7 +73,7 @@ class ScanNetDataset(Det3DDataset):
box_type_3d
:
str
=
'Depth'
,
box_type_3d
:
str
=
'Depth'
,
filter_empty_gt
:
bool
=
True
,
filter_empty_gt
:
bool
=
True
,
test_mode
:
bool
=
False
,
test_mode
:
bool
=
False
,
**
kwargs
):
**
kwargs
)
->
None
:
# construct seg_label_mapping for semantic mask
# construct seg_label_mapping for semantic mask
seg_max_cat_id
=
len
(
self
.
METAINFO
[
'seg_all_class_ids'
])
seg_max_cat_id
=
len
(
self
.
METAINFO
[
'seg_all_class_ids'
])
...
@@ -128,8 +130,8 @@ class ScanNetDataset(Det3DDataset):
...
@@ -128,8 +130,8 @@ class ScanNetDataset(Det3DDataset):
info (dict): Raw info dict.
info (dict): Raw info dict.
Returns:
Returns:
dict:
Data information that will be passed to the data
dict:
Has `ann_info` in training stage. And
preprocessing transforms. It includes the following keys:
all path has been converted to absolute path.
"""
"""
info
[
'axis_align_matrix'
]
=
self
.
_get_axis_align_matrix
(
info
)
info
[
'axis_align_matrix'
]
=
self
.
_get_axis_align_matrix
(
info
)
info
[
'pts_instance_mask_path'
]
=
osp
.
join
(
info
[
'pts_instance_mask_path'
]
=
osp
.
join
(
...
@@ -146,13 +148,13 @@ class ScanNetDataset(Det3DDataset):
...
@@ -146,13 +148,13 @@ class ScanNetDataset(Det3DDataset):
return
info
return
info
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
"""Process the `instances` in data info to `ann_info`
"""Process the `instances` in data info to `ann_info`
.
Args:
Args:
info (dict): Info dict.
info (dict): Info dict.
Returns:
Returns:
dict: Processed `ann_info`
dict: Processed `ann_info`
.
"""
"""
ann_info
=
super
().
parse_ann_info
(
info
)
ann_info
=
super
().
parse_ann_info
(
info
)
# empty gt
# empty gt
...
@@ -181,32 +183,36 @@ class ScanNetSegDataset(Seg3DDataset):
...
@@ -181,32 +183,36 @@ class ScanNetSegDataset(Seg3DDataset):
for data downloading.
for data downloading.
Args:
Args:
data_root (str): Path of dataset root.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file. Defaults to ''.
pipeline (list[dict], optional): Pipeline used for data processing.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to None.
Defaults to [].
classes (tuple[str], optional): Classes used in the dataset.
metainfo (dict, optional): Meta information for dataset, such as class
Defaults to None.
information. Defaults to None.
palette (list[list[int]], optional): The palette of segmentation map.
data_prefix (dict): Prefix for training data. Defaults to
Defaults to None.
dict(pts='points',
modality (dict, optional): Modality to specify the sensor data used
img='',
as input. Defaults to None.
pts_instance_mask='',
test_mode (bool, optional): Whether the dataset is in test mode.
pts_semantic_mask='').
Defaults to False.
modality (dict): Modality to specify the sensor data used as input.
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES).
unannotated points. If None is given, set to len(self.classes) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
Defaults to None.
scene_idxs (np.ndarray
|
str, optional): Precomputed index to load
scene_idxs (np.ndarray
or
str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
data. For scenes with many points, we may sample it several times.
Defaults to None.
Defaults to None.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
'
classes
'
:
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
(
'wall'
,
'floor'
,
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'otherfurniture'
),
'otherfurniture'
),
'
PALETTE
'
:
[
'
palette
'
:
[
[
174
,
199
,
232
],
[
174
,
199
,
232
],
[
152
,
223
,
138
],
[
152
,
223
,
138
],
[
31
,
119
,
180
],
[
31
,
119
,
180
],
...
@@ -239,12 +245,15 @@ class ScanNetSegDataset(Seg3DDataset):
...
@@ -239,12 +245,15 @@ class ScanNetSegDataset(Seg3DDataset):
ann_file
:
str
=
''
,
ann_file
:
str
=
''
,
metainfo
:
Optional
[
dict
]
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
data_prefix
:
dict
=
dict
(
data_prefix
:
dict
=
dict
(
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
),
pts
=
'points'
,
img
=
''
,
pts_instance_mask
=
''
,
pts_semantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
ignore_index
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
=
None
,
scene_idxs
:
Optional
[
Union
[
np
.
ndarray
,
str
]]
=
None
,
test_mode
=
False
,
test_mode
:
bool
=
False
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
...
@@ -258,7 +267,8 @@ class ScanNetSegDataset(Seg3DDataset):
...
@@ -258,7 +267,8 @@ class ScanNetSegDataset(Seg3DDataset):
test_mode
=
test_mode
,
test_mode
=
test_mode
,
**
kwargs
)
**
kwargs
)
def
get_scene_idxs
(
self
,
scene_idxs
):
def
get_scene_idxs
(
self
,
scene_idxs
:
Union
[
np
.
ndarray
,
str
,
None
])
->
np
.
ndarray
:
"""Compute scene_idxs for data sampling.
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points.
We sample more times for scenes with more points.
...
@@ -275,11 +285,11 @@ class ScanNetSegDataset(Seg3DDataset):
...
@@ -275,11 +285,11 @@ class ScanNetSegDataset(Seg3DDataset):
class
ScanNetInstanceSegDataset
(
Seg3DDataset
):
class
ScanNetInstanceSegDataset
(
Seg3DDataset
):
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
'
classes
'
:
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
(
'cabinet'
,
'bed'
,
'chair'
,
'sofa'
,
'table'
,
'door'
,
'window'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'bookshelf'
,
'picture'
,
'counter'
,
'desk'
,
'curtain'
,
'refrigerator'
,
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
),
'showercurtrain'
,
'toilet'
,
'sink'
,
'bathtub'
,
'garbagebin'
),
'
PLATTE
'
:
[
'
palette
'
:
[
[
174
,
199
,
232
],
[
174
,
199
,
232
],
[
152
,
223
,
138
],
[
152
,
223
,
138
],
[
31
,
119
,
180
],
[
31
,
119
,
180
],
...
@@ -312,13 +322,16 @@ class ScanNetInstanceSegDataset(Seg3DDataset):
...
@@ -312,13 +322,16 @@ class ScanNetInstanceSegDataset(Seg3DDataset):
ann_file
:
str
=
''
,
ann_file
:
str
=
''
,
metainfo
:
Optional
[
dict
]
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
data_prefix
:
dict
=
dict
(
data_prefix
:
dict
=
dict
(
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
),
pts
=
'points'
,
img
=
''
,
pts_instance_mask
=
''
,
pts_semantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
test_mode
=
False
,
test_mode
:
bool
=
False
,
ignore_index
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
=
None
,
scene_idxs
:
Optional
[
Union
[
np
.
ndarray
,
str
]]
=
None
,
file_client_args
=
dict
(
backend
=
'disk'
),
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
...
...
mmdet3d/datasets/seg3d_dataset.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
os
import
path
as
osp
from
os
import
path
as
osp
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Sequence
,
Union
import
mmengine
import
mmengine
import
numpy
as
np
import
numpy
as
np
...
@@ -16,40 +16,45 @@ class Seg3DDataset(BaseDataset):
...
@@ -16,40 +16,45 @@ class Seg3DDataset(BaseDataset):
This is the base dataset of ScanNet, S3DIS and SemanticKITTI dataset.
This is the base dataset of ScanNet, S3DIS and SemanticKITTI dataset.
Args:
Args:
data_root (str): Path of dataset root.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file. Defaults to ''.
pipeline (list[dict], optional): Pipeline used for data processing.
Defaults to None.
metainfo (dict, optional): Meta information for dataset, such as class
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
information. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='velodyne', img='', instance_mask='', semantic_mask='').
dict(pts='points',
pipeline (list[dict], optional): Pipeline used for data processing.
img='',
Defaults to None.
pts_instance_mask='',
modality (dict, optional): Modality to specify the sensor data used
pts_semantic_mask='').
as input, it usually has following keys.
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used
as input, it usually has following keys:
- use_camera: bool
- use_camera: bool
- use_lidar: bool
- use_lidar: bool
Defaults to `dict(use_lidar=True, use_camera=False)`
Defaults to dict(use_lidar=True, use_camera=False).
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g.
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.
CLASSES
) to
unannotated points. If None is given, set to len(self.
classes
) to
be consistent with PointSegClassMapping function in pipeline.
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
Defaults to None.
scene_idxs (np.ndarray
|
str, optional): Precomputed index to load
scene_idxs (np.ndarray
or
str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
data. For scenes with many points, we may sample it several times.
Defaults to None.
Defaults to None.
load_eval_anns (bool): Whether to load annotations
test_mode (bool): Whether the dataset is in test mode.
in test_mode, the annotation will be save in
Defaults to False.
`eval_ann_infos`, which can be use in Evaluator.
serialize_data (bool): Whether to hold memory using serialized objects,
when enabled, data loader workers can use shared RAM from master
process instead of making a copy.
Defaults to False for 3D Segmentation datasets.
load_eval_anns (bool): Whether to load annotations in test_mode,
the annotations will be saved in `eval_ann_infos`, which can be used
in Evaluator. Defaults to True.
file_client_args (dict): Configuration of file client.
file_client_args (dict): Configuration of file client.
Defaults to
`
dict(backend='disk')
`
.
Defaults to dict(backend='disk').
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
None
,
# names of all classes data used for the task
'
classes
'
:
None
,
# names of all classes data used for the task
'
PALETTE
'
:
None
,
# official color for visualization
'
palette
'
:
None
,
# official color for visualization
'seg_valid_class_ids'
:
None
,
# class_ids used for training
'seg_valid_class_ids'
:
None
,
# class_ids used for training
'seg_all_class_ids'
:
None
,
# all possible class_ids in loaded seg mask
'seg_all_class_ids'
:
None
,
# all possible class_ids in loaded seg mask
}
}
...
@@ -62,12 +67,13 @@ class Seg3DDataset(BaseDataset):
...
@@ -62,12 +67,13 @@ class Seg3DDataset(BaseDataset):
pts
=
'points'
,
pts
=
'points'
,
img
=
''
,
img
=
''
,
pts_instance_mask
=
''
,
pts_instance_mask
=
''
,
pts_emantic_mask
=
''
),
pts_
s
emantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
ignore_index
:
Optional
[
int
]
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
:
Optional
[
str
]
=
None
,
scene_idxs
:
Optional
[
Union
[
str
,
np
.
ndarray
]
]
=
None
,
test_mode
:
bool
=
False
,
test_mode
:
bool
=
False
,
serialize_data
:
bool
=
False
,
load_eval_anns
:
bool
=
True
,
load_eval_anns
:
bool
=
True
,
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
...
@@ -78,11 +84,11 @@ class Seg3DDataset(BaseDataset):
...
@@ -78,11 +84,11 @@ class Seg3DDataset(BaseDataset):
# TODO: We maintain the ignore_index attributes,
# TODO: We maintain the ignore_index attributes,
# but we may consider to remove it in the future.
# but we may consider to remove it in the future.
self
.
ignore_index
=
len
(
self
.
METAINFO
[
'
CLASSES
'
])
if
\
self
.
ignore_index
=
len
(
self
.
METAINFO
[
'
classes
'
])
if
\
ignore_index
is
None
else
ignore_index
ignore_index
is
None
else
ignore_index
# Get label mapping for custom classes
# Get label mapping for custom classes
new_classes
=
metainfo
.
get
(
'
CLASSES
'
,
None
)
new_classes
=
metainfo
.
get
(
'
classes
'
,
None
)
self
.
label_mapping
,
self
.
label2cat
,
seg_valid_class_ids
=
\
self
.
label_mapping
,
self
.
label2cat
,
seg_valid_class_ids
=
\
self
.
get_label_mapping
(
new_classes
)
self
.
get_label_mapping
(
new_classes
)
...
@@ -95,10 +101,10 @@ class Seg3DDataset(BaseDataset):
...
@@ -95,10 +101,10 @@ class Seg3DDataset(BaseDataset):
# generate palette if it is not defined based on
# generate palette if it is not defined based on
# label mapping, otherwise directly use palette
# label mapping, otherwise directly use palette
# defined in dataset config.
# defined in dataset config.
palette
=
metainfo
.
get
(
'
PALETTE
'
,
None
)
palette
=
metainfo
.
get
(
'
palette
'
,
None
)
updated_palette
=
self
.
_update_palette
(
new_classes
,
palette
)
updated_palette
=
self
.
_update_palette
(
new_classes
,
palette
)
metainfo
[
'
PALETTE
'
]
=
updated_palette
metainfo
[
'
palette
'
]
=
updated_palette
# construct seg_label_mapping for semantic mask
# construct seg_label_mapping for semantic mask
seg_max_cat_id
=
len
(
self
.
METAINFO
[
'seg_all_class_ids'
])
seg_max_cat_id
=
len
(
self
.
METAINFO
[
'seg_all_class_ids'
])
...
@@ -117,18 +123,19 @@ class Seg3DDataset(BaseDataset):
...
@@ -117,18 +123,19 @@ class Seg3DDataset(BaseDataset):
data_prefix
=
data_prefix
,
data_prefix
=
data_prefix
,
pipeline
=
pipeline
,
pipeline
=
pipeline
,
test_mode
=
test_mode
,
test_mode
=
test_mode
,
serialize_data
=
serialize_data
,
**
kwargs
)
**
kwargs
)
self
.
metainfo
[
'seg_label_mapping'
]
=
self
.
seg_label_mapping
self
.
metainfo
[
'seg_label_mapping'
]
=
self
.
seg_label_mapping
self
.
scene_idxs
=
self
.
get_scene_idxs
(
scene_idxs
)
self
.
scene_idxs
=
self
.
get_scene_idxs
(
scene_idxs
)
self
.
data_list
=
[
self
.
data_list
[
i
]
for
i
in
self
.
scene_idxs
]
# set group flag for the sampler
# set group flag for the sampler
if
not
self
.
test_mode
:
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
self
.
_set_group_flag
()
def
get_label_mapping
(
self
,
def
get_label_mapping
(
self
,
new_classes
:
Optional
[
Sequence
]
=
None
new_classes
:
Optional
[
Sequence
]
=
None
)
->
tuple
:
)
->
Union
[
Dict
,
None
]:
"""Get label mapping.
"""Get label mapping.
The ``label_mapping`` is a dictionary, its keys are the old label ids
The ``label_mapping`` is a dictionary, its keys are the old label ids
...
@@ -138,21 +145,20 @@ class Seg3DDataset(BaseDataset):
...
@@ -138,21 +145,20 @@ class Seg3DDataset(BaseDataset):
None, `label_mapping` is not None.
None, `label_mapping` is not None.
Args:
Args:
new_classes (list, tuple, optional): The new classes name from
new_classes (list or tuple, optional): The new classes name from
metainfo. Default to None.
metainfo. Defaults to None.
Returns:
Returns:
tuple: The mapping from old classes in cls.METAINFO to
tuple: The mapping from old classes in cls.METAINFO to
new classes in metainfo
new classes in metainfo
"""
"""
old_classes
=
self
.
METAINFO
.
get
(
'
CLASSES
'
,
None
)
old_classes
=
self
.
METAINFO
.
get
(
'
classes
'
,
None
)
if
(
new_classes
is
not
None
and
old_classes
is
not
None
if
(
new_classes
is
not
None
and
old_classes
is
not
None
and
list
(
new_classes
)
!=
list
(
old_classes
)):
and
list
(
new_classes
)
!=
list
(
old_classes
)):
if
not
set
(
new_classes
).
issubset
(
old_classes
):
if
not
set
(
new_classes
).
issubset
(
old_classes
):
raise
ValueError
(
raise
ValueError
(
f
'new classes
{
new_classes
}
is not a '
f
'new classes
{
new_classes
}
is not a '
f
'subset of
CLASSES
{
old_classes
}
in METAINFO.'
)
f
'subset of
classes
{
old_classes
}
in METAINFO.'
)
# obtain true id from valid_class_ids
# obtain true id from valid_class_ids
valid_class_ids
=
[
valid_class_ids
=
[
...
@@ -180,13 +186,14 @@ class Seg3DDataset(BaseDataset):
...
@@ -180,13 +186,14 @@ class Seg3DDataset(BaseDataset):
# map label to category name
# map label to category name
label2cat
=
{
label2cat
=
{
i
:
cat_name
i
:
cat_name
for
i
,
cat_name
in
enumerate
(
self
.
METAINFO
[
'
CLASSES
'
])
for
i
,
cat_name
in
enumerate
(
self
.
METAINFO
[
'
classes
'
])
}
}
valid_class_ids
=
self
.
METAINFO
[
'seg_valid_class_ids'
]
valid_class_ids
=
self
.
METAINFO
[
'seg_valid_class_ids'
]
return
label_mapping
,
label2cat
,
valid_class_ids
return
label_mapping
,
label2cat
,
valid_class_ids
def
_update_palette
(
self
,
new_classes
,
palette
)
->
list
:
def
_update_palette
(
self
,
new_classes
:
list
,
palette
:
Union
[
None
,
list
])
->
list
:
"""Update palette according to metainfo.
"""Update palette according to metainfo.
If length of palette is equal to classes, just return the palette.
If length of palette is equal to classes, just return the palette.
...
@@ -199,10 +206,10 @@ class Seg3DDataset(BaseDataset):
...
@@ -199,10 +206,10 @@ class Seg3DDataset(BaseDataset):
"""
"""
if
palette
is
None
:
if
palette
is
None
:
# If palette is not defined, generate a palette according
# If palette is not defined, generate a palette according
# to the original
PALETTE
and classes.
# to the original
palette
and classes.
old_classes
=
self
.
METAINFO
.
get
(
'
CLASSES
'
,
None
)
old_classes
=
self
.
METAINFO
.
get
(
'
classes
'
,
None
)
palette
=
[
palette
=
[
self
.
METAINFO
[
'
PALETTE
'
][
old_classes
.
index
(
cls_name
)]
self
.
METAINFO
[
'
palette
'
][
old_classes
.
index
(
cls_name
)]
for
cls_name
in
new_classes
for
cls_name
in
new_classes
]
]
return
palette
return
palette
...
@@ -211,8 +218,8 @@ class Seg3DDataset(BaseDataset):
...
@@ -211,8 +218,8 @@ class Seg3DDataset(BaseDataset):
if
len
(
palette
)
==
len
(
new_classes
):
if
len
(
palette
)
==
len
(
new_classes
):
return
palette
return
palette
else
:
else
:
raise
ValueError
(
'Once
PLATTE
in set in metainfo, it should'
raise
ValueError
(
'Once
palette
in set in metainfo, it should'
'match
CLASSES
in metainfo'
)
'match
classes
in metainfo'
)
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
"""Process the raw data info.
"""Process the raw data info.
...
@@ -260,7 +267,8 @@ class Seg3DDataset(BaseDataset):
...
@@ -260,7 +267,8 @@ class Seg3DDataset(BaseDataset):
return
info
return
info
def
get_scene_idxs
(
self
,
scene_idxs
):
def
get_scene_idxs
(
self
,
scene_idxs
:
Union
[
None
,
str
,
np
.
ndarray
])
->
np
.
ndarray
:
"""Compute scene_idxs for data sampling.
"""Compute scene_idxs for data sampling.
We sample more times for scenes with more points.
We sample more times for scenes with more points.
...
@@ -282,7 +290,7 @@ class Seg3DDataset(BaseDataset):
...
@@ -282,7 +290,7 @@ class Seg3DDataset(BaseDataset):
return
scene_idxs
.
astype
(
np
.
int32
)
return
scene_idxs
.
astype
(
np
.
int32
)
def
_set_group_flag
(
self
):
def
_set_group_flag
(
self
)
->
None
:
"""Set flag according to image aspect ratio.
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
Images with aspect ratio greater than 1 will be set as group 1,
...
...
mmdet3d/datasets/semantickitti_dataset.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
import
numpy
as
np
from
mmdet3d.registry
import
DATASETS
from
mmdet3d.registry
import
DATASETS
from
.seg3d_dataset
import
Seg3DDataset
from
.seg3d_dataset
import
Seg3DDataset
...
@@ -14,30 +16,35 @@ class SemanticKITTIDataset(Seg3DDataset):
...
@@ -14,30 +16,35 @@ class SemanticKITTIDataset(Seg3DDataset):
for data downloading
for data downloading
Args:
Args:
data_root (str): Path of dataset root.
data_root (str, optional): Path of dataset root. Defaults to None.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file. Defaults to ''.
pipeline (list[dict], optional): Pipeline used for data processing.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_prefix (dict): Prefix for training data. Defaults to
dict(pts='points',
img='',
pts_instance_mask='',
pts_semantic_mask='').
pipeline (List[dict]): Pipeline used for data processing.
Defaults to [].
modality (dict): Modality to specify the sensor data used as input,
it usually has following keys:
- use_camera: bool
- use_lidar: bool
Defaults to dict(use_lidar=True, use_camera=False).
ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.classes) to
be consistent with PointSegClassMapping function in pipeline.
Defaults to None.
Defaults to None.
classes (tuple[str], optional): Classes used in the dataset.
scene_idxs (np.ndarray or str, optional): Precomputed index to load
data. For scenes with many points, we may sample it several times.
Defaults to None.
Defaults to None.
modality (dict, optional): Modality to specify the sensor data used
test_mode (bool): Whether the dataset is in test mode.
as input. Defaults to None.
box_type_3d (str, optional): NO 3D box for this dataset.
You can choose any type
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
Defaults to 'LiDAR' in this dataset. Available options includes
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
(
'unlabeled'
,
'car'
,
'bicycle'
,
'motorcycle'
,
'truck'
,
'
classes
'
:
(
'unlabeled'
,
'car'
,
'bicycle'
,
'motorcycle'
,
'truck'
,
'bus'
,
'person'
,
'bicyclist'
,
'motorcyclist'
,
'road'
,
'bus'
,
'person'
,
'bicyclist'
,
'motorcyclist'
,
'road'
,
'parking'
,
'sidewalk'
,
'other-ground'
,
'building'
,
'fence'
,
'parking'
,
'sidewalk'
,
'other-ground'
,
'building'
,
'fence'
,
'vegetation'
,
'trunck'
,
'terrian'
,
'pole'
,
'traffic-sign'
),
'vegetation'
,
'trunck'
,
'terrian'
,
'pole'
,
'traffic-sign'
),
...
@@ -52,12 +59,15 @@ class SemanticKITTIDataset(Seg3DDataset):
...
@@ -52,12 +59,15 @@ class SemanticKITTIDataset(Seg3DDataset):
ann_file
:
str
=
''
,
ann_file
:
str
=
''
,
metainfo
:
Optional
[
dict
]
=
None
,
metainfo
:
Optional
[
dict
]
=
None
,
data_prefix
:
dict
=
dict
(
data_prefix
:
dict
=
dict
(
pts
=
'points'
,
img
=
''
,
instance_mask
=
''
,
semantic_mask
=
''
),
pts
=
'points'
,
img
=
''
,
pts_instance_mask
=
''
,
pts_semantic_mask
=
''
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
modality
:
dict
=
dict
(
use_lidar
=
True
,
use_camera
=
False
),
ignore_index
=
None
,
ignore_index
:
Optional
[
int
]
=
None
,
scene_idxs
=
None
,
scene_idxs
:
Optional
[
Union
[
str
,
np
.
ndarray
]]
=
None
,
test_mode
=
False
,
test_mode
:
bool
=
False
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
...
...
mmdet3d/datasets/sunrgbd_dataset.py
View file @
d7067e44
...
@@ -24,13 +24,13 @@ class SUNRGBDDataset(Det3DDataset):
...
@@ -24,13 +24,13 @@ class SUNRGBDDataset(Det3DDataset):
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file.
metainfo (dict, optional): Meta information for dataset, such as class
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
information. Defaults to None.
data_prefix (dict
, optiona;
): Prefix for data. Defaults to
data_prefix (dict): Prefix for data. Defaults to
dict(pts='points',img='sunrgbd_trainval').
dict(pts='points',img='sunrgbd_trainval').
pipeline (
l
ist[dict]
, optional
): Pipeline used for data processing.
pipeline (
L
ist[dict]): Pipeline used for data processing.
Defaults to
None
.
Defaults to
[]
.
modality (dict
, optional
): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used
as input.
as input.
Defaults to dict(use_camera=True, use_lidar=True).
Defaults to dict(use_camera=True, use_lidar=True).
default_cam_key (str
, optional
): The default camera name adopted.
default_cam_key (str): The default camera name adopted.
Defaults to 'CAM0'.
Defaults to 'CAM0'.
box_type_3d (str): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
...
@@ -40,13 +40,13 @@ class SUNRGBDDataset(Det3DDataset):
...
@@ -40,13 +40,13 @@ class SUNRGBDDataset(Det3DDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool
, optional
): Whether to filter empty GT.
filter_empty_gt (bool): Whether to filter empty GT.
Defaults to True.
Defaults to True.
test_mode (bool
, optional
): Whether the dataset is in test mode.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
"""
"""
METAINFO
=
{
METAINFO
=
{
'
CLASSES
'
:
(
'bed'
,
'table'
,
'sofa'
,
'chair'
,
'toilet'
,
'desk'
,
'
classes
'
:
(
'bed'
,
'table'
,
'sofa'
,
'chair'
,
'toilet'
,
'desk'
,
'dresser'
,
'night_stand'
,
'bookshelf'
,
'bathtub'
)
'dresser'
,
'night_stand'
,
'bookshelf'
,
'bathtub'
)
}
}
...
@@ -58,11 +58,11 @@ class SUNRGBDDataset(Det3DDataset):
...
@@ -58,11 +58,11 @@ class SUNRGBDDataset(Det3DDataset):
pts
=
'points'
,
img
=
'sunrgbd_trainval/image'
),
pts
=
'points'
,
img
=
'sunrgbd_trainval/image'
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
default_cam_key
:
str
=
'CAM0'
,
default_cam_key
:
str
=
'CAM0'
,
modality
=
dict
(
use_camera
=
True
,
use_lidar
=
True
),
modality
:
dict
=
dict
(
use_camera
=
True
,
use_lidar
=
True
),
box_type_3d
:
str
=
'Depth'
,
box_type_3d
:
str
=
'Depth'
,
filter_empty_gt
:
bool
=
True
,
filter_empty_gt
:
bool
=
True
,
test_mode
:
bool
=
False
,
test_mode
:
bool
=
False
,
**
kwargs
):
**
kwargs
)
->
None
:
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
...
@@ -121,7 +121,7 @@ class SUNRGBDDataset(Det3DDataset):
...
@@ -121,7 +121,7 @@ class SUNRGBDDataset(Det3DDataset):
return
info
return
info
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
"""Process the `instances` in data info to `ann_info`
"""Process the `instances` in data info to `ann_info`
.
Args:
Args:
info (dict): Info dict.
info (dict): Info dict.
...
...
mmdet3d/datasets/transforms/__init__.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.compose
import
Compose
from
.dbsampler
import
DataBaseSampler
from
.dbsampler
import
DataBaseSampler
from
.formating
import
Pack3DDetInputs
from
.formating
import
Pack3DDetInputs
from
.loading
import
(
LoadAnnotations3D
,
LoadImageFromFileMono3D
,
from
.loading
import
(
LoadAnnotations3D
,
LoadImageFromFileMono3D
,
...
@@ -21,15 +20,13 @@ from .transforms_3d import (AffineResize, BackgroundPointsFilter,
...
@@ -21,15 +20,13 @@ from .transforms_3d import (AffineResize, BackgroundPointsFilter,
__all__
=
[
__all__
=
[
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScaleTrans'
,
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScaleTrans'
,
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'Pack3DDetInputs'
,
'Pack3DDetInputs'
,
'LoadMultiViewImageFromFiles'
,
'LoadPointsFromFile'
,
'Compose'
,
'LoadMultiViewImageFromFiles'
,
'LoadPointsFromFile'
,
'DataBaseSampler'
,
'NormalizePointsColor'
,
'LoadAnnotations3D'
,
'DataBaseSampler'
,
'IndoorPointSample'
,
'PointSample'
,
'PointSegClassMapping'
,
'NormalizePointsColor'
,
'LoadAnnotations3D'
,
'IndoorPointSample'
,
'MultiScaleFlipAug3D'
,
'LoadPointsFromMultiSweeps'
,
'PointSample'
,
'PointSegClassMapping'
,
'MultiScaleFlipAug3D'
,
'BackgroundPointsFilter'
,
'VoxelBasedPointSampler'
,
'GlobalAlignment'
,
'LoadPointsFromMultiSweeps'
,
'BackgroundPointsFilter'
,
'IndoorPatchPointSample'
,
'LoadImageFromFileMono3D'
,
'ObjectNameFilter'
,
'VoxelBasedPointSampler'
,
'GlobalAlignment'
,
'IndoorPatchPointSample'
,
'RandomDropPointsColor'
,
'RandomJitterPoints'
,
'AffineResize'
,
'LoadImageFromFileMono3D'
,
'ObjectNameFilter'
,
'RandomDropPointsColor'
,
'RandomShiftScale'
,
'LoadPointsFromDict'
,
'Resize3D'
,
'RandomResize3D'
,
'RandomJitterPoints'
,
'AffineResize'
,
'RandomShiftScale'
,
'LoadPointsFromDict'
,
'Resize3D'
,
'RandomResize3D'
,
'MultiViewWrapper'
,
'PhotoMetricDistortion3D'
'MultiViewWrapper'
,
'PhotoMetricDistortion3D'
]
]
mmdet3d/datasets/transforms/compose.py
deleted
100644 → 0
View file @
28fe73d2
# Copyright (c) OpenMMLab. All rights reserved.
import
collections
from
mmdet3d.registry
import
TRANSFORMS
@TRANSFORMS.register_module()
class Compose:
    """Compose multiple transforms sequentially.

    Args:
        transforms (Sequence[dict | callable]): Sequence of transform object or
            config dict to be composed.
    """

    def __init__(self, transforms):
        # Import the ABC explicitly: a bare ``import collections`` does not
        # guarantee that the ``collections.abc`` submodule has been loaded.
        from collections.abc import Sequence

        assert isinstance(transforms, Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                # Config dicts are turned into transform objects through the
                # registry before being stored.
                transform = TRANSFORMS.build(transform)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict: Transformed data, or None as soon as any transform
            returns None (the remaining transforms are skipped).
        """
        for t in self.transforms:
            data = t(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        """Return a multi-line description listing every transform.

        Each transform is placed on its own line, indented by four spaces.
        The repr of a nested ``Compose`` spans several lines, so its inner
        lines are shifted by the same amount to keep the nesting readable.
        """
        indent = '    '
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            str_ = t.__repr__()
            if 'Compose(' in str_:
                # Re-indent the continuation lines of a nested Compose.
                str_ = str_.replace('\n', '\n' + indent)
            format_string += '\n'
            format_string += f'{indent}{str_}'
        format_string += '\n)'
        return format_string
mmdet3d/datasets/transforms/dbsampler.py
View file @
d7067e44
...
@@ -18,9 +18,8 @@ class BatchSampler:
...
@@ -18,9 +18,8 @@ class BatchSampler:
sample_list (list[dict]): List of samples.
sample_list (list[dict]): List of samples.
name (str, optional): The category of samples. Defaults to None.
name (str, optional): The category of samples. Defaults to None.
epoch (int, optional): Sampling epoch. Defaults to None.
epoch (int, optional): Sampling epoch. Defaults to None.
shuffle (bool, optional): Whether to shuffle indices.
shuffle (bool): Whether to shuffle indices. Defaults to False.
Defaults to False.
drop_reminder (bool): Drop reminder. Defaults to False.
drop_reminder (bool, optional): Drop reminder. Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -90,12 +89,11 @@ class DataBaseSampler(object):
...
@@ -90,12 +89,11 @@ class DataBaseSampler(object):
prepare (dict): Name of preparation functions and the input value.
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Defaults to None.
classes (list[str], optional): List of classes. Defaults to None.
points_loader(dict
, optional
): Config of points loader. Defaults to
points_loader
(dict): Config of points loader. Defaults to
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
file_client_args (dict, optional): Config dict of file clients,
file_client_args (dict): Arguments to instantiate a FileClient.
refer to
See :class:`mmengine.fileio.FileClient` for details.
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
Defaults to dict(backend='disk').
for more details. Defaults to dict(backend='disk').
"""
"""
def
__init__
(
def
__init__
(
...
@@ -133,12 +131,12 @@ class DataBaseSampler(object):
...
@@ -133,12 +131,12 @@ class DataBaseSampler(object):
from
mmengine.logging
import
MMLogger
from
mmengine.logging
import
MMLogger
logger
:
MMLogger
=
MMLogger
.
get_current_instance
()
logger
:
MMLogger
=
MMLogger
.
get_current_instance
()
for
k
,
v
in
db_infos
.
items
():
for
k
,
v
in
db_infos
.
items
():
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos'
)
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos
in DataBaseSampler
'
)
for
prep_func
,
val
in
prepare
.
items
():
for
prep_func
,
val
in
prepare
.
items
():
db_infos
=
getattr
(
self
,
prep_func
)(
db_infos
,
val
)
db_infos
=
getattr
(
self
,
prep_func
)(
db_infos
,
val
)
logger
.
info
(
'After filter database:'
)
logger
.
info
(
'After filter database:'
)
for
k
,
v
in
db_infos
.
items
():
for
k
,
v
in
db_infos
.
items
():
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos'
)
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos
in DataBaseSampler
'
)
self
.
db_infos
=
db_infos
self
.
db_infos
=
db_infos
...
@@ -219,9 +217,9 @@ class DataBaseSampler(object):
...
@@ -219,9 +217,9 @@ class DataBaseSampler(object):
dict: Dict of sampled 'pseudo ground truths'.
dict: Dict of sampled 'pseudo ground truths'.
- gt_labels_3d (np.ndarray): ground truths labels
- gt_labels_3d (np.ndarray): ground truths labels
of sampled objects.
of sampled objects.
- gt_bboxes_3d (:obj:`BaseInstance3DBoxes`):
- gt_bboxes_3d (:obj:`BaseInstance3DBoxes`):
sampled ground truth 3D bounding boxes
sampled ground truth 3D bounding boxes
- points (np.ndarray): sampled points
- points (np.ndarray): sampled points
- group_ids (np.ndarray): ids of sampled ground truths
- group_ids (np.ndarray): ids of sampled ground truths
"""
"""
...
...
mmdet3d/datasets/transforms/formating.py
View file @
d7067e44
...
@@ -102,7 +102,7 @@ class Pack3DDetInputs(BaseTransform):
...
@@ -102,7 +102,7 @@ class Pack3DDetInputs(BaseTransform):
- points
- points
- img
- img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info of
- 'data_samples' (
:
obj:`Det3DDataSample`): The annotation info of
the sample.
the sample.
"""
"""
# augtest
# augtest
...
...
mmdet3d/datasets/transforms/loading.py
View file @
d7067e44
...
@@ -7,10 +7,10 @@ import mmengine
...
@@ -7,10 +7,10 @@ import mmengine
import
numpy
as
np
import
numpy
as
np
from
mmcv.transforms
import
LoadImageFromFile
from
mmcv.transforms
import
LoadImageFromFile
from
mmcv.transforms.base
import
BaseTransform
from
mmcv.transforms.base
import
BaseTransform
from
mmdet.datasets.transforms
import
LoadAnnotations
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet3d.structures.points
import
BasePoints
,
get_points_type
from
mmdet3d.structures.points
import
BasePoints
,
get_points_type
from
mmdet.datasets.transforms
import
LoadAnnotations
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
...
@@ -20,19 +20,17 @@ class LoadMultiViewImageFromFiles(BaseTransform):
...
@@ -20,19 +20,17 @@ class LoadMultiViewImageFromFiles(BaseTransform):
Expects results['img_filename'] to be a list of filenames.
Expects results['img_filename'] to be a list of filenames.
Args:
Args:
to_float32 (bool
, optional
): Whether to convert the img to float32.
to_float32 (bool): Whether to convert the img to float32.
Defaults to False.
Defaults to False.
color_type (str, optional): Color type of the file.
color_type (str): Color type of the file. Defaults to 'unchanged'.
Defaults to 'unchanged'.
file_client_args (dict): Arguments to instantiate a FileClient.
file_client_args (dict): Config dict of file clients,
See :class:`mmengine.fileio.FileClient` for details.
refer to
Defaults to dict(backend='disk').
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
num_views (int): Number of views in a frame. Defaults to 5.
for more details. Defaults to dict(backend='disk').
num_ref_frames (int): Number of frames in loading. Defaults to -1.
num_views (int): num of view in a frame. Default to 5.
test_mode (bool): Whether is test mode in loading. Defaults to False.
num_ref_frames (int): num of frame in loading. Default to -1.
set_default_scale (bool): Whether to set default scale.
test_mode (bool): Whether is test mode in loading. Default to False.
Defaults to True.
set_default_scale (bool): Whether to set default scale. Default to
True.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -63,7 +61,7 @@ class LoadMultiViewImageFromFiles(BaseTransform):
...
@@ -63,7 +61,7 @@ class LoadMultiViewImageFromFiles(BaseTransform):
Returns:
Returns:
dict: The result dict containing the multi-view image data.
dict: The result dict containing the multi-view image data.
Added keys and values are described below.
Added keys and values are described below.
- filename (str): Multi-view image filenames.
- filename (str): Multi-view image filenames.
- img (np.ndarray): Multi-view image arrays.
- img (np.ndarray): Multi-view image arrays.
...
@@ -210,7 +208,7 @@ class LoadMultiViewImageFromFiles(BaseTransform):
...
@@ -210,7 +208,7 @@ class LoadMultiViewImageFromFiles(BaseTransform):
results
[
'num_ref_frames'
]
=
self
.
num_ref_frames
results
[
'num_ref_frames'
]
=
self
.
num_ref_frames
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(to_float32=
{
self
.
to_float32
}
, '
repr_str
+=
f
'(to_float32=
{
self
.
to_float32
}
, '
...
@@ -276,22 +274,17 @@ class LoadPointsFromMultiSweeps(BaseTransform):
...
@@ -276,22 +274,17 @@ class LoadPointsFromMultiSweeps(BaseTransform):
This is usually used for nuScenes dataset to utilize previous sweeps.
This is usually used for nuScenes dataset to utilize previous sweeps.
Args:
Args:
sweeps_num (int, optional): Number of sweeps. Defaults to 10.
sweeps_num (int): Number of sweeps. Defaults to 10.
load_dim (int, optional): Dimension number of the loaded points.
load_dim (int): Dimension number of the loaded points. Defaults to 5.
Defaults to 5.
use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4].
use_dim (list[int], optional): Which dimension to use.
file_client_args (dict): Arguments to instantiate a FileClient.
Defaults to [0, 1, 2, 4].
See :class:`mmengine.fileio.FileClient` for details.
file_client_args (dict, optional): Config dict of file clients,
Defaults to dict(backend='disk').
refer to
pad_empty_sweeps (bool): Whether to repeat keyframe when
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
sweeps is empty. Defaults to False.
sweeps is empty. Defaults to False.
remove_close (bool, optional): Whether to remove close points.
remove_close (bool): Whether to remove close points. Defaults to False.
Defaults to False.
test_mode (bool): If `test_mode=True`, it will not randomly sample
test_mode (bool, optional): If `test_mode=True`, it will not
sweeps but select the nearest N frames. Defaults to False.
randomly sample sweeps but select the nearest N frames.
Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -336,11 +329,11 @@ class LoadPointsFromMultiSweeps(BaseTransform):
...
@@ -336,11 +329,11 @@ class LoadPointsFromMultiSweeps(BaseTransform):
def
_remove_close
(
self
,
def
_remove_close
(
self
,
points
:
Union
[
np
.
ndarray
,
BasePoints
],
points
:
Union
[
np
.
ndarray
,
BasePoints
],
radius
:
float
=
1.0
)
->
Union
[
np
.
ndarray
,
BasePoints
]:
radius
:
float
=
1.0
)
->
Union
[
np
.
ndarray
,
BasePoints
]:
"""Remove
s
point too close within a certain radius from origin.
"""Remove points too close within a certain radius from origin.
Args:
Args:
points (np.ndarray | :obj:`BasePoints`): Sweep points.
points (np.ndarray | :obj:`BasePoints`): Sweep points.
radius (float
, optional
): Radius below which points are removed.
radius (float): Radius below which points are removed.
Defaults to 1.0.
Defaults to 1.0.
Returns:
Returns:
...
@@ -366,10 +359,10 @@ class LoadPointsFromMultiSweeps(BaseTransform):
...
@@ -366,10 +359,10 @@ class LoadPointsFromMultiSweeps(BaseTransform):
Returns:
Returns:
dict: The result dict containing the multi-sweep points data.
dict: The result dict containing the multi-sweep points data.
Updated key and value are described below.
Updated key and value are described below.
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
cloud arrays.
cloud arrays.
"""
"""
points
=
results
[
'points'
]
points
=
results
[
'points'
]
points
.
tensor
[:,
4
]
=
0
points
.
tensor
[:,
4
]
=
0
...
@@ -414,7 +407,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
...
@@ -414,7 +407,7 @@ class LoadPointsFromMultiSweeps(BaseTransform):
results
[
'points'
]
=
points
results
[
'points'
]
=
points
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
return
f
'
{
self
.
__class__
.
__name__
}
(sweeps_num=
{
self
.
sweeps_num
}
)'
return
f
'
{
self
.
__class__
.
__name__
}
(sweeps_num=
{
self
.
sweeps_num
}
)'
...
@@ -444,7 +437,7 @@ class PointSegClassMapping(BaseTransform):
...
@@ -444,7 +437,7 @@ class PointSegClassMapping(BaseTransform):
Returns:
Returns:
dict: The result dict containing the mapped category ids.
dict: The result dict containing the mapped category ids.
Updated key and value are described below.
Updated key and value are described below.
- pts_semantic_mask (np.ndarray): Mapped semantic masks.
- pts_semantic_mask (np.ndarray): Mapped semantic masks.
"""
"""
...
@@ -465,7 +458,7 @@ class PointSegClassMapping(BaseTransform):
...
@@ -465,7 +458,7 @@ class PointSegClassMapping(BaseTransform):
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
return
repr_str
return
repr_str
...
@@ -490,7 +483,7 @@ class NormalizePointsColor(BaseTransform):
...
@@ -490,7 +483,7 @@ class NormalizePointsColor(BaseTransform):
Returns:
Returns:
dict: The result dict containing the normalized points.
dict: The result dict containing the normalized points.
Updated key and value are described below.
Updated key and value are described below.
- points (:obj:`BasePoints`): Points after color normalization.
- points (:obj:`BasePoints`): Points after color normalization.
"""
"""
...
@@ -505,7 +498,7 @@ class NormalizePointsColor(BaseTransform):
...
@@ -505,7 +498,7 @@ class NormalizePointsColor(BaseTransform):
input_dict
[
'points'
]
=
points
input_dict
[
'points'
]
=
points
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(color_mean=
{
self
.
color_mean
}
)'
repr_str
+=
f
'(color_mean=
{
self
.
color_mean
}
)'
...
@@ -533,19 +526,15 @@ class LoadPointsFromFile(BaseTransform):
...
@@ -533,19 +526,15 @@ class LoadPointsFromFile(BaseTransform):
- 'LIDAR': Points in LiDAR coordinates.
- 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates.
- 'CAMERA': Points in camera coordinates.
load_dim (int, optional): The dimension of the loaded points.
load_dim (int): The dimension of the loaded points. Defaults to 6.
Defaults to 6.
use_dim (list[int] | int): Which dimensions of the points to use.
use_dim (list[int] | int, optional): Which dimensions of the points
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
to use. Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool, optional): Whether to use shifted height.
shift_height (bool): Whether to use shifted height. Defaults to False.
Defaults to False.
use_color (bool): Whether to use color features. Defaults to False.
use_color (bool, optional): Whether to use color features.
file_client_args (dict): Arguments to instantiate a FileClient.
Defaults to False.
See :class:`mmengine.fileio.FileClient` for details.
file_client_args (dict, optional): Config dict of file clients,
Defaults to dict(backend='disk').
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""
"""
def
__init__
(
def
__init__
(
...
@@ -602,7 +591,7 @@ class LoadPointsFromFile(BaseTransform):
...
@@ -602,7 +591,7 @@ class LoadPointsFromFile(BaseTransform):
Returns:
Returns:
dict: The result dict containing the point clouds data.
dict: The result dict containing the point clouds data.
Added key and value are described below.
Added key and value are described below.
- points (:obj:`BasePoints`): Point clouds data.
- points (:obj:`BasePoints`): Point clouds data.
"""
"""
...
@@ -638,7 +627,7 @@ class LoadPointsFromFile(BaseTransform):
...
@@ -638,7 +627,7 @@ class LoadPointsFromFile(BaseTransform):
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
+
'('
repr_str
=
self
.
__class__
.
__name__
+
'('
repr_str
+=
f
'shift_height=
{
self
.
shift_height
}
, '
repr_str
+=
f
'shift_height=
{
self
.
shift_height
}
, '
...
@@ -688,7 +677,7 @@ class LoadAnnotations3D(LoadAnnotations):
...
@@ -688,7 +677,7 @@ class LoadAnnotations3D(LoadAnnotations):
- pts_instance_mask_path (str): Path of instance mask file.
- pts_instance_mask_path (str): Path of instance mask file.
Only when `with_mask_3d` is True.
Only when `with_mask_3d` is True.
- pts_semantic_mask_path (str): Path of semantic mask file.
- pts_semantic_mask_path (str): Path of semantic mask file.
Only when
Only when
`with_seg_3d` is True.
Added Keys:
Added Keys:
...
@@ -713,33 +702,25 @@ class LoadAnnotations3D(LoadAnnotations):
...
@@ -713,33 +702,25 @@ class LoadAnnotations3D(LoadAnnotations):
Only when `with_seg_3d` is True.
Only when `with_seg_3d` is True.
Args:
Args:
with_bbox_3d (bool, optional): Whether to load 3D boxes.
with_bbox_3d (bool): Whether to load 3D boxes. Defaults to True.
Defaults to True.
with_label_3d (bool): Whether to load 3D labels. Defaults to True.
with_label_3d (bool, optional): Whether to load 3D labels.
with_attr_label (bool): Whether to load attribute label.
Defaults to True.
with_attr_label (bool, optional): Whether to load attribute label.
Defaults to False.
with_mask_3d (bool, optional): Whether to load 3D instance masks.
for points. Defaults to False.
with_seg_3d (bool, optional): Whether to load 3D semantic masks.
for points. Defaults to False.
with_bbox (bool, optional): Whether to load 2D boxes.
Defaults to False.
with_label (bool, optional): Whether to load 2D labels.
Defaults to False.
Defaults to False.
with_mask (bool
, optional
): Whether to load
2
D instance masks.
with_mask
_3d
(bool): Whether to load
3
D instance masks
for points
.
Defaults to False.
Defaults to False.
with_seg (bool
, optional
): Whether to load
2
D semantic masks.
with_seg
_3d
(bool): Whether to load
3
D semantic masks
for points
.
Defaults to False.
Defaults to False.
with_bbox_depth (bool, optional): Whether to load 2.5D boxes.
with_bbox (bool): Whether to load 2D boxes. Defaults to False.
Defaults to False.
with_label (bool): Whether to load 2D labels. Defaults to False.
poly2mask (bool, optional): Whether to convert polygon annotations
with_mask (bool): Whether to load 2D instance masks. Defaults to False.
to bitmasks. Defaults to True.
with_seg (bool): Whether to load 2D semantic masks. Defaults to False.
seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
with_bbox_depth (bool): Whether to load 2.5D boxes. Defaults to False.
Defaults to int64.
poly2mask (bool): Whether to convert polygon annotations to bitmasks.
file_client_args (dict): Config dict of file clients, refer to
Defaults to True.
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
seg_3d_dtype (dtype): Dtype of 3D semantic masks. Defaults to int64.
for more details.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmengine.fileio.FileClient` for details.
Defaults to dict(backend='disk').
"""
"""
def
__init__
(
def
__init__
(
...
@@ -889,7 +870,8 @@ class LoadAnnotations3D(LoadAnnotations):
...
@@ -889,7 +870,8 @@ class LoadAnnotations3D(LoadAnnotations):
`ignore_flag`
`ignore_flag`
Args:
Args:
results (dict): Result dict from :obj:``mmcv.BaseDataset``.
results (dict): Result dict from :obj:`mmcv.BaseDataset`.
Returns:
Returns:
dict: The dict contains loaded bounding box annotations.
dict: The dict contains loaded bounding box annotations.
"""
"""
...
@@ -900,7 +882,7 @@ class LoadAnnotations3D(LoadAnnotations):
...
@@ -900,7 +882,7 @@ class LoadAnnotations3D(LoadAnnotations):
"""Private function to load label annotations.
"""Private function to load label annotations.
Args:
Args:
results (dict): Result dict from :obj :obj:`
`
mmcv.BaseDataset`
`
.
results (dict): Result dict from :obj:`mmcv.BaseDataset`.
Returns:
Returns:
dict: The dict contains loaded label annotations.
dict: The dict contains loaded label annotations.
...
@@ -933,7 +915,7 @@ class LoadAnnotations3D(LoadAnnotations):
...
@@ -933,7 +915,7 @@ class LoadAnnotations3D(LoadAnnotations):
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
indent_str
=
' '
indent_str
=
' '
repr_str
=
self
.
__class__
.
__name__
+
'(
\n
'
repr_str
=
self
.
__class__
.
__name__
+
'(
\n
'
...
...
mmdet3d/datasets/transforms/test_time_aug.py
View file @
d7067e44
...
@@ -19,18 +19,17 @@ class MultiScaleFlipAug3D(BaseTransform):
...
@@ -19,18 +19,17 @@ class MultiScaleFlipAug3D(BaseTransform):
img_scale (tuple | list[tuple]): Images scales for resizing.
img_scale (tuple | list[tuple]): Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for
pts_scale_ratio (float | list[float]): Points scale ratios for
resizing.
resizing.
flip (bool, optional): Whether apply flip augmentation.
flip (bool): Whether to apply flip augmentation. Defaults to False.
Defaults to False.
flip_direction (str | list[str]): Flip augmentation directions
flip_direction (str | list[str], optional): Flip augmentation
for images, options are "horizontal" and "vertical".
directions for images, options are "horizontal" and "vertical".
If flip_direction is list, multiple flip augmentations will
If flip_direction is list, multiple flip augmentations will
be applied. It has no effect when ``flip == False``.
be applied. It has no effect when ``flip == False``.
Defaults to 'horizontal'.
Defaults to 'horizontal'.
pcd_horizontal_flip (bool
, optional
): Whether to apply horizontal
pcd_horizontal_flip (bool): Whether to apply horizontal
flip
flip
augmentation to point cloud. Defaults to
Tru
e.
augmentation to point cloud. Defaults to
Fals
e.
Note that it works only when 'flip' is turned on.
Note that it works only when 'flip' is turned on.
pcd_vertical_flip (bool
, optional
): Whether to apply vertical flip
pcd_vertical_flip (bool): Whether to apply vertical flip
augmentation to point cloud. Defaults to
Tru
e.
augmentation to point cloud. Defaults to
Fals
e.
Note that it works only when 'flip' is turned on.
Note that it works only when 'flip' is turned on.
"""
"""
...
@@ -75,7 +74,7 @@ class MultiScaleFlipAug3D(BaseTransform):
...
@@ -75,7 +74,7 @@ class MultiScaleFlipAug3D(BaseTransform):
Returns:
Returns:
List[dict]: The list contains the data that is augmented with
List[dict]: The list contains the data that is augmented with
different scales and flips.
different scales and flips.
"""
"""
aug_data_list
=
[]
aug_data_list
=
[]
...
@@ -112,7 +111,7 @@ class MultiScaleFlipAug3D(BaseTransform):
...
@@ -112,7 +111,7 @@ class MultiScaleFlipAug3D(BaseTransform):
return
aug_data_list
return
aug_data_list
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(transforms=
{
self
.
transforms
}
, '
repr_str
+=
f
'(transforms=
{
self
.
transforms
}
, '
...
...
mmdet3d/datasets/transforms/transforms_3d.py
View file @
d7067e44
...
@@ -6,7 +6,9 @@ from typing import List, Optional, Tuple, Union
...
@@ -6,7 +6,9 @@ from typing import List, Optional, Tuple, Union
import
cv2
import
cv2
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
from
mmcv.transforms
import
BaseTransform
,
RandomResize
,
Resize
from
mmcv.transforms
import
BaseTransform
,
Compose
,
RandomResize
,
Resize
from
mmdet.datasets.transforms
import
(
PhotoMetricDistortion
,
RandomCrop
,
RandomFlip
)
from
mmengine
import
is_tuple_of
from
mmengine
import
is_tuple_of
from
mmdet3d.models.task_modules
import
VoxelGenerator
from
mmdet3d.models.task_modules
import
VoxelGenerator
...
@@ -15,9 +17,6 @@ from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
...
@@ -15,9 +17,6 @@ from mmdet3d.structures import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes
)
LiDARInstance3DBoxes
)
from
mmdet3d.structures.ops
import
box_np_ops
from
mmdet3d.structures.ops
import
box_np_ops
from
mmdet3d.structures.points
import
BasePoints
from
mmdet3d.structures.points
import
BasePoints
from
mmdet.datasets.transforms
import
(
PhotoMetricDistortion
,
RandomCrop
,
RandomFlip
)
from
.compose
import
Compose
from
.data_augment_utils
import
noise_per_object_v3_
from
.data_augment_utils
import
noise_per_object_v3_
...
@@ -30,7 +29,7 @@ class RandomDropPointsColor(BaseTransform):
...
@@ -30,7 +29,7 @@ class RandomDropPointsColor(BaseTransform):
util/transform.py#L223>`_ for more details.
util/transform.py#L223>`_ for more details.
Args:
Args:
drop_ratio (float
, optional
): The probability of dropping point colors.
drop_ratio (float): The probability of dropping point colors.
Defaults to 0.2.
Defaults to 0.2.
"""
"""
...
@@ -46,8 +45,8 @@ class RandomDropPointsColor(BaseTransform):
...
@@ -46,8 +45,8 @@ class RandomDropPointsColor(BaseTransform):
input_dict (dict): Result dict from loading pipeline.
input_dict (dict): Result dict from loading pipeline.
Returns:
Returns:
dict: Results after color dropping,
dict: Results after color dropping,
'points' key is updated
'points' key is updated
in the result dict.
in the result dict.
"""
"""
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
assert
points
.
attribute_dims
is
not
None
and
\
assert
points
.
attribute_dims
is
not
None
and
\
...
@@ -64,7 +63,7 @@ class RandomDropPointsColor(BaseTransform):
...
@@ -64,7 +63,7 @@ class RandomDropPointsColor(BaseTransform):
points
.
color
=
points
.
color
*
0.0
points
.
color
=
points
.
color
*
0.0
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(drop_ratio=
{
self
.
drop_ratio
}
)'
repr_str
+=
f
'(drop_ratio=
{
self
.
drop_ratio
}
)'
...
@@ -108,8 +107,8 @@ class RandomFlip3D(RandomFlip):
...
@@ -108,8 +107,8 @@ class RandomFlip3D(RandomFlip):
in vertical direction. Defaults to 0.0.
in vertical direction. Defaults to 0.0.
flip_box3d (bool): Whether to flip bounding box. In most of the case,
flip_box3d (bool): Whether to flip bounding box. In most of the case,
the box should be flipped. In cam-based bev detection, this is set
the box should be flipped. In cam-based bev detection, this is set
to
f
alse, since the flip of 2D images does not influence the 3D
to
F
alse, since the flip of 2D images does not influence the 3D
box. Default to True.
box. Default
s
to True.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -150,12 +149,11 @@ class RandomFlip3D(RandomFlip):
...
@@ -150,12 +149,11 @@ class RandomFlip3D(RandomFlip):
Args:
Args:
input_dict (dict): Result dict from loading pipeline.
input_dict (dict): Result dict from loading pipeline.
direction (str, optional): Flip direction.
direction (str): Flip direction. Defaults to 'horizontal'.
Default: 'horizontal'.
Returns:
Returns:
dict: Flipped results, 'points', 'bbox3d_fields' keys are
dict: Flipped results, 'points', 'bbox3d_fields' keys are
updated in the result dict.
updated in the result dict.
"""
"""
assert
direction
in
[
'horizontal'
,
'vertical'
]
assert
direction
in
[
'horizontal'
,
'vertical'
]
if
self
.
flip_box3d
:
if
self
.
flip_box3d
:
...
@@ -210,8 +208,8 @@ class RandomFlip3D(RandomFlip):
...
@@ -210,8 +208,8 @@ class RandomFlip3D(RandomFlip):
Returns:
Returns:
dict: Flipped results, 'flip', 'flip_direction',
dict: Flipped results, 'flip', 'flip_direction',
'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added
'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added
into result dict.
into result dict.
"""
"""
# flip 2D image and its annotations
# flip 2D image and its annotations
if
'img'
in
input_dict
:
if
'img'
in
input_dict
:
...
@@ -241,7 +239,7 @@ class RandomFlip3D(RandomFlip):
...
@@ -241,7 +239,7 @@ class RandomFlip3D(RandomFlip):
input_dict
[
'transformation_3d_flow'
].
extend
([
'VF'
])
input_dict
[
'transformation_3d_flow'
].
extend
([
'VF'
])
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(sync_2d=
{
self
.
sync_2d
}
,'
repr_str
+=
f
'(sync_2d=
{
self
.
sync_2d
}
,'
...
@@ -254,7 +252,7 @@ class RandomJitterPoints(BaseTransform):
...
@@ -254,7 +252,7 @@ class RandomJitterPoints(BaseTransform):
"""Randomly jitter point coordinates.
"""Randomly jitter point coordinates.
Different from the global translation in ``GlobalRotScaleTrans``, here we
Different from the global translation in ``GlobalRotScaleTrans``, here we
apply different noises to each point in a scene.
apply different noises to each point in a scene.
Args:
Args:
jitter_std (list[float]): The standard deviation of jittering noise.
jitter_std (list[float]): The standard deviation of jittering noise.
...
@@ -267,7 +265,7 @@ class RandomJitterPoints(BaseTransform):
...
@@ -267,7 +265,7 @@ class RandomJitterPoints(BaseTransform):
Note:
Note:
This transform should only be used in point cloud segmentation tasks
This transform should only be used in point cloud segmentation tasks
because we don't transform ground-truth bboxes accordingly.
because we don't transform ground-truth bboxes accordingly.
For similar transform in detection task, please refer to `ObjectNoise`.
For similar transform in detection task, please refer to `ObjectNoise`.
"""
"""
...
@@ -296,7 +294,7 @@ class RandomJitterPoints(BaseTransform):
...
@@ -296,7 +294,7 @@ class RandomJitterPoints(BaseTransform):
Returns:
Returns:
dict: Results after adding noise to each point,
dict: Results after adding noise to each point,
'points' key is updated in the result dict.
'points' key is updated in the result dict.
"""
"""
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
jitter_std
=
np
.
array
(
self
.
jitter_std
,
dtype
=
np
.
float32
)
jitter_std
=
np
.
array
(
self
.
jitter_std
,
dtype
=
np
.
float32
)
...
@@ -309,7 +307,7 @@ class RandomJitterPoints(BaseTransform):
...
@@ -309,7 +307,7 @@ class RandomJitterPoints(BaseTransform):
points
.
translate
(
jitter_noise
)
points
.
translate
(
jitter_noise
)
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(jitter_std=
{
self
.
jitter_std
}
,'
repr_str
+=
f
'(jitter_std=
{
self
.
jitter_std
}
,'
...
@@ -344,11 +342,11 @@ class ObjectSample(BaseTransform):
...
@@ -344,11 +342,11 @@ class ObjectSample(BaseTransform):
Args:
Args:
db_sampler (dict): Config dict of the database sampler.
db_sampler (dict): Config dict of the database sampler.
sample_2d (bool): Whether to also paste 2D image patch to the images
sample_2d (bool): Whether to also paste 2D image patch to the images
.
This should be true when applying multi-modality cut-and-paste.
This should be true when applying multi-modality cut-and-paste.
Defaults to False.
Defaults to False.
use_ground_plane (bool): Whether to use ground plane to adjust the
use_ground_plane (bool): Whether to use ground plane to adjust the
3D labels.
3D labels.
Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -386,8 +384,8 @@ class ObjectSample(BaseTransform):
...
@@ -386,8 +384,8 @@ class ObjectSample(BaseTransform):
Returns:
Returns:
dict: Results after object sampling augmentation,
dict: Results after object sampling augmentation,
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict.
in the result dict.
"""
"""
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
gt_labels_3d
=
input_dict
[
'gt_labels_3d'
]
gt_labels_3d
=
input_dict
[
'gt_labels_3d'
]
...
@@ -445,12 +443,12 @@ class ObjectSample(BaseTransform):
...
@@ -445,12 +443,12 @@ class ObjectSample(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'db_sampler=
{
self
.
db_sampler
}
,'
repr_str
+=
f
'
(
db_sampler=
{
self
.
db_sampler
}
,'
repr_str
+=
f
' sample_2d=
{
self
.
sample_2d
}
,'
repr_str
+=
f
' sample_2d=
{
self
.
sample_2d
}
,'
repr_str
+=
f
' use_ground_plane=
{
self
.
use_ground_plane
}
'
repr_str
+=
f
' use_ground_plane=
{
self
.
use_ground_plane
}
)
'
return
repr_str
return
repr_str
...
@@ -469,15 +467,15 @@ class ObjectNoise(BaseTransform):
...
@@ -469,15 +467,15 @@ class ObjectNoise(BaseTransform):
- gt_bboxes_3d
- gt_bboxes_3d
Args:
Args:
translation_std (list[float]
, optional
): Standard deviation of the
translation_std (list[float]): Standard deviation of the
distribution where translation noise are sampled from.
distribution where translation noise are sampled from.
Defaults to [0.25, 0.25, 0.25].
Defaults to [0.25, 0.25, 0.25].
global_rot_range (list[float]
, optional
): Global rotation to the scene.
global_rot_range (list[float]): Global rotation to the scene.
Defaults to [0.0, 0.0].
Defaults to [0.0, 0.0].
rot_range (list[float]
, optional
): Object rotation range.
rot_range (list[float]): Object rotation range.
Defaults to [-0.15707963267, 0.15707963267].
Defaults to [-0.15707963267, 0.15707963267].
num_try (int
, optional
): Number of times to try if the noise applied is
num_try (int): Number of times to try if the noise applied is
invalid.
invalid.
Defaults to 100.
Defaults to 100.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -498,7 +496,7 @@ class ObjectNoise(BaseTransform):
...
@@ -498,7 +496,7 @@ class ObjectNoise(BaseTransform):
Returns:
Returns:
dict: Results after adding noise to each object,
dict: Results after adding noise to each object,
'points', 'gt_bboxes_3d' keys are updated in the result dict.
'points', 'gt_bboxes_3d' keys are updated in the result dict.
"""
"""
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
...
@@ -519,7 +517,7 @@ class ObjectNoise(BaseTransform):
...
@@ -519,7 +517,7 @@ class ObjectNoise(BaseTransform):
input_dict
[
'points'
]
=
points
.
new_point
(
numpy_points
)
input_dict
[
'points'
]
=
points
.
new_point
(
numpy_points
)
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(num_try=
{
self
.
num_try
}
,'
repr_str
+=
f
'(num_try=
{
self
.
num_try
}
,'
...
@@ -538,10 +536,10 @@ class GlobalAlignment(BaseTransform):
...
@@ -538,10 +536,10 @@ class GlobalAlignment(BaseTransform):
Note:
Note:
We do not record the applied rotation and translation as in
We do not record the applied rotation and translation as in
GlobalRotScaleTrans. Because usually, we do not need to reverse
GlobalRotScaleTrans. Because usually, we do not need to reverse
the alignment step.
the alignment step.
For example, ScanNet 3D detection task uses aligned ground-truth
For example, ScanNet 3D detection task uses aligned ground-truth
bounding boxes for evaluation.
bounding boxes for evaluation.
"""
"""
def
__init__
(
self
,
rotation_axis
:
int
)
->
None
:
def
__init__
(
self
,
rotation_axis
:
int
)
->
None
:
...
@@ -593,7 +591,7 @@ class GlobalAlignment(BaseTransform):
...
@@ -593,7 +591,7 @@ class GlobalAlignment(BaseTransform):
Returns:
Returns:
dict: Results after global alignment, 'points' and keys in
dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict.
input_dict['bbox3d_fields'] are updated in the result dict.
"""
"""
assert
'axis_align_matrix'
in
results
,
\
assert
'axis_align_matrix'
in
results
,
\
'axis_align_matrix is not provided in GlobalAlignment'
'axis_align_matrix is not provided in GlobalAlignment'
...
@@ -610,7 +608,7 @@ class GlobalAlignment(BaseTransform):
...
@@ -610,7 +608,7 @@ class GlobalAlignment(BaseTransform):
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(rotation_axis=
{
self
.
rotation_axis
}
)'
repr_str
+=
f
'(rotation_axis=
{
self
.
rotation_axis
}
)'
...
@@ -640,15 +638,15 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -640,15 +638,15 @@ class GlobalRotScaleTrans(BaseTransform):
- pcd_scale_factor (np.float32)
- pcd_scale_factor (np.float32)
Args:
Args:
rot_range (list[float]
, optional
): Range of rotation angle.
rot_range (list[float]): Range of rotation angle.
Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]
, optional
): Range of scale ratio.
scale_ratio_range (list[float]): Range of scale ratio.
Defaults to [0.95, 1.05].
Defaults to [0.95, 1.05].
translation_std (list[float]
, optional
): The standard deviation of
translation_std (list[float]): The standard deviation of
translation noise applied to a scene, which
translation noise applied to a scene, which
is sampled from a gaussian distribution whose standard deviation
is sampled from a gaussian distribution whose standard deviation
is set by ``translation_std``. Defaults to [0, 0, 0]
is set by ``translation_std``. Defaults to [0, 0, 0]
.
shift_height (bool
, optional
): Whether to shift height.
shift_height (bool): Whether to shift height.
(the fourth dimension of indoor points) when scaling.
(the fourth dimension of indoor points) when scaling.
Defaults to False.
Defaults to False.
"""
"""
...
@@ -689,8 +687,7 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -689,8 +687,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
Returns:
dict: Results after translation, 'points', 'pcd_trans'
dict: Results after translation, 'points', 'pcd_trans'
and `gt_bboxes_3d` is updated
and `gt_bboxes_3d` is updated in the result dict.
in the result dict.
"""
"""
translation_std
=
np
.
array
(
self
.
translation_std
,
dtype
=
np
.
float32
)
translation_std
=
np
.
array
(
self
.
translation_std
,
dtype
=
np
.
float32
)
trans_factor
=
np
.
random
.
normal
(
scale
=
translation_std
,
size
=
3
).
T
trans_factor
=
np
.
random
.
normal
(
scale
=
translation_std
,
size
=
3
).
T
...
@@ -708,8 +705,7 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -708,8 +705,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
Returns:
dict: Results after rotation, 'points', 'pcd_rotation'
dict: Results after rotation, 'points', 'pcd_rotation'
and `gt_bboxes_3d` is updated
and `gt_bboxes_3d` is updated in the result dict.
in the result dict.
"""
"""
rotation
=
self
.
rot_range
rotation
=
self
.
rot_range
noise_rotation
=
np
.
random
.
uniform
(
rotation
[
0
],
rotation
[
1
])
noise_rotation
=
np
.
random
.
uniform
(
rotation
[
0
],
rotation
[
1
])
...
@@ -735,8 +731,7 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -735,8 +731,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
Returns:
dict: Results after scaling, 'points' and
dict: Results after scaling, 'points' and
`gt_bboxes_3d` is updated
`gt_bboxes_3d` is updated in the result dict.
in the result dict.
"""
"""
scale
=
input_dict
[
'pcd_scale_factor'
]
scale
=
input_dict
[
'pcd_scale_factor'
]
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
...
@@ -774,7 +769,7 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -774,7 +769,7 @@ class GlobalRotScaleTrans(BaseTransform):
Returns:
Returns:
dict: Results after scaling, 'points', 'pcd_rotation',
dict: Results after scaling, 'points', 'pcd_rotation',
'pcd_scale_factor', 'pcd_trans' and `gt_bboxes_3d`
is
updated
'pcd_scale_factor', 'pcd_trans' and `gt_bboxes_3d`
are
updated
in the result dict.
in the result dict.
"""
"""
if
'transformation_3d_flow'
not
in
input_dict
:
if
'transformation_3d_flow'
not
in
input_dict
:
...
@@ -791,7 +786,7 @@ class GlobalRotScaleTrans(BaseTransform):
...
@@ -791,7 +786,7 @@ class GlobalRotScaleTrans(BaseTransform):
input_dict
[
'transformation_3d_flow'
].
extend
([
'R'
,
'S'
,
'T'
])
input_dict
[
'transformation_3d_flow'
].
extend
([
'R'
,
'S'
,
'T'
])
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(rot_range=
{
self
.
rot_range
}
,'
repr_str
+=
f
'(rot_range=
{
self
.
rot_range
}
,'
...
@@ -813,7 +808,7 @@ class PointShuffle(BaseTransform):
...
@@ -813,7 +808,7 @@ class PointShuffle(BaseTransform):
Returns:
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
idx
=
input_dict
[
'points'
].
shuffle
()
idx
=
input_dict
[
'points'
].
shuffle
()
idx
=
idx
.
numpy
()
idx
=
idx
.
numpy
()
...
@@ -829,7 +824,7 @@ class PointShuffle(BaseTransform):
...
@@ -829,7 +824,7 @@ class PointShuffle(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
...
@@ -850,7 +845,7 @@ class ObjectRangeFilter(BaseTransform):
...
@@ -850,7 +845,7 @@ class ObjectRangeFilter(BaseTransform):
point_cloud_range (list[float]): Point cloud range.
point_cloud_range (list[float]): Point cloud range.
"""
"""
def
__init__
(
self
,
point_cloud_range
:
List
[
float
]):
def
__init__
(
self
,
point_cloud_range
:
List
[
float
])
->
None
:
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
...
@@ -861,7 +856,7 @@ class ObjectRangeFilter(BaseTransform):
...
@@ -861,7 +856,7 @@ class ObjectRangeFilter(BaseTransform):
Returns:
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict.
keys are updated in the result dict.
"""
"""
# Check points instance type and initialise bev_range
# Check points instance type and initialise bev_range
if
isinstance
(
input_dict
[
'gt_bboxes_3d'
],
if
isinstance
(
input_dict
[
'gt_bboxes_3d'
],
...
@@ -887,7 +882,7 @@ class ObjectRangeFilter(BaseTransform):
...
@@ -887,7 +882,7 @@ class ObjectRangeFilter(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(point_cloud_range=
{
self
.
pcd_range
.
tolist
()
}
)'
repr_str
+=
f
'(point_cloud_range=
{
self
.
pcd_range
.
tolist
()
}
)'
...
@@ -923,7 +918,7 @@ class PointsRangeFilter(BaseTransform):
...
@@ -923,7 +918,7 @@ class PointsRangeFilter(BaseTransform):
Returns:
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
points_mask
=
points
.
in_range_3d
(
self
.
pcd_range
)
points_mask
=
points
.
in_range_3d
(
self
.
pcd_range
)
...
@@ -942,7 +937,7 @@ class PointsRangeFilter(BaseTransform):
...
@@ -942,7 +937,7 @@ class PointsRangeFilter(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(point_cloud_range=
{
self
.
pcd_range
.
tolist
()
}
)'
repr_str
+=
f
'(point_cloud_range=
{
self
.
pcd_range
.
tolist
()
}
)'
...
@@ -977,7 +972,7 @@ class ObjectNameFilter(BaseTransform):
...
@@ -977,7 +972,7 @@ class ObjectNameFilter(BaseTransform):
Returns:
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict.
keys are updated in the result dict.
"""
"""
gt_labels_3d
=
input_dict
[
'gt_labels_3d'
]
gt_labels_3d
=
input_dict
[
'gt_labels_3d'
]
gt_bboxes_mask
=
np
.
array
([
n
in
self
.
labels
for
n
in
gt_labels_3d
],
gt_bboxes_mask
=
np
.
array
([
n
in
self
.
labels
for
n
in
gt_labels_3d
],
...
@@ -987,7 +982,7 @@ class ObjectNameFilter(BaseTransform):
...
@@ -987,7 +982,7 @@ class ObjectNameFilter(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(classes=
{
self
.
classes
}
)'
repr_str
+=
f
'(classes=
{
self
.
classes
}
)'
...
@@ -1017,8 +1012,8 @@ class PointSample(BaseTransform):
...
@@ -1017,8 +1012,8 @@ class PointSample(BaseTransform):
sample_range (float, optional): The range where to sample points.
sample_range (float, optional): The range where to sample points.
If not None, the points with depth larger than `sample_range` are
If not None, the points with depth larger than `sample_range` are
prior to be sampled. Defaults to None.
prior to be sampled. Defaults to None.
replace (bool
, optional
): Whether the sampling is with or without
replace (bool): Whether the sampling is with or without
replacement.
replacement.
Defaults to False.
Defaults to False.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -1046,10 +1041,9 @@ class PointSample(BaseTransform):
...
@@ -1046,10 +1041,9 @@ class PointSample(BaseTransform):
num_samples (int): Number of samples to be sampled.
num_samples (int): Number of samples to be sampled.
sample_range (float, optional): Indicating the range where the
sample_range (float, optional): Indicating the range where the
points will be sampled. Defaults to None.
points will be sampled. Defaults to None.
replace (bool, optional): Sampling with or without replacement.
replace (bool): Sampling with or without replacement.
Defaults to False.
return_choices (bool, optional): Whether return choice.
Defaults to False.
Defaults to False.
return_choices (bool): Whether return choice. Defaults to False.
Returns:
Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
...
@@ -1089,7 +1083,7 @@ class PointSample(BaseTransform):
...
@@ -1089,7 +1083,7 @@ class PointSample(BaseTransform):
Returns:
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
points
,
choices
=
self
.
_points_random_sampling
(
points
,
choices
=
self
.
_points_random_sampling
(
...
@@ -1113,7 +1107,7 @@ class PointSample(BaseTransform):
...
@@ -1113,7 +1107,7 @@ class PointSample(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(num_points=
{
self
.
num_points
}
,'
repr_str
+=
f
'(num_points=
{
self
.
num_points
}
,'
...
@@ -1149,7 +1143,7 @@ class IndoorPatchPointSample(BaseTransform):
...
@@ -1149,7 +1143,7 @@ class IndoorPatchPointSample(BaseTransform):
Args:
Args:
num_points (int): Number of points to be sampled.
num_points (int): Number of points to be sampled.
block_size (float
, optional
): Size of a block to sample points from.
block_size (float): Size of a block to sample points from.
Defaults to 1.5.
Defaults to 1.5.
sample_rate (float, optional): Stride used in sliding patch generation.
sample_rate (float, optional): Stride used in sliding patch generation.
This parameter is unused in `IndoorPatchPointSample` and thus has
This parameter is unused in `IndoorPatchPointSample` and thus has
...
@@ -1159,24 +1153,24 @@ class IndoorPatchPointSample(BaseTransform):
...
@@ -1159,24 +1153,24 @@ class IndoorPatchPointSample(BaseTransform):
segmentation task. This is set in PointSegClassMapping as neg_cls.
segmentation task. This is set in PointSegClassMapping as neg_cls.
If not None, will be used as a patch selection criterion.
If not None, will be used as a patch selection criterion.
Defaults to None.
Defaults to None.
use_normalized_coord (bool
, optional
): Whether to use normalized xyz as
use_normalized_coord (bool): Whether to use normalized xyz as
additional features. Defaults to False.
additional features. Defaults to False.
num_try (int
, optional
): Number of times to try if the patch selected
num_try (int): Number of times to try if the patch selected
is invalid.
is invalid.
Defaults to 10.
Defaults to 10.
enlarge_size (float
, optional
): Enlarge the sampled patch to
enlarge_size (float): Enlarge the sampled patch to
[-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
[-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
an augmentation. If None, set it as 0. Defaults to 0.2.
an augmentation. If None, set it as 0. Defaults to 0.2.
min_unique_num (int, optional): Minimum number of unique points
min_unique_num (int, optional): Minimum number of unique points
the sampled patch should contain. If None, use PointNet++'s method
the sampled patch should contain. If None, use PointNet++'s method
to judge uniqueness. Defaults to None.
to judge uniqueness. Defaults to None.
eps (float
, optional
): A value added to patch boundary to guarantee
eps (float): A value added to patch boundary to guarantee
points coverage. Defaults to 1e-2.
points coverage. Defaults to 1e-2.
Note:
Note:
This transform should only be used in the training process of point
This transform should only be used in the training process of point
cloud segmentation tasks. For the sliding patch generation and
cloud segmentation tasks. For the sliding patch generation and
inference process in testing, please refer to the `slide_inference`
inference process in testing, please refer to the `slide_inference`
function of `EncoderDecoder3D` class.
function of `EncoderDecoder3D` class.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -1356,7 +1350,7 @@ class IndoorPatchPointSample(BaseTransform):
...
@@ -1356,7 +1350,7 @@ class IndoorPatchPointSample(BaseTransform):
Returns:
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
...
@@ -1386,7 +1380,7 @@ class IndoorPatchPointSample(BaseTransform):
...
@@ -1386,7 +1380,7 @@ class IndoorPatchPointSample(BaseTransform):
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(num_points=
{
self
.
num_points
}
,'
repr_str
+=
f
'(num_points=
{
self
.
num_points
}
,'
...
@@ -1405,7 +1399,7 @@ class BackgroundPointsFilter(BaseTransform):
...
@@ -1405,7 +1399,7 @@ class BackgroundPointsFilter(BaseTransform):
"""Filter background points near the bounding box.
"""Filter background points near the bounding box.
Args:
Args:
bbox_enlarge_range (tuple[float]
,
float): Bbox enlarge range.
bbox_enlarge_range (tuple[float]
|
float): Bbox enlarge range.
"""
"""
def
__init__
(
self
,
bbox_enlarge_range
:
Union
[
Tuple
[
float
],
float
])
->
None
:
def
__init__
(
self
,
bbox_enlarge_range
:
Union
[
Tuple
[
float
],
float
])
->
None
:
...
@@ -1427,7 +1421,7 @@ class BackgroundPointsFilter(BaseTransform):
...
@@ -1427,7 +1421,7 @@ class BackgroundPointsFilter(BaseTransform):
Returns:
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
...
@@ -1458,7 +1452,7 @@ class BackgroundPointsFilter(BaseTransform):
...
@@ -1458,7 +1452,7 @@ class BackgroundPointsFilter(BaseTransform):
input_dict
[
'pts_semantic_mask'
]
=
pts_semantic_mask
[
valid_masks
]
input_dict
[
'pts_semantic_mask'
]
=
pts_semantic_mask
[
valid_masks
]
return
input_dict
return
input_dict
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(bbox_enlarge_range=
{
self
.
bbox_enlarge_range
.
tolist
()
}
)'
repr_str
+=
f
'(bbox_enlarge_range=
{
self
.
bbox_enlarge_range
.
tolist
()
}
)'
...
@@ -1473,9 +1467,10 @@ class VoxelBasedPointSampler(BaseTransform):
...
@@ -1473,9 +1467,10 @@ class VoxelBasedPointSampler(BaseTransform):
Args:
Args:
cur_sweep_cfg (dict): Config for sampling current points.
cur_sweep_cfg (dict): Config for sampling current points.
prev_sweep_cfg (dict): Config for sampling previous points.
prev_sweep_cfg (dict, optional): Config for sampling previous points.
Defaults to None.
time_dim (int): Index that indicate the time dimension
time_dim (int): Index that indicate the time dimension
for input points.
for input points.
Defaults to 3.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -1502,7 +1497,7 @@ class VoxelBasedPointSampler(BaseTransform):
...
@@ -1502,7 +1497,7 @@ class VoxelBasedPointSampler(BaseTransform):
points (np.ndarray): Points subset to be sampled.
points (np.ndarray): Points subset to be sampled.
sampler (VoxelGenerator): Voxel based sampler for
sampler (VoxelGenerator): Voxel based sampler for
each points subset.
each points subset.
point_dim (int): The dimension of each points
point_dim (int): The dimension of each points
.
Returns:
Returns:
np.ndarray: Sampled points.
np.ndarray: Sampled points.
...
@@ -1529,7 +1524,7 @@ class VoxelBasedPointSampler(BaseTransform):
...
@@ -1529,7 +1524,7 @@ class VoxelBasedPointSampler(BaseTransform):
Returns:
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask'
dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
and 'pts_semantic_mask' keys are updated in the result dict.
"""
"""
points
=
results
[
'points'
]
points
=
results
[
'points'
]
original_dim
=
points
.
shape
[
1
]
original_dim
=
points
.
shape
[
1
]
...
@@ -1589,7 +1584,7 @@ class VoxelBasedPointSampler(BaseTransform):
...
@@ -1589,7 +1584,7 @@ class VoxelBasedPointSampler(BaseTransform):
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
def
_auto_indent
(
repr_str
,
indent
):
def
_auto_indent
(
repr_str
,
indent
):
...
@@ -1625,7 +1620,7 @@ class AffineResize(BaseTransform):
...
@@ -1625,7 +1620,7 @@ class AffineResize(BaseTransform):
img_scale (tuple): Images scales for resizing.
img_scale (tuple): Images scales for resizing.
down_ratio (int): The down ratio of feature map.
down_ratio (int): The down ratio of feature map.
Actually the arg should be >= 1.
Actually the arg should be >= 1.
bbox_clip_border (bool
, optional
): Whether clip the objects
bbox_clip_border (bool): Whether clip the objects
outside the border of the image. Defaults to True.
outside the border of the image. Defaults to True.
"""
"""
...
@@ -1646,7 +1641,7 @@ class AffineResize(BaseTransform):
...
@@ -1646,7 +1641,7 @@ class AffineResize(BaseTransform):
Returns:
Returns:
dict: Results after affine resize, 'affine_aug', 'trans_mat'
dict: Results after affine resize, 'affine_aug', 'trans_mat'
keys are added in the result dict.
keys are added in the result dict.
"""
"""
# The results have gone through RandomShiftScale before AffineResize
# The results have gone through RandomShiftScale before AffineResize
if
'center'
not
in
results
:
if
'center'
not
in
results
:
...
@@ -1803,7 +1798,7 @@ class AffineResize(BaseTransform):
...
@@ -1803,7 +1798,7 @@ class AffineResize(BaseTransform):
ref_point3
=
ref_point2
+
np
.
array
([
-
d
[
1
],
d
[
0
]])
ref_point3
=
ref_point2
+
np
.
array
([
-
d
[
1
],
d
[
0
]])
return
ref_point3
return
ref_point3
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(img_scale=
{
self
.
img_scale
}
, '
repr_str
+=
f
'(img_scale=
{
self
.
img_scale
}
, '
...
@@ -1838,7 +1833,7 @@ class RandomShiftScale(BaseTransform):
...
@@ -1838,7 +1833,7 @@ class RandomShiftScale(BaseTransform):
Returns:
Returns:
dict: Results after random shift and scale, 'center', 'size'
dict: Results after random shift and scale, 'center', 'size'
and 'affine_aug' keys are added in the result dict.
and 'affine_aug' keys are added in the result dict.
"""
"""
img
=
results
[
'img'
]
img
=
results
[
'img'
]
...
@@ -1863,7 +1858,7 @@ class RandomShiftScale(BaseTransform):
...
@@ -1863,7 +1858,7 @@ class RandomShiftScale(BaseTransform):
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
str
:
"""str: Return a string that describes the module."""
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(shift_scale=
{
self
.
shift_scale
}
, '
repr_str
+=
f
'(shift_scale=
{
self
.
shift_scale
}
, '
...
@@ -1874,7 +1869,7 @@ class RandomShiftScale(BaseTransform):
...
@@ -1874,7 +1869,7 @@ class RandomShiftScale(BaseTransform):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
Resize3D
(
Resize
):
class
Resize3D
(
Resize
):
def
_resize_3d
(
self
,
results
)
:
def
_resize_3d
(
self
,
results
:
dict
)
->
None
:
"""Resize centers_2d and modify camera intrinisc with
"""Resize centers_2d and modify camera intrinisc with
``results['scale']``."""
``results['scale']``."""
if
'centers_2d'
in
results
:
if
'centers_2d'
in
results
:
...
@@ -1888,6 +1883,7 @@ class Resize3D(Resize):
...
@@ -1888,6 +1883,7 @@ class Resize3D(Resize):
Args:
Args:
results (dict): Result dict from loading pipeline.
results (dict): Result dict from loading pipeline.
Returns:
Returns:
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map',
'gt_keypoints', 'scale', 'scale_factor', 'img_shape',
'gt_keypoints', 'scale', 'scale_factor', 'img_shape',
...
@@ -1909,7 +1905,7 @@ class RandomResize3D(RandomResize):
...
@@ -1909,7 +1905,7 @@ class RandomResize3D(RandomResize):
and cam2img with ``results['scale']``.
and cam2img with ``results['scale']``.
"""
"""
def
_resize_3d
(
self
,
results
)
:
def
_resize_3d
(
self
,
results
:
dict
)
->
None
:
"""Resize centers_2d and modify camera intrinisc with
"""Resize centers_2d and modify camera intrinisc with
``results['scale']``."""
``results['scale']``."""
if
'centers_2d'
in
results
:
if
'centers_2d'
in
results
:
...
@@ -1917,7 +1913,7 @@ class RandomResize3D(RandomResize):
...
@@ -1917,7 +1913,7 @@ class RandomResize3D(RandomResize):
results
[
'cam2img'
][
0
]
*=
np
.
array
(
results
[
'scale_factor'
][
0
])
results
[
'cam2img'
][
0
]
*=
np
.
array
(
results
[
'scale_factor'
][
0
])
results
[
'cam2img'
][
1
]
*=
np
.
array
(
results
[
'scale_factor'
][
1
])
results
[
'cam2img'
][
1
]
*=
np
.
array
(
results
[
'scale_factor'
][
1
])
def
transform
(
self
,
results
)
:
def
transform
(
self
,
results
:
dict
)
->
dict
:
"""Transform function to resize images, bounding boxes, masks, semantic
"""Transform function to resize images, bounding boxes, masks, semantic
segmentation map. Compared to RandomResize, this function would further
segmentation map. Compared to RandomResize, this function would further
check if scale is already set in results.
check if scale is already set in results.
...
@@ -1926,8 +1922,8 @@ class RandomResize3D(RandomResize):
...
@@ -1926,8 +1922,8 @@ class RandomResize3D(RandomResize):
results (dict): Result dict from loading pipeline.
results (dict): Result dict from loading pipeline.
Returns:
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
\
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
'keep_ratio' keys are added into result dict.
'keep_ratio' keys are added into result dict.
"""
"""
if
'scale'
not
in
results
:
if
'scale'
not
in
results
:
results
[
'scale'
]
=
self
.
_random_scale
()
results
[
'scale'
]
=
self
.
_random_scale
()
...
@@ -1989,14 +1985,14 @@ class RandomCrop3D(RandomCrop):
...
@@ -1989,14 +1985,14 @@ class RandomCrop3D(RandomCrop):
on cropped instance masks. Defaults to False.
on cropped instance masks. Defaults to False.
bbox_clip_border (bool): Whether clip the objects outside
bbox_clip_border (bool): Whether clip the objects outside
the border of the image. Defaults to True.
the border of the image. Defaults to True.
rel_offset_h (tuple): The cropping interval of image height. Default
rel_offset_h (tuple): The cropping interval of image height. Default
s
to (0., 1.).
to (0., 1.).
rel_offset_w (tuple): The cropping interval of image width. Default
rel_offset_w (tuple): The cropping interval of image width. Default
s
to (0., 1.).
to (0., 1.).
Note:
Note:
- If the image is smaller than the absolute crop size, return the
- If the image is smaller than the absolute crop size, return the
original image.
original image.
- The keys for bboxes, labels and masks must be aligned. That is,
- The keys for bboxes, labels and masks must be aligned. That is,
``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
``gt_bboxes`` corresponds to ``gt_labels`` and ``gt_masks``, and
``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
``gt_bboxes_ignore`` corresponds to ``gt_labels_ignore`` and
...
@@ -2005,14 +2001,16 @@ class RandomCrop3D(RandomCrop):
...
@@ -2005,14 +2001,16 @@ class RandomCrop3D(RandomCrop):
``allow_negative_crop`` is set to False, skip this image.
``allow_negative_crop`` is set to False, skip this image.
"""
"""
def
__init__
(
self
,
def
__init__
(
crop_size
,
self
,
crop_type
=
'absolute'
,
crop_size
:
tuple
,
allow_negative_crop
=
False
,
crop_type
:
str
=
'absolute'
,
recompute_bbox
=
False
,
allow_negative_crop
:
bool
=
False
,
bbox_clip_border
=
True
,
recompute_bbox
:
bool
=
False
,
rel_offset_h
=
(
0.
,
1.
),
bbox_clip_border
:
bool
=
True
,
rel_offset_w
=
(
0.
,
1.
)):
rel_offset_h
:
tuple
=
(
0.
,
1.
),
rel_offset_w
:
tuple
=
(
0.
,
1.
)
)
->
None
:
super
().
__init__
(
super
().
__init__
(
crop_size
=
crop_size
,
crop_size
=
crop_size
,
crop_type
=
crop_type
,
crop_type
=
crop_type
,
...
@@ -2024,7 +2022,10 @@ class RandomCrop3D(RandomCrop):
...
@@ -2024,7 +2022,10 @@ class RandomCrop3D(RandomCrop):
self
.
rel_offset_h
=
rel_offset_h
self
.
rel_offset_h
=
rel_offset_h
self
.
rel_offset_w
=
rel_offset_w
self
.
rel_offset_w
=
rel_offset_w
def
_crop_data
(
self
,
results
,
crop_size
,
allow_negative_crop
):
def
_crop_data
(
self
,
results
:
dict
,
crop_size
:
tuple
,
allow_negative_crop
:
bool
=
False
)
->
dict
:
"""Function to randomly crop images, bounding boxes, masks, semantic
"""Function to randomly crop images, bounding boxes, masks, semantic
segmentation maps.
segmentation maps.
...
@@ -2032,11 +2033,11 @@ class RandomCrop3D(RandomCrop):
...
@@ -2032,11 +2033,11 @@ class RandomCrop3D(RandomCrop):
results (dict): Result dict from loading pipeline.
results (dict): Result dict from loading pipeline.
crop_size (tuple): Expected absolute size after cropping, (h, w).
crop_size (tuple): Expected absolute size after cropping, (h, w).
allow_negative_crop (bool): Whether to allow a crop that does not
allow_negative_crop (bool): Whether to allow a crop that does not
contain any bbox area. Default to False.
contain any bbox area. Default
s
to False.
Returns:
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
updated according to crop size.
"""
"""
assert
crop_size
[
0
]
>
0
and
crop_size
[
1
]
>
0
assert
crop_size
[
0
]
>
0
and
crop_size
[
1
]
>
0
for
key
in
results
.
get
(
'img_fields'
,
[
'img'
]):
for
key
in
results
.
get
(
'img_fields'
,
[
'img'
]):
...
@@ -2119,7 +2120,7 @@ class RandomCrop3D(RandomCrop):
...
@@ -2119,7 +2120,7 @@ class RandomCrop3D(RandomCrop):
return
results
return
results
def
transform
(
self
,
results
)
:
def
transform
(
self
,
results
:
dict
)
->
dict
:
"""Transform function to randomly crop images, bounding boxes, masks,
"""Transform function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
semantic segmentation maps.
...
@@ -2128,7 +2129,7 @@ class RandomCrop3D(RandomCrop):
...
@@ -2128,7 +2129,7 @@ class RandomCrop3D(RandomCrop):
Returns:
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
updated according to crop size.
"""
"""
image_size
=
results
[
'img'
].
shape
[:
2
]
image_size
=
results
[
'img'
].
shape
[:
2
]
if
'crop_size'
not
in
results
:
if
'crop_size'
not
in
results
:
...
@@ -2139,7 +2140,8 @@ class RandomCrop3D(RandomCrop):
...
@@ -2139,7 +2140,8 @@ class RandomCrop3D(RandomCrop):
results
=
self
.
_crop_data
(
results
,
crop_size
,
self
.
allow_negative_crop
)
results
=
self
.
_crop_data
(
results
,
crop_size
,
self
.
allow_negative_crop
)
return
results
return
results
def
__repr__
(
self
):
def
__repr__
(
self
)
->
dict
:
"""str: Return a string that describes the module."""
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
f
'(crop_size=
{
self
.
crop_size
}
, '
repr_str
+=
f
'(crop_size=
{
self
.
crop_size
}
, '
repr_str
+=
f
'crop_type=
{
self
.
crop_type
}
, '
repr_str
+=
f
'crop_type=
{
self
.
crop_type
}
, '
...
@@ -2260,43 +2262,44 @@ class MultiViewWrapper(BaseTransform):
...
@@ -2260,43 +2262,44 @@ class MultiViewWrapper(BaseTransform):
transforms (list[dict]): A list of dict specifying the transformations
transforms (list[dict]): A list of dict specifying the transformations
for the monocular situation.
for the monocular situation.
override_aug_config (bool): flag of whether to use the same aug config
override_aug_config (bool): flag of whether to use the same aug config
for multiview image. Default to True.
for multiview image. Default
s
to True.
process_fields (list): Desired keys that the transformations should
process_fields (list): Desired keys that the transformations should
be conducted on. Default to ['img', 'cam2img', 'lidar2cam'],
be conducted on. Defaults to ['img', 'cam2img', 'lidar2cam'].
collected_keys (list): Collect information in transformation
collected_keys (list): Collect information in transformation
like rotate angles, crop roi, and flip state. Default to
like rotate angles, crop roi, and flip state. Default
s
to
['scale', 'scale_factor', 'crop',
['scale', 'scale_factor', 'crop',
'crop_offset', 'ori_shape',
'crop_offset', 'ori_shape',
'pad_shape', 'img_shape',
'pad_shape', 'img_shape',
'pad_fixed_size', 'pad_size_divisor',
'pad_fixed_size', 'pad_size_divisor',
'flip', 'flip_direction', 'rotate']
,
'flip', 'flip_direction', 'rotate']
.
randomness_keys (list): The keys that related to the randomness
randomness_keys (list): The keys that related to the randomness
in transformation Default to
in transformation
.
Default
s
to
['scale', 'scale_factor', 'crop_size', 'flip',
['scale', 'scale_factor', 'crop_size', 'flip',
'flip_direction', 'photometric_param']
'flip_direction', 'photometric_param']
"""
"""
def
__init__
(
self
,
def
__init__
(
transforms
:
dict
,
self
,
override_aug_config
:
bool
=
True
,
transforms
:
dict
,
process_fields
:
list
=
[
'img'
,
'cam2img'
,
'lidar2cam'
],
override_aug_config
:
bool
=
True
,
collected_keys
:
list
=
[
process_fields
:
list
=
[
'img'
,
'cam2img'
,
'lidar2cam'
],
'scale'
,
'scale_factor'
,
'crop'
,
'img_crop_offset'
,
collected_keys
:
list
=
[
'ori_shape'
,
'pad_shape'
,
'img_shape'
,
'pad_fixed_size'
,
'scale'
,
'scale_factor'
,
'crop'
,
'img_crop_offset'
,
'ori_shape'
,
'pad_size_divisor'
,
'flip'
,
'flip_direction'
,
'rotate'
'pad_shape'
,
'img_shape'
,
'pad_fixed_size'
,
'pad_size_divisor'
,
],
'flip'
,
'flip_direction'
,
'rotate'
randomness_keys
:
list
=
[
],
'scale'
,
'scale_factor'
,
'crop_size'
,
'img_crop_offset'
,
randomness_keys
:
list
=
[
'flip'
,
'flip_direction'
,
'photometric_param'
'scale'
,
'scale_factor'
,
'crop_size'
,
'img_crop_offset'
,
'flip'
,
]):
'flip_direction'
,
'photometric_param'
]
)
->
None
:
self
.
transforms
=
Compose
(
transforms
)
self
.
transforms
=
Compose
(
transforms
)
self
.
override_aug_config
=
override_aug_config
self
.
override_aug_config
=
override_aug_config
self
.
collected_keys
=
collected_keys
self
.
collected_keys
=
collected_keys
self
.
process_fields
=
process_fields
self
.
process_fields
=
process_fields
self
.
randomness_keys
=
randomness_keys
self
.
randomness_keys
=
randomness_keys
def
transform
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""Transform function to do the transform for multiview image.
"""Transform function to do the transform for multiview image.
Args:
Args:
...
...
mmdet3d/datasets/waymo_dataset.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
os.path
as
osp
import
os.path
as
osp
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -24,20 +24,20 @@ class WaymoDataset(KittiDataset):
...
@@ -24,20 +24,20 @@ class WaymoDataset(KittiDataset):
data_root (str): Path of dataset root.
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
ann_file (str): Path of annotation file.
data_prefix (dict): data prefix for point cloud and
data_prefix (dict): data prefix for point cloud and
camera data dict. Default to dict(
camera data dict. Default
s
to dict(
pts='velodyne',
pts='velodyne',
CAM_FRONT='image_0',
CAM_FRONT='image_0',
CAM_FRONT_RIGHT='image_1',
CAM_FRONT_RIGHT='image_1',
CAM_FRONT_LEFT='image_2',
CAM_FRONT_LEFT='image_2',
CAM_SIDE_RIGHT='image_3',
CAM_SIDE_RIGHT='image_3',
CAM_SIDE_LEFT='image_4')
CAM_SIDE_LEFT='image_4')
pipeline (
l
ist[dict]
, optional
): Pipeline used for data processing.
pipeline (
L
ist[dict]): Pipeline used for data processing.
Defaults to
None
.
Defaults to
[]
.
modality (dict
, optional
): Modality to specify the sensor data used
modality (dict): Modality to specify the sensor data used
as input. Defaults to dict(use_lidar=True).
as input. Defaults to dict(use_lidar=True).
default_cam_key (str
, optional
): Default camera key for lidar2img
default_cam_key (str): Default camera key for lidar2img
association. Defaults to 'CAM_FRONT'.
association. Defaults to 'CAM_FRONT'.
box_type_3d (str
, optional
): Type of 3D box of this dataset.
box_type_3d (str): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box
Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`.
to its original format then converted them to `box_type_3d`.
Defaults to 'LiDAR' in this dataset. Available options includes:
Defaults to 'LiDAR' in this dataset. Available options includes:
...
@@ -45,24 +45,30 @@ class WaymoDataset(KittiDataset):
...
@@ -45,24 +45,30 @@ class WaymoDataset(KittiDataset):
- 'LiDAR': Box in LiDAR coordinates.
- 'LiDAR': Box in LiDAR coordinates.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Depth': Box in depth coordinates, usually for indoor dataset.
- 'Camera': Box in camera coordinates.
- 'Camera': Box in camera coordinates.
filter_empty_gt (bool, optional): Whether to filter empty GT.
load_type (str): Type of loading mode. Defaults to 'frame_based'.
Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode.
- 'frame_based': Load all of the instances in the frame.
- 'mv_image_based': Load all of the instances in the frame and need
to convert to the FOV-based data type to support image-based
detector.
- 'fov_image_based': Only load the instances inside the default
cam, and need to convert to the FOV-based data type to support
image-based detector.
filter_empty_gt (bool): Whether to filter the data with empty GT.
If it's set to be True, the example with empty annotations after
data pipeline will be dropped and a random example will be chosen
in `__getitem__`. Defaults to True.
test_mode (bool): Whether the dataset is in test mode.
Defaults to False.
Defaults to False.
pcd_limit_range (
l
ist[float]
, optional
): The range of point cloud
pcd_limit_range (
L
ist[float]): The range of point cloud
used to filter invalid predicted boxes.
used to filter invalid predicted boxes.
Defaults to [-85, -85, -5, 85, 85, 5].
Defaults to [-85, -85, -5, 85, 85, 5].
cam_sync_instances (bool
, optional
): If use the camera sync label
cam_sync_instances (bool): If use the camera sync label
supported from waymo version 1.3.1. Defaults to False.
supported from waymo version 1.3.1. Defaults to False.
load_interval (int, optional): load frame interval.
load_interval (int): load frame interval. Defaults to 1.
Defaults to 1.
max_sweeps (int): max sweep for each frame. Defaults to 0.
task (str, optional): task for 3D detection (lidar, mono3d).
lidar: take all the ground trurh in the frame.
mono3d: take the groundtruth that can be seen in the cam.
Defaults to 'lidar'.
max_sweeps (int, optional): max sweep for each frame. Defaults to 0.
"""
"""
METAINFO
=
{
'
CLASSES
'
:
(
'Car'
,
'Pedestrian'
,
'Cyclist'
)}
METAINFO
=
{
'
classes
'
:
(
'Car'
,
'Pedestrian'
,
'Cyclist'
)}
def
__init__
(
self
,
def
__init__
(
self
,
data_root
:
str
,
data_root
:
str
,
...
@@ -75,28 +81,27 @@ class WaymoDataset(KittiDataset):
...
@@ -75,28 +81,27 @@ class WaymoDataset(KittiDataset):
CAM_SIDE_RIGHT
=
'image_3'
,
CAM_SIDE_RIGHT
=
'image_3'
,
CAM_SIDE_LEFT
=
'image_4'
),
CAM_SIDE_LEFT
=
'image_4'
),
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
pipeline
:
List
[
Union
[
dict
,
Callable
]]
=
[],
modality
:
Optional
[
dict
]
=
dict
(
use_lidar
=
True
),
modality
:
dict
=
dict
(
use_lidar
=
True
),
default_cam_key
:
str
=
'CAM_FRONT'
,
default_cam_key
:
str
=
'CAM_FRONT'
,
box_type_3d
:
str
=
'LiDAR'
,
box_type_3d
:
str
=
'LiDAR'
,
load_type
:
str
=
'frame_based'
,
filter_empty_gt
:
bool
=
True
,
filter_empty_gt
:
bool
=
True
,
test_mode
:
bool
=
False
,
test_mode
:
bool
=
False
,
pcd_limit_range
:
List
[
float
]
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
],
pcd_limit_range
:
List
[
float
]
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
],
cam_sync_instances
=
False
,
cam_sync_instances
:
bool
=
False
,
load_interval
=
1
,
load_interval
:
int
=
1
,
task
=
'lidar_det'
,
max_sweeps
:
int
=
0
,
max_sweeps
=
0
,
**
kwargs
)
->
None
:
**
kwargs
):
self
.
load_interval
=
load_interval
self
.
load_interval
=
load_interval
# set loading mode for different task settings
# set loading mode for different task settings
self
.
cam_sync_instances
=
cam_sync_instances
self
.
cam_sync_instances
=
cam_sync_instances
# construct self.cat_ids for vision-only anns parsing
# construct self.cat_ids for vision-only anns parsing
self
.
cat_ids
=
range
(
len
(
self
.
METAINFO
[
'
CLASSES
'
]))
self
.
cat_ids
=
range
(
len
(
self
.
METAINFO
[
'
classes
'
]))
self
.
cat2label
=
{
cat_id
:
i
for
i
,
cat_id
in
enumerate
(
self
.
cat_ids
)}
self
.
cat2label
=
{
cat_id
:
i
for
i
,
cat_id
in
enumerate
(
self
.
cat_ids
)}
self
.
max_sweeps
=
max_sweeps
self
.
max_sweeps
=
max_sweeps
self
.
task
=
task
# we do not provide file_client_args to custom_3d init
# we do not provide file_client_args to custom_3d init
# because we want disk loading for info
# because we want disk loading for info
# while ceph loading for
KITTI
2Waymo
# while ceph loading for
Prediction
2Waymo
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
...
@@ -108,24 +113,25 @@ class WaymoDataset(KittiDataset):
...
@@ -108,24 +113,25 @@ class WaymoDataset(KittiDataset):
default_cam_key
=
default_cam_key
,
default_cam_key
=
default_cam_key
,
data_prefix
=
data_prefix
,
data_prefix
=
data_prefix
,
test_mode
=
test_mode
,
test_mode
=
test_mode
,
load_type
=
load_type
,
**
kwargs
)
**
kwargs
)
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_ann_info
(
self
,
info
:
dict
)
->
dict
:
"""
Get annotation info according to the given index
.
"""
Process the `instances` in data info to `ann_info`
.
Args:
Args:
info (dict): Data information of single data sample.
info (dict): Data information of single data sample.
Returns:
Returns:
dict:
a
nnotation information consists of the following keys:
dict:
A
nnotation information consists of the following keys:
- bboxes_3d (:obj:`LiDARInstance3DBoxes`):
- bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes.
3D ground truth bboxes.
- bbox_labels_3d (np.ndarray): Labels of ground truths.
- bbox_labels_3d (np.ndarray): Labels of ground truths.
- gt_bboxes (np.ndarray): 2D ground truth bboxes.
- gt_bboxes (np.ndarray): 2D ground truth bboxes.
- gt_labels (np.ndarray): Labels of ground truths.
- gt_labels (np.ndarray): Labels of ground truths.
- difficulty (int): Difficulty defined by KITTI.
- difficulty (int): Difficulty defined by KITTI.
0, 1, 2 represent xxxxx respectively.
0, 1, 2 represent xxxxx respectively.
"""
"""
ann_info
=
Det3DDataset
.
parse_ann_info
(
self
,
info
)
ann_info
=
Det3DDataset
.
parse_ann_info
(
self
,
info
)
if
ann_info
is
None
:
if
ann_info
is
None
:
...
@@ -150,7 +156,7 @@ class WaymoDataset(KittiDataset):
...
@@ -150,7 +156,7 @@ class WaymoDataset(KittiDataset):
centers_2d
=
np
.
zeros
((
0
,
2
),
dtype
=
np
.
float32
)
centers_2d
=
np
.
zeros
((
0
,
2
),
dtype
=
np
.
float32
)
depths
=
np
.
zeros
((
0
),
dtype
=
np
.
float32
)
depths
=
np
.
zeros
((
0
),
dtype
=
np
.
float32
)
if
self
.
task
==
'mono_det'
:
if
self
.
load_type
in
[
'fov_image_based'
,
'mv_image_based'
]
:
gt_bboxes_3d
=
CameraInstance3DBoxes
(
gt_bboxes_3d
=
CameraInstance3DBoxes
(
ann_info
[
'gt_bboxes_3d'
],
ann_info
[
'gt_bboxes_3d'
],
box_dim
=
ann_info
[
'gt_bboxes_3d'
].
shape
[
-
1
],
box_dim
=
ann_info
[
'gt_bboxes_3d'
].
shape
[
-
1
],
...
@@ -182,13 +188,22 @@ class WaymoDataset(KittiDataset):
...
@@ -182,13 +188,22 @@ class WaymoDataset(KittiDataset):
data_list
=
data_list
[::
self
.
load_interval
]
data_list
=
data_list
[::
self
.
load_interval
]
return
data_list
return
data_list
def
parse_data_info
(
self
,
info
:
dict
)
->
dict
:
def
parse_data_info
(
self
,
info
:
dict
)
->
Union
[
dict
,
List
[
dict
]]
:
"""if task is lidar or multiview det, use super() method elif task is
"""if task is lidar or multiview det, use super() method elif task is
mono3d, split the info from frame-wise to img-wise."""
mono3d, split the info from frame-wise to img-wise."""
if
self
.
task
!=
'mono_det'
:
if
self
.
cam_sync_instances
:
if
self
.
cam_sync_instances
:
# use the cam sync labels
info
[
'instances'
]
=
info
[
'cam_sync_instances'
]
info
[
'instances'
]
=
info
[
'cam_sync_instances'
]
if
self
.
load_type
==
'frame_based'
:
return
super
().
parse_data_info
(
info
)
elif
self
.
load_type
==
'fov_image_based'
:
# only loading the fov image and the fov instance
new_image_info
=
{}
new_image_info
[
self
.
default_cam_key
]
=
\
info
[
'images'
][
self
.
default_cam_key
]
info
[
'images'
]
=
new_image_info
info
[
'instances'
]
=
info
[
'cam_instances'
][
self
.
default_cam_key
]
return
super
().
parse_data_info
(
info
)
return
super
().
parse_data_info
(
info
)
else
:
else
:
# in the mono3d, the instances is from cam sync.
# in the mono3d, the instances is from cam sync.
...
...
mmdet3d/evaluation/functional/waymo_utils/__init__.py
View file @
d7067e44
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.prediction_
kitti_
to_waymo
import
KITTI
2Waymo
from
.prediction_to_waymo
import
Prediction
2Waymo
__all__
=
[
'
KITTI
2Waymo'
]
__all__
=
[
'
Prediction
2Waymo'
]
mmdet3d/evaluation/functional/waymo_utils/prediction_
kitti_
to_waymo.py
→
mmdet3d/evaluation/functional/waymo_utils/prediction_to_waymo.py
View file @
d7067e44
...
@@ -5,29 +5,33 @@ r"""Adapted from `Waymo to KITTI converter
...
@@ -5,29 +5,33 @@ r"""Adapted from `Waymo to KITTI converter
try
:
try
:
from
waymo_open_dataset
import
dataset_pb2
as
open_dataset
from
waymo_open_dataset
import
dataset_pb2
as
open_dataset
from
waymo_open_dataset
import
label_pb2
from
waymo_open_dataset.protos
import
metrics_pb2
from
waymo_open_dataset.protos.metrics_pb2
import
Objects
except
ImportError
:
except
ImportError
:
Objects
=
None
raise
ImportError
(
raise
ImportError
(
'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
'to install the official devkit first.'
)
'to install the official devkit first.'
)
from
glob
import
glob
from
glob
import
glob
from
os.path
import
join
from
os.path
import
join
from
typing
import
List
,
Optional
import
mmengine
import
mmengine
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
waymo_open_dataset
import
label_pb2
from
waymo_open_dataset.protos
import
metrics_pb2
class
KITTI2Waymo
(
object
):
class
Prediction2Waymo
(
object
):
"""KITTI predictions to Waymo converter.
"""Predictions to Waymo converter. The format of prediction results could
be original format or kitti-format.
This class serves as the converter to change predictions from KITTI to
This class serves as the converter to change predictions from KITTI to
Waymo format.
Waymo format.
Args:
Args:
kitti_
result
_file
s (list[dict]): Prediction
s in KITTI format
.
results (list[dict]): Prediction
results
.
waymo_tfrecords_dir (str): Directory to load waymo raw data.
waymo_tfrecords_dir (str): Directory to load waymo raw data.
waymo_results_save_dir (str): Directory to save converted predictions
waymo_results_save_dir (str): Directory to save converted predictions
in waymo format (.bin files).
in waymo format (.bin files).
...
@@ -35,33 +39,47 @@ class KITTI2Waymo(object):
...
@@ -35,33 +39,47 @@ class KITTI2Waymo(object):
predictions in waymo format (.bin file), like 'a/b/c.bin'.
predictions in waymo format (.bin file), like 'a/b/c.bin'.
prefix (str): Prefix of filename. In general, 0 for training, 1 for
prefix (str): Prefix of filename. In general, 0 for training, 1 for
validation and 2 for testing.
validation and 2 for testing.
workers (str): Number of parallel processes.
classes (dict): A list of class name.
workers (str): Number of parallel processes. Defaults to 2.
file_client_args (str): File client for reading gt in waymo format.
Defaults to ``dict(backend='disk')``.
from_kitti_format (bool, optional): Whether the reuslts are kitti
format. Defaults to False.
idx2metainfo (Optional[dict], optional): The mapping from sample_idx to
metainfo. The metainfo must contain the keys: 'idx2contextname' and
'idx2timestamp'. Defaults to None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
kitti_result_files
,
results
:
List
[
dict
],
waymo_tfrecords_dir
,
waymo_tfrecords_dir
:
str
,
waymo_results_save_dir
,
waymo_results_save_dir
:
str
,
waymo_results_final_path
,
waymo_results_final_path
:
str
,
prefix
,
prefix
:
str
,
workers
=
64
,
classes
:
dict
,
file_client_args
=
dict
(
backend
=
'disk'
)):
workers
:
int
=
2
,
file_client_args
:
dict
=
dict
(
backend
=
'disk'
),
self
.
kitti_result_files
=
kitti_result_files
from_kitti_format
:
bool
=
False
,
idx2metainfo
:
Optional
[
dict
]
=
None
):
self
.
results
=
results
self
.
waymo_tfrecords_dir
=
waymo_tfrecords_dir
self
.
waymo_tfrecords_dir
=
waymo_tfrecords_dir
self
.
waymo_results_save_dir
=
waymo_results_save_dir
self
.
waymo_results_save_dir
=
waymo_results_save_dir
self
.
waymo_results_final_path
=
waymo_results_final_path
self
.
waymo_results_final_path
=
waymo_results_final_path
self
.
prefix
=
prefix
self
.
prefix
=
prefix
self
.
classes
=
classes
self
.
workers
=
int
(
workers
)
self
.
workers
=
int
(
workers
)
self
.
file_client_args
=
file_client_args
self
.
file_client_args
=
file_client_args
self
.
name2idx
=
{}
self
.
from_kitti_format
=
from_kitti_format
for
idx
,
result
in
enumerate
(
kitti_result_files
):
if
idx2metainfo
is
not
None
:
if
len
(
result
[
'sample_id'
])
>
0
:
self
.
idx2metainfo
=
idx2metainfo
self
.
name2idx
[
str
(
result
[
'sample_id'
][
0
])]
=
idx
# If ``fast_eval``, the metainfo does not need to be read from
# original data online. It's preprocessed offline.
self
.
fast_eval
=
True
else
:
self
.
fast_eval
=
False
# turn on eager execution for older tensorflow versions
self
.
name2idx
=
{}
if
int
(
tf
.
__version__
.
split
(
'.'
)[
0
])
<
2
:
tf
.
enable_eager_execution
()
self
.
k2w_cls_map
=
{
self
.
k2w_cls_map
=
{
'Car'
:
label_pb2
.
Label
.
TYPE_VEHICLE
,
'Car'
:
label_pb2
.
Label
.
TYPE_VEHICLE
,
...
@@ -70,12 +88,28 @@ class KITTI2Waymo(object):
...
@@ -70,12 +88,28 @@ class KITTI2Waymo(object):
'Cyclist'
:
label_pb2
.
Label
.
TYPE_CYCLIST
,
'Cyclist'
:
label_pb2
.
Label
.
TYPE_CYCLIST
,
}
}
self
.
T_ref_to_front_cam
=
np
.
array
([[
0.0
,
0.0
,
1.0
,
0.0
],
if
self
.
from_kitti_format
:
[
-
1.0
,
0.0
,
0.0
,
0.0
],
self
.
T_ref_to_front_cam
=
np
.
array
([[
0.0
,
0.0
,
1.0
,
0.0
],
[
0.0
,
-
1.0
,
0.0
,
0.0
],
[
-
1.0
,
0.0
,
0.0
,
0.0
],
[
0.0
,
0.0
,
0.0
,
1.0
]])
[
0.0
,
-
1.0
,
0.0
,
0.0
],
[
0.0
,
0.0
,
0.0
,
1.0
]])
# ``sample_idx`` of the sample in kitti-format is an array
for
idx
,
result
in
enumerate
(
results
):
if
len
(
result
[
'sample_idx'
])
>
0
:
self
.
name2idx
[
str
(
result
[
'sample_idx'
][
0
])]
=
idx
else
:
# ``sample_idx`` of the sample in the original prediction
# is an int value.
for
idx
,
result
in
enumerate
(
results
):
self
.
name2idx
[
str
(
result
[
'sample_idx'
])]
=
idx
if
not
self
.
fast_eval
:
# need to read original '.tfrecord' file
self
.
get_file_names
()
# turn on eager execution for older tensorflow versions
if
int
(
tf
.
__version__
.
split
(
'.'
)[
0
])
<
2
:
tf
.
enable_eager_execution
()
self
.
get_file_names
()
self
.
create_folder
()
self
.
create_folder
()
def
get_file_names
(
self
):
def
get_file_names
(
self
):
...
@@ -192,6 +226,13 @@ class KITTI2Waymo(object):
...
@@ -192,6 +226,13 @@ class KITTI2Waymo(object):
file_idx (int): Index of the file to be converted.
file_idx (int): Index of the file to be converted.
"""
"""
file_pathname
=
self
.
waymo_tfrecord_pathnames
[
file_idx
]
file_pathname
=
self
.
waymo_tfrecord_pathnames
[
file_idx
]
if
's3://'
in
file_pathname
and
tf
.
__version__
>=
'2.6.0'
:
try
:
import
tensorflow_io
as
tfio
# noqa: F401
except
ImportError
:
raise
ImportError
(
"Please run 'pip install tensorflow-io' to install tensorflow_io first."
# noqa: E501
)
file_data
=
tf
.
data
.
TFRecordDataset
(
file_pathname
,
compression_type
=
''
)
file_data
=
tf
.
data
.
TFRecordDataset
(
file_pathname
,
compression_type
=
''
)
for
frame_num
,
frame_data
in
enumerate
(
file_data
):
for
frame_num
,
frame_data
in
enumerate
(
file_data
):
...
@@ -200,22 +241,30 @@ class KITTI2Waymo(object):
...
@@ -200,22 +241,30 @@ class KITTI2Waymo(object):
filename
=
f
'
{
self
.
prefix
}{
file_idx
:
03
d
}{
frame_num
:
03
d
}
'
filename
=
f
'
{
self
.
prefix
}{
file_idx
:
03
d
}{
frame_num
:
03
d
}
'
for
camera
in
frame
.
context
.
camera_calibrations
:
# FRONT = 1, see dataset.proto for details
if
camera
.
name
==
1
:
T_front_cam_to_vehicle
=
np
.
array
(
camera
.
extrinsic
.
transform
).
reshape
(
4
,
4
)
T_k2w
=
T_front_cam_to_vehicle
@
self
.
T_ref_to_front_cam
context_name
=
frame
.
context
.
name
context_name
=
frame
.
context
.
name
frame_timestamp_micros
=
frame
.
timestamp_micros
frame_timestamp_micros
=
frame
.
timestamp_micros
if
filename
in
self
.
name2idx
:
if
filename
in
self
.
name2idx
:
kitti_result
=
\
if
self
.
from_kitti_format
:
self
.
kitti_result_files
[
self
.
name2idx
[
filename
]]
for
camera
in
frame
.
context
.
camera_calibrations
:
objects
=
self
.
parse_objects
(
kitti_result
,
T_k2w
,
context_name
,
# FRONT = 1, see dataset.proto for details
frame_timestamp_micros
)
if
camera
.
name
==
1
:
T_front_cam_to_vehicle
=
np
.
array
(
camera
.
extrinsic
.
transform
).
reshape
(
4
,
4
)
T_k2w
=
T_front_cam_to_vehicle
@
self
.
T_ref_to_front_cam
kitti_result
=
\
self
.
results
[
self
.
name2idx
[
filename
]]
objects
=
self
.
parse_objects
(
kitti_result
,
T_k2w
,
context_name
,
frame_timestamp_micros
)
else
:
index
=
self
.
name2idx
[
filename
]
objects
=
self
.
parse_objects_from_origin
(
self
.
results
[
index
],
context_name
,
frame_timestamp_micros
)
else
:
else
:
print
(
filename
,
'not found.'
)
print
(
filename
,
'not found.'
)
objects
=
metrics_pb2
.
Objects
()
objects
=
metrics_pb2
.
Objects
()
...
@@ -225,11 +274,100 @@ class KITTI2Waymo(object):
...
@@ -225,11 +274,100 @@ class KITTI2Waymo(object):
'wb'
)
as
f
:
'wb'
)
as
f
:
f
.
write
(
objects
.
SerializeToString
())
f
.
write
(
objects
.
SerializeToString
())
def
convert_one_fast
(
self
,
res_index
:
int
):
"""Convert action for single file. It read the metainfo from the
preprocessed file offline and will be faster.
Args:
res_index (int): The indices of the results.
"""
sample_idx
=
self
.
results
[
res_index
][
'sample_idx'
]
if
len
(
self
.
results
[
res_index
][
'pred_instances_3d'
])
>
0
:
objects
=
self
.
parse_objects_from_origin
(
self
.
results
[
res_index
],
self
.
idx2metainfo
[
str
(
sample_idx
)][
'contextname'
],
self
.
idx2metainfo
[
str
(
sample_idx
)][
'timestamp'
])
else
:
print
(
sample_idx
,
'not found.'
)
objects
=
metrics_pb2
.
Objects
()
with
open
(
join
(
self
.
waymo_results_save_dir
,
f
'
{
sample_idx
}
.bin'
),
'wb'
)
as
f
:
f
.
write
(
objects
.
SerializeToString
())
def
parse_objects_from_origin
(
self
,
result
:
dict
,
contextname
:
str
,
timestamp
:
str
)
->
Objects
:
"""Parse obejcts from the original prediction results.
Args:
result (dict): The original prediction results.
contextname (str): The ``contextname`` of sample in waymo.
timestamp (str): The ``timestamp`` of sample in waymo.
Returns:
metrics_pb2.Objects: The parsed object.
"""
lidar_boxes
=
result
[
'pred_instances_3d'
][
'bboxes_3d'
].
tensor
scores
=
result
[
'pred_instances_3d'
][
'scores_3d'
]
labels
=
result
[
'pred_instances_3d'
][
'labels_3d'
]
def
parse_one_object
(
index
):
class_name
=
self
.
classes
[
labels
[
index
].
item
()]
box
=
label_pb2
.
Label
.
Box
()
height
=
lidar_boxes
[
index
][
5
].
item
()
heading
=
lidar_boxes
[
index
][
6
].
item
()
while
heading
<
-
np
.
pi
:
heading
+=
2
*
np
.
pi
while
heading
>
np
.
pi
:
heading
-=
2
*
np
.
pi
box
.
center_x
=
lidar_boxes
[
index
][
0
].
item
()
box
.
center_y
=
lidar_boxes
[
index
][
1
].
item
()
box
.
center_z
=
lidar_boxes
[
index
][
2
].
item
()
+
height
/
2
box
.
length
=
lidar_boxes
[
index
][
3
].
item
()
box
.
width
=
lidar_boxes
[
index
][
4
].
item
()
box
.
height
=
height
box
.
heading
=
heading
o
=
metrics_pb2
.
Object
()
o
.
object
.
box
.
CopyFrom
(
box
)
o
.
object
.
type
=
self
.
k2w_cls_map
[
class_name
]
o
.
score
=
scores
[
index
].
item
()
o
.
context_name
=
contextname
o
.
frame_timestamp_micros
=
timestamp
return
o
objects
=
metrics_pb2
.
Objects
()
for
i
in
range
(
len
(
lidar_boxes
)):
objects
.
objects
.
append
(
parse_one_object
(
i
))
return
objects
def
convert
(
self
):
def
convert
(
self
):
"""Convert action."""
"""Convert action."""
print
(
'Start converting ...'
)
print
(
'Start converting ...'
)
mmengine
.
track_parallel_progress
(
self
.
convert_one
,
range
(
len
(
self
)),
convert_func
=
self
.
convert_one_fast
if
self
.
fast_eval
else
\
self
.
workers
)
self
.
convert_one
# from torch.multiprocessing import set_sharing_strategy
# # Force using "file_system" sharing strategy for stability
# set_sharing_strategy("file_system")
# mmengine.track_parallel_progress(convert_func, range(len(self)),
# self.workers)
# TODO: Support multiprocessing. Now, multiprocessing evaluation will
# cause shared memory error in torch-1.10 and torch-1.11. Details can
# be seen in https://github.com/pytorch/pytorch/issues/67864.
prog_bar
=
mmengine
.
ProgressBar
(
len
(
self
))
for
i
in
range
(
len
(
self
)):
convert_func
(
i
)
prog_bar
.
update
()
print
(
'
\n
Finished ...'
)
print
(
'
\n
Finished ...'
)
# combine all files into one .bin
# combine all files into one .bin
...
@@ -241,7 +379,8 @@ class KITTI2Waymo(object):
...
@@ -241,7 +379,8 @@ class KITTI2Waymo(object):
def
__len__
(
self
):
def
__len__
(
self
):
"""Length of the filename list."""
"""Length of the filename list."""
return
len
(
self
.
waymo_tfrecord_pathnames
)
return
len
(
self
.
results
)
if
self
.
fast_eval
else
len
(
self
.
waymo_tfrecord_pathnames
)
def
transform
(
self
,
T
,
x
,
y
,
z
):
def
transform
(
self
,
T
,
x
,
y
,
z
):
"""Transform the coordinates with matrix T.
"""Transform the coordinates with matrix T.
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment