Unverified Commit d4d7af24 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Pipeline function process points and masks simultaneously (#444)

* modify PointShuffle and add unittest

* modify PointsRangeFilter and add unittest

* fix small bugs in IndoorPointSample

* fix small typos in docstring
parent b035bc8e
......@@ -64,19 +64,19 @@ class Box3DMode(IntEnum):
"""Convert boxes from `src` mode to `dst` mode.
Args:
box (tuple | list | np.dnarray |
box (tuple | list | np.ndarray |
torch.Tensor | BaseInstance3DBoxes):
Can be a k-tuple, k-list or an Nxk array/tensor, where k = 7.
src (:obj:`BoxMode`): The src Box mode.
dst (:obj:`BoxMode`): The target Box mode.
rt_mat (np.dnarray | torch.Tensor): The rotation and translation
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
(tuple | list | np.dnarray | torch.Tensor | BaseInstance3DBoxes): \
(tuple | list | np.ndarray | torch.Tensor | BaseInstance3DBoxes): \
The converted box of the same type.
"""
if src == dst:
......
......@@ -296,7 +296,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Args:
dst (:obj:`BoxMode`): The target Box mode.
rt_mat (np.dnarray | torch.Tensor): The rotation and translation
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from ``src`` coordinates to ``dst`` coordinates
usually comes along the change of sensors, e.g., from camera
......
......@@ -77,19 +77,19 @@ class Coord3DMode(IntEnum):
"""Convert boxes from `src` mode to `dst` mode.
Args:
box (tuple | list | np.dnarray |
box (tuple | list | np.ndarray |
torch.Tensor | BaseInstance3DBoxes):
Can be a k-tuple, k-list or an Nxk array/tensor, where k = 7.
src (:obj:`CoordMode`): The src Box mode.
dst (:obj:`CoordMode`): The target Box mode.
rt_mat (np.dnarray | torch.Tensor): The rotation and translation
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
(tuple | list | np.dnarray | torch.Tensor | BaseInstance3DBoxes): \
(tuple | list | np.ndarray | torch.Tensor | BaseInstance3DBoxes): \
The converted box of the same type.
"""
if src == dst:
......@@ -182,19 +182,19 @@ class Coord3DMode(IntEnum):
"""Convert points from `src` mode to `dst` mode.
Args:
point (tuple | list | np.dnarray |
point (tuple | list | np.ndarray |
torch.Tensor | BasePoints):
Can be a k-tuple, k-list or an Nxk array/tensor.
src (:obj:`CoordMode`): The src Point mode.
dst (:obj:`CoordMode`): The target Point mode.
rt_mat (np.dnarray | torch.Tensor): The rotation and translation
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
(tuple | list | np.dnarray | torch.Tensor | BasePoints): \
(tuple | list | np.ndarray | torch.Tensor | BasePoints): \
The converted point of the same type.
"""
if src == dst:
......
......@@ -127,9 +127,14 @@ class BasePoints(object):
return self.tensor.shape
def shuffle(self):
    """Shuffle the points in place.

    Exactly one random permutation is drawn and applied to ``self.tensor``.
    The same permutation is returned so that callers can reorder other
    per-point data consistently with the points.

    Returns:
        torch.Tensor: The shuffled index, i.e. the new ``i``-th point is
            the old ``idx[i]``-th point.
    """
    # Draw the permutation on the same device as the point tensor.
    idx = torch.randperm(self.__len__(), device=self.tensor.device)
    # Apply it once; a second, independent shuffle would make the returned
    # index no longer describe the final ordering.
    self.tensor = self.tensor[idx]
    return idx
def rotate(self, rotation, axis=None):
"""Rotate points with the given rotation matrix or angle.
......
......@@ -120,7 +120,7 @@ class RandomFlip3D(RandomFlip):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(sync_2d={self.sync_2d},'
repr_str += f'flip_ratio_bev_vertical={self.flip_ratio_bev_vertical})'
repr_str += f' flip_ratio_bev_vertical={self.flip_ratio_bev_vertical})'
return repr_str
......@@ -453,10 +453,21 @@ class PointShuffle(object):
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points' keys are updated \
in the result dict.
dict: Results after filtering, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
input_dict['points'].shuffle()
idx = input_dict['points'].shuffle()
idx = idx.numpy()
pts_instance_mask = input_dict.get('pts_instance_mask', None)
pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
if pts_instance_mask is not None:
input_dict['pts_instance_mask'] = pts_instance_mask[idx]
if pts_semantic_mask is not None:
input_dict['pts_semantic_mask'] = pts_semantic_mask[idx]
return input_dict
def __repr__(self):
......@@ -527,13 +538,24 @@ class PointsRangeFilter(object):
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points' keys are updated \
in the result dict.
dict: Results after filtering, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
points_mask = points.in_range_3d(self.pcd_range)
clean_points = points[points_mask]
input_dict['points'] = clean_points
points_mask = points_mask.numpy()
pts_instance_mask = input_dict.get('pts_instance_mask', None)
pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
if pts_instance_mask is not None:
input_dict['pts_instance_mask'] = pts_instance_mask[points_mask]
if pts_semantic_mask is not None:
input_dict['pts_semantic_mask'] = pts_semantic_mask[points_mask]
return input_dict
def __repr__(self):
......@@ -638,15 +660,17 @@ class IndoorPointSample(object):
points = results['points']
points, choices = self.points_random_sampling(
points, self.num_points, return_choices=True)
results['points'] = points
pts_instance_mask = results.get('pts_instance_mask', None)
pts_semantic_mask = results.get('pts_semantic_mask', None)
results['points'] = points
if pts_instance_mask is not None and pts_semantic_mask is not None:
if pts_instance_mask is not None:
pts_instance_mask = pts_instance_mask[choices]
pts_semantic_mask = pts_semantic_mask[choices]
results['pts_instance_mask'] = pts_instance_mask
if pts_semantic_mask is not None:
pts_semantic_mask = pts_semantic_mask[choices]
results['pts_semantic_mask'] = pts_semantic_mask
return results
......@@ -885,8 +909,8 @@ class BackgroundPointsFilter(object):
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points' keys are updated \
in the result dict.
dict: Results after filtering, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
gt_bboxes_3d = input_dict['gt_bboxes_3d']
......
......@@ -4,10 +4,10 @@ import pytest
import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
from mmdet3d.core.points import LiDARPoints
from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, ObjectNoise,
ObjectSample, RandomFlip3D,
VoxelBasedPointSampler)
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
def test_remove_points_in_boxes():
......@@ -139,6 +139,88 @@ def test_object_noise():
assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d, 1e-3)
def test_point_shuffle():
    """Check that PointShuffle reorders points and both masks identically."""
    # Seed both RNGs; the expected permutation below is hard-coded.
    np.random.seed(0)
    torch.manual_seed(0)
    point_shuffle = PointShuffle()

    points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
                         np.float32).reshape(-1, 6)
    # Use an explicit fixed-width dtype: the ``np.long`` alias is not
    # available in every NumPy release and its width is platform-dependent.
    # NOTE(review): assumes the mask files are stored as int64 - confirm.
    ins_mask = np.fromfile(
        'tests/data/scannet/instance_mask/scene0000_00.bin', np.int64)
    sem_mask = np.fromfile(
        'tests/data/scannet/semantic_mask/scene0000_00.bin', np.int64)

    points = DepthPoints(
        points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
    # Pass copies so the originals stay available for building expectations.
    input_dict = dict(
        points=points.clone(),
        pts_instance_mask=ins_mask.copy(),
        pts_semantic_mask=sem_mask.copy())
    results = point_shuffle(input_dict)

    shuffle_pts = results['points']
    shuffle_ins_mask = results['pts_instance_mask']
    shuffle_sem_mask = results['pts_semantic_mask']

    shuffle_idx = np.array([
        44, 19, 93, 90, 71, 69, 37, 95, 53, 91, 81, 42, 80, 85, 74, 56, 76, 63,
        82, 40, 26, 92, 57, 10, 16, 66, 89, 41, 97, 8, 31, 24, 35, 30, 65, 7,
        98, 23, 20, 29, 78, 61, 94, 15, 4, 52, 59, 5, 54, 46, 3, 28, 2, 70, 6,
        60, 49, 68, 55, 72, 79, 77, 45, 1, 32, 34, 11, 0, 22, 12, 87, 50, 25,
        47, 36, 96, 9, 83, 62, 84, 18, 17, 75, 67, 13, 48, 39, 21, 64, 88, 38,
        27, 14, 73, 33, 58, 86, 43, 99, 51
    ])
    expected_pts = points.tensor.numpy()[shuffle_idx]
    expected_ins_mask = ins_mask[shuffle_idx]
    expected_sem_mask = sem_mask[shuffle_idx]

    assert np.allclose(shuffle_pts.tensor.numpy(), expected_pts)
    assert np.all(shuffle_ins_mask == expected_ins_mask)
    assert np.all(shuffle_sem_mask == expected_sem_mask)

    repr_str = repr(point_shuffle)
    expected_repr_str = 'PointShuffle'
    assert repr_str == expected_repr_str
def test_points_range_filter():
    """Check that PointsRangeFilter filters points and both masks alike."""
    pcd_range = [0.0, 0.0, 0.0, 3.0, 3.0, 3.0]
    points_range_filter = PointsRangeFilter(pcd_range)

    points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
                         np.float32).reshape(-1, 6)
    # Use an explicit fixed-width dtype: the ``np.long`` alias is not
    # available in every NumPy release and its width is platform-dependent.
    # NOTE(review): assumes the mask files are stored as int64 - confirm.
    ins_mask = np.fromfile(
        'tests/data/scannet/instance_mask/scene0000_00.bin', np.int64)
    sem_mask = np.fromfile(
        'tests/data/scannet/semantic_mask/scene0000_00.bin', np.int64)

    points = DepthPoints(
        points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
    # Pass copies so the originals stay available for building expectations.
    input_dict = dict(
        points=points.clone(),
        pts_instance_mask=ins_mask.copy(),
        pts_semantic_mask=sem_mask.copy())
    results = points_range_filter(input_dict)

    filtered_pts = results['points']
    filtered_ins_mask = results['pts_instance_mask']
    filtered_sem_mask = results['pts_semantic_mask']

    # Hard-coded indices of the points expected to survive the filter.
    select_idx = np.array(
        [5, 11, 22, 26, 27, 33, 46, 47, 56, 63, 74, 78, 79, 91])
    expected_pts = points.tensor.numpy()[select_idx]
    expected_ins_mask = ins_mask[select_idx]
    expected_sem_mask = sem_mask[select_idx]

    assert np.allclose(filtered_pts.tensor.numpy(), expected_pts)
    assert np.all(filtered_ins_mask == expected_ins_mask)
    assert np.all(filtered_sem_mask == expected_sem_mask)

    repr_str = repr(points_range_filter)
    expected_repr_str = f'PointsRangeFilter(point_cloud_range={pcd_range})'
    assert repr_str == expected_repr_str
def test_random_flip_3d():
random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
......@@ -190,7 +272,7 @@ def test_random_flip_3d():
[5.0903, -5.1004, -1.2694, 0.7100, 1.7000, 1.8300, 5.0552]])
repr_str = repr(random_flip_3d)
expected_repr_str = 'RandomFlip3D(sync_2d=True,' \
'flip_ratio_bev_vertical=1.0)'
' flip_ratio_bev_vertical=1.0)'
assert np.allclose(points, expected_points)
assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
assert repr_str == expected_repr_str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment