Commit ff17dc45 authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'indoor_sample' into 'master'

pipeline:Indoor sample

See merge request open-mmlab/mmdet.3d!10
parents 624575d6 c55b395b
import numpy as np
from mmdet.datasets.registry import PIPELINES
@PIPELINES.register_module
class PointSample(object):
"""Point Sample.
Sampling data to a certain number.
Args:
name (str): Name of the dataset.
num_points (int): Number of points to be sampled.
"""
def __init__(self, num_points):
self.num_points = num_points
def points_random_sampling(self,
points,
num_samples,
replace=None,
return_choices=False):
"""Points Random Sampling.
Sample points to a certain number.
Args:
points (ndarray): 3D Points.
num_samples (int): Number of samples to be sampled.
replace (bool): Whether the sample is with or without replacement.
return_choices (bool): Whether return choice.
Returns:
points (ndarray): 3D Points.
choices (ndarray): The generated random samples
"""
if replace is None:
replace = (points.shape[0] < num_samples)
choices = np.random.choice(
points.shape[0], num_samples, replace=replace)
if return_choices:
return points[choices], choices
else:
return points[choices]
def __call__(self, results):
points = results.get('points', None)
points, choices = self.points_random_sampling(
points, self.num_points, return_choices=True)
pts_instance_mask = results.get('pts_instance_mask', None)
pts_semantic_mask = results.get('pts_semantic_mask', None)
results['points'] = points
if pts_instance_mask is not None and pts_semantic_mask is not None:
pts_instance_mask = pts_instance_mask[choices]
pts_semantic_mask = pts_semantic_mask[choices]
results['pts_instance_mask'] = pts_instance_mask
results['pts_semantic_mask'] = pts_semantic_mask
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(num_points={})'.format(self.num_points)
return repr_str
import numpy as np
from mmdet3d.datasets.pipelines.indoor_sample import PointSample
def test_indoor_sample():
np.random.seed(0)
scannet_sample_points = PointSample(5)
scannet_results = dict()
scannet_points = np.array([[1.0719866, -0.7870435, 0.8408122, 0.9196809],
[1.103661, 0.81065744, 2.6616862, 2.7405548],
[1.0276475, 1.5061463, 2.6174362, 2.6963048],
[-0.9709588, 0.6750515, 0.93901765, 1.0178864],
[1.0578915, 1.1693821, 0.87503505, 0.95390373],
[0.05560996, -1.5688863, 1.2440368, 1.3229055],
[-0.15731563, -1.7735453, 2.7535574, 2.832426],
[1.1188195, -0.99211365, 2.5551798, 2.6340485],
[-0.9186557, -1.7041215, 2.0562649, 2.1351335],
[-1.0128691, -1.3394243, 0.040936, 0.1198047]])
scannet_results['points'] = scannet_points
scannet_pts_instance_mask = np.array(
[15, 12, 11, 38, 0, 18, 17, 12, 17, 0])
scannet_results['pts_instance_mask'] = scannet_pts_instance_mask
scannet_pts_semantic_mask = np.array([38, 1, 1, 40, 0, 40, 1, 1, 1, 0])
scannet_results['pts_semantic_mask'] = scannet_pts_semantic_mask
scannet_results = scannet_sample_points(scannet_results)
scannet_points_result = scannet_results.get('points', None)
scannet_instance_labels_result = scannet_results.get(
'pts_instance_mask', None)
scannet_semantic_labels_result = scannet_results.get(
'pts_semantic_mask', None)
scannet_choices = np.array([2, 8, 4, 9, 1])
assert np.allclose(scannet_points[scannet_choices], scannet_points_result)
assert np.all(scannet_pts_instance_mask[scannet_choices] ==
scannet_instance_labels_result)
assert np.all(scannet_pts_semantic_mask[scannet_choices] ==
scannet_semantic_labels_result)
np.random.seed(0)
sunrgbd_sample_points = PointSample(5)
sunrgbd_results = dict()
sunrgbd_point_cloud = np.array(
[[-1.8135729e-01, 1.4695230e+00, -1.2780589e+00, 7.8938007e-03],
[1.2581362e-03, 2.0561588e+00, -1.0341064e+00, 2.5184631e-01],
[6.8236995e-01, 3.3611867e+00, -9.2599887e-01, 3.5995382e-01],
[-2.9432583e-01, 1.8714852e+00, -9.0929651e-01, 3.7665617e-01],
[-0.5024875, 1.8032674, -1.1403012, 0.14565146],
[-0.520559, 1.6324949, -0.9896099, 0.2963428],
[0.95929825, 2.9402404, -0.8746674, 0.41128528],
[-0.74624217, 1.5244724, -0.8678476, 0.41810507],
[0.56485355, 1.5747732, -0.804522, 0.4814307],
[-0.0913099, 1.3673826, -1.2800645, 0.00588822]])
sunrgbd_results['points'] = sunrgbd_point_cloud
sunrgbd_results = sunrgbd_sample_points(sunrgbd_results)
sunrgbd_choices = np.array([2, 8, 4, 9, 1])
sunrgbd_points_result = sunrgbd_results.get('points', None)
assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices],
sunrgbd_points_result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment