from typing import List

import numpy as np
from mmcv import BaseTransform

from mmdet3d.registry import TRANSFORMS


@TRANSFORMS.register_module()
class ObjectRangeFilter3D(BaseTransform):
    """Filter objects by the range.

    It differs from `ObjectRangeFilter` by using `in_range_3d` instead of
    `in_range_bev`.

    Required Keys:

    - gt_bboxes_3d
    - gt_labels_3d

    Modified Keys:

    - gt_bboxes_3d
    - gt_labels_3d

    Args:
        point_cloud_range (list[float]): Point cloud range. It is stored as a
            float32 array and passed unchanged to ``in_range_3d``
            (presumably ``[x_min, y_min, z_min, x_max, y_max, z_max]`` --
            confirm against the box class in use).
    """

    def __init__(self, point_cloud_range: List[float]) -> None:
        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)

    def transform(self, input_dict: dict) -> dict:
        """Transform function to filter objects by the range.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
            keys are updated in the result dict.
        """
        gt_bboxes_3d = input_dict['gt_bboxes_3d']
        gt_labels_3d = input_dict['gt_labels_3d']
        mask = gt_bboxes_3d.in_range_3d(self.pcd_range)
        gt_bboxes_3d = gt_bboxes_3d[mask]
        # mask is a torch tensor but gt_labels_3d is still numpy array
        # using mask to index gt_labels_3d will cause bug when
        # len(gt_labels_3d) == 1, where mask=1 will be interpreted
        # as gt_labels_3d[1] and cause out of index error
        gt_labels_3d = gt_labels_3d[mask.numpy().astype(bool)]

        # Wrap yaw into one 2*pi period centred on zero (offset=0.5 of the
        # period), i.e. approximately [-pi, pi).
        gt_bboxes_3d.limit_yaw(offset=0.5, period=2 * np.pi)
        input_dict['gt_bboxes_3d'] = gt_bboxes_3d
        input_dict['gt_labels_3d'] = gt_labels_3d

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
        return repr_str


@TRANSFORMS.register_module()
class PointsRangeFilter3D(BaseTransform):
    """Filter points by the range.

    It differs from `PointRangeFilter` by using `in_range_bev` instead of
    `in_range_3d`.

    Required Keys:

    - points
    - pts_instance_mask (optional)
    - pts_semantic_mask (optional)

    Modified Keys:

    - points
    - pts_instance_mask (optional)
    - pts_semantic_mask (optional)

    Args:
        point_cloud_range (list[float]): Point cloud range. Only indices
            0, 1, 3 and 4 are used (the BEV extent; presumably
            ``x_min, y_min, x_max, y_max`` of a 6-element
            ``[x_min, y_min, z_min, x_max, y_max, z_max]`` range).
    """

    def __init__(self, point_cloud_range: List[float]) -> None:
        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)

    def transform(self, input_dict: dict) -> dict:
        """Transform function to filter points by the range.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after filtering, 'points', 'pts_instance_mask'
            and 'pts_semantic_mask' keys are updated in the result dict.
        """
        points = input_dict['points']
        # Select the BEV part of the range for ``in_range_bev``.
        points_mask = points.in_range_bev(self.pcd_range[[0, 1, 3, 4]])
        clean_points = points[points_mask]
        input_dict['points'] = clean_points
        # The per-point annotation masks are indexed with a numpy mask.
        points_mask = points_mask.numpy()

        pts_instance_mask = input_dict.get('pts_instance_mask', None)
        pts_semantic_mask = input_dict.get('pts_semantic_mask', None)

        if pts_instance_mask is not None:
            input_dict['pts_instance_mask'] = pts_instance_mask[points_mask]

        if pts_semantic_mask is not None:
            input_dict['pts_semantic_mask'] = pts_semantic_mask[points_mask]

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
        return repr_str