# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn

from mmdet3d.models.layers.paconv import PAConv, PAConvCUDA
from .builder import SA_MODULES
from .point_sa_module import BasePointSAModule


@SA_MODULES.register_module()
class PAConvSAModuleMSG(BasePointSAModule):
    r"""Point set abstraction module with multi-scale grouping (MSG) used in
    PAConv networks.

    Replace the MLPs in `PointSAModuleMSG` with PAConv layers.
    See the `paper <https://arxiv.org/abs/2103.14635>`_ for more details.

    Args:
        paconv_num_kernels (list[list[int]]): Number of kernel weights in the
            weight banks of each layer's PAConv.
        paconv_kernel_input (str, optional): Input features to be multiplied
            with kernel weights. Can be 'identity' or 'w_neighbor'.
            Defaults to 'w_neighbor'.
        scorenet_input (str, optional): Type of the input to ScoreNet.
            Defaults to 'w_neighbor_dist'. Can be the following values:

            - 'identity': Use xyz coordinates as input.
            - 'w_neighbor': Use xyz coordinates and the difference with center
                points as input.
            - 'w_neighbor_dist': Use xyz coordinates, the difference with
                center points and the Euclidean distance as input.

        scorenet_cfg (dict, optional): Config of the ScoreNet module, which
            may contain the following keys and values:

            - mlp_channels (List[int]): Hidden units of MLPs.
            - score_norm (str): Normalization function of output scores.
                Can be 'softmax', 'sigmoid' or 'identity'.
            - temp_factor (float): Temperature factor to scale the output
                scores before softmax.
            - last_bn (bool): Whether to use BN on the last output of mlps.
    """

    def __init__(self,
                 num_point,
                 radii,
                 sample_nums,
                 mlp_channels,
                 paconv_num_kernels,
                 fps_mod=['D-FPS'],
                 fps_sample_range_list=[-1],
                 dilated_group=False,
                 norm_cfg=dict(type='BN2d', momentum=0.1),
                 use_xyz=True,
                 pool_mod='max',
                 normalize_xyz=False,
                 bias='auto',
                 paconv_kernel_input='w_neighbor',
                 scorenet_input='w_neighbor_dist',
                 scorenet_cfg=dict(
                     mlp_channels=[16, 16, 16],
                     score_norm='softmax',
                     temp_factor=1.0,
                     last_bn=False)):
        super(PAConvSAModuleMSG, self).__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz,
            grouper_return_grouped_xyz=True)

        assert len(paconv_num_kernels) == len(mlp_channels)
        for i in range(len(mlp_channels)):
            assert len(paconv_num_kernels[i]) == len(mlp_channels[i]) - 1, \
                'PAConv number of kernel weights wrong'

        # in PAConv, bias only exists in ScoreNet.
        # Copy the cfg before adding the key: writing into the argument in
        # place would leak 'bias' into the shared default dict (and any
        # caller-supplied dict) across instantiations.
        scorenet_cfg = dict(scorenet_cfg)
        scorenet_cfg['bias'] = bias

        for i in range(len(self.mlp_channels)):
            mlp_channel = self.mlp_channels[i]
            if use_xyz:
                # grouped xyz is concatenated to the features, so the first
                # PAConv layer of each scale sees 3 extra input channels
                mlp_channel[0] += 3

            num_kernels = paconv_num_kernels[i]

            mlp = nn.Sequential()
            # use a distinct inner index (`j`): the original reused `i`,
            # shadowing the outer scale index
            for j in range(len(mlp_channel) - 1):
                mlp.add_module(
                    f'layer{j}',
                    PAConv(
                        mlp_channel[j],
                        mlp_channel[j + 1],
                        num_kernels[j],
                        norm_cfg=norm_cfg,
                        kernel_input=paconv_kernel_input,
                        scorenet_input=scorenet_input,
                        scorenet_cfg=scorenet_cfg))
            self.mlps.append(mlp)


