Unverified Commit fc9e0d9d authored by Wenhao Wu's avatar Wenhao Wu Committed by GitHub
Browse files

[Enhance] Sampling points based on distance metric (#667)

* [Enhance] Sampling points based on distance metric

* fix typo

* refine unittest

* refine unittest

* refine details & add unittest & refine configs

* remove __repr__ & rename arg

* fix unittest

* add unittest

* refine unittest

* refine code

* refine code

* refine depth calculation

* refine code
parent d3213cd3
...@@ -52,7 +52,7 @@ train_pipeline = [ ...@@ -52,7 +52,7 @@ train_pipeline = [
scale_ratio_range=[0.9, 1.1]), scale_ratio_range=[0.9, 1.1]),
# 3DSSD can get a higher performance without this transform # 3DSSD can get a higher performance without this transform
# dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)), # dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
dict(type='IndoorPointSample', num_points=16384), dict(type='PointSample', num_points=16384),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
...@@ -78,7 +78,7 @@ test_pipeline = [ ...@@ -78,7 +78,7 @@ test_pipeline = [
dict(type='RandomFlip3D'), dict(type='RandomFlip3D'),
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range), type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='IndoorPointSample', num_points=16384), dict(type='PointSample', num_points=16384),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -20,7 +20,7 @@ train_pipeline = [ ...@@ -20,7 +20,7 @@ train_pipeline = [
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
...@@ -47,7 +47,7 @@ test_pipeline = [ ...@@ -47,7 +47,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
), ),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -187,7 +187,7 @@ train_pipeline = [ ...@@ -187,7 +187,7 @@ train_pipeline = [
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Collect3D',
...@@ -225,7 +225,7 @@ test_pipeline = [ ...@@ -225,7 +225,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
), ),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -203,7 +203,7 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for ...@@ -203,7 +203,7 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39), # all valid categories ids 36, 39), # all valid categories ids
max_cat_id=40), # max possible category id in input segmentation mask max_cat_id=40), # max possible category id in input segmentation mask
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details dict(type='PointSample', # Sample points, refer to mmdet3d.datasets.pipelines.transforms_3d for more details
num_points=40000), # Number of points to be sampled num_points=40000), # Number of points to be sampled
dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes
flip_ratio_yz=0.5, # Probability of being flipped along yz plane flip_ratio_yz=0.5, # Probability of being flipped along yz plane
...@@ -232,7 +232,7 @@ test_pipeline = [ # Testing pipeline, refer to mmdet3d.datasets.pipelines for m ...@@ -232,7 +232,7 @@ test_pipeline = [ # Testing pipeline, refer to mmdet3d.datasets.pipelines for m
shift_height=True, # Whether to use shifted height shift_height=True, # Whether to use shifted height
load_dim=6, # The dimension of the loaded points load_dim=6, # The dimension of the loaded points
use_dim=[0, 1, 2]), # Which dimensions of the points to be used use_dim=[0, 1, 2]), # Which dimensions of the points to be used
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details dict(type='PointSample', # Sample points, refer to mmdet3d.datasets.pipelines.transforms_3d for more details
num_points=40000), # Number of points to be sampled num_points=40000), # Number of points to be sampled
dict( dict(
type='DefaultFormatBundle3D', # Default format bundle to gather data in the pipeline, refer to mmdet3d.datasets.pipelines.formating for more details type='DefaultFormatBundle3D', # Default format bundle to gather data in the pipeline, refer to mmdet3d.datasets.pipelines.formating for more details
...@@ -286,7 +286,7 @@ data = dict( ...@@ -286,7 +286,7 @@ data = dict(
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39), 28, 33, 34, 36, 39),
max_cat_id=40), max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='IndoorFlipData', type='IndoorFlipData',
flip_ratio_yz=0.5, flip_ratio_yz=0.5,
...@@ -325,7 +325,7 @@ data = dict( ...@@ -325,7 +325,7 @@ data = dict(
shift_height=True, shift_height=True,
load_dim=6, load_dim=6,
use_dim=[0, 1, 2]), use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table', class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
...@@ -350,7 +350,7 @@ data = dict( ...@@ -350,7 +350,7 @@ data = dict(
shift_height=True, shift_height=True,
load_dim=6, load_dim=6,
use_dim=[0, 1, 2]), use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table', class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
......
...@@ -281,6 +281,8 @@ class BasePoints(object): ...@@ -281,6 +281,8 @@ class BasePoints(object):
Nonzero elements in the vector will be selected. Nonzero elements in the vector will be selected.
4. `new_points = points[3:11, vector]`: 4. `new_points = points[3:11, vector]`:
return a slice of points and attribute dims. return a slice of points and attribute dims.
5. `new_points = points[4:12, 2]`:
return a slice of points with single attribute.
Note that the returned Points might share storage with this Points, Note that the returned Points might share storage with this Points,
subject to Pytorch's indexing semantics. subject to Pytorch's indexing semantics.
...@@ -303,6 +305,10 @@ class BasePoints(object): ...@@ -303,6 +305,10 @@ class BasePoints(object):
item = list(item) item = list(item)
item[1] = list(range(start, stop, step)) item[1] = list(range(start, stop, step))
item = tuple(item) item = tuple(item)
elif isinstance(item[1], int):
item = list(item)
item[1] = [item[1]]
item = tuple(item)
p = self.tensor[item[0], item[1]] p = self.tensor[item[0], item[1]]
keep_dims = list( keep_dims = list(
......
...@@ -7,14 +7,17 @@ from .kitti_mono_dataset import KittiMonoDataset ...@@ -7,14 +7,17 @@ from .kitti_mono_dataset import KittiMonoDataset
from .lyft_dataset import LyftDataset from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset from .nuscenes_dataset import NuScenesDataset
from .nuscenes_mono_dataset import NuScenesMonoDataset from .nuscenes_mono_dataset import NuScenesMonoDataset
# yapf: disable
from .pipelines import (BackgroundPointsFilter, GlobalAlignment, from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample, GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, LoadAnnotations3D, IndoorPointSample, LoadAnnotations3D,
LoadPointsFromFile, LoadPointsFromMultiSweeps, LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNameFilter, ObjectNoise, NormalizePointsColor, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle, ObjectRangeFilter, ObjectSample, PointSample,
PointsRangeFilter, RandomDropPointsColor, RandomFlip3D, PointShuffle, PointsRangeFilter, RandomDropPointsColor,
RandomJitterPoints, VoxelBasedPointSampler) RandomFlip3D, RandomJitterPoints,
VoxelBasedPointSampler)
# yapf: enable
from .s3dis_dataset import S3DISSegDataset from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset from .semantickitti_dataset import SemanticKITTIDataset
...@@ -30,9 +33,10 @@ __all__ = [ ...@@ -30,9 +33,10 @@ __all__ = [
'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter',
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'S3DISSegDataset', 'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'S3DISSegDataset',
'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample', 'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset', 'PointSample', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset', 'ScanNetDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline', 'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler',
'RandomDropPointsColor', 'RandomJitterPoints', 'ObjectNameFilter' 'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints',
'ObjectNameFilter'
] ]
...@@ -9,10 +9,10 @@ from .test_time_aug import MultiScaleFlipAug3D ...@@ -9,10 +9,10 @@ from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment, from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample, GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNameFilter, ObjectNoise, IndoorPointSample, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle, ObjectRangeFilter, ObjectSample, PointSample,
PointsRangeFilter, RandomDropPointsColor, PointShuffle, PointsRangeFilter,
RandomFlip3D, RandomJitterPoints, RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -20,8 +20,9 @@ __all__ = [ ...@@ -20,8 +20,9 @@ __all__ = [
'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile', 'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler', 'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample', 'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps', 'PointSample', 'PointSegClassMapping', 'MultiScaleFlipAug3D',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment', 'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter', 'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
'RandomDropPointsColor', 'RandomJitterPoints' 'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
'RandomJitterPoints'
] ]
...@@ -838,24 +838,27 @@ class ObjectNameFilter(object): ...@@ -838,24 +838,27 @@ class ObjectNameFilter(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class IndoorPointSample(object): class PointSample(object):
"""Indoor point sample. """Point sample.
Sampling data to a certain number. Sampling data to a certain number.
Args: Args:
name (str): Name of the dataset.
num_points (int): Number of points to be sampled. num_points (int): Number of points to be sampled.
sample_range (float, optional): The range where to sample points.
""" """
def __init__(self, num_points): def __init__(self, num_points, sample_range=None, replace=False):
self.num_points = num_points self.num_points = num_points
self.sample_range = sample_range
def points_random_sampling(self, self.replace = replace
points,
num_samples, def _points_random_sampling(self,
replace=None, points,
return_choices=False): num_samples,
sample_range=None,
replace=False,
return_choices=False):
"""Points random sampling. """Points random sampling.
Sample points to a certain number. Sample points to a certain number.
...@@ -863,20 +866,33 @@ class IndoorPointSample(object): ...@@ -863,20 +866,33 @@ class IndoorPointSample(object):
Args: Args:
points (np.ndarray | :obj:`BasePoints`): 3D Points. points (np.ndarray | :obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled. num_samples (int): Number of samples to be sampled.
replace (bool): Whether the sample is with or without replacement. sample_range (float, optional): Indicating the range where the
Defaults to None. points will be sampled.
return_choices (bool): Whether return choice. Defaults to False. Defaults to None.
replace (bool, optional): Sampling with or without replacement.
Defaults to None.
return_choices (bool, optional): Whether return choice.
Defaults to False.
Returns: Returns:
tuple[np.ndarray] | np.ndarray: tuple[np.ndarray] | np.ndarray:
- points (np.ndarray | :obj:`BasePoints`): 3D Points. - points (np.ndarray | :obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples. - choices (np.ndarray, optional): The generated random samples.
""" """
if replace is None: if not replace:
replace = (points.shape[0] < num_samples) replace = (points.shape[0] < num_samples)
choices = np.random.choice( point_range = range(len(points))
points.shape[0], num_samples, replace=replace) if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth > sample_range)[0]
near_inds = np.where(depth <= sample_range)[0]
point_range = near_inds
num_samples -= len(far_inds)
choices = np.random.choice(point_range, num_samples, replace=replace)
if sample_range is not None and not replace:
choices = np.concatenate((far_inds, choices))
# Shuffle points after sampling
np.random.shuffle(choices)
if return_choices: if return_choices:
return points[choices], choices return points[choices], choices
else: else:
...@@ -887,14 +903,23 @@ class IndoorPointSample(object): ...@@ -887,14 +903,23 @@ class IndoorPointSample(object):
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
from mmdet3d.core.points import CameraPoints
points = results['points'] points = results['points']
points, choices = self.points_random_sampling( # Points in Camera coord can provide the depth information.
points, self.num_points, return_choices=True) # TODO: Need to support distance-based sampling for other coord system.
if self.sample_range is not None:
assert isinstance(points, CameraPoints), \
'Sampling based on distance is only appliable for CAMERA coord'
points, choices = self._points_random_sampling(
points,
self.num_points,
self.sample_range,
self.replace,
return_choices=True)
results['points'] = points results['points'] = points
pts_instance_mask = results.get('pts_instance_mask', None) pts_instance_mask = results.get('pts_instance_mask', None)
...@@ -913,10 +938,29 @@ class IndoorPointSample(object): ...@@ -913,10 +938,29 @@ class IndoorPointSample(object):
def __repr__(self): def __repr__(self):
"""str: Return a string that describes the module.""" """str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points})' repr_str += f'(num_points={self.num_points},'
repr_str += f' sample_range={self.sample_range})'
return repr_str return repr_str
@PIPELINES.register_module()
class IndoorPointSample(PointSample):
    """Deprecated alias of :class:`PointSample`.

    Sampling data to a certain number. Kept only for backward
    compatibility of existing configs; new code should use
    ``PointSample`` directly.

    Args:
        num_points (int): Number of points to be sampled.
    """

    def __init__(self, *args, **kwargs):
        # Notify users of the rename before delegating to PointSample.
        warnings.warn(
            'IndoorPointSample is deprecated in favor of PointSample')
        super().__init__(*args, **kwargs)
@PIPELINES.register_module() @PIPELINES.register_module()
class IndoorPatchPointSample(object): class IndoorPatchPointSample(object):
r"""Indoor point sample within a patch. Modified from `PointNet++ <https:// r"""Indoor point sample within a patch. Modified from `PointNet++ <https://
......
...@@ -32,7 +32,7 @@ def test_getitem(): ...@@ -32,7 +32,7 @@ def test_getitem():
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)), 34, 36, 39)),
dict(type='IndoorPointSample', num_points=5), dict(type='PointSample', num_points=5),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
......
...@@ -28,7 +28,7 @@ def _generate_sunrgbd_dataset_config(): ...@@ -28,7 +28,7 @@ def _generate_sunrgbd_dataset_config():
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=5), dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Collect3D',
...@@ -73,7 +73,7 @@ def _generate_sunrgbd_multi_modality_dataset_config(): ...@@ -73,7 +73,7 @@ def _generate_sunrgbd_multi_modality_dataset_config():
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=5), dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Collect3D',
......
...@@ -17,7 +17,7 @@ def test_multi_scale_flip_aug_3D(): ...@@ -17,7 +17,7 @@ def test_multi_scale_flip_aug_3D():
'sync_2d': False, 'sync_2d': False,
'flip_ratio_bev_horizontal': 0.5 'flip_ratio_bev_horizontal': 0.5
}, { }, {
'type': 'IndoorPointSample', 'type': 'PointSample',
'num_points': 5 'num_points': 5
}, { }, {
'type': 'type':
......
...@@ -5,11 +5,12 @@ import torch ...@@ -5,11 +5,12 @@ import torch
from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes, from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes,
DepthInstance3DBoxes, LiDARInstance3DBoxes) DepthInstance3DBoxes, LiDARInstance3DBoxes)
from mmdet3d.core.bbox import Coord3DMode
from mmdet3d.core.points import DepthPoints, LiDARPoints from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment, from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, ObjectNameFilter, GlobalRotScaleTrans, ObjectNameFilter,
ObjectNoise, ObjectRangeFilter, ObjectSample, ObjectNoise, ObjectRangeFilter, ObjectSample,
PointShuffle, PointsRangeFilter, PointSample, PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D, RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
...@@ -705,3 +706,36 @@ def test_voxel_based_point_filter(): ...@@ -705,3 +706,36 @@ def test_voxel_based_point_filter():
assert pts_instance_mask.min() >= 0 assert pts_instance_mask.min() >= 0
assert pts_semantic_mask.max() < 6 assert pts_semantic_mask.max() < 6
assert pts_semantic_mask.min() >= 0 assert pts_semantic_mask.min() >= 0
def test_points_sample():
    """Regression test for distance-based PointSample on KITTI data."""
    np.random.seed(0)

    # Load a reduced KITTI point cloud and the matching calibration info.
    raw_points = np.fromfile(
        './tests/data/kitti/training/velodyne_reduced/000000.bin',
        np.float32).reshape(-1, 4)
    info = mmcv.load('./tests/data/kitti/kitti_infos_train.pkl')[0]
    rect = torch.tensor(info['calib']['R0_rect'].astype(np.float32))
    Trv2c = torch.tensor(info['calib']['Tr_velo_to_cam'].astype(np.float32))

    # Distance-based sampling only works in the camera coordinate system,
    # so convert the LiDAR points to CAM first.
    cam_points = LiDARPoints(
        raw_points.copy(),
        points_dim=4).convert_to(Coord3DMode.CAM, rect @ Trv2c)

    num_points = 20
    sample_range = 40
    point_sample = PointSample(
        num_points=num_points, sample_range=sample_range)
    sampled_pts = point_sample(dict(points=cam_points))['points']

    # With the fixed seed above, these are the indices the sampler picks.
    select_idx = np.array([
        622, 146, 231, 444, 504, 533, 80, 401, 379, 2, 707, 562, 176, 491,
        496, 464, 15, 590, 194, 449
    ])
    assert np.allclose(sampled_pts.tensor.numpy(),
                       cam_points.tensor.numpy()[select_idx])

    # __repr__ must report both num_points and sample_range.
    expected_repr_str = f'PointSample(num_points={num_points},' \
        + f' sample_range={sample_range})'
    assert repr(point_sample) == expected_repr_str
