# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import BaseModule, auto_fp16

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): A permutation of ('conv', 'norm', 'act').
            When the first entry is not 'conv' (pre activate), the input
            layer is built as a convolution only.
            Defaults to ('conv', 'norm', 'act').
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode block.
        init_cfg (dict, optional): Initialization config passed on to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
                                                                        64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
                                                                 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3, \
            'order must be a tuple of three layer names'
        assert set(order) == {'conv', 'norm', 'act'}, \
            "order must be a permutation of ('conv', 'norm', 'act')"

        # Both variants of the input layer share every argument except
        # ``order``, so build the common keyword set once.
        input_kwargs = dict(
            norm_cfg=norm_cfg,
            padding=1,
            indice_key='subm1',
            conv_type='SubMConv3d')
        if self.order[0] != 'conv':
            # pre activate: the input layer is a convolution only
            input_kwargs['order'] = ('conv', )
        # otherwise (post activate) the default order of the builder is used
        self.conv_input = make_sparse_convmodule(in_channels,
                                                 self.base_channels, 3,
                                                 **input_kwargs)

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.float32): Voxel features in shape [N, C].
            coors (torch.int32): Coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            dict[str, torch.Tensor]: Backbone features. ``spatial_features``
                is the dense output of ``conv_out`` with its channel and
                depth dimensions merged, i.e. in shape (N, C * D, H, W).
                ``seg_features`` holds the features of the last decoder
                stage.
        """
        coors = coors.int()
        input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
                                                  self.sparse_shape,
                                                  batch_size)
        x = self.conv_input(input_sp_tensor)

        # keep the output of every encoder stage for the lateral connections
        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        x = encode_features[-1]
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))

        # only the output of the last decoder stage is consumed
        seg_features = x.features

        return dict(
            spatial_features=spatial_features, seg_features=seg_features)

    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Forward of upsample and residual block.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        # concatenate bottom and lateral features along the channel dim
        x.features = torch.cat((x_bottom.features, x.features), dim=1)
        x_merge = merge_layer(x)
        # shrink the concatenated features to the merged channel count so
        # that the two can be added element-wise
        x = self.reduce_channel(x, x_merge.features.shape[1])
        x.features = x_merge.features + x.features
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x, out_channels):
        """reduce channel for element-wise addition.

        The channels are split into ``out_channels`` consecutive groups of
        equal size and each group is summed. ``x`` is modified in place.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels
                == 0) and (in_channels >= out_channels), \
            'in_channels must be a positive multiple of out_channels'

        x.features = features.view(n, out_channels, -1).sum(dim=2)
        return x

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """make encoder layers using sparse convs.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels.
        """
        self.encoder_layers = spconv.SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            paddings = tuple(self.encoder_paddings[i])
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    block_kwargs = dict(
                        stride=2,
                        indice_key=f'spconv{i + 1}',
                        conv_type='SparseConv3d')
                else:
                    block_kwargs = dict(
                        indice_key=f'subm{i + 1}', conv_type='SubMConv3d')
                blocks_list.append(
                    make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[j],
                        **block_kwargs))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = spconv.SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)

        # ``in_channels`` tracks the output of the last block built, and is
        # still defined when no block was built at all
        return in_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """make decoder layers using sparse convs.

        For every decoder stage a lateral, a merge and an upsample layer are
        registered as ``lateral_layer{k}``, ``merge_layer{k}`` and
        ``upsample_layer{k}``, with ``k`` counting down to 1.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of decoder input channels.

        Returns:
            int: The number of decoder output channels.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            stage = block_num - i
            setattr(
                self, f'lateral_layer{stage}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{stage}'),
                    norm_cfg=norm_cfg))
            # the merge layer consumes the concatenated bottom and lateral
            # features, hence twice the input channels
            setattr(
                self, f'merge_layer{stage}',
                make_block(
                    in_channels * 2,
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{stage}',
                    conv_type='SubMConv3d'))
            if stage != 1:
                upsample_layer = make_block(
                    in_channels,
                    block_channels[2],
                    3,
                    norm_cfg=norm_cfg,
                    indice_key=f'spconv{stage}',
                    conv_type='SparseInverseConv3d')
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                upsample_layer = make_block(
                    in_channels,
                    block_channels[2],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[1],
                    indice_key='subm1',
                    conv_type='SubMConv3d')
            setattr(self, f'upsample_layer{stage}', upsample_layer)
            in_channels = block_channels[2]

        return in_channels