sparse_unet.py 11.5 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
wuyuefeng's avatar
wuyuefeng committed
2
import torch
VVsssssk's avatar
VVsssssk committed
3
4
5
6
7
8
9
10

from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential

11
from mmcv.runner import BaseModule, auto_fp16
wuyuefeng's avatar
wuyuefeng committed
12

wuyuefeng's avatar
wuyuefeng committed
13
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
14
from ..builder import MIDDLE_ENCODERS
wuyuefeng's avatar
wuyuefeng committed
15
16


17
def _replace_feature(sp_tensor, new_features):
    """Return a sparse tensor whose features are ``new_features``.

    spconv 2.x does not allow assigning ``SparseConvTensor.features``
    directly (its property setter raises ``ValueError``); it provides
    ``replace_feature``, which returns a new tensor instead. Tensor
    implementations without that method are updated in place, which is the
    behaviour this module relied on before.

    Args:
        sp_tensor (:obj:`SparseConvTensor`): Sparse tensor to update.
        new_features (torch.Tensor): Replacement feature matrix, one row per
            active voxel of ``sp_tensor``.

    Returns:
        :obj:`SparseConvTensor`: Sparse tensor carrying ``new_features``.
    """
    if hasattr(sp_tensor, 'replace_feature'):
        return sp_tensor.replace_feature(new_features)
    sp_tensor.features = new_features
    return sp_tensor


@MIDDLE_ENCODERS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): Order of conv, norm and act in the conv modules.
            Must be a permutation of ('conv', 'norm', 'act'). When it does
            not start with 'conv' (pre-activation), the input layer is built
            as a bare convolution. Defaults to ('conv', 'norm', 'act').
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block, given as
            (lateral, merge, upsample) output channels per block.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode block,
            given as (merge padding, upsample padding). The upsample padding
            is only used by the last block, which upsamples with a
            submanifold conv instead of an inverse conv.
        init_cfg (dict, optional): Initialization config passed to
            :class:`BaseModule`. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
                                                                        64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
                                                                 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        if self.order[0] != 'conv':  # pre activate
            # Pre-activation: the input layer is a bare conv; norm/act are
            # applied by the following blocks.
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d',
                order=('conv', ))
        else:  # post activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d')

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        # Strided conv along the first spatial axis only, producing the
        # feature volume that forward() flattens into a 2D BEV map.
        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.float32): Voxel features in shape [N, C].
            coors (torch.int32): Coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            dict[str, torch.Tensor]: Backbone features with keys

                - spatial_features: dense output of ``conv_out`` with the
                  channel and first spatial axis merged, i.e. shape
                  (B, C * D, H, W).
                - seg_features: feature matrix of the last decoder output
                  (one row per active voxel).
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2] (with the default sparse shape)
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        decode_features = []
        x = encode_features[-1]
        # Decoder layers are indexed from stage_num (deepest) down to 1 so
        # each one pairs with the encoder stage of the same index.
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))
            decode_features.append(x)

        seg_features = decode_features[-1].features

        ret = dict(
            spatial_features=spatial_features, seg_features=seg_features)

        return ret

    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Forward of upsample and residual block.

        The lateral feature is concatenated with the bottom feature, merged,
        and added to a channel-reduced copy of the concatenation (residual)
        before upsampling.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        # spconv 2.x forbids ``x.features = ...``; go through the helper,
        # which works for both spconv 2.x and the mmcv implementation.
        x = _replace_feature(
            x, torch.cat((x_bottom.features, x.features), dim=1))
        x_merge = merge_layer(x)
        x = self.reduce_channel(x, x_merge.features.shape[1])
        x = _replace_feature(x, x_merge.features + x.features)
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x, out_channels):
        """reduce channel for element-wise addition.

        Consecutive groups of ``C1 // out_channels`` channels are summed.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature. With spconv
            2.x this is a new tensor; otherwise ``x`` is updated in place
            and returned.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels
                == 0) and (in_channels >= out_channels)

        return _replace_feature(
            x, features.view(n, out_channels, -1).sum(dim=2))

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """make encoder layers using sparse convs.

        Builds ``self.encoder_layers`` with one ``encoder_layer{i}`` stage per
        entry of ``self.encoder_channels``. Every stage except the first
        starts with a stride-2 sparse conv (indice key ``spconv{i}``) whose
        indices are later reused by the decoder's inverse convs.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels.
        """
        self.encoder_layers = SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)
        # ``in_channels`` tracks the last block's output channels; returning
        # it (instead of the loop variable) also works for an empty config.
        return in_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """make decoder layers using sparse convs.

        For each decoder block ``k`` (from ``len(decoder_channels)`` down to
        1) this sets the attributes ``lateral_layer{k}``, ``merge_layer{k}``
        and ``upsample_layer{k}`` that :meth:`forward` looks up by name.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of input channels of the deepest
                decoder block, i.e. the number of encoder output channels.

        Returns:
            None: Layers are registered as attributes on ``self``.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            setattr(
                self, f'lateral_layer{block_num - i}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{block_num - i}'),
                    norm_cfg=norm_cfg))
            # The merge input is the bottom feature concatenated with the
            # lateral output; ``in_channels * 2`` assumes the lateral block
            # keeps the channel count (block_channels[0] == in_channels).
            setattr(
                self, f'merge_layer{block_num - i}',
                make_block(
                    in_channels * 2,
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{block_num - i}',
                    conv_type='SubMConv3d'))
            if block_num - i != 1:
                # Inverse conv reusing the encoder's ``spconv{k}`` indices to
                # undo that stage's stride-2 downsampling.
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        indice_key=f'spconv{block_num - i}',
                        conv_type='SparseInverseConv3d'))
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[1],
                        indice_key='subm1',
                        conv_type='SubMConv3d'))
            in_channels = block_channels[2]