...@@ -32,7 +32,7 @@ def test_scannet_pipeline(): ...@@ -32,7 +32,7 @@ def test_scannet_pipeline():
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)), 34, 36, 39)),
dict(type='IndoorPointSample', num_points=5), dict(type='PointSample', num_points=5),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -278,7 +278,7 @@ def test_sunrgbd_pipeline(): ...@@ -278,7 +278,7 @@ def test_sunrgbd_pipeline():
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=5), dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']), type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']),
......
import numpy as np import numpy as np
from mmdet3d.core.points import DepthPoints from mmdet3d.core.points import DepthPoints
from mmdet3d.datasets.pipelines import (IndoorPatchPointSample, from mmdet3d.datasets.pipelines import (IndoorPatchPointSample, PointSample,
IndoorPointSample,
PointSegClassMapping) PointSegClassMapping)
def test_indoor_sample(): def test_indoor_sample():
np.random.seed(0) np.random.seed(0)
scannet_sample_points = IndoorPointSample(5) scannet_sample_points = PointSample(5)
scannet_results = dict() scannet_results = dict()
scannet_points = np.array([[1.0719866, -0.7870435, 0.8408122, 0.9196809], scannet_points = np.array([[1.0719866, -0.7870435, 0.8408122, 0.9196809],
[1.103661, 0.81065744, 2.6616862, 2.7405548], [1.103661, 0.81065744, 2.6616862, 2.7405548],
...@@ -39,7 +38,7 @@ def test_indoor_sample(): ...@@ -39,7 +38,7 @@ def test_indoor_sample():
scannet_semantic_labels_result) scannet_semantic_labels_result)
np.random.seed(0) np.random.seed(0)
sunrgbd_sample_points = IndoorPointSample(5) sunrgbd_sample_points = PointSample(5)
sunrgbd_results = dict() sunrgbd_results = dict()
sunrgbd_point_cloud = np.array( sunrgbd_point_cloud = np.array(
[[-1.8135729e-01, 1.4695230e+00, -1.2780589e+00, 7.8938007e-03], [[-1.8135729e-01, 1.4695230e+00, -1.2780589e+00, 7.8938007e-03],
...@@ -58,7 +57,7 @@ def test_indoor_sample(): ...@@ -58,7 +57,7 @@ def test_indoor_sample():
sunrgbd_choices = np.array([2, 8, 4, 9, 1]) sunrgbd_choices = np.array([2, 8, 4, 9, 1])
sunrgbd_points_result = sunrgbd_results['points'].tensor.numpy() sunrgbd_points_result = sunrgbd_results['points'].tensor.numpy()
repr_str = repr(sunrgbd_sample_points) repr_str = repr(sunrgbd_sample_points)
expected_repr_str = 'IndoorPointSample(num_points=5)' expected_repr_str = 'PointSample(num_points=5, sample_range=None)'
assert repr_str == expected_repr_str assert repr_str == expected_repr_str
assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices], assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices],
sunrgbd_points_result) sunrgbd_points_result)
......
...@@ -193,6 +193,8 @@ def test_base_points(): ...@@ -193,6 +193,8 @@ def test_base_points():
[[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409], [[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409],
[6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]]) [6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]])
assert torch.allclose(expected_tensor, base_points[mask].tensor, 1e-4) assert torch.allclose(expected_tensor, base_points[mask].tensor, 1e-4)
expected_tensor = torch.tensor([[0.6666], [0.1502], [0.6565], [0.2803]])
assert torch.allclose(expected_tensor, base_points[:, 3].tensor, 1e-4)
# test length # test length
assert len(base_points) == 4 assert len(base_points) == 4
...@@ -451,6 +453,8 @@ def test_cam_points(): ...@@ -451,6 +453,8 @@ def test_cam_points():
[[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409], [[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409],
[6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]]) [6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]])
assert torch.allclose(expected_tensor, cam_points[mask].tensor, 1e-4) assert torch.allclose(expected_tensor, cam_points[mask].tensor, 1e-4)
expected_tensor = torch.tensor([[0.6666], [0.1502], [0.6565], [0.2803]])
assert torch.allclose(expected_tensor, cam_points[:, 3].tensor, 1e-4)
# test length # test length
assert len(cam_points) == 4 assert len(cam_points) == 4
...@@ -725,6 +729,8 @@ def test_lidar_points(): ...@@ -725,6 +729,8 @@ def test_lidar_points():
[[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409], [[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409],
[6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]]) [6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]])
assert torch.allclose(expected_tensor, lidar_points[mask].tensor, 1e-4) assert torch.allclose(expected_tensor, lidar_points[mask].tensor, 1e-4)
expected_tensor = torch.tensor([[0.6666], [0.1502], [0.6565], [0.2803]])
assert torch.allclose(expected_tensor, lidar_points[:, 3].tensor, 1e-4)
# test length # test length
assert len(lidar_points) == 4 assert len(lidar_points) == 4
...@@ -999,6 +1005,8 @@ def test_depth_points(): ...@@ -999,6 +1005,8 @@ def test_depth_points():
[[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409], [[9.0722, 47.3678, -2.5382, 0.6666, 0.1956, 0.4974, 0.9409],
[6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]]) [6.8547, 42.2509, -2.5955, 0.6565, 0.6248, 0.6954, 0.2538]])
assert torch.allclose(expected_tensor, depth_points[mask].tensor, 1e-4) assert torch.allclose(expected_tensor, depth_points[mask].tensor, 1e-4)
expected_tensor = torch.tensor([[0.6666], [0.1502], [0.6565], [0.2803]])
assert torch.allclose(expected_tensor, depth_points[:, 3].tensor, 1e-4)
# test length # test length
assert len(depth_points) == 4 assert len(depth_points) == 4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment