sparse_unet.py 12.5 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Dict, List, Optional, Tuple

wuyuefeng's avatar
wuyuefeng committed
4
import torch
5
from torch import Tensor, nn
VVsssssk's avatar
VVsssssk committed
6

zhangshilong's avatar
zhangshilong committed
7
from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE
VVsssssk's avatar
VVsssssk committed
8
9
10
11
12
13

if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential

14
from mmengine.model import BaseModule
wuyuefeng's avatar
wuyuefeng committed
15

zhangshilong's avatar
zhangshilong committed
16
17
from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.models.layers.sparse_block import replace_feature
VVsssssk's avatar
VVsssssk committed
18
from mmdet3d.registry import MODELS
wuyuefeng's avatar
wuyuefeng committed
19

20
21
# Per-stage settings: one inner tuple per stage, one entry per block.
# Both nesting levels are variable-length, hence the ``...``. (Padding
# entries may themselves be per-axis tuples, e.g. ``(0, 1, 1)`` in
# SparseUNet's default encoder_paddings; this alias does not capture that.)
TwoTupleIntType = Tuple[Tuple[int, ...], ...]

wuyuefeng's avatar
wuyuefeng committed
22

23
@MODELS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    The network is an encoder-decoder over sparse voxels. ``forward``
    returns a dense 2D map computed from the deepest encoder stage
    (``spatial_features``) and the per-voxel features of the last decoder
    stage (``seg_features``).

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): Order of conv module. Must be a permutation of
            ``('conv', 'norm', 'act')``. It only affects ``conv_input``:
            when it does not start with ``'conv'`` that layer is built with
            ``order=('conv', )``. Defaults to ``('conv', 'norm', 'act')``.
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block, given as
            ``(lateral, merge, upsample)`` output channels. Must contain
            one entry per encoder stage.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode
            block, given as ``(merge padding, upsample padding)``; the
            upsample padding is only used by the last decode block.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
            self,
            in_channels: int,
            sparse_shape: List[int],
            order: Tuple[str] = ('conv', 'norm', 'act'),
            norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
            base_channels: int = 16,
            output_channels: int = 128,
            encoder_channels: Optional[TwoTupleIntType] = (
                (16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)),
            encoder_paddings: Optional[TwoTupleIntType] = (
                (1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)),
            decoder_channels: Optional[TwoTupleIntType] = (
                (64, 64, 64), (64, 64, 32), (32, 32, 16), (16, 16, 16)),
            decoder_paddings: Optional[TwoTupleIntType] = (
                (1, 0), (1, 0), (0, 0), (0, 1)),
            init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}
        # forward() pairs encoder stage i with lateral/merge/upsample layer
        # i, so the decoder must define exactly one block per encoder stage.
        # Fail here rather than with a missing attribute in forward().
        assert len(self.decoder_channels) == self.stage_num, (
            f'decoder_channels has {len(self.decoder_channels)} stages, '
            f'but encoder_channels has {self.stage_num}')

        if self.order[0] != 'conv':  # pre activate
            # In this mode the input layer is built with order=('conv', ).
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d',
                order=('conv', ))
        else:  # post activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d')

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        # Used by forward() for the detection branch: one more strided
        # sparse conv on the deepest encoder output, kernel (3, 1, 1) and
        # stride (2, 1, 1), i.e. downsampling only along the first axis.
        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    def forward(self, voxel_features: Tensor, coors: Tensor,
                batch_size: int) -> Dict[str, Tensor]:
        """Forward of SparseUNet.

        Args:
            voxel_features (Tensor): Voxel features in shape [N, C].
            coors (Tensor): Coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
                Converted to int32 before use.
            batch_size (int): Batch size.

        Returns:
            dict[str, Tensor]: Backbone features with keys
                ``spatial_features`` (the dense output of ``conv_out`` with
                its D axis merged into channels, shape (N, C * D, H, W))
                and ``seg_features`` (the ``features`` of the last decoder
                stage's sparse tensor).
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        # Keep every stage's output: the decoder uses them as lateral
        # inputs.
        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        # Merge the D axis into the channel axis to get a 4D feature map.
        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape:
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        # Decode from the deepest stage back to the first; stage i combines
        # encoder output i with the feature coming from stage i + 1.
        x = encode_features[-1]
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))

        # Only the output of the last decoder stage is returned.
        seg_features = x.features

        ret = dict(
            spatial_features=spatial_features, seg_features=seg_features)

        return ret

    def decoder_layer_forward(
            self, x_lateral: SparseConvTensor, x_bottom: SparseConvTensor,
            lateral_layer: SparseBasicBlock, merge_layer: SparseSequential,
            upsample_layer: SparseSequential) -> SparseConvTensor:
        """Forward of upsample and residual block.

        The lateral feature is refined and concatenated (channel-wise) after
        the bottom feature. The concatenation goes through ``merge_layer``,
        and the result is added to a channel-reduced copy of the
        concatenation. The sum is then passed through ``upsample_layer``.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        # Concatenation is row-wise aligned, so x_bottom and the lateral
        # output must hold the same active voxels in the same order.
        x = replace_feature(x, torch.cat((x_bottom.features, x.features),
                                         dim=1))
        x_merge = merge_layer(x)
        # Residual shortcut: shrink the concatenated features to the merge
        # output width so the two can be added element-wise.
        x = self.reduce_channel(x, x_merge.features.shape[1])
        x = replace_feature(x, x_merge.features + x.features)
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x: SparseConvTensor,
                       out_channels: int) -> SparseConvTensor:
        """Reduce channels for element-wise addition.

        Consecutive groups of ``C1 // out_channels`` channels are summed, so
        output channel ``k`` is the sum of input channels
        ``[k * g, (k + 1) * g)`` with ``g = C1 // out_channels``.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.
                ``C1`` must be a multiple of it.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels
                == 0) and (in_channels >= out_channels)
        x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
        return x

    def make_encoder_layers(self, make_block: nn.Module, norm_cfg: dict,
                            in_channels: int) -> int:
        """Make encoder layers using sparse convs.

        Builds ``self.encoder_layers``: one ``SparseSequential`` per entry of
        ``self.encoder_channels``, named ``encoder_layer{i + 1}``. Every
        stage but the first starts with a stride-2 ``SparseConv3d`` (indice
        key ``spconv{i + 1}``, later reused by the decoder's inverse convs);
        all other blocks are ``SubMConv3d`` with indice key ``subm{i + 1}``.

        Args:
            make_block (method): Function used to build each block, called
                with ``(in_channels, out_channels, kernel_size, **kwargs)``.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels.
        """
        self.encoder_layers = SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)

        # ``in_channels`` tracks the last block's output width. Unlike the
        # loop variable, it is defined even when no blocks are configured.
        return in_channels

    def make_decoder_layers(self, make_block: nn.Module, norm_cfg: dict,
                            in_channels: int) -> int:
        """Make decoder layers using sparse convs.

        For each decoder stage (numbered ``block_num`` down to 1), sets the
        attributes ``lateral_layer{k}``, ``merge_layer{k}`` and
        ``upsample_layer{k}``, which ``forward`` looks up by name.

        Args:
            make_block (method): Function used to build each block, called
                with ``(in_channels, out_channels, kernel_size, **kwargs)``.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of decoder input channels, i.e.
                the encoder output channels.

        Returns:
            int: The number of output channels of the last decoder stage.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            setattr(
                self, f'lateral_layer{block_num - i}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{block_num - i}'),
                    norm_cfg=norm_cfg))
            # Input is the bottom feature (in_channels) concatenated with
            # the lateral output, which is expected to have the same width.
            setattr(
                self, f'merge_layer{block_num - i}',
                make_block(
                    in_channels * 2,
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{block_num - i}',
                    conv_type='SubMConv3d'))
            # decoder_layer_forward() feeds the upsample layer with
            # merge output + reduced shortcut, i.e. block_channels[1]
            # channels (identical to in_channels for the default config).
            upsample_in_channels = block_channels[1]
            if block_num - i != 1:
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        upsample_in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        indice_key=f'spconv{block_num - i}',
                        conv_type='SparseInverseConv3d'))
            else:
                # use submanifold conv instead of inverse conv
                # in the last block
                setattr(
                    self, f'upsample_layer{block_num - i}',
                    make_block(
                        upsample_in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[1],
                        indice_key='subm1',
                        conv_type='SubMConv3d'))
            in_channels = block_channels[2]
        return in_channels