Unverified Commit 4040a074 authored by encore-zhou, committed by GitHub

[feature] Add data_aug BackgroundPointsFilter (#84)

* add data_aug BackgroundPointsFilter

* modify BackgroundPointsFilter

* modify BackgroundPointsFilter
parent 16e8d143
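
For context, a minimal sketch of how the new BackgroundPointsFilter transform might be enabled in a KITTI-style training pipeline config. The surrounding pipeline entries and their parameters are illustrative assumptions, not part of this commit:

# Illustrative pipeline snippet (assumed setup, not taken from this commit).
train_pipeline = [
    dict(type='LoadPointsFromFile', load_dim=4, use_dim=4),
    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
    # Remove background points that fall inside the GT boxes once the boxes
    # are enlarged by 0.5 m in x, 2.0 m in y and 0.5 m in z.
    dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
    dict(type='PointShuffle'),
]
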
-from mmdet.datasets.builder import build_dataloader
-from .builder import DATASETS, build_dataset
+from mmdet.datasets.builder import DATASETS, build_dataloader, build_dataset
from .custom_3d import Custom3DDataset
from .kitti_dataset import KittiDataset
from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset
-from .pipelines import (GlobalRotScaleTrans, IndoorPointSample,
-                        LoadAnnotations3D, LoadPointsFromFile,
-                        LoadPointsFromMultiSweeps, NormalizePointsColor,
-                        ObjectNoise, ObjectRangeFilter, ObjectSample,
-                        PointShuffle, PointsRangeFilter, RandomFlip3D)
+from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans,
+                        IndoorPointSample, LoadAnnotations3D,
+                        LoadPointsFromFile, LoadPointsFromMultiSweeps,
+                        NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
+                        ObjectSample, PointShuffle, PointsRangeFilter,
+                        RandomFlip3D)
from .scannet_dataset import ScanNetDataset
from .sunrgbd_dataset import SUNRGBDDataset
@@ -20,5 +20,5 @@ __all__ = [
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset',
-    'LoadPointsFromMultiSweeps'
+    'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter'
]
@@ -5,9 +5,10 @@ from .loading import (LoadAnnotations3D, LoadMultiViewImageFromFiles,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
-from .transforms_3d import (GlobalRotScaleTrans, IndoorPointSample,
-                            ObjectNoise, ObjectRangeFilter, ObjectSample,
-                            PointShuffle, PointsRangeFilter, RandomFlip3D)
+from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans,
+                            IndoorPointSample, ObjectNoise, ObjectRangeFilter,
+                            ObjectSample, PointShuffle, PointsRangeFilter,
+                            RandomFlip3D)
__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
@@ -15,5 +16,6 @@ __all__ = [
'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
-    'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps'
+    'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
+    'BackgroundPointsFilter'
]
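
Because the transform is exposed through the PIPELINES registry (see the @PIPELINES.register_module() decorator below), it can also be built directly from a config dict. A short sketch, assuming the standard Compose exported from mmdet3d.datasets.pipelines:

from mmdet3d.datasets.pipelines import Compose

# Build the registered transform from a plain config dict (illustrative only).
# A single float enlarges the boxes equally along x, y and z.
pipeline = Compose([dict(type='BackgroundPointsFilter', bbox_enlarge_range=0.5)])
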
import numpy as np
from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg
from mmdet3d.core.bbox import box_np_ops
@@ -634,3 +635,65 @@ class IndoorPointSample(object):
repr_str = self.__class__.__name__
repr_str += '(num_points={})'.format(self.num_points)
return repr_str
@PIPELINES.register_module()
class BackgroundPointsFilter(object):
"""Filter background points near the bounding box.
Args:
bbox_enlarge_range (tuple[float], float): Bbox enlarge range.
"""
def __init__(self, bbox_enlarge_range):
assert (is_tuple_of(bbox_enlarge_range, float)
and len(bbox_enlarge_range) == 3) \
or isinstance(bbox_enlarge_range, float), \
f'Invalid arguments bbox_enlarge_range {bbox_enlarge_range}'
if isinstance(bbox_enlarge_range, float):
bbox_enlarge_range = [bbox_enlarge_range] * 3
self.bbox_enlarge_range = np.array(
bbox_enlarge_range, dtype=np.float32)[np.newaxis, :]
def __call__(self, input_dict):
"""Call function to filter points by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points' keys are updated \
in the result dict.
"""
points = input_dict['points']
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_bboxes_3d_np = gt_bboxes_3d.tensor.numpy()
gt_bboxes_3d_np[:, :3] = gt_bboxes_3d.gravity_center.numpy()
enlarged_gt_bboxes_3d = gt_bboxes_3d_np.copy()
enlarged_gt_bboxes_3d[:, 3:6] += self.bbox_enlarge_range
        # A point is foreground if it falls inside any original GT box, and a
        # nearby background point if it only falls inside an enlarged box.
        foreground_masks = box_np_ops.points_in_rbbox(points, gt_bboxes_3d_np)
        enlarge_foreground_masks = box_np_ops.points_in_rbbox(
            points, enlarged_gt_bboxes_3d)
        foreground_masks = foreground_masks.max(1)
        enlarge_foreground_masks = enlarge_foreground_masks.max(1)
        # Keep a point unless it is inside an enlarged box but outside every
        # original box.
        valid_masks = ~np.logical_and(~foreground_masks,
                                      enlarge_foreground_masks)
        input_dict['points'] = points[valid_masks]
pts_instance_mask = input_dict.get('pts_instance_mask', None)
if pts_instance_mask is not None:
input_dict['pts_instance_mask'] = pts_instance_mask[valid_masks]
pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
if pts_semantic_mask is not None:
input_dict['pts_semantic_mask'] = pts_semantic_mask[valid_masks]
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += '(bbox_enlarge_range={})'.format(
self.bbox_enlarge_range.tolist())
return repr_str
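
To make the mask construction in __call__ concrete, here is a self-contained numpy sketch of the same idea, using a single axis-aligned toy box instead of box_np_ops.points_in_rbbox; the numbers are made up purely for illustration:

import numpy as np

# Toy illustration of the valid-mask logic in BackgroundPointsFilter.__call__,
# with one axis-aligned box instead of box_np_ops.points_in_rbbox.
points = np.array([[0.0, 0.0, 0.0],   # inside the GT box        -> kept (foreground)
                   [1.2, 0.0, 0.0],   # only in the enlarged box -> dropped
                   [5.0, 0.0, 0.0]])  # far from the box         -> kept (background)
center = np.zeros(3)
size = np.array([2.0, 2.0, 2.0])
enlarge = np.array([0.5, 2.0, 0.5])

def in_box(pts, half_extent):
    return np.all(np.abs(pts - center) <= half_extent, axis=1)

foreground = in_box(points, size / 2)
enlarged = in_box(points, (size + enlarge) / 2)
# Keep a point unless it lies in the enlarged box but not in the original box.
valid = ~(~foreground & enlarged)
print(points[valid])  # the first and third points survive
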
import mmcv
import numpy as np
import pytest
import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
-from mmdet3d.datasets import ObjectNoise, ObjectSample, RandomFlip3D
+from mmdet3d.datasets import (BackgroundPointsFilter, ObjectNoise,
+                              ObjectSample, RandomFlip3D)
def test_remove_points_in_boxes():
@@ -186,3 +188,45 @@ def test_random_flip_3d():
assert np.allclose(points, expected_points)
assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
assert repr_str == expected_repr_str
def test_background_points_filter():
np.random.seed(0)
background_points_filter = BackgroundPointsFilter((0.5, 2.0, 0.5))
points = np.fromfile(
'./tests/data/kitti/training/velodyne_reduced/000000.bin',
np.float32).reshape(-1, 4)
orig_points = points.copy()
annos = mmcv.load('./tests/data/kitti/kitti_infos_train.pkl')
info = annos[0]
rect = info['calib']['R0_rect'].astype(np.float32)
Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
annos = info['annos']
loc = annos['location']
dims = annos['dimensions']
rots = annos['rotation_y']
gt_bboxes_3d = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1).astype(np.float32)
gt_bboxes_3d = CameraInstance3DBoxes(gt_bboxes_3d).convert_to(
Box3DMode.LIDAR, np.linalg.inv(rect @ Trv2c))
    # Append the four top corners of the GT box, lifted slightly (0.1 m) above
    # it: they land outside the original box but inside the enlarged one, so
    # BackgroundPointsFilter should remove them again.
    extra_points = gt_bboxes_3d.corners.reshape(8, 3)[[1, 2, 5, 6], :]
    extra_points[:, 2] += 0.1
    extra_points = torch.cat([extra_points, extra_points.new_zeros(4, 1)], 1)
    points = np.concatenate([points, extra_points.numpy()], 0)
input_dict = dict(points=points, gt_bboxes_3d=gt_bboxes_3d)
input_dict = background_points_filter(input_dict)
points = input_dict['points']
repr_str = repr(background_points_filter)
expected_repr_str = 'BackgroundPointsFilter(bbox_enlarge_range=' \
'[[0.5, 2.0, 0.5]])'
assert repr_str == expected_repr_str
assert points.shape == (800, 4)
assert np.allclose(orig_points, points)
# test single float config
BackgroundPointsFilter(0.5)
# The length of bbox_enlarge_range should be 3
with pytest.raises(AssertionError):
BackgroundPointsFilter((0.5, 2.0))
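
The test above only checks the 'points' key; the optional per-point masks are sliced with the same valid mask. A hedged, synthetic example of that behaviour (random data, not part of the commit's test suite):

import numpy as np
from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.datasets import BackgroundPointsFilter

# Synthetic data only: check that instance/semantic masks stay aligned with
# the filtered points.
points = (np.random.rand(100, 4) * 20.0).astype(np.float32)
gt_bboxes_3d = LiDARInstance3DBoxes(
    np.array([[5.0, 5.0, 0.0, 2.0, 2.0, 2.0, 0.0]], dtype=np.float32))
input_dict = dict(
    points=points,
    gt_bboxes_3d=gt_bboxes_3d,
    pts_instance_mask=np.zeros(100, dtype=np.int64),
    pts_semantic_mask=np.zeros(100, dtype=np.int64))
out = BackgroundPointsFilter(0.5)(input_dict)
assert out['points'].shape[0] == out['pts_instance_mask'].shape[0]
assert out['pts_instance_mask'].shape == out['pts_semantic_mask'].shape
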