@SA_MODULES.register_module()
class PAConvSAModule(PAConvSAModuleMSG):
    r"""Point set abstraction module with single-scale grouping (SSG) used in
    PAConv networks.

    Replace the MLPs in `PointSAModule` with PAConv layers. See the `paper
    <https://arxiv.org/abs/2103.14635>`_ for more details.
    """

    def __init__(self,
                 mlp_channels,
                 paconv_num_kernels,
                 num_point=None,
                 radius=None,
                 num_sample=None,
                 norm_cfg=dict(type='BN2d', momentum=0.1),
                 use_xyz=True,
                 pool_mod='max',
                 fps_mod=['D-FPS'],
                 fps_sample_range_list=[-1],
                 normalize_xyz=False,
                 paconv_kernel_input='w_neighbor',
                 scorenet_input='w_neighbor_dist',
                 scorenet_cfg=dict(
                     mlp_channels=[16, 16, 16],
                     score_norm='softmax',
                     temp_factor=1.0,
                     last_bn=False)):
        # Wrap the single-scale arguments into the list form the MSG base
        # class expects, then delegate everything to it.
        msg_kwargs = dict(
            mlp_channels=[mlp_channels],
            paconv_num_kernels=[paconv_num_kernels],
            num_point=num_point,
            radii=[radius],
            sample_nums=[num_sample],
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            normalize_xyz=normalize_xyz,
            paconv_kernel_input=paconv_kernel_input,
            scorenet_input=scorenet_input,
            scorenet_cfg=scorenet_cfg)
        super(PAConvSAModule, self).__init__(**msg_kwargs)


@SA_MODULES.register_module()
class PAConvCUDASAModuleMSG(BasePointSAModule):
    r"""Point set abstraction module with multi-scale grouping (MSG) used in
    PAConv networks.

    Replace the non CUDA version PAConv with CUDA implemented PAConv for
    efficient computation. See the `paper <https://arxiv.org/abs/2103.14635>`_
    for more details.
    """

    def __init__(self,
                 num_point,
                 radii,
                 sample_nums,
                 mlp_channels,
                 paconv_num_kernels,
                 fps_mod=['D-FPS'],
                 fps_sample_range_list=[-1],
                 dilated_group=False,
                 norm_cfg=dict(type='BN2d', momentum=0.1),
                 use_xyz=True,
                 pool_mod='max',
                 normalize_xyz=False,
                 bias='auto',
                 paconv_kernel_input='w_neighbor',
                 scorenet_input='w_neighbor_dist',
                 scorenet_cfg=dict(
                     mlp_channels=[8, 16, 16],
                     score_norm='softmax',
                     temp_factor=1.0,
                     last_bn=False)):
        super(PAConvCUDASAModuleMSG, self).__init__(
            num_point=num_point,
            radii=radii,
            sample_nums=sample_nums,
            mlp_channels=mlp_channels,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            dilated_group=dilated_group,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz,
            grouper_return_grouped_xyz=True,
            grouper_return_grouped_idx=True)

        assert len(paconv_num_kernels) == len(mlp_channels)
        for i in range(len(mlp_channels)):
            assert len(paconv_num_kernels[i]) == len(mlp_channels[i]) - 1, \
                'PAConv number of kernel weights wrong'

        # in PAConv, bias only exists in ScoreNet.
        # Copy the cfg before adding the key: writing into the argument in
        # place would leak 'bias' into the shared default dict (and any
        # caller-supplied dict) across instantiations.
        scorenet_cfg = dict(scorenet_cfg)
        scorenet_cfg['bias'] = bias

        # we need to manually concat xyz for CUDA implemented PAConv
        self.use_xyz = use_xyz

        for i in range(len(self.mlp_channels)):
            mlp_channel = self.mlp_channels[i]
            if use_xyz:
                # xyz is concatenated to the features in `forward`, so the
                # first PAConvCUDA layer of each scale sees 3 extra channels
                mlp_channel[0] += 3

            num_kernels = paconv_num_kernels[i]

            # can't use `nn.Sequential` for PAConvCUDA because its input and
            # output have different shapes
            mlp = nn.ModuleList()
            # use a distinct inner index (`j`): the original reused `i`,
            # shadowing the outer scale index
            for j in range(len(mlp_channel) - 1):
                mlp.append(
                    PAConvCUDA(
                        mlp_channel[j],
                        mlp_channel[j + 1],
                        num_kernels[j],
                        norm_cfg=norm_cfg,
                        kernel_input=paconv_kernel_input,
                        scorenet_input=scorenet_input,
                        scorenet_cfg=scorenet_cfg))
            self.mlps.append(mlp)

    def forward(
        self,
        points_xyz,
        features=None,
        indices=None,
        target_xyz=None,
    ):
        """forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor, optional): (B, C, N) features of each point.
                Default: None.
            indices (Tensor, optional): (B, num_point) Index of the features.
                Default: None.
            target_xyz (Tensor, optional): (B, M, 3) new coords of the outputs.
                Default: None.

        Returns:
            Tensor: (B, M, 3) where M is the number of points.
                New features xyz.
            Tensor: (B, M, sum_k(mlps[k][-1])) where M is the number
                of points. New feature descriptors.
            Tensor: (B, M) where M is the number of points.
                Index of the features.
        """
        new_features_list = []

        # sample points, (B, num_point, 3), (B, num_point)
        new_xyz, indices = self._sample_points(points_xyz, features, indices,
                                               target_xyz)

        for i in range(len(self.groupers)):
            xyz = points_xyz
            new_features = features
            for j in range(len(self.mlps[i])):
                # we don't use grouped_features here to avoid large GPU memory
                # _, (B, 3, num_point, nsample), (B, num_point, nsample)
                _, grouped_xyz, grouped_idx = self.groupers[i](xyz, new_xyz,
                                                               new_features)

                # concat xyz as additional features
                # NOTE(review): this assumes `features` is not None when
                # use_xyz is set — torch.cat would fail otherwise; confirm
                # against callers.
                if self.use_xyz and j == 0:
                    # (B, C+3, N)
                    new_features = torch.cat(
                        (points_xyz.permute(0, 2, 1), new_features), dim=1)

                # (B, out_c, num_point, nsample)
                grouped_new_features = self.mlps[i][j](
                    (new_features, grouped_xyz, grouped_idx.long()))[0]

                # different from PointNet++ and non CUDA version of PAConv
                # CUDA version of PAConv needs to aggregate local features
                # every time after it passes through a Conv layer
                # in order to transform to valid input shape
                # (B, out_c, num_point)
                new_features = self._pool_features(grouped_new_features)

                # constrain the points to be grouped for next PAConv layer
                # because new_features only contains sampled centers now
                # (B, num_point, 3)
                xyz = new_xyz

            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), indices


