point_sa_module.py 14.2 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import List, Optional, Tuple, Union

wuyuefeng's avatar
wuyuefeng committed
4
5
import torch
from mmcv.cnn import ConvModule
6
7
8
from mmcv.ops import GroupAll
from mmcv.ops import PointsSampler as Points_Sampler
from mmcv.ops import QueryAndGroup, gather_points
9
from torch import Tensor
zhangwenwei's avatar
zhangwenwei committed
10
11
from torch import nn as nn
from torch.nn import functional as F
wuyuefeng's avatar
wuyuefeng committed
12

zhangshilong's avatar
zhangshilong committed
13
from mmdet3d.models.layers import PAConv
14
from mmdet3d.utils import ConfigType
15
from .builder import SA_MODULES
wuyuefeng's avatar
wuyuefeng committed
16
17


18
19
class BasePointSAModule(nn.Module):
    """Base module for point set abstraction module used in PointNets.

    This class builds the point sampler and one grouper per scale. It leaves
    ``self.mlps`` empty: subclasses must append one network per grouper
    before :meth:`forward` is called.

    Args:
        num_point (int): Number of points. An int is wrapped into a
            one-element list, a list/tuple is stored as is, and ``None``
            disables sampling (``GroupAll`` is then used for grouping).
        radii (List[float]): List of radius in each ball query.
        sample_nums (List[int]): Number of samples in each ball query.
        mlp_channels (List[List[int]]): Specify of the pointnet before
            the global pooling for each scale.
        fps_mod (List[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].

            - F-FPS: using feature distances for FPS.
            - D-FPS: using Euclidean distances of points for FPS.
            - FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (List[int]): Range of points to apply FPS.
            Defaults to [-1].
        dilated_group (bool): Whether to use dilated ball query.
            Defaults to False.
        use_xyz (bool): Whether to use xyz. Defaults to True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Defaults to 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Defaults to False.
        grouper_return_grouped_xyz (bool): Whether to return grouped xyz
            in `QueryAndGroup`. Defaults to False.
        grouper_return_grouped_idx (bool): Whether to return grouped idx
            in `QueryAndGroup`. Defaults to False.
    """

    def __init__(self,
                 num_point: int,
                 radii: List[float],
                 sample_nums: List[int],
                 mlp_channels: List[List[int]],
                 fps_mod: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1],
                 dilated_group: bool = False,
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 normalize_xyz: bool = False,
                 grouper_return_grouped_xyz: bool = False,
                 grouper_return_grouped_idx: bool = False) -> None:
        super().__init__()

        assert len(radii) == len(sample_nums) == len(mlp_channels)
        assert pool_mod in ['max', 'avg']
        assert isinstance(fps_mod, (list, tuple))
        assert isinstance(fps_sample_range_list, (list, tuple))
        assert len(fps_mod) == len(fps_sample_range_list)

        # Normalise a tuple-of-sequences into a list of lists.
        if isinstance(mlp_channels, tuple):
            mlp_channels = list(map(list, mlp_channels))
        self.mlp_channels = mlp_channels

        if isinstance(num_point, int):
            self.num_point = [num_point]
        elif isinstance(num_point, (list, tuple)):
            self.num_point = num_point
        elif num_point is None:
            self.num_point = None
        else:
            raise NotImplementedError('Error type of num_point!')

        self.pool_mod = pool_mod
        self.groupers = nn.ModuleList()
        # Filled in by subclasses, one entry per grouper.
        self.mlps = nn.ModuleList()
        self.fps_mod_list = fps_mod
        self.fps_sample_range_list = fps_sample_range_list

        if self.num_point is not None:
            self.points_sampler = Points_Sampler(self.num_point,
                                                 self.fps_mod_list,
                                                 self.fps_sample_range_list)
        else:
            self.points_sampler = None

        for i, radius in enumerate(radii):
            sample_num = sample_nums[i]
            if num_point is not None:
                # With dilated grouping, each scale only looks at the ring
                # between the previous radius and its own radius.
                if dilated_group and i != 0:
                    min_radius = radii[i - 1]
                else:
                    min_radius = 0
                grouper = QueryAndGroup(
                    radius,
                    sample_num,
                    min_radius=min_radius,
                    use_xyz=use_xyz,
                    normalize_xyz=normalize_xyz,
                    return_grouped_xyz=grouper_return_grouped_xyz,
                    return_grouped_idx=grouper_return_grouped_idx)
            else:
                grouper = GroupAll(use_xyz)
            self.groupers.append(grouper)

    def _sample_points(self, points_xyz: Tensor, features: Optional[Tensor],
                       indices: Optional[Tensor],
                       target_xyz: Optional[Tensor]) -> Tuple[Tensor]:
        """Perform point sampling based on inputs.

        If `indices` is specified, directly sample corresponding points.
        Else if `target_xyz` is specified, use it as sampled points.
        Otherwise sample points using `self.points_sampler`.

        When the module was built with ``num_point=None`` there is nothing
        to sample, so ``new_xyz`` is ``None`` unless `target_xyz` is given.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor, optional): (B, C, N) Features of each point.
            indices (Tensor, optional): (B, num_point) Index of the features.
            target_xyz (Tensor, optional): (B, M, 3) new_xyz coordinates of
                the outputs.

        Returns:
            Tuple[Tensor]:

            - new_xyz: (B, num_point, 3) Sampled xyz coordinates of points.
            - indices: (B, num_point) Sampled points' index.
        """
        xyz_flipped = points_xyz.transpose(1, 2).contiguous()
        if indices is not None:
            if self.num_point is None:
                # Check for None before touching self.num_point[0];
                # otherwise the size check below raises a TypeError.
                new_xyz = None
            else:
                assert indices.shape[1] == self.num_point[0]
                new_xyz = gather_points(xyz_flipped, indices).transpose(
                    1, 2).contiguous()
        elif target_xyz is not None:
            new_xyz = target_xyz.contiguous()
        elif self.num_point is not None:
            indices = self.points_sampler(points_xyz, features)
            new_xyz = gather_points(xyz_flipped,
                                    indices).transpose(1, 2).contiguous()
        else:
            new_xyz = None

        return new_xyz, indices

    def _pool_features(self, features: Tensor) -> Tensor:
        """Perform feature aggregation using pooling operation.

        Args:
            features (Tensor): (B, C, N, K) Features of locally grouped
                points before pooling.

        Returns:
            Tensor: (B, C, N) Pooled features aggregating local information.

        Raises:
            NotImplementedError: If ``self.pool_mod`` is neither 'max' nor
                'avg'.
        """
        # Pool over the whole last (neighbour) dimension in one window.
        window = [1, features.size(3)]
        if self.pool_mod == 'max':
            # (B, C, N, 1)
            new_features = F.max_pool2d(features, kernel_size=window)
        elif self.pool_mod == 'avg':
            # (B, C, N, 1)
            new_features = F.avg_pool2d(features, kernel_size=window)
        else:
            raise NotImplementedError

        return new_features.squeeze(-1).contiguous()

    def forward(
        self,
        points_xyz: Tensor,
        features: Optional[Tensor] = None,
        indices: Optional[Tensor] = None,
        target_xyz: Optional[Tensor] = None,
    ) -> Tuple[Tensor]:
        """Forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor, optional): (B, C, N) Features of each point.
                Defaults to None.
            indices (Tensor, optional): (B, num_point) Index of the features.
                Defaults to None.
            target_xyz (Tensor, optional): (B, M, 3) New coords of the outputs.
                Defaults to None.

        Returns:
            Tuple[Tensor]:

                - new_xyz: (B, M, 3) Where M is the number of points.
                  New features xyz.
                - new_features: (B, sum_k(mlps[k][-1]), M) Where M is the
                  number of points. New feature descriptors, concatenated
                  over all scales along the channel dimension.
                - indices: (B, M) Where M is the number of points.
                  Index of the features.
        """
        new_features_list = []

        # sample points, (B, num_point, 3), (B, num_point)
        new_xyz, indices = self._sample_points(points_xyz, features, indices,
                                               target_xyz)

        for i, grouper in enumerate(self.groupers):
            # grouped_results may contain:
            # - grouped_features: (B, C, num_point, nsample)
            # - grouped_xyz: (B, 3, num_point, nsample)
            # - grouped_idx: (B, num_point, nsample)
            grouped_results = grouper(points_xyz, new_xyz, features)

            # (B, mlp[-1], num_point, nsample)
            mlp = self.mlps[i]
            new_features = mlp(grouped_results)

            # this is a bit hack because PAConv outputs two values
            # we take the first one as feature
            if isinstance(mlp[0], PAConv):
                assert isinstance(new_features, tuple)
                new_features = new_features[0]

            # (B, mlp[-1], num_point)
            new_features = self._pool_features(new_features)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), indices


232
233
234
235
236
237
238
@SA_MODULES.register_module()
class PointSAModuleMSG(BasePointSAModule):
    """Point set abstraction module with multi-scale grouping (MSG) used in
    PointNets.

    One stack of 1x1 ``Conv2d`` ``ConvModule`` layers is built per scale and
    appended to ``self.mlps``. When ``use_xyz`` is True, the input width of
    every scale is increased by 3. The lists passed in ``mlp_channels`` are
    not modified; the adjusted widths are stored on ``self.mlp_channels``.

    Args:
        num_point (int): Number of points.
        radii (List[float]): List of radius in each ball query.
        sample_nums (List[int]): Number of samples in each ball query.
        mlp_channels (List[List[int]]): Specify of the pointnet before
            the global pooling for each scale.
        fps_mod (List[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].

            - F-FPS: using feature distances for FPS.
            - D-FPS: using Euclidean distances of points for FPS.
            - FS: using F-FPS and D-FPS simultaneously.
        fps_sample_range_list (List[int]): Range of points to apply FPS.
            Defaults to [-1].
        dilated_group (bool): Whether to use dilated ball query.
            Defaults to False.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN2d').
        use_xyz (bool): Whether to use xyz. Defaults to True.
        pool_mod (str): Type of pooling method. Defaults to 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Defaults to False.
        bias (bool or str): If specified as `auto`, it will be decided by
            `norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
            otherwise False. Defaults to 'auto'.
    """

    def __init__(self,
                 num_point: int,
                 radii: List[float],
                 sample_nums: List[int],
                 mlp_channels: List[List[int]],
                 fps_mod: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1],
                 dilated_group: bool = False,
                 norm_cfg: ConfigType = dict(type='BN2d'),
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 normalize_xyz: bool = False,
                 bias: Union[bool, str] = 'auto') -> None:
        super().__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz)

        # Work on copies: adding 3 in place would modify the caller's lists,
        # so reusing the same list for a second module would add 3 twice.
        adjusted_channels = []
        for scale_channels in self.mlp_channels:
            scale_channels = list(scale_channels)
            if use_xyz:
                # xyz offsets are concatenated to the grouped features.
                scale_channels[0] += 3
            adjusted_channels.append(scale_channels)

            mlp = nn.Sequential()
            for layer_idx in range(len(scale_channels) - 1):
                mlp.add_module(
                    f'layer{layer_idx}',
                    ConvModule(
                        scale_channels[layer_idx],
                        scale_channels[layer_idx + 1],
                        kernel_size=(1, 1),
                        stride=(1, 1),
                        conv_cfg=dict(type='Conv2d'),
                        norm_cfg=norm_cfg,
                        bias=bias))
            self.mlps.append(mlp)

        # Keep the attribute in sync with the widths actually used.
        self.mlp_channels = adjusted_channels


