Unverified Commit fef04cc4 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Migrate to `PointSample` and enhance `PointSample` function (#840)

* replace IndoorPointSample to PointSample

* add comments for PointSample transform

* fix bug when far points > num_sample

* refine format

* add unit test

* minor fix

* minor fix
parent e37c8777
...@@ -24,7 +24,7 @@ train_pipeline = [ ...@@ -24,7 +24,7 @@ train_pipeline = [
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39), 36, 39),
max_cat_id=40), max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -67,7 +67,7 @@ test_pipeline = [ ...@@ -67,7 +67,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5), flip_ratio_bev_vertical=0.5),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -92,7 +92,7 @@ train_pipeline = [ ...@@ -92,7 +92,7 @@ train_pipeline = [
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39)),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -133,7 +133,7 @@ test_pipeline = [ ...@@ -133,7 +133,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5), flip_ratio_bev_vertical=0.5),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -91,7 +91,7 @@ train_pipeline = [ ...@@ -91,7 +91,7 @@ train_pipeline = [
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39)),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -132,7 +132,7 @@ test_pipeline = [ ...@@ -132,7 +132,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5), flip_ratio_bev_vertical=0.5),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -107,7 +107,7 @@ train_pipeline = [ ...@@ -107,7 +107,7 @@ train_pipeline = [
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39)),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -148,7 +148,7 @@ test_pipeline = [ ...@@ -148,7 +148,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5), flip_ratio_bev_vertical=0.5),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -108,7 +108,7 @@ train_pipeline = [ ...@@ -108,7 +108,7 @@ train_pipeline = [
type='PointSegClassMapping', type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39)),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -149,7 +149,7 @@ test_pipeline = [ ...@@ -149,7 +149,7 @@ test_pipeline = [
sync_2d=False, sync_2d=False,
flip_ratio_bev_horizontal=0.5, flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5), flip_ratio_bev_vertical=0.5),
dict(type='IndoorPointSample', num_points=50000), dict(type='PointSample', num_points=50000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=class_names, class_names=class_names,
......
...@@ -61,7 +61,7 @@ We adopt new pre-processing and conversion steps of ScanNet dataset. In previous ...@@ -61,7 +61,7 @@ We adopt new pre-processing and conversion steps of ScanNet dataset. In previous
- Since the aligned boxes share the same key as in old data infos, we do not need to modify the code related to it. But do remember that they are not in the same coordinate system as the saved points. - Since the aligned boxes share the same key as in old data infos, we do not need to modify the code related to it. But do remember that they are not in the same coordinate system as the saved points.
- There is an `IndoorPointSample` pipeline in the data pipelines for ScanNet detection task which down-samples points. So removing down-sampling in data generation will not affect the code. - There is an `PointSample` pipeline in the data pipelines for ScanNet detection task which down-samples points. So removing down-sampling in data generation will not affect the code.
We have trained a [VoteNet](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/votenet/votenet_8x8_scannet-3d-18class.py) model on the newly processed ScanNet dataset and get similar benchmark results. In order to prepare ScanNet data for both detection and segmentation tasks, please re-run the new pre-processing scripts following the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/). We have trained a [VoteNet](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/votenet/votenet_8x8_scannet-3d-18class.py) model on the newly processed ScanNet dataset and get similar benchmark results. In order to prepare ScanNet data for both detection and segmentation tasks, please re-run the new pre-processing scripts following the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/).
......
...@@ -267,7 +267,7 @@ train_pipeline = [ ...@@ -267,7 +267,7 @@ train_pipeline = [
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39), 36, 39),
max_cat_id=40), max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='RandomFlip3D', type='RandomFlip3D',
sync_2d=False, sync_2d=False,
...@@ -291,7 +291,7 @@ train_pipeline = [ ...@@ -291,7 +291,7 @@ train_pipeline = [
- `GlobalAlignment`: The previous point cloud would be axis-aligned using the axis-aligned matrix. - `GlobalAlignment`: The previous point cloud would be axis-aligned using the axis-aligned matrix.
- `PointSegClassMapping`: Only the valid category ids will be mapped to class label ids like [0, 18) during training. - `PointSegClassMapping`: Only the valid category ids will be mapped to class label ids like [0, 18) during training.
- Data augmentation: - Data augmentation:
- `IndoorPointSample`: downsample the input point cloud. - `PointSample`: downsample the input point cloud.
- `RandomFlip3D`: randomly flip the input point cloud horizontally or vertically. - `RandomFlip3D`: randomly flip the input point cloud horizontally or vertically.
- `GlobalRotScaleTrans`: rotate the input point cloud, usually in the range of [-5, 5] (degrees) for ScanNet; then scale the input point cloud, usually by 1.0 for ScanNet; finally translate the input point cloud, usually by 0 for ScanNet. - `GlobalRotScaleTrans`: rotate the input point cloud, usually in the range of [-5, 5] (degrees) for ScanNet; then scale the input point cloud, usually by 1.0 for ScanNet; finally translate the input point cloud, usually by 0 for ScanNet.
......
...@@ -282,7 +282,7 @@ train_pipeline = [ ...@@ -282,7 +282,7 @@ train_pipeline = [
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
...@@ -291,7 +291,7 @@ train_pipeline = [ ...@@ -291,7 +291,7 @@ train_pipeline = [
Data augmentation for point clouds: Data augmentation for point clouds:
- `RandomFlip3D`: randomly flip the input point cloud horizontally or vertically. - `RandomFlip3D`: randomly flip the input point cloud horizontally or vertically.
- `GlobalRotScaleTrans`: rotate the input point cloud, usually in the range of [-30, 30] (degrees) for SUN RGB-D; then scale the input point cloud, usually in the range of [0.85, 1.15] for SUN RGB-D; finally translate the input point cloud, usually by 0 for SUN RGB-D. - `GlobalRotScaleTrans`: rotate the input point cloud, usually in the range of [-30, 30] (degrees) for SUN RGB-D; then scale the input point cloud, usually in the range of [0.85, 1.15] for SUN RGB-D; finally translate the input point cloud, usually by 0 for SUN RGB-D.
- `IndoorPointSample`: downsample the input point cloud. - `PointSample`: downsample the input point cloud.
A typical train pipeline of SUN RGB-D for multi-modality (point cloud and image) 3D detection is as follows. A typical train pipeline of SUN RGB-D for multi-modality (point cloud and image) 3D detection is as follows.
...@@ -320,7 +320,7 @@ train_pipeline = [ ...@@ -320,7 +320,7 @@ train_pipeline = [
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Collect3D',
......
...@@ -281,7 +281,7 @@ train_pipeline = [ ...@@ -281,7 +281,7 @@ train_pipeline = [
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
...@@ -290,7 +290,7 @@ train_pipeline = [ ...@@ -290,7 +290,7 @@ train_pipeline = [
点云上的数据增强 点云上的数据增强
- `RandomFlip3D`:随机左右或前后翻转输入点云。 - `RandomFlip3D`:随机左右或前后翻转输入点云。
- `GlobalRotScaleTrans`:旋转输入点云,对于 SUN RGB-D 角度通常落入 [-30, 30] (度)的范围;并放缩输入点云,对于 SUN RGB-D 比例通常落入 [0.85, 1.15] 的范围;最后平移输入点云,对于 SUN RGB-D 通常位移量为 0。 - `GlobalRotScaleTrans`:旋转输入点云,对于 SUN RGB-D 角度通常落入 [-30, 30] (度)的范围;并放缩输入点云,对于 SUN RGB-D 比例通常落入 [0.85, 1.15] 的范围;最后平移输入点云,对于 SUN RGB-D 通常位移量为 0。
- `IndoorPointSample`:降采样输入点云。 - `PointSample`:降采样输入点云。
SUN RGB-D 上多模态(点云和图像)3D 物体检测的经典流程如下: SUN RGB-D 上多模态(点云和图像)3D 物体检测的经典流程如下:
...@@ -319,7 +319,7 @@ train_pipeline = [ ...@@ -319,7 +319,7 @@ train_pipeline = [
rot_range=[-0.523599, 0.523599], rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15], scale_ratio_range=[0.85, 1.15],
shift_height=True), shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Collect3D',
......
...@@ -204,7 +204,7 @@ train_pipeline = [ # 训练流水线,更多细节请参考 mmdet3d.datasets.p ...@@ -204,7 +204,7 @@ train_pipeline = [ # 训练流水线,更多细节请参考 mmdet3d.datasets.p
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39), # 所有有效类别的编号 36, 39), # 所有有效类别的编号
max_cat_id=40), # 输入语义分割掩码中可能存在的最大类别编号 max_cat_id=40), # 输入语义分割掩码中可能存在的最大类别编号
dict(type='IndoorPointSample', # 室内点采样,更多细节请参考 mmdet3d.datasets.pipelines.indoor_sample dict(type='PointSample', # 室内点采样,更多细节请参考 mmdet3d.datasets.pipelines.indoor_sample
num_points=40000), # 采样的点的数量 num_points=40000), # 采样的点的数量
dict(type='IndoorFlipData', # 数据增广流程,随机翻转点和 3D 框 dict(type='IndoorFlipData', # 数据增广流程,随机翻转点和 3D 框
flip_ratio_yz=0.5, # 沿着 yz 平面被翻转的概率 flip_ratio_yz=0.5, # 沿着 yz 平面被翻转的概率
...@@ -233,7 +233,7 @@ test_pipeline = [ # 测试流水线,更多细节请参考 mmdet3d.datasets.pi ...@@ -233,7 +233,7 @@ test_pipeline = [ # 测试流水线,更多细节请参考 mmdet3d.datasets.pi
shift_height=True, # 是否使用变换高度 shift_height=True, # 是否使用变换高度
load_dim=6, # 读取的点的维度 load_dim=6, # 读取的点的维度
use_dim=[0, 1, 2]), # 使用所读取点的哪些维度 use_dim=[0, 1, 2]), # 使用所读取点的哪些维度
dict(type='IndoorPointSample', # 室内点采样,更多细节请参考 mmdet3d.datasets.pipelines.indoor_sample dict(type='PointSample', # 室内点采样,更多细节请参考 mmdet3d.datasets.pipelines.indoor_sample
num_points=40000), # 采样的点的数量 num_points=40000), # 采样的点的数量
dict( dict(
type='DefaultFormatBundle3D', # 默认格式打包以收集读取的所有数据,更多细节请参考 mmdet3d.datasets.pipelines.formating type='DefaultFormatBundle3D', # 默认格式打包以收集读取的所有数据,更多细节请参考 mmdet3d.datasets.pipelines.formating
...@@ -287,7 +287,7 @@ data = dict( ...@@ -287,7 +287,7 @@ data = dict(
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39), 28, 33, 34, 36, 39),
max_cat_id=40), max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='IndoorFlipData', type='IndoorFlipData',
flip_ratio_yz=0.5, flip_ratio_yz=0.5,
...@@ -326,7 +326,7 @@ data = dict( ...@@ -326,7 +326,7 @@ data = dict(
shift_height=True, shift_height=True,
load_dim=6, load_dim=6,
use_dim=[0, 1, 2]), use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table', class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
...@@ -351,7 +351,7 @@ data = dict( ...@@ -351,7 +351,7 @@ data = dict(
shift_height=True, shift_height=True,
load_dim=6, load_dim=6,
use_dim=[0, 1, 2]), use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000), dict(type='PointSample', num_points=40000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table', class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
......
...@@ -846,6 +846,10 @@ class PointSample(object): ...@@ -846,6 +846,10 @@ class PointSample(object):
Args: Args:
num_points (int): Number of points to be sampled. num_points (int): Number of points to be sampled.
sample_range (float, optional): The range where to sample points. sample_range (float, optional): The range where to sample points.
If not None, the points with depth larger than `sample_range` are
prior to be sampled. Defaults to None.
replace (bool, optional): Whether the sampling is with or without
replacement. Defaults to False.
""" """
def __init__(self, num_points, sample_range=None, replace=False): def __init__(self, num_points, sample_range=None, replace=False):
...@@ -867,8 +871,7 @@ class PointSample(object): ...@@ -867,8 +871,7 @@ class PointSample(object):
points (np.ndarray | :obj:`BasePoints`): 3D Points. points (np.ndarray | :obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled. num_samples (int): Number of samples to be sampled.
sample_range (float, optional): Indicating the range where the sample_range (float, optional): Indicating the range where the
points will be sampled. points will be sampled. Defaults to None.
Defaults to None.
replace (bool, optional): Sampling with or without replacement. replace (bool, optional): Sampling with or without replacement.
Defaults to None. Defaults to None.
return_choices (bool, optional): Whether return choice. return_choices (bool, optional): Whether return choice.
...@@ -886,6 +889,10 @@ class PointSample(object): ...@@ -886,6 +889,10 @@ class PointSample(object):
depth = np.linalg.norm(points.tensor, axis=1) depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth > sample_range)[0] far_inds = np.where(depth > sample_range)[0]
near_inds = np.where(depth <= sample_range)[0] near_inds = np.where(depth <= sample_range)[0]
# in case there are too many far points
if len(far_inds) > num_samples:
far_inds = np.random.choice(
far_inds, num_samples, replace=False)
point_range = near_inds point_range = near_inds
num_samples -= len(far_inds) num_samples -= len(far_inds)
choices = np.random.choice(point_range, num_samples, replace=replace) choices = np.random.choice(point_range, num_samples, replace=replace)
...@@ -907,11 +914,11 @@ class PointSample(object): ...@@ -907,11 +914,11 @@ class PointSample(object):
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
from mmdet3d.core.points import CameraPoints
points = results['points'] points = results['points']
# Points in Camera coord can provide the depth information. # Points in Camera coord can provide the depth information.
# TODO: Need to suport distance-based sampling for other coord system. # TODO: Need to suport distance-based sampling for other coord system.
if self.sample_range is not None: if self.sample_range is not None:
from mmdet3d.core.points import CameraPoints
assert isinstance(points, CameraPoints), \ assert isinstance(points, CameraPoints), \
'Sampling based on distance is only appliable for CAMERA coord' 'Sampling based on distance is only appliable for CAMERA coord'
points, choices = self._points_random_sampling( points, choices = self._points_random_sampling(
...@@ -939,7 +946,8 @@ class PointSample(object): ...@@ -939,7 +946,8 @@ class PointSample(object):
"""str: Return a string that describes the module.""" """str: Return a string that describes the module."""
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points},' repr_str += f'(num_points={self.num_points},'
repr_str += f' sample_range={self.sample_range})' repr_str += f' sample_range={self.sample_range},'
repr_str += f' replace={self.replace})'
return repr_str return repr_str
......
...@@ -722,7 +722,7 @@ def test_points_sample(): ...@@ -722,7 +722,7 @@ def test_points_sample():
points.copy(), points_dim=4).convert_to(Coord3DMode.CAM, rect @ Trv2c) points.copy(), points_dim=4).convert_to(Coord3DMode.CAM, rect @ Trv2c)
num_points = 20 num_points = 20
sample_range = 40 sample_range = 40
input_dict = dict(points=points) input_dict = dict(points=points.clone())
point_sample = PointSample( point_sample = PointSample(
num_points=num_points, sample_range=sample_range) num_points=num_points, sample_range=sample_range)
...@@ -736,6 +736,17 @@ def test_points_sample(): ...@@ -736,6 +736,17 @@ def test_points_sample():
assert np.allclose(sampled_pts.tensor.numpy(), expected_pts) assert np.allclose(sampled_pts.tensor.numpy(), expected_pts)
repr_str = repr(point_sample) repr_str = repr(point_sample)
expected_repr_str = f'PointSample(num_points={num_points},'\ expected_repr_str = f'PointSample(num_points={num_points}, ' \
+ f' sample_range={sample_range})' f'sample_range={sample_range}, ' \
'replace=False)'
assert repr_str == expected_repr_str assert repr_str == expected_repr_str
# test when number of far points are larger than number of sampled points
np.random.seed(0)
point_sample = PointSample(num_points=2, sample_range=sample_range)
input_dict = dict(points=points.clone())
sampled_pts = point_sample(input_dict)['points']
select_idx = np.array([449, 444])
expected_pts = points.tensor.numpy()[select_idx]
assert np.allclose(sampled_pts.tensor.numpy(), expected_pts)
...@@ -57,7 +57,9 @@ def test_indoor_sample(): ...@@ -57,7 +57,9 @@ def test_indoor_sample():
sunrgbd_choices = np.array([2, 8, 4, 9, 1]) sunrgbd_choices = np.array([2, 8, 4, 9, 1])
sunrgbd_points_result = sunrgbd_results['points'].tensor.numpy() sunrgbd_points_result = sunrgbd_results['points'].tensor.numpy()
repr_str = repr(sunrgbd_sample_points) repr_str = repr(sunrgbd_sample_points)
expected_repr_str = 'PointSample(num_points=5, sample_range=None)' expected_repr_str = 'PointSample(num_points=5, ' \
'sample_range=None, ' \
'replace=False)'
assert repr_str == expected_repr_str assert repr_str == expected_repr_str
assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices], assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices],
sunrgbd_points_result) sunrgbd_points_result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment