# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential

from mmcv.runner import BaseModule, auto_fp16

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops.sparse_block import replace_feature
from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    The module is made of an input convolution, a stack of encoder stages,
    an output convolution whose dense result feeds the detection head, and
    one decoder stage per encoder stage whose final features feed the
    segmentation head.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): A permutation of ``('conv', 'norm', 'act')``.
            When the first entry is not ``'conv'`` (pre-activation), the
            input layer is built as a bare convolution.
            Defaults to ``('conv', 'norm', 'act')``.
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode block.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
                                                                        64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
                                                                 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        # ``forward`` runs one decoder stage per encoder stage.
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        # Both variants of the input layer share every argument except
        # ``order``: pre-activation uses a bare convolution, post-activation
        # keeps the default order of ``make_sparse_convmodule``.
        conv_input_kwargs = dict(
            norm_cfg=norm_cfg,
            padding=1,
            indice_key='subm1',
            conv_type='SubMConv3d')
        if self.order[0] != 'conv':  # pre activate
            conv_input_kwargs['order'] = ('conv', )
        self.conv_input = make_sparse_convmodule(
            in_channels, self.base_channels, 3, **conv_input_kwargs)

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.float32): Voxel features in shape [N, C].
            coors (torch.int32): Coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            dict[str, torch.Tensor]: Backbone features.

            - spatial_features: Dense output of ``conv_out`` with the
              channel and depth axes merged, in shape (N, C * D, H, W).
            - seg_features: Features of the last decoder stage output.
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        # Keep every encoder stage output; the decoder reuses them as
        # lateral inputs.
        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        decode_features = []
        x = encode_features[-1]
        # Walk the stages from the deepest one back to the first.
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))
            decode_features.append(x)

        seg_features = decode_features[-1].features

        ret = dict(
            spatial_features=spatial_features, seg_features=seg_features)

        return ret

    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Forward of upsample and residual block.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        # Concatenate bottom and lateral features along the channel axis.
        x = replace_feature(x, torch.cat((x_bottom.features, x.features),
                                         dim=1))
        x_merge = merge_layer(x)
        # Residual connection: shrink the concatenated features to the
        # merge layer's channel count so both can be added element-wise.
        x = self.reduce_channel(x, x_merge.features.shape[1])
        x = replace_feature(x, x_merge.features + x.features)
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x, out_channels):
        """reduce channel for element-wise addition.

        The C1 input channels are split into ``out_channels`` groups of
        consecutive channels and each group is summed.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.
                C1 must be a multiple of it.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels
                == 0) and (in_channels >= out_channels)
        x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
        return x

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """make encoder layers using sparse convs.

        The stages are stored in ``self.encoder_layers`` under the names
        ``encoder_layer1``, ``encoder_layer2``, ...

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels.
        """
        self.encoder_layers = SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)
        # ``in_channels`` tracks the output width of the last block built,
        # and unlike the loop variable it is bound even when no block was
        # created.
        return in_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """make decoder layers using sparse convs.

        For every entry of ``self.decoder_channels`` three modules are
        registered as attributes: ``lateral_layer{k}``, ``merge_layer{k}``
        and ``upsample_layer{k}``. ``k`` counts down from the number of
        decoder blocks to 1, so the first entry gets the highest index.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of input channels of the first
                decoder block, i.e. the encoder output channels.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            stage = block_num - i
            setattr(
                self, f'lateral_layer{stage}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{stage}'),
                    norm_cfg=norm_cfg))
            # The merge layer consumes the concatenation of the bottom and
            # lateral features, hence twice the input channels.
            setattr(
                self, f'merge_layer{stage}',
                make_block(
                    in_channels * 2,
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{stage}',
                    conv_type='SubMConv3d'))
            if stage != 1:
                setattr(
                    self, f'upsample_layer{stage}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        indice_key=f'spconv{stage}',
                        conv_type='SparseInverseConv3d'))
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                setattr(
                    self, f'upsample_layer{stage}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[1],
                        indice_key='subm1',
                        conv_type='SubMConv3d'))
            in_channels = block_channels[2]