309
@SA_MODULES.register_module()
class PointSAModule(PointSAModuleMSG):
    """Point set abstraction module with single-scale grouping (SSG) used in
    PointNets.

    This is the multi-scale module restricted to one scale: ``radius``,
    ``num_sample`` and ``mlp_channels`` are each wrapped into a one-element
    list and forwarded to :class:`PointSAModuleMSG`.

    Args:
        mlp_channels (List[int]): Specify of the pointnet before
            the global pooling for each scale.
        num_point (int, optional): Number of points. Defaults to None.
        radius (float, optional): Radius to group with. Defaults to None.
        num_sample (int, optional): Number of samples in each ball query.
            Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Default to dict(type='BN2d').
        use_xyz (bool): Whether to use xyz. Defaults to True.
        pool_mod (str): Type of pooling method. Defaults to 'max'.
        fps_mod (List[str]): Type of FPS method, valid mod
            ['F-FPS', 'D-FPS', 'FS']. Defaults to ['D-FPS'].
        fps_sample_range_list (List[int]): Range of points to apply FPS.
            Defaults to [-1].
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Defaults to False.
    """

    def __init__(self,
                 mlp_channels: List[int],
                 num_point: Optional[int] = None,
                 radius: Optional[float] = None,
                 num_sample: Optional[int] = None,
                 norm_cfg: ConfigType = dict(type='BN2d'),
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 fps_mod: List[str] = ['D-FPS'],
                 fps_sample_range_list: List[int] = [-1],
                 normalize_xyz: bool = False) -> None:
        # Single-scale settings, expressed as one-element multi-scale lists.
        single_scale = dict(
            radii=[radius],
            sample_nums=[num_sample],
            mlp_channels=[mlp_channels])
        super().__init__(
            num_point=num_point,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz,
            **single_scale)