Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
8282f10c
Commit
8282f10c
authored
May 25, 2022
by
VVsssssk
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor]Fix ObjectRangeFilter + PointsRangeFilter + ObjectNameFilter
parent
84e479ea
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
21 deletions
+57
-21
mmdet3d/datasets/pipelines/transforms_3d.py
mmdet3d/datasets/pipelines/transforms_3d.py
+57
-21
No files found.
mmdet3d/datasets/pipelines/transforms_3d.py
View file @
8282f10c
# Copyright (c) OpenMMLab. All rights reserved.
import
random
import
warnings
from
typing
import
Dict
import
cv2
import
numpy
as
np
...
...
@@ -12,6 +11,7 @@ from mmengine.registry import build_from_cfg
from
mmdet3d.core
import
VoxelGenerator
from
mmdet3d.core.bbox
import
(
CameraInstance3DBoxes
,
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
,
box_np_ops
)
from
mmdet3d.core.points
import
BasePoints
from
mmdet3d.registry
import
TRANSFORMS
from
mmdet.datasets.pipelines
import
RandomFlip
from
.data_augment_utils
import
noise_per_object_v3_
...
...
@@ -274,6 +274,7 @@ class ObjectSample(BaseTransform):
- gt_bboxes (optional)
Modified Keys:
- points
- gt_bboxes_3d
- gt_labels_3d
...
...
@@ -293,7 +294,10 @@ class ObjectSample(BaseTransform):
3D labels.
"""
def
__init__
(
self
,
db_sampler
,
sample_2d
=
False
,
use_ground_plane
=
False
):
def
__init__
(
self
,
db_sampler
:
dict
,
sample_2d
:
bool
=
False
,
use_ground_plane
:
bool
=
False
):
self
.
sampler_cfg
=
db_sampler
self
.
sample_2d
=
sample_2d
if
'type'
not
in
db_sampler
.
keys
():
...
...
@@ -302,7 +306,8 @@ class ObjectSample(BaseTransform):
self
.
use_ground_plane
=
use_ground_plane
@
staticmethod
def
remove_points_in_boxes
(
points
,
boxes
):
def
remove_points_in_boxes
(
points
:
BasePoints
,
boxes
:
np
.
ndarray
)
->
np
.
ndarray
:
"""Remove the points in the sampled bounding boxes.
Args:
...
...
@@ -422,10 +427,10 @@ class ObjectNoise(BaseTransform):
"""
def
__init__
(
self
,
translation_std
=
[
0.25
,
0.25
,
0.25
],
global_rot_range
=
[
0.0
,
0.0
],
rot_range
=
[
-
0.15707963267
,
0.15707963267
],
num_try
=
100
):
translation_std
:
list
=
[
0.25
,
0.25
,
0.25
],
global_rot_range
:
list
=
[
0.0
,
0.0
],
rot_range
:
list
=
[
-
0.15707963267
,
0.15707963267
],
num_try
:
int
=
100
):
self
.
translation_std
=
translation_std
self
.
global_rot_range
=
global_rot_range
self
.
rot_range
=
rot_range
...
...
@@ -756,18 +761,26 @@ class PointShuffle(object):
@
TRANSFORMS
.
register_module
()
class
ObjectRangeFilter
(
object
):
class
ObjectRangeFilter
(
BaseTransform
):
"""Filter objects by the range.
Required Keys:
- gt_bboxes_3d
Modified Keys:
- gt_bboxes_3d
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def
__init__
(
self
,
point_cloud_range
):
def
__init__
(
self
,
point_cloud_range
:
list
):
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
input_dict
)
:
"""
Call
function to filter objects by the range.
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""
Transform
function to filter objects by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
...
...
@@ -808,18 +821,28 @@ class ObjectRangeFilter(object):
@
TRANSFORMS
.
register_module
()
class
PointsRangeFilter
(
object
):
class
PointsRangeFilter
(
BaseTransform
):
"""Filter points by the range.
Required Keys:
- points
- pts_instance_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def
__init__
(
self
,
point_cloud_range
):
def
__init__
(
self
,
point_cloud_range
:
list
):
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
def
__call__
(
self
,
input_dict
)
:
"""
Call
function to filter points by the range.
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""
Transform
function to filter points by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
...
...
@@ -853,19 +876,27 @@ class PointsRangeFilter(object):
@
TRANSFORMS
.
register_module
()
class
ObjectNameFilter
(
object
):
class
ObjectNameFilter
(
BaseTransform
):
"""Filter GT objects by their names.
Required Keys:
- gt_labels_3d
Modified Keys:
- gt_labels_3d
Args:
classes (list[str]): List of class names to be kept for training.
"""
def
__init__
(
self
,
classes
):
def
__init__
(
self
,
classes
:
list
):
self
.
classes
=
classes
self
.
labels
=
list
(
range
(
len
(
self
.
classes
)))
def
__call__
(
self
,
input_dict
)
:
"""
Call
function to filter objects by their names.
def
transform
(
self
,
input_dict
:
dict
)
->
dict
:
"""
Transform
function to filter objects by their names.
Args:
input_dict (dict): Result dict from loading pipeline.
...
...
@@ -896,11 +927,13 @@ class PointSample(BaseTransform):
Sampling data to a certain number.
Required Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
...
...
@@ -914,7 +947,10 @@ class PointSample(BaseTransform):
replacement. Defaults to False.
"""
def
__init__
(
self
,
num_points
,
sample_range
=
None
,
replace
=
False
):
def
__init__
(
self
,
num_points
:
int
,
sample_range
:
float
=
None
,
replace
:
bool
=
False
):
self
.
num_points
=
num_points
self
.
sample_range
=
sample_range
self
.
replace
=
replace
...
...
@@ -967,7 +1003,7 @@ class PointSample(BaseTransform):
else
:
return
points
[
choices
]
def
transform
(
self
,
input_dict
:
D
ict
)
->
D
ict
:
def
transform
(
self
,
input_dict
:
d
ict
)
->
d
ict
:
"""Transform function to sample points to in indoor scenes.
Args:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment