Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
8282f10c
Commit
8282f10c
authored
May 25, 2022
by
VVsssssk
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor]Fix ObjectRangeFilter + PointsRangeFilter + ObjectNameFilter
parent
84e479ea
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
21 deletions
+57
-21
mmdet3d/datasets/pipelines/transforms_3d.py
mmdet3d/datasets/pipelines/transforms_3d.py
+57
-21
No files found.
mmdet3d/datasets/pipelines/transforms_3d.py
View file @
8282f10c
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
random
import
warnings
import
warnings
from
typing
import
Dict
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -12,6 +11,7 @@ from mmengine.registry import build_from_cfg
...
@@ -12,6 +11,7 @@ from mmengine.registry import build_from_cfg
from
mmdet3d.core
import
VoxelGenerator
from
mmdet3d.core
import
VoxelGenerator
from
mmdet3d.core.bbox
import
(
CameraInstance3DBoxes
,
DepthInstance3DBoxes
,
from
mmdet3d.core.bbox
import
(
CameraInstance3DBoxes
,
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
,
box_np_ops
)
LiDARInstance3DBoxes
,
box_np_ops
)
from
mmdet3d.core.points
import
BasePoints
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet.datasets.pipelines
import
RandomFlip
from
mmdet.datasets.pipelines
import
RandomFlip
from
.data_augment_utils
import
noise_per_object_v3_
from
.data_augment_utils
import
noise_per_object_v3_
...
@@ -274,6 +274,7 @@ class ObjectSample(BaseTransform):
...
@@ -274,6 +274,7 @@ class ObjectSample(BaseTransform):
- gt_bboxes (optional)
- gt_bboxes (optional)
Modified Keys:
Modified Keys:
- points
- points
- gt_bboxes_3d
- gt_bboxes_3d
- gt_labels_3d
- gt_labels_3d
...
@@ -293,7 +294,10 @@ class ObjectSample(BaseTransform):
...
@@ -293,7 +294,10 @@ class ObjectSample(BaseTransform):
3D labels.
3D labels.
"""
"""
def
__init__
(
self
,
db_sampler
,
sample_2d
=
False
,
use_ground_plane
=
False
):
def
__init__
(
self
,
db_sampler
:
dict
,
sample_2d
:
bool
=
False
,
use_ground_plane
:
bool
=
False
):
self
.
sampler_cfg
=
db_sampler
self
.
sampler_cfg
=
db_sampler
self
.
sample_2d
=
sample_2d
self
.
sample_2d
=
sample_2d
if
'type'
not
in
db_sampler
.
keys
():
if
'type'
not
in
db_sampler
.
keys
():
...
@@ -302,7 +306,8 @@ class ObjectSample(BaseTransform):
...
@@ -302,7 +306,8 @@ class ObjectSample(BaseTransform):
self
.
use_ground_plane
=
use_ground_plane
self
.
use_ground_plane
=
use_ground_plane
@
staticmethod
@
staticmethod
def
remove_points_in_boxes
(
points
,
boxes
):
def
remove_points_in_boxes
(
points
:
BasePoints
,
boxes
:
np
.
ndarray
)
->
np
.
ndarray
:
"""Remove the points in the sampled bounding boxes.
"""Remove the points in the sampled bounding boxes.
Args:
Args:
...
@@ -422,10 +427,10 @@ class ObjectNoise(BaseTransform):
...
@@ -422,10 +427,10 @@ class ObjectNoise(BaseTransform):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
translation_std
=
[
0.25
,
0.25
,
0.25
],
translation_std
:
list
=
[
0.25
,
0.25
,
0.25
],
global_rot_range
=
[
0.0
,
0.0
],
global_rot_range
:
list
=
[
0.0
,
0.0
],
rot_range
=
[
-
0.15707963267
,
0.15707963267
],
rot_range
:
list
=
[
-
0.15707963267
,
0.15707963267
],
num_try
=
100
):
num_try
:
int
=
100
):
self
.
translation_std
=
translation_std
self
.
translation_std
=
translation_std
self
.
global_rot_range
=
global_rot_range
self
.
global_rot_range
=
global_rot_range
self
.
rot_range
=
rot_range
self
.
rot_range
=
rot_range
...
@@ -756,18 +761,26 @@ class PointShuffle(object):
...
@@ -756,18 +761,26 @@ class PointShuffle(object):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
ObjectRangeFilter
(
object
):
class
ObjectRangeFilter
(
BaseTransform
):
"""Filter objects by the range.
"""Filter objects by the range.
Required Keys:
- gt_bboxes_3d
Modified Keys:
- gt_bboxes_3d
Args:
Args:
point_cloud_range (list[float]): Point cloud range.
point_cloud_range (list[float]): Point cloud range.
"""
"""
def
__init__
(
self
,
point_cloud_range
):
def
__init__
(
self
,
point_cloud_range
:
list
):
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""
Call
function to filter objects by the range.
"""
Transform
function to filter objects by the range.
Args:
Args:
input_dict (dict): Result dict from loading pipeline.
input_dict (dict): Result dict from loading pipeline.
...
@@ -808,18 +821,28 @@ class ObjectRangeFilter(object):
...
@@ -808,18 +821,28 @@ class ObjectRangeFilter(object):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
PointsRangeFilter
(
object
):
class
PointsRangeFilter
(
BaseTransform
):
"""Filter points by the range.
"""Filter points by the range.
Required Keys:
- points
- pts_instance_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
Args:
Args:
point_cloud_range (list[float]): Point cloud range.
point_cloud_range (list[float]): Point cloud range.
"""
"""
def
__init__
(
self
,
point_cloud_range
):
def
__init__
(
self
,
point_cloud_range
:
list
):
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""
Call
function to filter points by the range.
"""
Transform
function to filter points by the range.
Args:
Args:
input_dict (dict): Result dict from loading pipeline.
input_dict (dict): Result dict from loading pipeline.
...
@@ -853,19 +876,27 @@ class PointsRangeFilter(object):
...
@@ -853,19 +876,27 @@ class PointsRangeFilter(object):
@
TRANSFORMS
.
register_module
()
@
TRANSFORMS
.
register_module
()
class
ObjectNameFilter
(
object
):
class
ObjectNameFilter
(
BaseTransform
):
"""Filter GT objects by their names.
"""Filter GT objects by their names.
Required Keys:
- gt_labels_3d
Modified Keys:
- gt_labels_3d
Args:
Args:
classes (list[str]): List of class names to be kept for training.
classes (list[str]): List of class names to be kept for training.
"""
"""
def
__init__
(
self
,
classes
):
def
__init__
(
self
,
classes
:
list
):
self
.
classes
=
classes
self
.
classes
=
classes
self
.
labels
=
list
(
range
(
len
(
self
.
classes
)))
self
.
labels
=
list
(
range
(
len
(
self
.
classes
)))
def
__call__
(
self
,
input_dict
)
:
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""
Call
function to filter objects by their names.
"""
Transform
function to filter objects by their names.
Args:
Args:
input_dict (dict): Result dict from loading pipeline.
input_dict (dict): Result dict from loading pipeline.
...
@@ -896,11 +927,13 @@ class PointSample(BaseTransform):
...
@@ -896,11 +927,13 @@ class PointSample(BaseTransform):
Sampling data to a certain number.
Sampling data to a certain number.
Required Keys:
Required Keys:
- points
- points
- pts_instance_mask (optional)
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
- pts_semantic_mask (optional)
Modified Keys:
Modified Keys:
- points
- points
- pts_instance_mask (optional)
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
- pts_semantic_mask (optional)
...
@@ -914,7 +947,10 @@ class PointSample(BaseTransform):
...
@@ -914,7 +947,10 @@ class PointSample(BaseTransform):
replacement. Defaults to False.
replacement. Defaults to False.
"""
"""
def
__init__
(
self
,
num_points
,
sample_range
=
None
,
replace
=
False
):
def
__init__
(
self
,
num_points
:
int
,
sample_range
:
float
=
None
,
replace
:
bool
=
False
):
self
.
num_points
=
num_points
self
.
num_points
=
num_points
self
.
sample_range
=
sample_range
self
.
sample_range
=
sample_range
self
.
replace
=
replace
self
.
replace
=
replace
...
@@ -967,7 +1003,7 @@ class PointSample(BaseTransform):
...
@@ -967,7 +1003,7 @@ class PointSample(BaseTransform):
else
:
else
:
return
points
[
choices
]
return
points
[
choices
]
def
transform
(
self
,
input_dict
:
D
ict
)
->
D
ict
:
def
transform
(
self
,
input_dict
:
d
ict
)
->
d
ict
:
"""Transform function to sample points to in indoor scenes.
"""Transform function to sample points to in indoor scenes.
Args:
Args:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment