sparse_unet.py 11.6 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
wuyuefeng's avatar
wuyuefeng committed
2
import torch
VVsssssk's avatar
VVsssssk committed
3

zhangshilong's avatar
zhangshilong committed
4
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
VVsssssk's avatar
VVsssssk committed
5
6
7
8
9
10

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential

11
from mmengine.model import BaseModule
wuyuefeng's avatar
wuyuefeng committed
12

zhangshilong's avatar
zhangshilong committed
13
14
from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.models.layers.sparse_block import replace_feature
VVsssssk's avatar
VVsssssk committed
15
from mmdet3d.registry import MODELS
wuyuefeng's avatar
wuyuefeng committed
16
17


18
@MODELS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): Order of the conv module. Must be a permutation
            of ``('conv', 'norm', 'act')``. If it does not start with
            ``'conv'`` (pre-activation), the input layer is built as a bare
            conv. Defaults to ('conv', 'norm', 'act').
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block, as
            (lateral, merge, upsample) output channels.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode block.
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
                                                                        64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
                                                                 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        # Pre-activation (order not starting with 'conv') builds the input
        # layer as a bare conv; post-activation keeps the default order.
        input_conv_kwargs = dict(
            norm_cfg=norm_cfg,
            padding=1,
            indice_key='subm1',
            conv_type='SubMConv3d')
        if self.order[0] != 'conv':
            input_conv_kwargs['order'] = ('conv', )
        self.conv_input = make_sparse_convmodule(in_channels,
                                                 self.base_channels, 3,
                                                 **input_conv_kwargs)

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape [N, C].
            coors (torch.Tensor): Coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
                They are cast to int32 before building the sparse tensor.
            batch_size (int): Batch size.

        Returns:
            dict[str, torch.Tensor]: Backbone features with keys
            ``spatial_features`` (the densified ``conv_out`` output with the
            channel and depth axes merged, i.e. [N, C * D, H, W]) and
            ``seg_features`` (features of the last decoder stage).
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        decode_features = []
        x = encode_features[-1]
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))
            decode_features.append(x)

        seg_features = decode_features[-1].features

        ret = dict(
            spatial_features=spatial_features, seg_features=seg_features)

        return ret

    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Forward of upsample and residual block.

        The lateral tensor is convolved and concatenated (along channels)
        after the bottom feature. The result is passed through the merge
        layer, and a channel-reduced copy of the concatenation is added back
        as a residual. The sum is then upsampled.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        # Channel concat assumes both tensors share the same active indices.
        x = replace_feature(x, torch.cat((x_bottom.features, x.features),
                                         dim=1))
        x_merge = merge_layer(x)
        x = self.reduce_channel(x, x_merge.features.shape[1])
        x = replace_feature(x, x_merge.features + x.features)
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x, out_channels):
        """Reduce channels for element-wise addition.

        Consecutive groups of ``C1 // out_channels`` channels are summed, so
        output channel ``k`` is the sum of input channels
        ``[k * r, (k + 1) * r)`` with ``r = C1 // out_channels``.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels == 0) and (
            in_channels >= out_channels), (
                f'Cannot reduce {in_channels} channels to {out_channels}')
        x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
        return x

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """Make encoder layers using sparse convs.

        Every stage except the first starts with a stride-2 ``SparseConv3d``
        (indice key ``spconv{stage}``); all other blocks are
        ``SubMConv3d`` sharing the indice key ``subm{stage}``.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels (``in_channels`` if
            no encoder block is configured).
        """
        self.encoder_layers = SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            stage_paddings = tuple(self.encoder_paddings[i])
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = stage_paddings[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)
        # After the loop, in_channels holds the last block's out channels.
        return in_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """Make decoder layers using sparse convs.

        For each decoder block ``k`` (counting down from the deepest stage),
        this sets ``lateral_layer{k}``, ``merge_layer{k}`` and
        ``upsample_layer{k}`` as attributes. The upsample layer of the last
        block is a submanifold conv; the others are inverse convs reusing
        the encoder's ``spconv{k}`` indices.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder output channels, i.e.
                the input channels of the first decoder block.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            setattr(
                self, f'lateral_layer{block_num - i}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{block_num - i}'),
                    norm_cfg=norm_cfg))
            # The merge input is the bottom feature (in_channels) concatenated
            # with the lateral output (block_channels[0]).
            setattr(
                self, f'merge_layer{block_num - i}',
                make_block(
                    in_channels + block_channels[0],
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{block_num - i}',
                    conv_type='SubMConv3d'))
            # The upsample layer consumes the merged feature, which has
            # block_channels[1] channels after reduce_channel + addition.
            if block_num - i != 1:
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        block_channels[1],
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        indice_key=f'spconv{block_num - i}',
                        conv_type='SparseInverseConv3d'))
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        block_channels[1],
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[1],
                        indice_key='subm1',
                        conv_type='SubMConv3d'))
            in_channels = block_channels[2]