Commit 8282f10c authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Fix ObjectRangeFilter + PointsRangeFilter + ObjectNameFilter

parent 84e479ea
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings import warnings
from typing import Dict
import cv2 import cv2
import numpy as np import numpy as np
...@@ -12,6 +11,7 @@ from mmengine.registry import build_from_cfg ...@@ -12,6 +11,7 @@ from mmengine.registry import build_from_cfg
from mmdet3d.core import VoxelGenerator from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes, from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes, box_np_ops) LiDARInstance3DBoxes, box_np_ops)
from mmdet3d.core.points import BasePoints
from mmdet3d.registry import TRANSFORMS from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.pipelines import RandomFlip from mmdet.datasets.pipelines import RandomFlip
from .data_augment_utils import noise_per_object_v3_ from .data_augment_utils import noise_per_object_v3_
...@@ -274,6 +274,7 @@ class ObjectSample(BaseTransform): ...@@ -274,6 +274,7 @@ class ObjectSample(BaseTransform):
- gt_bboxes (optional) - gt_bboxes (optional)
Modified Keys: Modified Keys:
- points - points
- gt_bboxes_3d - gt_bboxes_3d
- gt_labels_3d - gt_labels_3d
...@@ -293,7 +294,10 @@ class ObjectSample(BaseTransform): ...@@ -293,7 +294,10 @@ class ObjectSample(BaseTransform):
3D labels. 3D labels.
""" """
def __init__(self, db_sampler, sample_2d=False, use_ground_plane=False): def __init__(self,
db_sampler: dict,
sample_2d: bool = False,
use_ground_plane: bool = False):
self.sampler_cfg = db_sampler self.sampler_cfg = db_sampler
self.sample_2d = sample_2d self.sample_2d = sample_2d
if 'type' not in db_sampler.keys(): if 'type' not in db_sampler.keys():
...@@ -302,7 +306,8 @@ class ObjectSample(BaseTransform): ...@@ -302,7 +306,8 @@ class ObjectSample(BaseTransform):
self.use_ground_plane = use_ground_plane self.use_ground_plane = use_ground_plane
@staticmethod @staticmethod
def remove_points_in_boxes(points, boxes): def remove_points_in_boxes(points: BasePoints,
boxes: np.ndarray) -> np.ndarray:
"""Remove the points in the sampled bounding boxes. """Remove the points in the sampled bounding boxes.
Args: Args:
...@@ -422,10 +427,10 @@ class ObjectNoise(BaseTransform): ...@@ -422,10 +427,10 @@ class ObjectNoise(BaseTransform):
""" """
def __init__(self, def __init__(self,
translation_std=[0.25, 0.25, 0.25], translation_std: list = [0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0], global_rot_range: list = [0.0, 0.0],
rot_range=[-0.15707963267, 0.15707963267], rot_range: list = [-0.15707963267, 0.15707963267],
num_try=100): num_try: int = 100):
self.translation_std = translation_std self.translation_std = translation_std
self.global_rot_range = global_rot_range self.global_rot_range = global_rot_range
self.rot_range = rot_range self.rot_range = rot_range
...@@ -756,18 +761,26 @@ class PointShuffle(object): ...@@ -756,18 +761,26 @@ class PointShuffle(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class ObjectRangeFilter(object): class ObjectRangeFilter(BaseTransform):
"""Filter objects by the range. """Filter objects by the range.
Required Keys:
- gt_bboxes_3d
Modified Keys:
- gt_bboxes_3d
Args: Args:
point_cloud_range (list[float]): Point cloud range. point_cloud_range (list[float]): Point cloud range.
""" """
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range: list):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32) self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def __call__(self, input_dict): def transform(self, input_dict: dict) -> dict:
"""Call function to filter objects by the range. """Transform function to filter objects by the range.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
...@@ -808,18 +821,28 @@ class ObjectRangeFilter(object): ...@@ -808,18 +821,28 @@ class ObjectRangeFilter(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class PointsRangeFilter(object): class PointsRangeFilter(BaseTransform):
"""Filter points by the range. """Filter points by the range.
Required Keys:
- points
- pts_instance_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
Args: Args:
point_cloud_range (list[float]): Point cloud range. point_cloud_range (list[float]): Point cloud range.
""" """
def __init__(self, point_cloud_range): def __init__(self, point_cloud_range: list):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32) self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def __call__(self, input_dict): def transform(self, input_dict: dict) -> dict:
"""Call function to filter points by the range. """Transform function to filter points by the range.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
...@@ -853,19 +876,27 @@ class PointsRangeFilter(object): ...@@ -853,19 +876,27 @@ class PointsRangeFilter(object):
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class ObjectNameFilter(object): class ObjectNameFilter(BaseTransform):
"""Filter GT objects by their names. """Filter GT objects by their names.
Required Keys:
- gt_labels_3d
Modified Keys:
- gt_labels_3d
Args: Args:
classes (list[str]): List of class names to be kept for training. classes (list[str]): List of class names to be kept for training.
""" """
def __init__(self, classes): def __init__(self, classes: list):
self.classes = classes self.classes = classes
self.labels = list(range(len(self.classes))) self.labels = list(range(len(self.classes)))
def __call__(self, input_dict): def transform(self, input_dict: dict) -> dict:
"""Call function to filter objects by their names. """Transform function to filter objects by their names.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
...@@ -896,11 +927,13 @@ class PointSample(BaseTransform): ...@@ -896,11 +927,13 @@ class PointSample(BaseTransform):
Sampling data to a certain number. Sampling data to a certain number.
Required Keys: Required Keys:
- points - points
- pts_instance_mask (optional) - pts_instance_mask (optional)
- pts_semantic_mask (optional) - pts_semantic_mask (optional)
Modified Keys: Modified Keys:
- points - points
- pts_instance_mask (optional) - pts_instance_mask (optional)
- pts_semantic_mask (optional) - pts_semantic_mask (optional)
...@@ -914,7 +947,10 @@ class PointSample(BaseTransform): ...@@ -914,7 +947,10 @@ class PointSample(BaseTransform):
replacement. Defaults to False. replacement. Defaults to False.
""" """
def __init__(self, num_points, sample_range=None, replace=False): def __init__(self,
num_points: int,
sample_range: float = None,
replace: bool = False):
self.num_points = num_points self.num_points = num_points
self.sample_range = sample_range self.sample_range = sample_range
self.replace = replace self.replace = replace
...@@ -967,7 +1003,7 @@ class PointSample(BaseTransform): ...@@ -967,7 +1003,7 @@ class PointSample(BaseTransform):
else: else:
return points[choices] return points[choices]
def transform(self, input_dict: Dict) -> Dict: def transform(self, input_dict: dict) -> dict:
"""Transform function to sample points to in indoor scenes. """Transform function to sample points to in indoor scenes.
Args: Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment