indoor_sample.py 2.24 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
2
3
4
5
6
7
8
9
import numpy as np

from mmdet.datasets.registry import PIPELINES


def points_random_sampling(points,
                           num_samples,
                           replace=None,
                           return_choices=False):
    """Randomly sample a fixed number of points.

    Rows of ``points`` (indexed along axis 0) are drawn with
    ``np.random.choice``, so the result depends on NumPy's global RNG state.

    Args:
        points (ndarray): 3D points; samples are taken along axis 0.
        num_samples (int): Number of points to draw.
        replace (bool, optional): Whether to sample with replacement.
            If None, replacement is enabled only when ``points`` has
            fewer rows than ``num_samples``. Defaults to None.
        return_choices (bool): Whether to also return the sampled row
            indices. Defaults to False.

    Returns:
        ndarray | tuple[ndarray, ndarray]: The sampled points, or
        ``(sampled_points, choices)`` when ``return_choices`` is True,
        where ``choices`` holds the row indices drawn from ``points``.
    """
    total = points.shape[0]
    # Auto mode: only fall back to replacement when there are not enough
    # distinct points to draw ``num_samples`` of them.
    if replace is None:
        replace = total < num_samples
    idx = np.random.choice(total, num_samples, replace=replace)
    sampled = points[idx]
    return (sampled, idx) if return_choices else sampled


@PIPELINES.register_module
class PointSample(object):
    """Point Sample.

    Randomly sample ``results['points']`` down (or up) to a fixed number of
    points. For ScanNet, per-point annotations that are present in
    ``results`` are resampled with the same indices so they stay aligned
    with the sampled points.

    Args:
        name (str): Name of the dataset, either 'scannet' or 'sunrgbd'.
        num_points (int): Number of points to be sampled.
    """

    # Per-point keys resampled alongside 'points' for ScanNet. Keys that are
    # missing (e.g. labels on an unannotated split) are simply skipped.
    _SCANNET_PER_POINT_KEYS = ('instance_labels', 'semantic_labels',
                               'pcl_color')

    def __init__(self, name, num_points):
        assert name in ['scannet', 'sunrgbd']
        self.name = name
        self.num_points = num_points

    def __call__(self, results):
        """Sample points and aligned per-point data in ``results``.

        Args:
            results (dict): Pipeline results; must contain 'points'.

        Returns:
            dict: The same dict, with 'points' (and, for ScanNet, any
            present per-point keys) replaced by their sampled versions.

        Raises:
            KeyError: If 'points' is missing from ``results``.
        """
        points, choices = points_random_sampling(
            results['points'], self.num_points, return_choices=True)
        results['points'] = points

        if self.name == 'scannet':
            for key in self._SCANNET_PER_POINT_KEYS:
                value = results.get(key, None)
                # Previously a missing key raised TypeError on None[choices].
                if value is not None:
                    results[key] = value[choices]
        # NOTE(review): for 'sunrgbd' no other per-point key is resampled;
        # if such data is present it would no longer align with 'points'.

        return results

    def __repr__(self):
        # Single argument list, e.g. "PointSample(dataset_name=scannet,
        # num_points=40000)"; the old form repeated the parentheses.
        return '{}(dataset_name={}, num_points={})'.format(
            self.__class__.__name__, self.name, self.num_points)