Unverified Commit 24abf9b0 authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Feature]: Add sweep points sampler (#125)

* add sweep points sampler

* add sweep points sampler

* rebase master

* modify unittest

* modify voxel based point sampler

* modify voxel based point sampler

* fix bugs for voxel based point sampler

* modify voxel based points sampler

* fix bugs

* modify repr string
parent 873db382
...@@ -58,6 +58,19 @@ class VoxelGenerator(object): ...@@ -58,6 +58,19 @@ class VoxelGenerator(object):
"""np.ndarray: The size of grids.""" """np.ndarray: The size of grids."""
return self._grid_size return self._grid_size
def __repr__(self):
    """str: Return a string that describes the module."""
    name = self.__class__.__name__
    # Continuation lines are aligned just past the opening parenthesis.
    pad = ' ' * (len(name) + 1)
    fields = [
        f'voxel_size={self._voxel_size}',
        f'point_cloud_range={self._point_cloud_range.tolist()}',
        f'max_num_points={self._max_num_points}',
        f'max_voxels={self._max_voxels}',
        f'grid_size={self._grid_size.tolist()}',
    ]
    return name + '(' + (',\n' + pad).join(fields) + ')'
def points_to_voxel(points, def points_to_voxel(points,
voxel_size, voxel_size,
......
...@@ -9,7 +9,7 @@ from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans, ...@@ -9,7 +9,7 @@ from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans,
LoadPointsFromFile, LoadPointsFromMultiSweeps, LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter, NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D) RandomFlip3D, VoxelBasedPointSampler)
from .scannet_dataset import ScanNetDataset from .scannet_dataset import ScanNetDataset
from .sunrgbd_dataset import SUNRGBDDataset from .sunrgbd_dataset import SUNRGBDDataset
from .waymo_dataset import WaymoDataset from .waymo_dataset import WaymoDataset
...@@ -22,5 +22,6 @@ __all__ = [ ...@@ -22,5 +22,6 @@ __all__ = [
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample', 'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset', 'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset',
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter' 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
'VoxelBasedPointSampler'
] ]
...@@ -8,7 +8,7 @@ from .test_time_aug import MultiScaleFlipAug3D ...@@ -8,7 +8,7 @@ from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans, from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans,
IndoorPointSample, ObjectNoise, ObjectRangeFilter, IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D) RandomFlip3D, VoxelBasedPointSampler)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -17,5 +17,5 @@ __all__ = [ ...@@ -17,5 +17,5 @@ __all__ = [
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler', 'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample', 'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps', 'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter' 'BackgroundPointsFilter', 'VoxelBasedPointSampler'
] ]
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
from mmcv import is_tuple_of from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip from mmdet.datasets.pipelines import RandomFlip
...@@ -697,3 +698,143 @@ class BackgroundPointsFilter(object): ...@@ -697,3 +698,143 @@ class BackgroundPointsFilter(object):
repr_str += '(bbox_enlarge_range={})'.format( repr_str += '(bbox_enlarge_range={})'.format(
self.bbox_enlarge_range.tolist()) self.bbox_enlarge_range.tolist())
return repr_str return repr_str
@PIPELINES.register_module()
class VoxelBasedPointSampler(object):
    """Voxel based point sampler.

    Apply voxel sampling to multiple sweep points.

    Args:
        cur_sweep_cfg (dict): Config for sampling current points.
        prev_sweep_cfg (dict): Config for sampling previous points.
            If None, only the current sweep points are kept.
        time_dim (int): Index that indicates the time dimension
            for input points.
    """

    def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3):
        self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg)
        self.cur_voxel_num = self.cur_voxel_generator._max_voxels
        self.time_dim = time_dim
        if prev_sweep_cfg is not None:
            # Both samplers must emit the same number of points per voxel,
            # otherwise their outputs cannot be concatenated in __call__.
            assert prev_sweep_cfg['max_num_points'] == \
                cur_sweep_cfg['max_num_points']
            self.prev_voxel_generator = VoxelGenerator(**prev_sweep_cfg)
            self.prev_voxel_num = self.prev_voxel_generator._max_voxels
        else:
            self.prev_voxel_generator = None
            self.prev_voxel_num = 0

    def _sample_points(self, points, sampler, point_dim):
        """Sample points for each points subset.

        The output always holds ``sampler._max_voxels`` voxels. Missing
        voxels are filled with copies of the first generated voxel, or left
        as zeros when the sampler generated no voxel at all.

        Args:
            points (np.ndarray): Points subset to be sampled.
            sampler (VoxelGenerator): Voxel based sampler for
                each points subset.
            point_dim (int): The dimension of each point.

        Returns:
            np.ndarray: Sampled points.
        """
        # Only the voxels are needed; coordinates and per-voxel point
        # counts returned by the generator are not used here.
        voxels = sampler.generate(points)[0]
        num_missing = sampler._max_voxels - voxels.shape[0]
        if num_missing <= 0:
            return voxels

        padding_points = np.zeros(
            [num_missing, sampler._max_num_points, point_dim],
            dtype=points.dtype)
        # Guard against an empty result: indexing voxels[0] would raise
        # IndexError, so the padding simply stays zero in that case.
        if voxels.shape[0] > 0:
            padding_points[:] = voxels[0]
        return np.concatenate([voxels, padding_points], axis=0)

    def __call__(self, results):
        """Call function to sample points from multiple sweeps.

        Args:
            results (dict): Result dict from loading pipeline. Must contain
                'points', 'pts_mask_fields' and 'pts_seg_fields'.

        Returns:
            dict: Results after sampling, 'points', 'pts_instance_mask' \
                and 'pts_semantic_mask' keys are updated in the result dict.
        """
        points = results['points']
        original_dim = points.shape[1]

        # TODO: process instance and semantic mask while _max_num_points
        # is larger than 1
        # Extend points with seg and mask fields: every field becomes one
        # extra column so it is sampled together with its point.
        field_keys = list(results['pts_mask_fields']) + \
            list(results['pts_seg_fields'])
        map_fields2dim = [(key, original_dim + offset)
                          for offset, key in enumerate(field_keys)]
        # Remember the dtypes: concatenating with the points promotes the
        # fields to the points' floating dtype.
        field_dtypes = {key: results[key].dtype for key in field_keys}
        extra_channel = [points]
        extra_channel.extend(results[key][..., None] for key in field_keys)
        points = np.concatenate(extra_channel, axis=-1)

        # Split points into two part, current sweep points and
        # previous sweeps points.
        # TODO: support different sampling methods for next sweeps points
        # and previous sweeps points.
        cur_points_flag = (points[:, self.time_dim] == 0)
        cur_sweep_points = points[cur_points_flag]
        prev_sweeps_points = points[~cur_points_flag]
        if prev_sweeps_points.shape[0] == 0:
            # No previous sweep available: reuse the current sweep points.
            prev_sweeps_points = cur_sweep_points

        # Shuffle points before sampling
        np.random.shuffle(cur_sweep_points)
        np.random.shuffle(prev_sweeps_points)

        cur_sweep_points = self._sample_points(cur_sweep_points,
                                               self.cur_voxel_generator,
                                               points.shape[1])
        if self.prev_voxel_generator is not None:
            prev_sweeps_points = self._sample_points(
                prev_sweeps_points, self.prev_voxel_generator,
                points.shape[1])
            points = np.concatenate([cur_sweep_points, prev_sweeps_points], 0)
        else:
            points = cur_sweep_points

        if self.cur_voxel_generator._max_num_points == 1:
            # One point per voxel: drop the per-voxel axis.
            points = points.squeeze(1)
        results['points'] = points[..., :original_dim]

        # Restore the corresponding seg and mask fields with the dtype
        # they had before being merged into the point array.
        for key, dim_index in map_fields2dim:
            results[key] = points[..., dim_index].astype(field_dtypes[key])

        return results

    def __repr__(self):
        """str: Return a string that describes the module."""

        def _auto_indent(repr_str, indent):
            # Prefix every line of a (possibly multi-line) repr.
            prefix = ' ' * indent
            return '\n'.join(prefix + t for t in repr_str.split('\n'))

        repr_str = self.__class__.__name__
        indent = 4
        repr_str += '(\n'
        repr_str += ' ' * indent + f'num_cur_sweep={self.cur_voxel_num},\n'
        repr_str += ' ' * indent + f'num_prev_sweep={self.prev_voxel_num},\n'
        repr_str += ' ' * indent + f'time_dim={self.time_dim},\n'
        repr_str += ' ' * indent + 'cur_voxel_generator=\n'
        repr_str += f'{_auto_indent(repr(self.cur_voxel_generator), 8)},\n'
        repr_str += ' ' * indent + 'prev_voxel_generator=\n'
        repr_str += f'{_auto_indent(repr(self.prev_voxel_generator), 8)})'
        return repr_str
