sparse_unet.py 16 KB
Newer Older
wuyuefeng's avatar
wuyuefeng committed
1
2
3
4
import torch
import torch.nn as nn

import mmdet3d.ops.spconv as spconv
wuyuefeng's avatar
wuyuefeng committed
5
from mmdet3d.ops import SparseBasicBlock
wuyuefeng's avatar
wuyuefeng committed
6
7
8
9
10
from mmdet.ops import build_norm_layer
from ..registry import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module
wuyuefeng's avatar
wuyuefeng committed
11
class SparseUNet(nn.Module):
wuyuefeng's avatar
wuyuefeng committed
12
13
14
15

    def __init__(self,
                 in_channels,
                 output_shape,
                 pre_act=False,
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64),
                                   (64, 64, 64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1),
                                   ((0, 1, 1), 1, 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1))):
        """SparseUNet for PartA^2.

        See https://arxiv.org/abs/1907.03670 for more details.

        Args:
            in_channels (int): the number of input channels
            output_shape (list[int]): the shape of output tensor; it is also
                stored as ``sparse_shape`` and used as the spatial shape of
                the input sparse tensor in ``forward``
            pre_act (bool): build blocks with ``pre_act_block`` if True,
                otherwise with ``post_act_block``
            norm_cfg (dict): config of normalization layer
            base_channels (int): out channels for conv_input layer
            output_channels (int): out channels for conv_out layer
            encoder_channels (tuple[tuple[int]]):
                conv channels of each encode block
            encoder_paddings (tuple[tuple[int]]): paddings of each encode block
            decoder_channels (tuple[tuple[int]]):
                conv channels of each decode block
            decoder_paddings (tuple[tuple[int]]): paddings of each decode block
        """
        super().__init__()
        self.sparse_shape = output_shape
        self.output_shape = output_shape
        self.in_channels = in_channels
        self.pre_act = pre_act
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        # Spconv init all weight on its own

        # TODO: use ConvModule to encapsulate
        # The input stem always starts with the same submanifold conv; only
        # the post-activation variant appends norm + ReLU after it.
        stem_layers = [
            spconv.SubMConv3d(
                in_channels,
                self.base_channels,
                3,
                padding=1,
                bias=False,
                indice_key='subm1')
        ]
        if pre_act:
            make_block = self.pre_act_block
        else:
            stem_layers.append(
                build_norm_layer(norm_cfg, self.base_channels)[1])
            stem_layers.append(nn.ReLU())
            make_block = self.post_act_block
        self.conv_input = spconv.SparseSequential(*stem_layers)

        # Encoder first, then decoder: the decoder's first block consumes the
        # channel count produced by the last encoder block.
        encoder_out_channels = self.make_encoder_layers(
            make_block, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_block, norm_cfg, encoder_out_channels)

        # [200, 176, 5] -> [200, 176, 2]
        down_conv = spconv.SparseConv3d(
            encoder_out_channels,
            self.output_channels, (3, 1, 1),
            stride=(2, 1, 1),
            padding=0,
            bias=False,
            indice_key='spconv_down2')
        out_norm = build_norm_layer(norm_cfg, self.output_channels)[1]
        self.conv_out = spconv.SparseSequential(down_conv, out_norm,
                                                nn.ReLU())
wuyuefeng's avatar
wuyuefeng committed
97
98

    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.float32): shape [N, C]
            coors (torch.int32): shape [N, 4](batch_idx, z_idx, y_idx, x_idx)
            batch_size (int): batch size

        Returns:
            dict: backbone features, with keys ``spatial_features`` (the
                densified ``conv_out`` result reshaped to [N, C * D, H, W])
                and ``seg_features`` (the ``features`` of the last decoder
                output).
        """
        coors = coors.int()
        sp_input = spconv.SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(sp_input)

        # Keep every encoder stage output: the decoder reuses them as
        # lateral inputs.
        stage_outputs = []
        for stage in self.encoder_layers:
            x = stage(x)
            stage_outputs.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        det_out = self.conv_out(stage_outputs[-1])
        dense = det_out.dense()

        # Fold the depth axis into the channel axis.
        batch, channels, depth, height, width = dense.shape
        spatial_features = dense.view(batch, channels * depth, height, width)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        x = stage_outputs[-1]
        for stage_idx in reversed(range(1, self.stage_num + 1)):
            x = self.decoder_layer_forward(
                stage_outputs[stage_idx - 1], x,
                getattr(self, f'lateral_layer{stage_idx}'),
                getattr(self, f'merge_layer{stage_idx}'),
                getattr(self, f'upsample_layer{stage_idx}'))

        # After the loop ``x`` is the output of the last decoder stage.
        return dict(
            spatial_features=spatial_features, seg_features=x.features)

wuyuefeng's avatar
wuyuefeng committed
149
150
    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Run one decoder stage: lateral conv, merge, residual add, upsample.

        Args:
            x_lateral (SparseConvTensor): lateral tensor
            x_bottom (SparseConvTensor): feature from bottom layer
            lateral_layer (SparseBasicBlock): convolution for lateral tensor
            merge_layer (SparseSequential): convolution for merging features
            upsample_layer (SparseSequential): convolution for upsampling

        Returns:
            SparseConvTensor: upsampled feature
        """
        # NOTE: the statement order below is significant because the tensors
        # are updated through their ``features`` attribute.
        lateral = lateral_layer(x_lateral)
        # Concatenate bottom and lateral features along the channel axis.
        lateral.features = torch.cat((x_bottom.features, lateral.features),
                                     dim=1)
        merged = merge_layer(lateral)
        # Shrink the concatenated features to the merged channel count so the
        # two can be summed element-wise.
        reduced = self.reduce_channel(lateral, merged.features.shape[1])
        reduced.features = merged.features + reduced.features
        return upsample_layer(reduced)

    @staticmethod
wuyuefeng's avatar
wuyuefeng committed
172
173
    def reduce_channel(x, out_channels):
        """reduce channel for element-wise addition.
wuyuefeng's avatar
wuyuefeng committed
174
175
176
177
178
179
180
181
182
183

        Args:
            x (SparseConvTensor): x.features (N, C1)
            out_channels (int): the number of channel after reduction

        Returns:
            SparseConvTensor: channel reduced feature
        """
        features = x.features
        n, in_channels = features.shape
wuyuefeng's avatar
wuyuefeng committed
184
185
        assert (in_channels % out_channels
                == 0) and (in_channels >= out_channels)
wuyuefeng's avatar
wuyuefeng committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208

        x.features = features.view(n, out_channels, -1).sum(dim=2)
        return x

    def pre_act_block(self,
                      in_channels,
                      out_channels,
                      kernel_size,
                      indice_key=None,
                      stride=1,
                      padding=0,
                      conv_type='subm',
                      norm_cfg=None):
        """Make pre activate sparse convolution block.

        The block is ordered norm -> ReLU -> conv, with the norm layer built
        for ``in_channels``.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of out channels
            kernel_size (int): kernel size of convolution
            indice_key (str): the indice key used for sparse tensor
            stride (int): the stride of convolution; only used when
                ``conv_type`` is 'spconv'
            padding (int or list[int]): the padding number of input; not
                used when ``conv_type`` is 'inverseconv'
            conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
            norm_cfg (dict): config of normalization layer

        Returns:
            spconv.SparseSequential: pre activate sparse convolution block.
        """
        # TODO: use ConvModule to encapsulate
        assert conv_type in ['subm', 'spconv', 'inverseconv']

        # Pick the conv class and the extra keyword arguments it accepts;
        # nothing is constructed until the conv type is known to be valid.
        if conv_type == 'subm':
            conv_cls = spconv.SubMConv3d
            extra_kwargs = dict(padding=padding)
        elif conv_type == 'spconv':
            conv_cls = spconv.SparseConv3d
            extra_kwargs = dict(stride=stride, padding=padding)
        elif conv_type == 'inverseconv':
            conv_cls = spconv.SparseInverseConv3d
            extra_kwargs = dict()
        else:
            # Only reachable when asserts are disabled.
            raise NotImplementedError

        return spconv.SparseSequential(
            build_norm_layer(norm_cfg, in_channels)[1],
            nn.ReLU(inplace=True),
            conv_cls(
                in_channels,
                out_channels,
                kernel_size,
                bias=False,
                indice_key=indice_key,
                **extra_kwargs))

    def post_act_block(self,
                       in_channels,
                       out_channels,
                       kernel_size,
                       indice_key,
                       stride=1,
                       padding=0,
                       conv_type='subm',
                       norm_cfg=None):
        """Make post activate sparse convolution block.

        The block is ordered conv -> norm -> ReLU, with the norm layer built
        for ``out_channels``.

        Args:
            in_channels (int): the number of input channels
            out_channels (int): the number of out channels
            kernel_size (int): kernel size of convolution
            indice_key (str): the indice key used for sparse tensor
            stride (int): the stride of convolution; only used when
                ``conv_type`` is 'spconv'
            padding (int or list[int]): the padding number of input; only
                used when ``conv_type`` is 'spconv'
            conv_type (str): conv type in 'subm', 'spconv' or 'inverseconv'
            norm_cfg (dict[str]): config of normalization layer

        Returns:
            spconv.SparseSequential: post activate sparse convolution block.
        """
        # TODO: use ConvModule to encapsulate
        assert conv_type in ['subm', 'spconv', 'inverseconv']

        # Pick the conv class and the extra keyword arguments it accepts;
        # nothing is constructed until the conv type is known to be valid.
        if conv_type == 'subm':
            conv_cls = spconv.SubMConv3d
            # NOTE(review): unlike pre_act_block, ``padding`` is not forwarded
            # to the submanifold conv here. Kept as-is to preserve behavior;
            # confirm whether this is intentional.
            extra_kwargs = dict()
        elif conv_type == 'spconv':
            conv_cls = spconv.SparseConv3d
            extra_kwargs = dict(stride=stride, padding=padding)
        elif conv_type == 'inverseconv':
            conv_cls = spconv.SparseInverseConv3d
            extra_kwargs = dict()
        else:
            # Only reachable when asserts are disabled.
            raise NotImplementedError

        return spconv.SparseSequential(
            conv_cls(
                in_channels,
                out_channels,
                kernel_size,
                bias=False,
                indice_key=indice_key,
                **extra_kwargs),
            build_norm_layer(norm_cfg, out_channels)[1],
            nn.ReLU(inplace=True))
wuyuefeng's avatar
wuyuefeng committed
316

317
318
    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """Make encoder layers using sparse convs.

        Builds one ``SparseSequential`` per entry of ``self.encoder_channels``
        and registers it on ``self.encoder_layers`` as ``encoder_layer{k}``
        (1-based).

        Args:
            make_block (method): a bounded function to build blocks
            norm_cfg (dict[str]): config of normalization layer
            in_channels (int): the number of encoder input channels

        Returns:
            int: the number of encoder output channels, i.e. the out channels
                of the last block built (``in_channels`` unchanged if no
                block is configured).
        """
        self.encoder_layers = spconv.SparseSequential()
        for i, blocks in enumerate(self.encoder_channels):
            stage_id = i + 1
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    block = make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        stride=2,
                        padding=padding,
                        indice_key=f'spconv{stage_id}',
                        conv_type='spconv')
                else:
                    block = make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        padding=padding,
                        indice_key=f'subm{stage_id}')
                blocks_list.append(block)
                in_channels = out_channels
            self.encoder_layers.add_module(
                f'encoder_layer{stage_id}',
                spconv.SparseSequential(*blocks_list))
        # ``in_channels`` tracks the last block's out channels; returning it
        # (rather than the loop variable) stays defined even when no block
        # was configured.
        return in_channels

361
362
    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """Make decoder layers using sparse convs.

        For every entry of ``self.decoder_channels`` three modules are
        registered on ``self``: ``lateral_layer{k}``, ``merge_layer{k}`` and
        ``upsample_layer{k}``, where ``k`` counts down from the number of
        decoder blocks to 1.

        Args:
            make_block (method): a bounded function to build blocks
            norm_cfg (dict[str]): config of normalization layer
            in_channels (int): the number of input channels of the first
                decoder block

        Returns:
            int: the number of decoder output channels, i.e. the out channels
                of the last upsample layer (``in_channels`` unchanged if no
                decoder block is configured).
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            # Decoder blocks are numbered in reverse: the first block built
            # gets the highest index.
            stage_id = block_num - i

            lateral = SparseBasicBlock(
                in_channels,
                block_channels[0],
                conv_cfg=dict(type='SubMConv3d', indice_key=f'subm{stage_id}'),
                norm_cfg=norm_cfg)
            setattr(self, f'lateral_layer{stage_id}', lateral)

            # The merge layer takes twice the channels because
            # decoder_layer_forward concatenates bottom and lateral features.
            merge = make_block(
                in_channels * 2,
                block_channels[1],
                3,
                norm_cfg=norm_cfg,
                padding=paddings[0],
                indice_key=f'subm{stage_id}')
            setattr(self, f'merge_layer{stage_id}', merge)

            if stage_id != 1:
                upsample_kwargs = dict(
                    indice_key=f'spconv{stage_id}', conv_type='inverseconv')
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                upsample_kwargs = dict(indice_key='subm1', conv_type='subm')
            upsample = make_block(
                in_channels,
                block_channels[2],
                3,
                norm_cfg=norm_cfg,
                padding=paddings[1],
                **upsample_kwargs)
            setattr(self, f'upsample_layer{stage_id}', upsample)

            in_channels = block_channels[2]
        return in_channels