@SA_MODULES.register_module()
class PAConvCUDASAModule(PAConvCUDASAModuleMSG):
    r"""Point set abstraction module with single-scale grouping (SSG) used in
    PAConv networks.

    Replace the non CUDA version PAConv with CUDA implemented PAConv for
    efficient computation. See the `paper <https://arxiv.org/abs/2103.14635>`_
    for more details.
    """

    def __init__(self,
                 mlp_channels,
                 paconv_num_kernels,
                 num_point=None,
                 radius=None,
                 num_sample=None,
                 norm_cfg=dict(type='BN2d', momentum=0.1),
                 use_xyz=True,
                 pool_mod='max',
                 fps_mod=['D-FPS'],
                 fps_sample_range_list=[-1],
                 normalize_xyz=False,
                 paconv_kernel_input='w_neighbor',
                 scorenet_input='w_neighbor_dist',
                 scorenet_cfg=dict(
                     mlp_channels=[8, 16, 16],
                     score_norm='softmax',
                     temp_factor=1.0,
                     last_bn=False)):
        # Wrap the single-scale arguments into the list form the MSG base
        # class expects, then delegate everything to it.
        msg_kwargs = dict(
            mlp_channels=[mlp_channels],
            paconv_num_kernels=[paconv_num_kernels],
            num_point=num_point,
            radii=[radius],
            sample_nums=[num_sample],
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            fps_mod=fps_mod,
            fps_sample_range_list=fps_sample_range_list,
            normalize_xyz=normalize_xyz,
            paconv_kernel_input=paconv_kernel_input,
            scorenet_input=scorenet_input,
            scorenet_cfg=scorenet_cfg)
        super(PAConvCUDASAModule, self).__init__(**msg_kwargs)