Unverified Commit 24abf9b0 authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Feature]:add sweep point sample (#125)

* add sweep points sampler

* add sweep points sampler

* rebase master

* modify unittest

* modify voxel based point sampler

* modify voxel based point sampler

* fix bugs for voxel based point sampler

* modify voxel based points sampler

* fix bugs

* modify repr string
parent 873db382
......@@ -58,6 +58,19 @@ class VoxelGenerator(object):
"""np.ndarray: The size of grids."""
return self._grid_size
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
indent = ' ' * (len(repr_str) + 1)
repr_str += f'(voxel_size={self._voxel_size},\n'
repr_str += indent + 'point_cloud_range='
repr_str += f'{self._point_cloud_range.tolist()},\n'
repr_str += indent + f'max_num_points={self._max_num_points},\n'
repr_str += indent + f'max_voxels={self._max_voxels},\n'
repr_str += indent + f'grid_size={self._grid_size.tolist()}'
repr_str += ')'
return repr_str
def points_to_voxel(points,
voxel_size,
......
......@@ -9,7 +9,7 @@ from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D)
RandomFlip3D, VoxelBasedPointSampler)
from .scannet_dataset import ScanNetDataset
from .sunrgbd_dataset import SUNRGBDDataset
from .waymo_dataset import WaymoDataset
......@@ -22,5 +22,6 @@ __all__ = [
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset',
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter'
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
'VoxelBasedPointSampler'
]
......@@ -8,7 +8,7 @@ from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans,
IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D)
RandomFlip3D, VoxelBasedPointSampler)
__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
......@@ -17,5 +17,5 @@ __all__ = [
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter'
'BackgroundPointsFilter', 'VoxelBasedPointSampler'
]
......@@ -2,6 +2,7 @@ import numpy as np
from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg
from mmdet3d.core import VoxelGenerator
from mmdet3d.core.bbox import box_np_ops
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import RandomFlip
......@@ -697,3 +698,143 @@ class BackgroundPointsFilter(object):
repr_str += '(bbox_enlarge_range={})'.format(
self.bbox_enlarge_range.tolist())
return repr_str
@PIPELINES.register_module()
class VoxelBasedPointSampler(object):
"""Voxel based point sampler.
Apply voxel sampling to multiple sweep points.
Args:
cur_sweep_cfg (dict): Config for sampling current points.
prev_sweep_cfg (dict): Config for sampling previous points.
time_dim (int): Index that indicate the time dimention
for input points.
"""
def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3):
self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg)
self.cur_voxel_num = self.cur_voxel_generator._max_voxels
self.time_dim = time_dim
if prev_sweep_cfg is not None:
assert prev_sweep_cfg['max_num_points'] == \
cur_sweep_cfg['max_num_points']
self.prev_voxel_generator = VoxelGenerator(**prev_sweep_cfg)
self.prev_voxel_num = self.prev_voxel_generator._max_voxels
else:
self.prev_voxel_generator = None
self.prev_voxel_num = 0
def _sample_points(self, points, sampler, point_dim):
"""Sample points for each points subset.
Args:
points (np.ndarray): Points subset to be sampled.
sampler (VoxelGenerator): Voxel based sampler for
each points subset.
point_dim (int): The dimention of each points
Returns:
np.ndarray: Sampled points.
"""
voxels, coors, num_points_per_voxel = sampler.generate(points)
if voxels.shape[0] < sampler._max_voxels:
padding_points = np.zeros([
sampler._max_voxels - voxels.shape[0], sampler._max_num_points,
point_dim
],
dtype=points.dtype)
padding_points[:] = voxels[0]
sample_points = np.concatenate([voxels, padding_points], axis=0)
else:
sample_points = voxels
return sample_points
def __call__(self, results):
"""Call function to sample points from multiple sweeps.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = results['points']
original_dim = points.shape[1]
# TODO: process instance and semantic mask while _max_num_points
# is larger than 1
# Extend points with seg and mask fields
map_fields2dim = []
start_dim = original_dim
extra_channel = [points]
for idx, key in enumerate(results['pts_mask_fields']):
map_fields2dim.append((key, idx + start_dim))
extra_channel.append(results[key][..., None])
start_dim += len(results['pts_mask_fields'])
for idx, key in enumerate(results['pts_seg_fields']):
map_fields2dim.append((key, idx + start_dim))
extra_channel.append(results[key][..., None])
points = np.concatenate(extra_channel, axis=-1)
# Split points into two part, current sweep points and
# previous sweeps points.
# TODO: support different sampling methods for next sweeps points
# and previous sweeps points.
cur_points_flag = (points[:, self.time_dim] == 0)
cur_sweep_points = points[cur_points_flag]
prev_sweeps_points = points[~cur_points_flag]
if prev_sweeps_points.shape[0] == 0:
prev_sweeps_points = cur_sweep_points
# Shuffle points before sampling
np.random.shuffle(cur_sweep_points)
np.random.shuffle(prev_sweeps_points)
cur_sweep_points = self._sample_points(cur_sweep_points,
self.cur_voxel_generator,
points.shape[1])
if self.prev_voxel_generator is not None:
prev_sweeps_points = self._sample_points(prev_sweeps_points,
self.prev_voxel_generator,
points.shape[1])
points = np.concatenate([cur_sweep_points, prev_sweeps_points], 0)
else:
points = cur_sweep_points
if self.cur_voxel_generator._max_num_points == 1:
points = points.squeeze(1)
results['points'] = points[..., :original_dim]
# Restore the correspoinding seg and mask fields
for key, dim_index in map_fields2dim:
results[key] = points[..., dim_index]
return results
def __repr__(self):
"""str: Return a string that describes the module."""
def _auto_indent(repr_str, indent):
repr_str = repr_str.split('\n')
repr_str = [' ' * indent + t + '\n' for t in repr_str]
repr_str = ''.join(repr_str)[:-1]
return repr_str
repr_str = self.__class__.__name__
indent = 4
repr_str += '(\n'
repr_str += ' ' * indent + f'num_cur_sweep={self.cur_voxel_num},\n'
repr_str += ' ' * indent + f'num_prev_sweep={self.prev_voxel_num},\n'
repr_str += ' ' * indent + f'time_dim={self.time_dim},\n'
repr_str += ' ' * indent + 'cur_voxel_generator=\n'
repr_str += f'{_auto_indent(repr(self.cur_voxel_generator), 8)},\n'
repr_str += ' ' * indent + 'prev_voxel_generator=\n'
repr_str += f'{_auto_indent(repr(self.prev_voxel_generator), 8)})'
return repr_str
......@@ -5,7 +5,8 @@ import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
from mmdet3d.datasets import (BackgroundPointsFilter, ObjectNoise,
ObjectSample, RandomFlip3D)
ObjectSample, RandomFlip3D,
VoxelBasedPointSampler)
def test_remove_points_in_boxes():
......@@ -231,3 +232,75 @@ def test_background_points_filter():
# The length of bbox_enlarge_range should be 3
with pytest.raises(AssertionError):
BackgroundPointsFilter((0.5, 2.0))
def test_voxel_based_point_filter():
np.random.seed(0)
cur_sweep_cfg = dict(
voxel_size=[0.1, 0.1, 0.1],
point_cloud_range=[-50, -50, -4, 50, 50, 2],
max_num_points=1,
max_voxels=1024)
prev_sweep_cfg = dict(
voxel_size=[0.1, 0.1, 0.1],
point_cloud_range=[-50, -50, -4, 50, 50, 2],
max_num_points=1,
max_voxels=1024)
voxel_based_points_filter = VoxelBasedPointSampler(
cur_sweep_cfg, prev_sweep_cfg, time_dim=3)
points = np.stack([
np.random.rand(4096) * 120 - 60,
np.random.rand(4096) * 120 - 60,
np.random.rand(4096) * 10 - 6
],
axis=-1)
input_time = np.concatenate([np.zeros([2048, 1]), np.ones([2048, 1])], 0)
input_points = np.concatenate([points, input_time], 1)
input_dict = dict(
points=input_points, pts_mask_fields=[], pts_seg_fields=[])
input_dict = voxel_based_points_filter(input_dict)
points = input_dict['points']
repr_str = repr(voxel_based_points_filter)
expected_repr_str = """VoxelBasedPointSampler(
num_cur_sweep=1024,
num_prev_sweep=1024,
time_dim=3,
cur_voxel_generator=
VoxelGenerator(voxel_size=[0.1 0.1 0.1],
point_cloud_range=[-50.0, -50.0, -4.0, 50.0, 50.0, 2.0],
max_num_points=1,
max_voxels=1024,
grid_size=[1000, 1000, 60]),
prev_voxel_generator=
VoxelGenerator(voxel_size=[0.1 0.1 0.1],
point_cloud_range=[-50.0, -50.0, -4.0, 50.0, 50.0, 2.0],
max_num_points=1,
max_voxels=1024,
grid_size=[1000, 1000, 60]))"""
assert repr_str == expected_repr_str
assert points.shape == (2048, 4)
assert (points[:, :3].min(0) <
cur_sweep_cfg['point_cloud_range'][0:3]).sum() == 0
assert (points[:, :3].max(0) >
cur_sweep_cfg['point_cloud_range'][3:6]).sum() == 0
# Test instance mask and semantic mask
input_dict = dict(points=input_points)
input_dict['pts_instance_mask'] = np.random.randint(0, 10, [4096])
input_dict['pts_semantic_mask'] = np.random.randint(0, 6, [4096])
input_dict['pts_mask_fields'] = ['pts_instance_mask']
input_dict['pts_seg_fields'] = ['pts_semantic_mask']
input_dict = voxel_based_points_filter(input_dict)
pts_instance_mask = input_dict['pts_instance_mask']
pts_semantic_mask = input_dict['pts_semantic_mask']
assert pts_instance_mask.shape == (2048, )
assert pts_semantic_mask.shape == (2048, )
assert pts_instance_mask.max() < 10
assert pts_instance_mask.min() >= 0
assert pts_semantic_mask.max() < 6
assert pts_semantic_mask.min() >= 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment