paconv.py 15.8 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import copy
3
from typing import List, Tuple, Union
4

5
import torch
6
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
7
from mmcv.ops import assign_score_withk as assign_score_cuda
8
from mmengine.model import constant_init
9
from torch import Tensor
10
11
12
from torch import nn as nn
from torch.nn import functional as F

13
from mmdet3d.utils import ConfigType
14
15
16
17
from .utils import assign_kernel_withoutk, assign_score, calc_euclidian_dist


class ScoreNet(nn.Module):
    r"""ScoreNet that outputs coefficient scores to assemble kernel weights in
    the weight bank according to the relative position of point pairs.

    Args:
        mlp_channels (List[int]): Hidden unit sizes of SharedMLP layers.
        last_bn (bool): Whether to use BN on the last output of mlps.
            Defaults to False.
        score_norm (str): Normalization function of output scores.
            Can be 'softmax', 'sigmoid' or 'identity'. Defaults to 'softmax'.
        temp_factor (float): Temperature factor to scale the output
            scores before softmax. Defaults to 1.0.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN2d').
        bias (bool or str): If specified as `auto`, it will be decided by
            `norm_cfg`. `bias` will be set as True if `norm_cfg` is None,
            otherwise False. Defaults to 'auto'.

    Note:
        The official code applies xavier_init to all Conv layers in ScoreNet,
        see `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg
        /model/pointnet2/paconv.py#L105>`_. However in our experiments, we
        did not find much difference in applying such xavier initialization
        or not. So we neglect this initialization in our implementation.
    """

    def __init__(self,
                 mlp_channels: List[int],
                 last_bn: bool = False,
                 score_norm: str = 'softmax',
                 temp_factor: float = 1.0,
                 norm_cfg: ConfigType = dict(type='BN2d'),
                 bias: Union[bool, str] = 'auto') -> None:
        super(ScoreNet, self).__init__()

        assert score_norm in ['softmax', 'sigmoid', 'identity'], \
            f'unsupported score_norm function {score_norm}'

        self.score_norm = score_norm
        self.temp_factor = temp_factor

        # shared MLP built from 1x1 Conv2d layers; every layer except the
        # last uses normalization and ConvModule's default activation
        self.mlps = nn.Sequential()
        for i in range(len(mlp_channels) - 2):
            self.mlps.add_module(
                f'layer{i}',
                ConvModule(
                    mlp_channels[i],
                    mlp_channels[i + 1],
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=norm_cfg,
                    bias=bias))

        # for the last mlp that outputs scores, no relu and possibly no bn
        i = len(mlp_channels) - 2
        self.mlps.add_module(
            f'layer{i}',
            ConvModule(
                mlp_channels[i],
                mlp_channels[i + 1],
                kernel_size=(1, 1),
                stride=(1, 1),
                conv_cfg=dict(type='Conv2d'),
                norm_cfg=norm_cfg if last_bn else None,
                act_cfg=None,
                bias=bias))

    def forward(self, xyz_features: Tensor) -> Tensor:
        """Forward.

        Args:
            xyz_features (Tensor): (B, C, N, K) Features constructed from xyz
                coordinates of point pairs. May contain relative positions,
                Euclidean distance, etc.

        Returns:
            Tensor: (B, N, K, M) Predicted scores for `M` kernels.
        """
        scores = self.mlps(xyz_features)  # (B, M, N, K)

        # perform score normalization over the kernel dim (dim=1);
        # 'identity' leaves the raw scores unchanged
        if self.score_norm == 'softmax':
            scores = F.softmax(scores / self.temp_factor, dim=1)
        elif self.score_norm == 'sigmoid':
            scores = torch.sigmoid(scores / self.temp_factor)

        scores = scores.permute(0, 2, 3, 1)  # (B, N, K, M)

        return scores


class PAConv(nn.Module):
    """Non-CUDA version of PAConv.

    PAConv stores a trainable weight bank containing several kernel weights.
    Given input points and features, it computes coefficient scores to assemble
    those kernels to form conv kernels, and then runs convolution on the input.

    Args:
        in_channels (int): Input channels of point features.
        out_channels (int): Output channels of point features.
        num_kernels (int): Number of kernel weights in the weight bank.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN2d', momentum=0.1).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to dict(type='ReLU', inplace=True).
        scorenet_input (str): Type of input to ScoreNet.
            Can be 'identity', 'w_neighbor' or 'w_neighbor_dist'.
            Defaults to 'w_neighbor_dist'.
        weight_bank_init (str): Init method of weight bank kernels.
            Can be 'kaiming' or 'xavier'. Defaults to 'kaiming'.
        kernel_input (str): Input features to be multiplied with kernel
            weights. Can be 'identity' or 'w_neighbor'.
            Defaults to 'w_neighbor'.
        scorenet_cfg (dict): Config of the ScoreNet module, which may contain
            the following keys and values:

            - mlp_channels (List[int]): Hidden units of MLPs.
            - score_norm (str): Normalization function of output scores.
              Can be 'softmax', 'sigmoid' or 'identity'.
            - temp_factor (float): Temperature factor to scale the output
              scores before softmax.
            - last_bn (bool): Whether to use BN on the last output of mlps.

            Defaults to dict(mlp_channels=[16, 16, 16],
                             score_norm='softmax',
                             temp_factor=1.0,
                             last_bn=False).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_kernels: int,
        norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
        act_cfg: ConfigType = dict(type='ReLU', inplace=True),
        scorenet_input: str = 'w_neighbor_dist',
        weight_bank_init: str = 'kaiming',
        kernel_input: str = 'w_neighbor',
        scorenet_cfg: dict = dict(
            mlp_channels=[16, 16, 16],
            score_norm='softmax',
            temp_factor=1.0,
            last_bn=False)
    ) -> None:
        super(PAConv, self).__init__()

        # determine weight kernel size according to used features
        if kernel_input == 'identity':
            # only use grouped_features
            kernel_mul = 1
        elif kernel_input == 'w_neighbor':
            # concat of (grouped_features - center_features, grouped_features)
            kernel_mul = 2
        else:
            raise NotImplementedError(
                f'unsupported kernel_input {kernel_input}')
        self.kernel_input = kernel_input
        # effective input channels of the conv kernels after optional concat
        in_channels = kernel_mul * in_channels

        # determine mlp channels in ScoreNet according to used xyz features
        if scorenet_input == 'identity':
            # only use relative position (grouped_xyz - center_xyz)
            self.scorenet_in_channels = 3
        elif scorenet_input == 'w_neighbor':
            # (grouped_xyz - center_xyz, grouped_xyz)
            self.scorenet_in_channels = 6
        elif scorenet_input == 'w_neighbor_dist':
            # (center_xyz, grouped_xyz - center_xyz, Euclidean distance)
            self.scorenet_in_channels = 7
        else:
            raise NotImplementedError(
                f'unsupported scorenet_input {scorenet_input}')
        self.scorenet_input = scorenet_input

        # construct kernel weights in weight bank
        # self.weight_bank is of shape [C, num_kernels * out_c]
        # where C can be in_c or (2 * in_c)
        if weight_bank_init == 'kaiming':
            weight_init = nn.init.kaiming_normal_
        elif weight_bank_init == 'xavier':
            weight_init = nn.init.xavier_normal_
        else:
            raise NotImplementedError(
                f'unsupported weight bank init method {weight_bank_init}')

        self.num_kernels = num_kernels  # the parameter `m` in the paper
        # init per-kernel weights [m, C, out_c], then flatten kernels into
        # the last dim -> [C, m * out_c] so one matmul evaluates all kernels
        weight_bank = weight_init(
            torch.empty(self.num_kernels, in_channels, out_channels))
        weight_bank = weight_bank.permute(1, 0, 2).reshape(
            in_channels, self.num_kernels * out_channels).contiguous()
        self.weight_bank = nn.Parameter(weight_bank, requires_grad=True)

        # construct ScoreNet
        # deepcopy before insert/append so the caller's (possibly shared
        # default) cfg dict is not mutated
        scorenet_cfg_ = copy.deepcopy(scorenet_cfg)
        scorenet_cfg_['mlp_channels'].insert(0, self.scorenet_in_channels)
        scorenet_cfg_['mlp_channels'].append(self.num_kernels)
        self.scorenet = ScoreNet(**scorenet_cfg_)

        # optional BN / activation applied to the assembled output features
        self.bn = build_norm_layer(norm_cfg, out_channels)[1] if \
            norm_cfg is not None else None
        self.activate = build_activation_layer(act_cfg) if \
            act_cfg is not None else None

        # set some basic attributes of Conv layers
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.init_weights()

    def init_weights(self) -> None:
        """Initialize weights of shared MLP layers and BN layers."""
        if self.bn is not None:
            constant_init(self.bn, val=1, bias=0)

    def _prepare_scorenet_input(self, points_xyz: Tensor) -> Tensor:
        """Prepare input point pairs features for self.ScoreNet.

        Args:
            points_xyz (Tensor): (B, 3, npoint, K) Coordinates of the
                grouped points.

        Returns:
            Tensor: (B, C, npoint, K) The generated features per point pair.
        """
        B, _, npoint, K = points_xyz.size()
        # NOTE(review): assumes index 0 along the K dim is the group center
        # itself — confirm against the grouping op that builds points_xyz
        center_xyz = points_xyz[..., :1].repeat(1, 1, 1, K)
        xyz_diff = points_xyz - center_xyz  # [B, 3, npoint, K]
        if self.scorenet_input == 'identity':
            xyz_features = xyz_diff
        elif self.scorenet_input == 'w_neighbor':
            xyz_features = torch.cat((xyz_diff, points_xyz), dim=1)
        else:  # w_neighbor_dist
            # per-pair Euclidean distance, computed on flattened point pairs
            # then reshaped back to one channel
            euclidian_dist = calc_euclidian_dist(
                center_xyz.permute(0, 2, 3, 1).reshape(B * npoint * K, 3),
                points_xyz.permute(0, 2, 3, 1).reshape(B * npoint * K, 3)).\
                    reshape(B, 1, npoint, K)
            xyz_features = torch.cat((center_xyz, xyz_diff, euclidian_dist),
                                     dim=1)
        return xyz_features

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward.

        Args:
            inputs (Tuple[Tensor]):

                - features (Tensor): (B, in_c, npoint, K)
                  Features of the queried points.
                - points_xyz (Tensor): (B, 3, npoint, K)
                  Coordinates of the grouped points.

        Returns:
            Tuple[Tensor]:

                - new_features: (B, out_c, npoint, K) Features after PAConv.
                - points_xyz: Same as input.
        """
        features, points_xyz = inputs
        B, _, npoint, K = features.size()

        if self.kernel_input == 'w_neighbor':
            # NOTE(review): same "index 0 along K is the center" assumption
            # as in _prepare_scorenet_input
            center_features = features[..., :1].repeat(1, 1, 1, K)
            features_diff = features - center_features
            # to (B, 2 * in_c, npoint, K)
            features = torch.cat((features_diff, features), dim=1)

        # prepare features for between each point and its grouping center
        xyz_features = self._prepare_scorenet_input(points_xyz)

        # scores to assemble kernel weights
        scores = self.scorenet(xyz_features)  # [B, npoint, K, m]

        # first compute out features over all kernels
        # features is [B, C, npoint, K], weight_bank is [C, m * out_c]
        new_features = torch.matmul(
            features.permute(0, 2, 3, 1),
            self.weight_bank).view(B, npoint, K, self.num_kernels,
                                   -1)  # [B, npoint, K, m, out_c]

        # then aggregate using scores
        new_features = assign_score(scores, new_features)
        # to [B, out_c, npoint, K]
        new_features = new_features.permute(0, 3, 1, 2).contiguous()

        if self.bn is not None:
            new_features = self.bn(new_features)
        if self.activate is not None:
            new_features = self.activate(new_features)

        # in order to keep input output consistency
        # so that we can wrap PAConv in Sequential
        return (new_features, points_xyz)
313
314
315
316
317
318
319
320
321
322
323
324
325


class PAConvCUDA(PAConv):
    """CUDA version of PAConv that implements a cuda op to efficiently perform
    kernel assembling.

    Different from vanilla PAConv, the input features of this function is not
    grouped by centers. Instead, they will be queried on-the-fly by the
    additional input `points_idx`. This avoids the large intermediate matrix.
    See the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_ appendix Sec. D for
    more detailed descriptions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_kernels: int,
        norm_cfg: ConfigType = dict(type='BN2d', momentum=0.1),
        act_cfg: ConfigType = dict(type='ReLU', inplace=True),
        scorenet_input: str = 'w_neighbor_dist',
        weight_bank_init: str = 'kaiming',
        kernel_input: str = 'w_neighbor',
        scorenet_cfg: dict = dict(
            mlp_channels=[8, 16, 16],
            score_norm='softmax',
            temp_factor=1.0,
            last_bn=False)
    ) -> None:
        # same construction as PAConv; note the smaller default ScoreNet
        # mlp_channels ([8, 16, 16] vs [16, 16, 16])
        super(PAConvCUDA, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            num_kernels=num_kernels,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            scorenet_input=scorenet_input,
            weight_bank_init=weight_bank_init,
            kernel_input=kernel_input,
            scorenet_cfg=scorenet_cfg)

        # the cuda op (assign_score_withk) only handles the concatenated
        # (neighbor - center, neighbor) kernel input
        assert self.kernel_input == 'w_neighbor', \
            'CUDA implemented PAConv only supports w_neighbor kernel_input'

    def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
        """Forward.

        Args:
            inputs (Tuple[Tensor]):

                - features (Tensor): (B, in_c, N)
                  Features of all points in the current point cloud.
                  Different from non-CUDA version PAConv, here the features
                  are not grouped by each center to form a K dim.
                - points_xyz (Tensor): (B, 3, npoint, K)
                  Coordinates of the grouped points.
                - points_idx (Tensor): (B, npoint, K)
                  Index of the grouped points.

        Returns:
            Tuple[Tensor]:

                - new_features: (B, out_c, npoint, K) Features after PAConv.
                - points_xyz: Same as input.
                - points_idx: Same as input.
        """
        features, points_xyz, points_idx = inputs

        # prepare features for between each point and its grouping center
        xyz_features = self._prepare_scorenet_input(points_xyz)

        # scores to assemble kernel weights
        scores = self.scorenet(xyz_features)  # [B, npoint, K, m]

        # pre-compute features for points and centers separately
        # features is [B, in_c, N], weight_bank is [C, m * out_dim]
        point_feat, center_feat = assign_kernel_withoutk(
            features, self.weight_bank, self.num_kernels)

        # aggregate features using custom cuda op; neighbors are gathered
        # on-the-fly via points_idx instead of a pre-grouped tensor
        new_features = assign_score_cuda(
            scores, point_feat, center_feat, points_idx,
            'sum').contiguous()  # [B, out_c, npoint, K]

        if self.bn is not None:
            new_features = self.bn(new_features)
        if self.activate is not None:
            new_features = self.activate(new_features)

        # in order to keep input output consistency
        return (new_features, points_xyz, points_idx)