"configs/datasets/winogrande/winogrande_gen.py" did not exist on "c94cc943485e275897ad95cfa5192ff8e066378a"
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE

# Take SparseConvTensor / SparseSequential from spconv when
# IS_SPCONV2_AVAILABLE is set; otherwise fall back to mmcv.ops.
if IS_SPCONV2_AVAILABLE:
    from spconv.pytorch import SparseConvTensor, SparseSequential
else:
    from mmcv.ops import SparseConvTensor, SparseSequential

from mmengine.model import BaseModule

from mmdet3d.models.layers import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.models.layers.sparse_block import replace_feature
from mmdet3d.registry import MODELS


@MODELS.register_module()
class SparseUNet(BaseModule):
    r"""SparseUNet for PartA^2.

    See the `paper <https://arxiv.org/abs/1907.03670>`_ for more details.

    The network consists of an input stem (``conv_input``), a stack of
    sparse encoder stages (``encoder_layers``), a detection branch
    (``conv_out``) whose output is densified and has its depth axis folded
    into channels, and a decoder that walks back up through the encoder
    stages with lateral connections to produce per-voxel segmentation
    features.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str]): Order of conv, norm and act layers. Must be a
            permutation of ``('conv', 'norm', 'act')``. Only its first
            element is consulted here: if it is not ``'conv'``
            (pre-activation) the input stem is a bare convolution.
            Defaults to ``('conv', 'norm', 'act')``.
        norm_cfg (dict): Config of normalization layer.
        base_channels (int): Out channels for conv_input layer.
        output_channels (int): Out channels for conv_out layer.
        encoder_channels (tuple[tuple[int]]):
            Convolutional channels of each encode block.
        encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
        decoder_channels (tuple[tuple[int]]):
            Convolutional channels of each decode block, given as
            ``(lateral, merge, upsample)`` out channels.
        decoder_paddings (tuple[tuple[int]]): Paddings of each decode block,
            given as ``(merge, upsample)``. The upsample padding is only
            used by the last (submanifold) block.
        init_cfg (dict, optional): Initialization config passed to
            ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64),
                                   (64, 64, 64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1),
                                   ((0, 1, 1), 1, 1)),
                 decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
                                   (16, 16, 16)),
                 decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(order, tuple) and len(order) == 3, (
            f'order must be a tuple of length 3, got {order!r}')
        assert set(order) == {'conv', 'norm', 'act'}, (
            f"order must be a permutation of ('conv', 'norm', 'act'), "
            f'got {order!r}')

        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.decoder_channels = decoder_channels
        self.decoder_paddings = decoder_paddings
        self.stage_num = len(self.encoder_channels)
        # Spconv init all weight on its own, so no custom init is done here.

        # Pre-activation (order does not start with 'conv'): the stem is a
        # bare convolution. Post-activation: make_sparse_convmodule's
        # default order is used.
        # NOTE(review): self.order is only consulted for this stem; the
        # encoder/decoder blocks below are built without an ``order``
        # argument. Confirm this is intended for pre-activation configs.
        stem_kwargs = dict(order=('conv', )) if self.order[0] != 'conv' else {}
        self.conv_input = make_sparse_convmodule(
            in_channels,
            self.base_channels,
            3,
            norm_cfg=norm_cfg,
            padding=1,
            indice_key='subm1',
            conv_type='SubMConv3d',
            **stem_kwargs)

        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule, norm_cfg, self.base_channels)
        self.make_decoder_layers(make_sparse_convmodule, norm_cfg,
                                 encoder_out_channels)

        # Detection branch: stride-2 sparse conv along the first spatial
        # axis only (kernel (3, 1, 1), no padding).
        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseUNet.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape [N, C].
            coors (torch.Tensor): Integer coordinates in shape [N, 4],
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
                Converted to int32 before use.
            batch_size (int): Batch size.

        Returns:
            dict[str, torch.Tensor]: Backbone features with keys:

                - spatial_features: dense output of ``conv_out`` with the
                  depth axis folded into channels, shape (N, C * D, H, W).
                - seg_features: features of the last decoder output, one
                  row per active voxel.
        """
        coors = coors.int()
        input_sp_tensor = SparseConvTensor(voxel_features, coors,
                                           self.sparse_shape, batch_size)
        x = self.conv_input(input_sp_tensor)

        encode_features = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            encode_features.append(x)

        # for detection head (example shapes)
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(encode_features[-1])
        spatial_features = out.dense()

        # Fold the depth axis into channels to get a 2D feature map.
        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        # for segmentation head, with output shape (example shapes):
        # [400, 352, 11] <- [200, 176, 5]
        # [800, 704, 21] <- [400, 352, 11]
        # [1600, 1408, 41] <- [800, 704, 21]
        # [1600, 1408, 41] <- [1600, 1408, 41]
        # Decode from the deepest level (stage_num) down to level 1.
        x = encode_features[-1]
        for i in range(self.stage_num, 0, -1):
            x = self.decoder_layer_forward(encode_features[i - 1], x,
                                           getattr(self, f'lateral_layer{i}'),
                                           getattr(self, f'merge_layer{i}'),
                                           getattr(self, f'upsample_layer{i}'))

        # Only the output of the last decoder level feeds the seg head.
        seg_features = x.features

        return dict(
            spatial_features=spatial_features, seg_features=seg_features)

    def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
                              merge_layer, upsample_layer):
        """Forward of upsample and residual block.

        ``x_lateral`` is refined by ``lateral_layer`` and its features are
        concatenated (channel-wise) after the features of ``x_bottom``.
        ``merge_layer`` fuses the concatenation, and its output is added to
        a channel-reduced copy of the concatenation before ``upsample_layer``
        is applied.

        Args:
            x_lateral (:obj:`SparseConvTensor`): Lateral tensor.
            x_bottom (:obj:`SparseConvTensor`): Feature from bottom layer.
            lateral_layer (SparseBasicBlock): Convolution for lateral tensor.
            merge_layer (SparseSequential): Convolution for merging features.
            upsample_layer (SparseSequential): Convolution for upsampling.

        Returns:
            :obj:`SparseConvTensor`: Upsampled feature.
        """
        x = lateral_layer(x_lateral)
        # NOTE: the row-wise concat assumes x_bottom and x have the same
        # active sites in the same order (presumably guaranteed by the
        # shared indice_keys); verify if the layer layout changes.
        x = replace_feature(x, torch.cat((x_bottom.features, x.features),
                                         dim=1))
        x_merge = merge_layer(x)
        # Residual connection: shrink the concatenation to merge's width.
        x = self.reduce_channel(x, x_merge.features.shape[1])
        x = replace_feature(x, x_merge.features + x.features)
        x = upsample_layer(x)
        return x

    @staticmethod
    def reduce_channel(x, out_channels):
        """Reduce channels for element-wise addition.

        Consecutive groups of ``g = in_channels // out_channels`` channels
        are summed: output channel ``k`` is the sum of input channels
        ``k * g`` to ``(k + 1) * g - 1``.

        Args:
            x (:obj:`SparseConvTensor`): Sparse tensor, ``x.features``
                are in shape (N, C1).
            out_channels (int): The number of channel after reduction.

        Returns:
            :obj:`SparseConvTensor`: Channel reduced feature.
        """
        features = x.features
        n, in_channels = features.shape
        assert (in_channels % out_channels == 0
                and in_channels >= out_channels), (
                    f'Cannot reduce {in_channels} channels to '
                    f'{out_channels}: in_channels must be a multiple of '
                    'out_channels and not smaller than it.')
        x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
        return x

    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
        """Make encoder layers using sparse convs.

        Stage ``i`` (1-based) is registered as ``encoder_layer{i}`` in
        ``self.encoder_layers``. Every stage except the first starts with a
        stride-2 ``SparseConv3d`` (indice key ``spconv{i}``). All other
        blocks are ``SubMConv3d`` (indice key ``subm{i}``).

        Args:
            make_block (callable): Function used to build each conv block,
                e.g. ``make_sparse_convmodule``.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.

        Returns:
            int: The number of encoder output channels, i.e. the out
            channels of the last block (``in_channels`` itself when no
            blocks are configured).
        """
        self.encoder_layers = SparseSequential()

        for i, blocks in enumerate(self.encoder_channels):
            stage_paddings = tuple(self.encoder_paddings[i])
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = stage_paddings[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0:
                    block = make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        stride=2,
                        padding=padding,
                        indice_key=f'spconv{i + 1}',
                        conv_type='SparseConv3d')
                else:
                    block = make_block(
                        in_channels,
                        out_channels,
                        3,
                        norm_cfg=norm_cfg,
                        padding=padding,
                        indice_key=f'subm{i + 1}',
                        conv_type='SubMConv3d')
                blocks_list.append(block)
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            self.encoder_layers.add_module(stage_name,
                                           SparseSequential(*blocks_list))

        # in_channels tracks the out channels of the most recent block, so
        # this is well-defined even when encoder_channels is empty.
        return in_channels

    def make_decoder_layers(self, make_block, norm_cfg, in_channels):
        """Make decoder layers using sparse convs.

        Decoder block ``i`` (0-based) is registered at level
        ``k = len(self.decoder_channels) - i`` as ``lateral_layer{k}``,
        ``merge_layer{k}`` and ``upsample_layer{k}``. ``forward`` consumes
        them from the deepest level down to level 1.

        Args:
            make_block (callable): Function used to build each conv block,
                e.g. ``make_sparse_convmodule``.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of decoder input channels, i.e.
                the number of encoder output channels.

        Returns:
            None: The layers are registered on ``self`` via ``setattr``.
        """
        block_num = len(self.decoder_channels)
        for i, block_channels in enumerate(self.decoder_channels):
            paddings = self.decoder_paddings[i]
            level = block_num - i
            setattr(
                self, f'lateral_layer{level}',
                SparseBasicBlock(
                    in_channels,
                    block_channels[0],
                    conv_cfg=dict(
                        type='SubMConv3d', indice_key=f'subm{level}'),
                    norm_cfg=norm_cfg))
            # Input is the concat of bottom and lateral features.
            setattr(
                self, f'merge_layer{level}',
                make_block(
                    in_channels * 2,
                    block_channels[1],
                    3,
                    norm_cfg=norm_cfg,
                    padding=paddings[0],
                    indice_key=f'subm{level}',
                    conv_type='SubMConv3d'))
            if level != 1:
                # Shares indice key spconv{level} with the encoder's
                # stride-2 conv of the same stage (presumably to restore
                # the active sites it downsampled).
                setattr(
                    self, f'upsample_layer{level}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        indice_key=f'spconv{level}',
                        conv_type='SparseInverseConv3d'))
            else:
                # use submanifold conv instead of inverse conv
                # in the last block: encoder stage 1 has no stride-2 conv,
                # so there is no 'spconv1' to invert.
                setattr(
                    self, f'upsample_layer{level}',
                    make_block(
                        in_channels,
                        block_channels[2],
                        3,
                        norm_cfg=norm_cfg,
                        padding=paddings[1],
                        indice_key='subm1',
                        conv_type='SubMConv3d'))
            in_channels = block_channels[2]