points_sampler.py 5.31 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import torch
3
from mmcv.runner import force_fp32
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch import nn as nn
from typing import List

from .furthest_point_sample import (furthest_point_sample,
                                    furthest_point_sample_with_dist)
from .utils import calc_square_dist


def get_sampler_type(sampler_type):
    """Resolve a sampler name to the class that implements it.

    Args:
        sampler_type (str): Name of the points sampler. Supported names
            are "D-FPS", "F-FPS" and "FS".

    Returns:
        class: The sampler class matching ``sampler_type``.

    Raises:
        ValueError: If ``sampler_type`` is not one of the supported names.
    """
    # Guard-clause form: hand back the class as soon as a name matches.
    if sampler_type == 'D-FPS':
        return DFPS_Sampler
    if sampler_type == 'F-FPS':
        return FFPS_Sampler
    if sampler_type == 'FS':
        return FS_Sampler

    # Nothing matched: report the offending value to the caller.
    raise ValueError('Only "sampler_type" of "D-FPS", "F-FPS", or "FS"'
                     f' are supported, got {sampler_type}')


class Points_Sampler(nn.Module):
    """Points sampling.

    Runs one sampler per stage, each on its own contiguous slice of the
    input points, and concatenates the sampled indices of all stages.

    Args:
        num_point (list[int]): Number of sample points for each stage.
        fps_mod_list (list[str], optional): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
            F-FPS: using feature distances for FPS.
            D-FPS: using Euclidean distances of points for FPS.
            FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (list[int], optional):
            Range of points to apply FPS. Each entry is the (exclusive) end
            index of the slice handled by the corresponding stage; the slice
            starts where the previous stage ended. -1 means "up to the last
            point". Default: [-1].
    """

    def __init__(self,
                 num_point: List[int],
                 fps_mod_list: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1]):
        super().__init__()
        # FPS would be applied to different fps_mod in the list,
        # so the length of the num_point should be equal to
        # fps_mod_list and fps_sample_range_list.
        assert len(num_point) == len(fps_mod_list) == len(
            fps_sample_range_list)
        # Store copies so the shared default lists (and the caller's lists)
        # are never aliased by this instance.
        self.num_point = list(num_point)
        self.fps_sample_range_list = list(fps_sample_range_list)
        self.samplers = nn.ModuleList(
            get_sampler_type(fps_mod)() for fps_mod in fps_mod_list)
        # Flag read by the ``force_fp32`` decorator applied to ``forward``.
        self.fp16_enabled = False

    @force_fp32()
    def forward(self, points_xyz, features):
        """Sample point indices stage by stage.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) Descriptors of the features.
                May be None when only coordinate-based samplers are used.

        Return:
            Tensor: Indices of sampled points, i.e. the per-stage index
                tensors concatenated along dim 1 (presumably (B, M) with
                M the total number of sampled points -- confirm against the
                sampling ops).
        """
        indices = []
        last_fps_end_index = 0

        for fps_sample_range, sampler, npoint in zip(
                self.fps_sample_range_list, self.samplers, self.num_point):
            assert fps_sample_range < points_xyz.shape[1]

            # -1 selects everything from the previous end to the last point;
            # any other value is the exclusive end index of this stage.
            end = None if fps_sample_range == -1 else fps_sample_range
            sample_points_xyz = points_xyz[:, last_fps_end_index:end]
            if features is not None:
                sample_features = features[:, :, last_fps_end_index:end]
            else:
                sample_features = None

            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
                              npoint)

            # Sampler indices are relative to the slice; shift them back into
            # the index space of the full point set.
            indices.append(fps_idx + last_fps_end_index)
            # The range is an absolute end index (see the slicing above), so
            # the next stage starts exactly there. Accumulating with ``+=``
            # would skip points once three or more stages are configured.
            last_fps_end_index = fps_sample_range
        indices = torch.cat(indices, dim=1)

        return indices


class DFPS_Sampler(nn.Module):
    """Distance-based furthest point sampler (D-FPS).

    Picks points with furthest point sampling on the point coordinates;
    the features argument is accepted for interface parity and ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Return the indices of ``npoint`` points chosen by D-FPS."""
        # The sampling op is given a contiguous copy/view of the points.
        contiguous_points = points.contiguous()
        return furthest_point_sample(contiguous_points, npoint)


class FFPS_Sampler(nn.Module):
    """FFPS_Sampler.

    Using feature distances for FPS.
    """

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sampling points with F-FPS.

        Args:
            points (Tensor): Point coordinates; concatenated with the
                transposed features along dim 2, so presumably (B, N, 3).
            features (Tensor): Point features, transposed on dims 1 and 2
                before concatenation, so presumably (B, C, N). Must not be
                None.
            npoint (int): Number of points to sample.

        Returns:
            Tensor: Indices returned by ``furthest_point_sample_with_dist``.

        Raises:
            AssertionError: If ``features`` is None.
        """
        assert features is not None, \
            'feature input to FFPS_Sampler should not be None'
        # Build one descriptor per point from coordinates plus features.
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        # Distance of the descriptor set to itself, passed to the
        # distance-driven sampling op.
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
        fps_idx = furthest_point_sample_with_dist(features_dist, npoint)
        return fps_idx


class FS_Sampler(nn.Module):
    """FS_Sampling.

    Using F-FPS and D-FPS simultaneously.
    """

    def __init__(self):
        super().__init__()

    def forward(self, points, features, npoint):
        """Sampling points with FS_Sampling.

        Args:
            points (Tensor): Point coordinates; concatenated with the
                transposed features along dim 2, so presumably (B, N, 3).
            features (Tensor): Point features, transposed on dims 1 and 2
                before concatenation, so presumably (B, C, N). Must not be
                None.
            npoint (int): Number of points each of the two samplers picks.

        Returns:
            Tensor: F-FPS indices followed by D-FPS indices, concatenated
                along dim 1 (so twice as many indices as a single sampler
                returns).

        Raises:
            AssertionError: If ``features`` is None.
        """
        assert features is not None, \
            'feature input to FS_Sampler should not be None'
        # Build one descriptor per point from coordinates plus features.
        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
        features_dist = calc_square_dist(
            features_for_fps, features_for_fps, norm=False)
        # Feature-distance sampling and coordinate sampling run on the same
        # input; their results are stacked side by side.
        fps_idx_ffps = furthest_point_sample_with_dist(features_dist, npoint)
        fps_idx_dfps = furthest_point_sample(points, npoint)
        fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
        return fps_idx