...@@ -5,7 +5,8 @@ import torch ...@@ -5,7 +5,8 @@ import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
from mmdet3d.datasets import (BackgroundPointsFilter, ObjectNoise, from mmdet3d.datasets import (BackgroundPointsFilter, ObjectNoise,
ObjectSample, RandomFlip3D) ObjectSample, RandomFlip3D,
VoxelBasedPointSampler)
def test_remove_points_in_boxes(): def test_remove_points_in_boxes():
...@@ -231,3 +232,75 @@ def test_background_points_filter(): ...@@ -231,3 +232,75 @@ def test_background_points_filter():
# The length of bbox_enlarge_range should be 3 # The length of bbox_enlarge_range should be 3
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
BackgroundPointsFilter((0.5, 2.0)) BackgroundPointsFilter((0.5, 2.0))
def test_voxel_based_point_filter():
    """Check sampling output, repr string and mask handling of the sampler."""
    np.random.seed(0)

    def _make_sweep_cfg():
        # Build a fresh dict for each sweep so the configs share no state.
        return dict(
            voxel_size=[0.1, 0.1, 0.1],
            point_cloud_range=[-50, -50, -4, 50, 50, 2],
            max_num_points=1,
            max_voxels=1024)

    cur_sweep_cfg = _make_sweep_cfg()
    prev_sweep_cfg = _make_sweep_cfg()
    sampler = VoxelBasedPointSampler(
        cur_sweep_cfg, prev_sweep_cfg, time_dim=3)

    # Random xyz coordinates, deliberately exceeding the point cloud range.
    xs = np.random.rand(4096) * 120 - 60
    ys = np.random.rand(4096) * 120 - 60
    zs = np.random.rand(4096) * 10 - 6
    xyz = np.stack([xs, ys, zs], axis=-1)
    # First half belongs to the current sweep (time 0), the rest to a
    # previous sweep (time 1).
    timestamps = np.concatenate(
        [np.zeros([2048, 1]), np.ones([2048, 1])], axis=0)
    input_points = np.concatenate([xyz, timestamps], axis=1)

    sampled = sampler(
        dict(points=input_points, pts_mask_fields=[], pts_seg_fields=[]))
    points = sampled['points']

    repr_str = repr(sampler)
    expected_repr_str = """VoxelBasedPointSampler(
    num_cur_sweep=1024,
    num_prev_sweep=1024,
    time_dim=3,
    cur_voxel_generator=
        VoxelGenerator(voxel_size=[0.1 0.1 0.1],
                       point_cloud_range=[-50.0, -50.0, -4.0, 50.0, 50.0, 2.0],
                       max_num_points=1,
                       max_voxels=1024,
                       grid_size=[1000, 1000, 60]),
    prev_voxel_generator=
        VoxelGenerator(voxel_size=[0.1 0.1 0.1],
                       point_cloud_range=[-50.0, -50.0, -4.0, 50.0, 50.0, 2.0],
                       max_num_points=1,
                       max_voxels=1024,
                       grid_size=[1000, 1000, 60]))"""
    assert repr_str == expected_repr_str

    assert points.shape == (2048, 4)
    lower_bound = cur_sweep_cfg['point_cloud_range'][0:3]
    upper_bound = cur_sweep_cfg['point_cloud_range'][3:6]
    assert (points[:, :3].min(0) < lower_bound).sum() == 0
    assert (points[:, :3].max(0) > upper_bound).sum() == 0

    # Test instance mask and semantic mask
    masked_input = dict(points=input_points)
    masked_input['pts_instance_mask'] = np.random.randint(0, 10, [4096])
    masked_input['pts_semantic_mask'] = np.random.randint(0, 6, [4096])
    masked_input['pts_mask_fields'] = ['pts_instance_mask']
    masked_input['pts_seg_fields'] = ['pts_semantic_mask']

    masked_output = sampler(masked_input)
    pts_instance_mask = masked_output['pts_instance_mask']
    pts_semantic_mask = masked_output['pts_semantic_mask']

    assert pts_instance_mask.shape == (2048, )
    assert pts_semantic_mask.shape == (2048, )
    assert pts_instance_mask.max() < 10
    assert pts_instance_mask.min() >= 0
    assert pts_semantic_mask.max() < 6
    assert pts_semantic_mask.min() >= 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment