import torch
import torch.nn as nn

import mmdet3d.ops.spconv as spconv
from mmdet3d.ops import SparseBasicBlock
from mmdet.ops import build_norm_layer
from ..registry import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module
class SparseUnetV2(nn.Module):
    """Sparse U-Net middle encoder for PartA^2.

    The encoder (``conv_input`` .. ``conv4``) downsamples the input sparse
    voxel tensor with strided sparse convolutions, and ``conv_out`` produces
    the feature map for the detection head (the depth axis is folded into
    the channel axis in ``forward``). The decoder reuses the encoder's
    indice keys to upsample with inverse sparse convolutions and feeds two
    linear heads (segmentation / regression) on the finest-level features.
    """

    def __init__(self,
                 in_channels,
                 output_shape,
                 pre_act,
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01)):
        """SparseUnet for PartA^2

        Args:
            in_channels (int): the number of input channels
            output_shape (list[int]): spatial shape of the sparse voxel grid;
                used as the sparse shape of the input tensor in ``forward``
            pre_act (bool): use pre_act_block or post_act_block
            norm_cfg (dict): normalize layer config
        """
        super().__init__()
        self.sparse_shape = output_shape
        self.output_shape = output_shape
        self.in_channels = in_channels
        self.pre_act = pre_act
        # Spconv init all weight on its own
        # TODO: make the network could be modified

        if pre_act:
            # Pre-activation blocks normalize/activate their own inputs, so
            # the stem is a bare convolution.
            self.conv_input = spconv.SparseSequential(
                spconv.SubMConv3d(
                    in_channels,
                    16,
                    3,
                    padding=1,
                    bias=False,
                    indice_key='subm1'), )
            block = self.pre_act_block
        else:
            _, norm_layer = build_norm_layer(norm_cfg, 16)
            self.conv_input = spconv.SparseSequential(
                spconv.SubMConv3d(
                    in_channels,
                    16,
                    3,
                    padding=1,
                    bias=False,
                    indice_key='subm1'),
                norm_layer,
                nn.ReLU(),
            )
            block = self.post_act_block

        self.conv1 = spconv.SparseSequential(
            block(16, 16, 3, norm_cfg=norm_cfg, padding=1,
                  indice_key='subm1'), )

        self.conv2 = spconv.SparseSequential(
            # [1600, 1408, 41] -> [800, 704, 21]
            block(
                16,
                32,
                3,
                norm_cfg=norm_cfg,
                stride=2,
                padding=1,
                indice_key='spconv2',
                conv_type='spconv'),
            block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
            block(32, 32, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm2'),
        )

        self.conv3 = spconv.SparseSequential(
            # [800, 704, 21] -> [400, 352, 11]
            block(
                32,
                64,
                3,
                norm_cfg=norm_cfg,
                stride=2,
                padding=1,
                indice_key='spconv3',
                conv_type='spconv'),
            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3'),
        )

        self.conv4 = spconv.SparseSequential(
            # [400, 352, 11] -> [200, 176, 5]
            block(
                64,
                64,
                3,
                norm_cfg=norm_cfg,
                stride=2,
                padding=(0, 1, 1),
                indice_key='spconv4',
                conv_type='spconv'),
            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
            block(64, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4'),
        )

        _, norm_layer = build_norm_layer(norm_cfg, 128)
        self.conv_out = spconv.SparseSequential(
            # [200, 176, 5] -> [200, 176, 2]
            spconv.SparseConv3d(
                64,
                128, (3, 1, 1),
                stride=(2, 1, 1),
                padding=0,
                bias=False,
                indice_key='spconv_down2'),
            norm_layer,
            nn.ReLU(),
        )

        # decoder
        # Each stage: conv_up_t* transforms the lateral encoder tensor,
        # conv_up_m* merges it with the tensor coming from below (hence the
        # doubled input channels), and inv_conv* upsamples using the indice
        # key of the matching encoder downsampling layer.
        # [400, 352, 11] <- [200, 176, 5]
        self.conv_up_t4 = SparseBasicBlock(
            64,
            64,
            conv_cfg=dict(type='SubMConv3d', indice_key='subm4'),
            norm_cfg=norm_cfg)
        self.conv_up_m4 = block(
            128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm4')
        self.inv_conv4 = block(
            64,
            64,
            3,
            norm_cfg=norm_cfg,
            indice_key='spconv4',
            conv_type='inverseconv')

        # [800, 704, 21] <- [400, 352, 11]
        self.conv_up_t3 = SparseBasicBlock(
            64,
            64,
            conv_cfg=dict(type='SubMConv3d', indice_key='subm3'),
            norm_cfg=norm_cfg)
        self.conv_up_m3 = block(
            128, 64, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm3')
        self.inv_conv3 = block(
            64,
            32,
            3,
            norm_cfg=norm_cfg,
            indice_key='spconv3',
            conv_type='inverseconv')

        # [1600, 1408, 41] <- [800, 704, 21]
        self.conv_up_t2 = SparseBasicBlock(
            32,
            32,
            conv_cfg=dict(type='SubMConv3d', indice_key='subm2'),
            norm_cfg=norm_cfg)
        self.conv_up_m2 = block(
            64, 32, 3, norm_cfg=norm_cfg, indice_key='subm2')
        self.inv_conv2 = block(
            32,
            16,
            3,
            norm_cfg=norm_cfg,
            indice_key='spconv2',
            conv_type='inverseconv')

        # [1600, 1408, 41] <- [1600, 1408, 41]
        self.conv_up_t1 = SparseBasicBlock(
            16,
            16,
            conv_cfg=dict(type='SubMConv3d', indice_key='subm1'),
            norm_cfg=norm_cfg)
        self.conv_up_m1 = block(
            32, 16, 3, norm_cfg=norm_cfg, indice_key='subm1')

        # The last stage has no upsampling; conv5 plays the role of conv_inv.
        self.conv5 = spconv.SparseSequential(
            block(16, 16, 3, norm_cfg=norm_cfg, padding=1, indice_key='subm1'))

        self.seg_cls_layer = nn.Linear(16, 1, bias=True)
        self.seg_reg_layer = nn.Linear(16, 3, bias=True)

    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUnetV2

        Args:
            voxel_features (torch.float32): shape [N, C]
            coors (torch.int32): shape [N, 4](batch_idx, z_idx, y_idx, x_idx)
            batch_size (int): batch size

        Returns:
            dict: backbone features with keys ``'spatial_features'``
                (dense map of shape [N, C * D, H, W]), ``'u_seg_preds'``,
                ``'u_reg_preds'`` and ``'seg_features'``.
        """
        coors = coors.int()
        input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
                                                  self.sparse_shape,
                                                  batch_size)
        x = self.conv_input(input_sp_tensor)

        # NOTE(review): in pre_act mode every block starts with a norm layer;
        # if SparseSequential writes that output back into its input tensor,
        # x_conv1..x_conv3 reach the decoder already normalized/activated.
        # Confirm this is intended.
        x_conv1 = self.conv1(x)
        x_conv2 = self.conv2(x_conv1)
        x_conv3 = self.conv3(x_conv2)
        x_conv4 = self.conv4(x_conv3)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(x_conv4)
        spatial_features = out.dense()

        # Fold the depth axis into channels: [N, C, D, H, W] -> [N, C*D, H, W]
        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        ret = {'spatial_features': spatial_features}

        # for segmentation head
        # [400, 352, 11] <- [200, 176, 5]
        x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4,
                                      self.conv_up_m4, self.inv_conv4)
        # [800, 704, 21] <- [400, 352, 11]
        x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3,
                                      self.conv_up_m3, self.inv_conv3)
        # [1600, 1408, 41] <- [800, 704, 21]
        x_up2 = self.UR_block_forward(x_conv2, x_up3, self.conv_up_t2,
                                      self.conv_up_m2, self.inv_conv2)
        # [1600, 1408, 41] <- [1600, 1408, 41]
        x_up1 = self.UR_block_forward(x_conv1, x_up2, self.conv_up_t1,
                                      self.conv_up_m1, self.conv5)

        seg_features = x_up1.features

        seg_cls_preds = self.seg_cls_layer(seg_features)  # (N, 1)
        seg_reg_preds = self.seg_reg_layer(seg_features)  # (N, 3)

        ret.update({
            'u_seg_preds': seg_cls_preds,
            'u_reg_preds': seg_reg_preds,
            'seg_features': seg_features
        })

        return ret

    def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv):
        """Forward of upsample and residual block.

        ``conv_t(x_lateral)`` is concatenated channel-wise with
        ``x_bottom``. The result is merged by ``conv_m``, which is added to a
        channel-reduced copy of the concatenation (residual shortcut), and
        finally passed through ``conv_inv``.

        Args:
            x_lateral (SparseConvTensor): lateral tensor
            x_bottom (SparseConvTensor): tensor from bottom layer
            conv_t (SparseBasicBlock): convolution for lateral tensor
            conv_m (SparseSequential): convolution for merging features
            conv_inv (SparseSequential): convolution for upsampling

        Returns:
            SparseConvTensor: upsampled feature
        """
        x = conv_t(x_lateral)
        # The concatenation assumes x_bottom's feature rows line up with the
        # rows of conv_t's output (same active sites).
        cat_features = torch.cat((x_bottom.features, x.features), dim=1)
        x.features = cat_features
        x_m = conv_m(x)
        # A pre-activation conv_m begins with a norm layer whose output may
        # be written back into ``x.features``. Restore the raw concatenation
        # so the shortcut is built from it in both block styles.
        x.features = cat_features
        x = self.channel_reduction(x, x_m.features.shape[1])
        x.features = x_m.features + x.features
        return conv_inv(x)

    @staticmethod
    def channel_reduction(x, out_channels):
        """Channel reduction for element-wise add.

        Groups of ``in_channels // out_channels`` consecutive channels are
        summed, so output channel ``j`` is the sum of input channels
        ``j * r`` .. ``j * r + r - 1`` with ``r = in_channels // out_channels``.
        ``x`` is modified in place and returned.

        Args:
            x (SparseConvTensor): x.features (N, C1)
            out_channels (int): the number of channel after reduction

        Returns:
            SparseConvTensor: channel reduced feature
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels == 0) and \
            (in_channels >= out_channels), \
            'in_channels ({}) must be a positive multiple of ' \
            'out_channels ({})'.format(in_channels, out_channels)

        x.features = features.view(n, out_channels, -1).sum(dim=2)
        return x

    def pre_act_block(self,
                      in_channels,
                      out_channels,
                      kernel_size,
                      indice_key=None,
                      stride=1,
                      padding=0,
                      conv_type='subm',
                      norm_cfg=None):
        """Make pre activate sparse convolution block.

        The block is ``norm -> ReLU -> conv``; the norm layer is built for
        ``in_channels``.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of out channels
            kernel_size (int): kernel size of convolution
            indice_key (str): the indice key used for sparse tensor
            stride (int): the stride of convolution
            padding (int or list[int]): the padding number of input
            conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
            norm_cfg (dict): normal layer configs

        Returns:
            spconv.SparseSequential: pre activate sparse convolution block.
        """
        assert conv_type in ['subm', 'spconv', 'inverseconv']

        _, norm_layer = build_norm_layer(norm_cfg, in_channels)
        if conv_type == 'subm':
            m = spconv.SparseSequential(
                norm_layer,
                nn.ReLU(inplace=True),
                spconv.SubMConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    padding=padding,
                    bias=False,
                    indice_key=indice_key),
            )
        elif conv_type == 'spconv':
            m = spconv.SparseSequential(
                norm_layer,
                nn.ReLU(inplace=True),
                spconv.SparseConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=False,
                    indice_key=indice_key),
            )
        elif conv_type == 'inverseconv':
            # Stride/padding are not passed: the inverse conv is driven by
            # the indice pairs recorded under ``indice_key``.
            m = spconv.SparseSequential(
                norm_layer,
                nn.ReLU(inplace=True),
                spconv.SparseInverseConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    bias=False,
                    indice_key=indice_key),
            )
        else:
            raise NotImplementedError
        return m

    def post_act_block(self,
                       in_channels,
                       out_channels,
                       kernel_size,
                       indice_key,
                       stride=1,
                       padding=0,
                       conv_type='subm',
                       norm_cfg=None):
        """Make post activate sparse convolution block.

        The block is ``conv -> norm -> ReLU``; the norm layer is built for
        ``out_channels``.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of out channels
            kernel_size (int): kernel size of convolution
            indice_key (str): the indice key used for sparse tensor
            stride (int): the stride of convolution
            padding (int or list[int]): the padding number of input
            conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
            norm_cfg (dict[str]): normal layer configs

        Returns:
            spconv.SparseSequential: post activate sparse convolution block.
        """
        assert conv_type in ['subm', 'spconv', 'inverseconv']

        _, norm_layer = build_norm_layer(norm_cfg, out_channels)
        if conv_type == 'subm':
            # NOTE(review): unlike pre_act_block, ``padding`` is not
            # forwarded to SubMConv3d here; kept as-is to preserve existing
            # (trained) behavior. Confirm whether spconv uses it for subm.
            m = spconv.SparseSequential(
                spconv.SubMConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    bias=False,
                    indice_key=indice_key),
                norm_layer,
                nn.ReLU(inplace=True),
            )
        elif conv_type == 'spconv':
            m = spconv.SparseSequential(
                spconv.SparseConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=False,
                    indice_key=indice_key),
                norm_layer,
                nn.ReLU(inplace=True),
            )
        elif conv_type == 'inverseconv':
            # Stride/padding are not passed: the inverse conv is driven by
            # the indice pairs recorded under ``indice_key``.
            m = spconv.SparseSequential(
                spconv.SparseInverseConv3d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    bias=False,
                    indice_key=indice_key),
                norm_layer,
                nn.ReLU(inplace=True),
            )
        else:
            raise NotImplementedError
        return m