Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
f4f8ae22
Commit
f4f8ae22
authored
Jun 09, 2022
by
jshilong
Committed by
ChaimZhu
Jul 20, 2022
Browse files
Refactor GlobalAlignment and PointSegclassMappin
parent
7c6810e3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
22 deletions
+74
-22
mmdet3d/datasets/pipelines/loading.py
mmdet3d/datasets/pipelines/loading.py
+21
-7
mmdet3d/datasets/pipelines/transforms_3d.py
mmdet3d/datasets/pipelines/transforms_3d.py
+14
-14
tests/test_data/test_transforms/test_augs.py
tests/test_data/test_transforms/test_augs.py
+23
-1
tests/test_data/test_transforms/test_loading.py
tests/test_data/test_transforms/test_loading.py
+16
-0
No files found.
mmdet3d/datasets/pipelines/loading.py
View file @
f4f8ae22
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
from
mmcv
import
BaseTransform
from
mmcv.transforms
import
LoadImageFromFile
from
mmcv.transforms
import
LoadImageFromFile
from
mmcv.transforms.base
import
BaseTransform
from
mmdet3d.core.points
import
BasePoints
,
get_points_type
from
mmdet3d.core.points
import
BasePoints
,
get_points_type
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet3d.registry
import
TRANSFORMS
...
@@ -241,9 +243,19 @@ class LoadPointsFromMultiSweeps(object):
...
@@ -241,9 +243,19 @@ class LoadPointsFromMultiSweeps(object):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
PointSegClassMapping
(
object
):
class
PointSegClassMapping
(
BaseTransform
):
"""Map original semantic class to valid category ids.
"""Map original semantic class to valid category ids.
Required Keys:
- lidar_points (dict)
- lidar_path (str)
Added Keys:
- points (np.float32)
Map valid classes as 0~len(valid_cat_ids)-1 and
Map valid classes as 0~len(valid_cat_ids)-1 and
others as len(valid_cat_ids).
others as len(valid_cat_ids).
...
@@ -253,7 +265,9 @@ class PointSegClassMapping(object):
...
@@ -253,7 +265,9 @@ class PointSegClassMapping(object):
segmentation mask. Defaults to 40.
segmentation mask. Defaults to 40.
"""
"""
def
__init__
(
self
,
valid_cat_ids
,
max_cat_id
=
40
):
def
__init__
(
self
,
valid_cat_ids
:
Sequence
[
int
],
max_cat_id
:
int
=
40
)
->
None
:
assert
max_cat_id
>=
np
.
max
(
valid_cat_ids
),
\
assert
max_cat_id
>=
np
.
max
(
valid_cat_ids
),
\
'max_cat_id should be greater than maximum id in valid_cat_ids'
'max_cat_id should be greater than maximum id in valid_cat_ids'
...
@@ -267,7 +281,7 @@ class PointSegClassMapping(object):
...
@@ -267,7 +281,7 @@ class PointSegClassMapping(object):
for
cls_idx
,
cat_id
in
enumerate
(
valid_cat_ids
):
for
cls_idx
,
cat_id
in
enumerate
(
valid_cat_ids
):
self
.
cat_id2class
[
cat_id
]
=
cls_idx
self
.
cat_id2class
[
cat_id
]
=
cls_idx
def
__call__
(
self
,
results
)
:
def
transform
(
self
,
results
:
dict
)
->
None
:
"""Call function to map original semantic class to valid category ids.
"""Call function to map original semantic class to valid category ids.
Args:
Args:
...
@@ -320,11 +334,11 @@ class NormalizePointsColor(object):
...
@@ -320,11 +334,11 @@ class NormalizePointsColor(object):
"""
"""
points
=
results
[
'points'
]
points
=
results
[
'points'
]
assert
points
.
attribute_dims
is
not
None
and
\
assert
points
.
attribute_dims
is
not
None
and
\
'color'
in
points
.
attribute_dims
.
keys
(),
\
'color'
in
points
.
attribute_dims
.
keys
(),
\
'Expect points have color attribute'
'Expect points have color attribute'
if
self
.
color_mean
is
not
None
:
if
self
.
color_mean
is
not
None
:
points
.
color
=
points
.
color
-
\
points
.
color
=
points
.
color
-
\
points
.
color
.
new_tensor
(
self
.
color_mean
)
points
.
color
.
new_tensor
(
self
.
color_mean
)
points
.
color
=
points
.
color
/
255.0
points
.
color
=
points
.
color
/
255.0
results
[
'points'
]
=
points
results
[
'points'
]
=
points
return
results
return
results
...
...
mmdet3d/datasets/pipelines/transforms_3d.py
View file @
f4f8ae22
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
random
import
warnings
import
warnings
from
typing
import
List
from
typing
import
Dict
,
List
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -507,7 +507,7 @@ class ObjectNoise(BaseTransform):
...
@@ -507,7 +507,7 @@ class ObjectNoise(BaseTransform):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
GlobalAlignment
(
object
):
class
GlobalAlignment
(
BaseTransform
):
"""Apply global alignment to 3D scene points by rotation and translation.
"""Apply global alignment to 3D scene points by rotation and translation.
Args:
Args:
...
@@ -521,10 +521,10 @@ class GlobalAlignment(object):
...
@@ -521,10 +521,10 @@ class GlobalAlignment(object):
bounding boxes for evaluation.
bounding boxes for evaluation.
"""
"""
def
__init__
(
self
,
rotation_axis
)
:
def
__init__
(
self
,
rotation_axis
:
int
)
->
None
:
self
.
rotation_axis
=
rotation_axis
self
.
rotation_axis
=
rotation_axis
def
_trans_points
(
self
,
input_d
ict
,
trans_factor
)
:
def
_trans_points
(
self
,
results
:
D
ict
,
trans_factor
:
np
.
ndarray
)
->
None
:
"""Private function to translate points.
"""Private function to translate points.
Args:
Args:
...
@@ -534,9 +534,9 @@ class GlobalAlignment(object):
...
@@ -534,9 +534,9 @@ class GlobalAlignment(object):
Returns:
Returns:
dict: Results after translation, 'points' is updated in the dict.
dict: Results after translation, 'points' is updated in the dict.
"""
"""
input_dict
[
'points'
].
translate
(
trans_factor
)
results
[
'points'
].
translate
(
trans_factor
)
def
_rot_points
(
self
,
input_d
ict
,
rot_mat
)
:
def
_rot_points
(
self
,
results
:
D
ict
,
rot_mat
:
np
.
ndarray
)
->
None
:
"""Private function to rotate bounding boxes and points.
"""Private function to rotate bounding boxes and points.
Args:
Args:
...
@@ -547,9 +547,9 @@ class GlobalAlignment(object):
...
@@ -547,9 +547,9 @@ class GlobalAlignment(object):
dict: Results after rotation, 'points' is updated in the dict.
dict: Results after rotation, 'points' is updated in the dict.
"""
"""
# input should be rot_mat_T so I transpose it here
# input should be rot_mat_T so I transpose it here
input_dict
[
'points'
].
rotate
(
rot_mat
.
T
)
results
[
'points'
].
rotate
(
rot_mat
.
T
)
def
_check_rot_mat
(
self
,
rot_mat
)
:
def
_check_rot_mat
(
self
,
rot_mat
:
np
.
ndarray
)
->
None
:
"""Check if rotation matrix is valid for self.rotation_axis.
"""Check if rotation matrix is valid for self.rotation_axis.
Args:
Args:
...
@@ -562,7 +562,7 @@ class GlobalAlignment(object):
...
@@ -562,7 +562,7 @@ class GlobalAlignment(object):
is_valid
&=
(
rot_mat
[:,
self
.
rotation_axis
]
==
valid_array
).
all
()
is_valid
&=
(
rot_mat
[:,
self
.
rotation_axis
]
==
valid_array
).
all
()
assert
is_valid
,
f
'invalid rotation matrix
{
rot_mat
}
'
assert
is_valid
,
f
'invalid rotation matrix
{
rot_mat
}
'
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
results
:
Dict
)
->
Dict
:
"""Call function to shuffle points.
"""Call function to shuffle points.
Args:
Args:
...
@@ -572,20 +572,20 @@ class GlobalAlignment(object):
...
@@ -572,20 +572,20 @@ class GlobalAlignment(object):
dict: Results after global alignment, 'points' and keys in
dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict.
input_dict['bbox3d_fields'] are updated in the result dict.
"""
"""
assert
'axis_align_matrix'
in
input_dict
[
'ann_info'
].
keys
()
,
\
assert
'axis_align_matrix'
in
results
,
\
'axis_align_matrix is not provided in GlobalAlignment'
'axis_align_matrix is not provided in GlobalAlignment'
axis_align_matrix
=
input_dict
[
'ann_info'
]
[
'axis_align_matrix'
]
axis_align_matrix
=
results
[
'axis_align_matrix'
]
assert
axis_align_matrix
.
shape
==
(
4
,
4
),
\
assert
axis_align_matrix
.
shape
==
(
4
,
4
),
\
f
'invalid shape
{
axis_align_matrix
.
shape
}
for axis_align_matrix'
f
'invalid shape
{
axis_align_matrix
.
shape
}
for axis_align_matrix'
rot_mat
=
axis_align_matrix
[:
3
,
:
3
]
rot_mat
=
axis_align_matrix
[:
3
,
:
3
]
trans_vec
=
axis_align_matrix
[:
3
,
-
1
]
trans_vec
=
axis_align_matrix
[:
3
,
-
1
]
self
.
_check_rot_mat
(
rot_mat
)
self
.
_check_rot_mat
(
rot_mat
)
self
.
_rot_points
(
input_dict
,
rot_mat
)
self
.
_rot_points
(
results
,
rot_mat
)
self
.
_trans_points
(
input_dict
,
trans_vec
)
self
.
_trans_points
(
results
,
trans_vec
)
return
input_dict
return
results
def
__repr__
(
self
):
def
__repr__
(
self
):
repr_str
=
self
.
__class__
.
__name__
repr_str
=
self
.
__class__
.
__name__
...
...
tests/test_data/test_transforms/test_augs.py
View file @
f4f8ae22
...
@@ -2,11 +2,12 @@
...
@@ -2,11 +2,12 @@
import
copy
import
copy
import
unittest
import
unittest
import
numpy
as
np
import
torch
import
torch
from
mmengine.testing
import
assert_allclose
from
mmengine.testing
import
assert_allclose
from
utils
import
create_data_info_after_loading
from
utils
import
create_data_info_after_loading
from
mmdet3d.datasets
import
RandomFlip3D
from
mmdet3d.datasets
import
GlobalAlignment
,
RandomFlip3D
from
mmdet3d.datasets.pipelines
import
GlobalRotScaleTrans
from
mmdet3d.datasets.pipelines
import
GlobalRotScaleTrans
...
@@ -77,3 +78,24 @@ class TestRandomFlip3D(unittest.TestCase):
...
@@ -77,3 +78,24 @@ class TestRandomFlip3D(unittest.TestCase):
-
ori_data_info
[
'gt_bboxes_3d'
].
tensor
[:,
1
])
-
ori_data_info
[
'gt_bboxes_3d'
].
tensor
[:,
1
])
assert_allclose
(
data_info
[
'gt_bboxes_3d'
].
tensor
[:,
2
],
assert_allclose
(
data_info
[
'gt_bboxes_3d'
].
tensor
[:,
2
],
ori_data_info
[
'gt_bboxes_3d'
].
tensor
[:,
2
])
ori_data_info
[
'gt_bboxes_3d'
].
tensor
[:,
2
])
class
TestGlobalAlignment
(
unittest
.
TestCase
):
def
test_global_alignment
(
self
):
data_info
=
create_data_info_after_loading
()
global_align_transform
=
GlobalAlignment
(
rotation_axis
=
2
)
data_info
[
'axis_align_matrix'
]
=
np
.
array
(
[[
0.945519
,
0.325568
,
0.
,
-
5.38439
],
[
-
0.325568
,
0.945519
,
0.
,
-
2.87178
],
[
0.
,
0.
,
1.
,
-
0.06435
],
[
0.
,
0.
,
0.
,
1.
]],
dtype
=
np
.
float32
)
global_align_transform
(
data_info
)
data_info
[
'axis_align_matrix'
]
=
np
.
array
(
[[
0.945519
,
0.325568
,
0.
,
-
5.38439
],
[
0
,
2
,
0.
,
-
2.87178
],
[
0.
,
0.
,
1.
,
-
0.06435
],
[
0.
,
0.
,
0.
,
1.
]],
dtype
=
np
.
float32
)
# assert the rot metric
with
self
.
assertRaises
(
AssertionError
):
global_align_transform
(
data_info
)
tests/test_data/test_transforms/test_loading.py
View file @
f4f8ae22
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
unittest
import
unittest
import
numpy
as
np
import
torch
import
torch
from
mmengine.testing
import
assert_allclose
from
mmengine.testing
import
assert_allclose
from
utils
import
create_dummy_data_info
from
utils
import
create_dummy_data_info
from
mmdet3d.core
import
DepthPoints
,
LiDARPoints
from
mmdet3d.core
import
DepthPoints
,
LiDARPoints
from
mmdet3d.datasets.pipelines
import
PointSegClassMapping
from
mmdet3d.datasets.pipelines.loading
import
(
LoadAnnotations3D
,
from
mmdet3d.datasets.pipelines.loading
import
(
LoadAnnotations3D
,
LoadPointsFromFile
)
LoadPointsFromFile
)
...
@@ -71,3 +73,17 @@ class TestLoadAnnotations3D(unittest.TestCase):
...
@@ -71,3 +73,17 @@ class TestLoadAnnotations3D(unittest.TestCase):
self
.
assertIn
(
'with_bbox_3d=True'
,
repr_str
)
self
.
assertIn
(
'with_bbox_3d=True'
,
repr_str
)
self
.
assertIn
(
'with_label_3d=True'
,
repr_str
)
self
.
assertIn
(
'with_label_3d=True'
,
repr_str
)
self
.
assertIn
(
'with_bbox_depth=False'
,
repr_str
)
self
.
assertIn
(
'with_bbox_depth=False'
,
repr_str
)
class
TestPointSegClassMapping
(
unittest
.
TestCase
):
def
test_point_seg_class_mapping
(
self
):
results
=
dict
()
results
[
'pts_semantic_mask'
]
=
np
.
array
([
1
,
2
,
3
,
4
,
5
])
point_seg_mapping_transform
=
PointSegClassMapping
(
valid_cat_ids
=
[
1
,
2
,
3
],
max_cat_id
=
results
[
'pts_semantic_mask'
].
max
())
results
=
point_seg_mapping_transform
(
results
)
assert_allclose
(
results
[
'pts_semantic_mask'
],
np
.
array
([
0
,
1
,
2
,
3
,
3
]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment