Commit 8282f10c authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Fix ObjectRangeFilter + PointsRangeFilter + ObjectNameFilter

parent 84e479ea
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import Dict
import cv2
import numpy as np
......@@ -12,6 +11,7 @@ from mmengine.registry import build_from_cfg
from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes, box_np_ops)
from mmdet3d.core.points import BasePoints
from mmdet3d.registry import TRANSFORMS
from mmdet.datasets.pipelines import RandomFlip
from .data_augment_utils import noise_per_object_v3_
......@@ -274,6 +274,7 @@ class ObjectSample(BaseTransform):
- gt_bboxes (optional)
Modified Keys:
- points
- gt_bboxes_3d
- gt_labels_3d
......@@ -293,7 +294,10 @@ class ObjectSample(BaseTransform):
3D labels.
"""
def __init__(self, db_sampler, sample_2d=False, use_ground_plane=False):
def __init__(self,
db_sampler: dict,
sample_2d: bool = False,
use_ground_plane: bool = False):
self.sampler_cfg = db_sampler
self.sample_2d = sample_2d
if 'type' not in db_sampler.keys():
......@@ -302,7 +306,8 @@ class ObjectSample(BaseTransform):
self.use_ground_plane = use_ground_plane
@staticmethod
def remove_points_in_boxes(points, boxes):
def remove_points_in_boxes(points: BasePoints,
boxes: np.ndarray) -> np.ndarray:
"""Remove the points in the sampled bounding boxes.
Args:
......@@ -422,10 +427,10 @@ class ObjectNoise(BaseTransform):
"""
def __init__(self,
translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0],
rot_range=[-0.15707963267, 0.15707963267],
num_try=100):
translation_std: list = [0.25, 0.25, 0.25],
global_rot_range: list = [0.0, 0.0],
rot_range: list = [-0.15707963267, 0.15707963267],
num_try: int = 100):
self.translation_std = translation_std
self.global_rot_range = global_rot_range
self.rot_range = rot_range
......@@ -756,18 +761,26 @@ class PointShuffle(object):
@TRANSFORMS.register_module()
class ObjectRangeFilter(object):
class ObjectRangeFilter(BaseTransform):
"""Filter objects by the range.
Required Keys:
- gt_bboxes_3d
Modified Keys:
- gt_bboxes_3d
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range):
def __init__(self, point_cloud_range: list):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def __call__(self, input_dict):
"""Call function to filter objects by the range.
def transform(self, input_dict: dict) -> dict:
"""Transform function to filter objects by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
......@@ -808,18 +821,28 @@ class ObjectRangeFilter(object):
@TRANSFORMS.register_module()
class PointsRangeFilter(object):
class PointsRangeFilter(BaseTransform):
"""Filter points by the range.
Required Keys:
- points
- pts_instance_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range):
def __init__(self, point_cloud_range: list):
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def __call__(self, input_dict):
"""Call function to filter points by the range.
def transform(self, input_dict: dict) -> dict:
"""Transform function to filter points by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
......@@ -853,19 +876,27 @@ class PointsRangeFilter(object):
@TRANSFORMS.register_module()
class ObjectNameFilter(object):
class ObjectNameFilter(BaseTransform):
"""Filter GT objects by their names.
Required Keys:
- gt_labels_3d
Modified Keys:
- gt_labels_3d
Args:
classes (list[str]): List of class names to be kept for training.
"""
def __init__(self, classes):
def __init__(self, classes: list):
self.classes = classes
self.labels = list(range(len(self.classes)))
def __call__(self, input_dict):
"""Call function to filter objects by their names.
def transform(self, input_dict: dict) -> dict:
"""Transform function to filter objects by their names.
Args:
input_dict (dict): Result dict from loading pipeline.
......@@ -896,11 +927,13 @@ class PointSample(BaseTransform):
Sampling data to a certain number.
Required Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
- pts_semantic_mask (optional)
......@@ -914,7 +947,10 @@ class PointSample(BaseTransform):
replacement. Defaults to False.
"""
def __init__(self, num_points, sample_range=None, replace=False):
def __init__(self,
num_points: int,
sample_range: float = None,
replace: bool = False):
self.num_points = num_points
self.sample_range = sample_range
self.replace = replace
......@@ -967,7 +1003,7 @@ class PointSample(BaseTransform):
else:
return points[choices]
def transform(self, input_dict: Dict) -> Dict:
def transform(self, input_dict: dict) -> dict:
"""Transform function to sample points to in indoor scenes.
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment