spvcnn_backone.py 11.4 KB
Newer Older
Sun Jiahao's avatar
Sun Jiahao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import torch
from mmengine.registry import MODELS
from torch import Tensor, nn

from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from .minkunet_backbone import MinkUNetBackbone

if IS_TORCHSPARSE_AVAILABLE:
    import torchsparse
    import torchsparse.nn.functional as F
    from torchsparse.nn.utils import get_kernel_offsets
    from torchsparse.tensor import PointTensor, SparseTensor
else:
    PointTensor = SparseTensor = None


@MODELS.register_module()
class SPVCNNBackbone(MinkUNetBackbone):
    """SPVCNN backbone with torchsparse backend.

    More details can be found in `paper <https://arxiv.org/abs/2007.16100>`_ .

    On top of the voxel encoder/decoder built by the parent class, this
    backbone keeps a point-wise feature branch that is fused with the voxel
    branch at three places in :meth:`forward`.

    Args:
        in_channels (int): Number of input voxel feature channels.
            Defaults to 4.
        base_channels (int): The input channels for first encoder layer.
            Defaults to 32.
        num_stages (int): Number of stages in encoder and decoder.
            Only 4 is accepted. Defaults to 4.
        encoder_channels (List[int]): Convolutional channels of each encode
            layer. Defaults to [32, 64, 128, 256].
        decoder_channels (List[int]): Convolutional channels of each decode
            layer. Defaults to [256, 128, 96, 96].
        drop_ratio (float): Dropout ratio of voxel features. Defaults to 0.3.
        sparseconv_backend (str): Sparse convolution backend. Only
            'torchsparse' is accepted. Defaults to 'torchsparse'.
        init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`]
            , optional): Initialization config dict. Defaults to None.
            Forwarded to the parent class through ``**kwargs``.
    """

    def __init__(self,
                 in_channels: int = 4,
                 base_channels: int = 32,
                 num_stages: int = 4,
                 encoder_channels: Sequence[int] = [32, 64, 128, 256],
                 decoder_channels: Sequence[int] = [256, 128, 96, 96],
                 drop_ratio: float = 0.3,
                 sparseconv_backend: str = 'torchsparse',
                 **kwargs) -> None:
        assert num_stages == 4, 'SPVCNN backbone only supports 4 stages.'
        assert sparseconv_backend == 'torchsparse', \
            f'SPVCNN backbone only supports torchsparse backend, but got ' \
            f'sparseconv backend: {sparseconv_backend}.'
        super().__init__(
            in_channels=in_channels,
            base_channels=base_channels,
            num_stages=num_stages,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            sparseconv_backend=sparseconv_backend,
            **kwargs)

        # Linear-BN-ReLU branches applied to the point features in forward():
        # [0] after the encoder, [1] after decoder stage index 1, [2] after
        # the last decoder stage.
        # NOTE(review): ``decoder_channels[4]`` is out of range for the
        # 4-element default shown in the signature. Presumably the parent
        # constructor extends ``encoder_channels`` / ``decoder_channels`` in
        # place before this point -- confirm against MinkUNetBackbone before
        # changing these indices or the (mutable) list defaults.
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(base_channels, encoder_channels[-1]),
                nn.BatchNorm1d(encoder_channels[-1]), nn.ReLU(True)),
            nn.Sequential(
                nn.Linear(encoder_channels[-1], decoder_channels[2]),
                nn.BatchNorm1d(decoder_channels[2]), nn.ReLU(True)),
            nn.Sequential(
                nn.Linear(decoder_channels[2], decoder_channels[4]),
                nn.BatchNorm1d(decoder_channels[4]), nn.ReLU(True))
        ])
        # Second positional argument of nn.Dropout is ``inplace``.
        self.dropout = nn.Dropout(drop_ratio, True)

    def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
        """Forward function.

        Args:
            voxel_features (Tensor): Voxel features in shape (N, C).
            coors (Tensor): Coordinates in shape (N, 4),
                the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

        Returns:
            Tensor: Point-wise backbone features, i.e. ``points.F`` after the
            final voxel-to-point fusion.
        """
        voxels = SparseTensor(voxel_features, coors)
        points = PointTensor(voxels.F, voxels.C.float())
        voxels = initial_voxelize(points)

        # Stem convolution, then a first point <-> voxel exchange.
        voxels = self.conv_input(voxels)
        points = voxel_to_point(voxels, points)
        voxels = point_to_voxel(voxels, points)

        # Encoder: keep every stage input as a skip connection. The deepest
        # output is dropped from the list and the rest is reversed so that
        # laterals[i] lines up with decoder stage i.
        laterals = [voxels]
        for encoder in self.encoder:
            voxels = encoder(voxels)
            laterals.append(voxels)
        laterals = laterals[:-1][::-1]

        # Fuse the point branch at the bottleneck.
        points = voxel_to_point(voxels, points, self.point_transforms[0])
        voxels = point_to_voxel(voxels, points)
        voxels.F = self.dropout(voxels.F)

        for i, decoder in enumerate(self.decoder):
            # decoder[0] runs before the skip concatenation, decoder[1]
            # after it.
            voxels = decoder[0](voxels)
            voxels = torchsparse.cat((voxels, laterals[i]))
            voxels = decoder[1](voxels)
            if i == 1:
                # Fuse the point branch again halfway through the decoder.
                points = voxel_to_point(voxels, points,
                                        self.point_transforms[1])
                voxels = point_to_voxel(voxels, points)
                voxels.F = self.dropout(voxels.F)

        points = voxel_to_point(voxels, points, self.point_transforms[2])
        return points.F
Sun Jiahao's avatar
Sun Jiahao committed
120
121


122
123
124
@MODELS.register_module()
class MinkUNetBackboneV2(MinkUNetBackbone):
    r"""MinkUNet backbone V2.

    refer to https://github.com/PJLab-ADG/PCSeg/blob/master/pcseg/model/segmentor/voxel/minkunet/minkunet.py

    Unlike the parent backbone, :meth:`forward` returns point-wise features
    gathered from several depths of the network and concatenated along the
    channel dimension.

    Args:
        sparseconv_backend (str): Sparse convolution backend. Only
            'torchsparse' is accepted. Defaults to 'torchsparse'.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``MinkUNetBackbone``.
    """  # noqa: E501

    def __init__(self,
                 sparseconv_backend: str = 'torchsparse',
                 **kwargs) -> None:
        # The message names this class; it previously mentioned the SPVCNN
        # backbone, which was misleading when this assertion fired.
        assert sparseconv_backend == 'torchsparse', \
            f'MinkUNetBackboneV2 only supports torchsparse backend, but ' \
            f'got sparseconv backend: {sparseconv_backend}.'
        super().__init__(sparseconv_backend=sparseconv_backend, **kwargs)

    def forward(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
        """Forward function.

        Args:
            voxel_features (Tensor): Voxel features in shape (N, C).
            coors (Tensor): Coordinates in shape (N, 4),
                the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

        Returns:
            Tensor: Point-wise backbone features: the point features taken
            after the encoder and after every odd-indexed decoder stage,
            concatenated along dim 1.
        """
        voxels = SparseTensor(voxel_features, coors)
        points = PointTensor(voxels.F, voxels.C.float())

        voxels = initial_voxelize(points)
        voxels = self.conv_input(voxels)
        points = voxel_to_point(voxels, points)

        # Encoder: keep every stage input as a skip connection. The deepest
        # output is dropped from the list and the rest is reversed so that
        # laterals[i] lines up with decoder stage i.
        laterals = [voxels]
        for encoder_layer in self.encoder:
            voxels = encoder_layer(voxels)
            laterals.append(voxels)
        laterals = laterals[:-1][::-1]
        points = voxel_to_point(voxels, points)
        output_features = [points.F]

        for i, decoder_layer in enumerate(self.decoder):
            # decoder_layer[0] runs before the skip concatenation,
            # decoder_layer[1] after it.
            voxels = decoder_layer[0](voxels)
            voxels = torchsparse.cat((voxels, laterals[i]))
            voxels = decoder_layer[1](voxels)
            if i % 2 == 1:
                # Sample point features at every second decoder stage.
                points = voxel_to_point(voxels, points)
                output_features.append(points.F)

        points.F = torch.cat(output_features, dim=1)
        return points.F


def initial_voxelize(points: PointTensor) -> SparseTensor:
    """Build a stride-1 sparse tensor from a point tensor.

    Point coordinates are floored, hashed and de-duplicated; coordinates and
    features of the points that share a hash are then reduced per voxel with
    ``F.spvoxelize``. The point-to-voxel index and the per-voxel point counts
    are stored on ``points.additional_features`` as a side effect.

    Args:
        points (PointTensor): Input points after voxelization.

    Returns:
        SparseTensor: New voxels.
    """
    floored_coords = torch.floor(points.C)

    # One hash per point; the unique hashes define the voxel set.
    point_hash = F.sphash(floored_coords.int())
    voxel_hash = torch.unique(point_hash)
    idx_query = F.sphashquery(point_hash, voxel_hash)
    counts = F.spcount(idx_query.int(), voxel_hash.shape[0])

    # Reduce coordinates and features of the points falling in each voxel.
    voxel_coords = F.spvoxelize(floored_coords, idx_query, counts)
    voxel_coords = torch.round(voxel_coords).int()
    voxel_feats = F.spvoxelize(points.F, idx_query, counts)

    new_tensor = SparseTensor(voxel_feats, voxel_coords, 1)
    new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)

    # NOTE(review): the cache key here is the int ``1``, while the other
    # helpers in this file look entries up by ``voxels.s`` (which they also
    # index with ``[0]``). If ``voxels.s`` is not the int 1 these entries
    # are never hit -- confirm against the torchsparse version in use.
    extras = points.additional_features
    extras['idx_query'][1] = idx_query
    extras['counts'][1] = counts
    return new_tensor


def voxel_to_point(voxels: SparseTensor,
                   points: PointTensor,
                   point_transform: Optional[nn.Module] = None,
                   nearest: bool = False) -> PointTensor:
    """Feed voxel features to points.

    The point-to-voxel neighbor indices and interpolation weights are looked
    up in ``points.idx_query`` / ``points.weights`` under the key
    ``voxels.s``. When either is missing they are computed and written back
    to both ``points`` and the returned tensor.

    Args:
        voxels (SparseTensor): Input voxels.
        points (PointTensor): Input points.
        point_transform (nn.Module, optional): Point transform module
            for input point features. Its output on ``points.F`` is added
            to the devoxelized features. Defaults to None.
        nearest (bool): Whether to use nearest neighbor interpolation.
            Only takes effect when the indices/weights are computed here,
            not when they are read from the cache. Defaults to False.

    Returns:
        PointTensor: Points with new features.
    """
    stride = voxels.s
    query_cache = points.idx_query
    weight_cache = points.weights

    cache_miss = (
        query_cache is None or weight_cache is None
        or query_cache.get(stride) is None
        or weight_cache.get(stride) is None)

    if cache_miss:
        device = points.F.device
        offsets = get_kernel_offsets(2, stride, 1, device=device)

        # Snap xyz to the voxel grid of this stride; the last column of
        # points.C is appended unchanged (cast to int).
        snapped_xyz = torch.floor(
            points.C[:, :3] / stride[0]).int() * stride[0]
        last_column = points.C[:, -1].int().view(-1, 1)
        neighbor_hash = F.sphash(
            torch.cat([snapped_xyz, last_column], 1), offsets)
        voxel_hash = F.sphash(voxels.C.to(device))

        idx_query = F.sphashquery(neighbor_hash, voxel_hash)
        # presumably trilinear interpolation weights -- see torchsparse docs
        weights = F.calc_ti_weights(points.C, idx_query, scale=stride[0])
        weights = weights.transpose(0, 1).contiguous()
        idx_query = idx_query.transpose(0, 1).contiguous()
        if nearest:
            # Keep only the first neighbor column.
            weights[:, 1:] = 0.
            idx_query[:, 1:] = -1
    else:
        idx_query = query_cache.get(stride)
        weights = weight_cache.get(stride)

    new_features = F.spdevoxelize(voxels.F, idx_query, weights)
    new_tensor = PointTensor(
        new_features,
        points.C,
        idx_query=query_cache,
        weights=weight_cache)
    new_tensor.additional_features = points.additional_features

    if cache_miss:
        # Record the freshly computed lookup on both tensors.
        new_tensor.idx_query[stride] = idx_query
        new_tensor.weights[stride] = weights
        query_cache[stride] = idx_query
        weight_cache[stride] = weights

    if point_transform is not None:
        new_tensor.F = new_tensor.F + point_transform(points.F)

    return new_tensor


def point_to_voxel(voxels: SparseTensor, points: PointTensor) -> SparseTensor:
    """Feed point features to voxels.

    The point-to-voxel index and per-voxel counts are read from
    ``points.additional_features`` under the key ``voxels.s``; when no index
    is cached for that key both are computed and stored there.

    Args:
        voxels (SparseTensor): Input voxels.
        points (PointTensor): Input points.

    Returns:
        SparseTensor: Voxels with new features. Coordinates, stride,
        ``cmaps`` and ``kmaps`` are taken from ``voxels``.
    """
    stride = voxels.s
    extras = points.additional_features

    # Look for a cached point-to-voxel index for this stride.
    cached_query = None
    if extras is not None:
        query_cache = extras.get('idx_query')
        if query_cache is not None:
            cached_query = query_cache.get(stride)

    if cached_query is None:
        # Snap xyz to the voxel grid of this stride; the last column of
        # points.C is appended unchanged (cast to int).
        snapped_xyz = torch.floor(
            points.C[:, :3] / stride[0]).int() * stride[0]
        last_column = points.C[:, -1].int().view(-1, 1)
        point_hash = F.sphash(torch.cat([snapped_xyz, last_column], 1))
        voxel_hash = F.sphash(voxels.C)

        idx_query = F.sphashquery(point_hash, voxel_hash)
        counts = F.spcount(idx_query.int(), voxels.C.shape[0])
        extras['idx_query'][stride] = idx_query
        extras['counts'][stride] = counts
    else:
        idx_query = cached_query
        counts = extras['counts'][stride]

    inserted_features = F.spvoxelize(points.F, idx_query, counts)
    new_tensor = SparseTensor(inserted_features, voxels.C, voxels.s)
    # Share the coordinate/kernel map caches of the source tensor.
    new_tensor.cmaps = voxels.cmaps
    new_tensor.kmaps = voxels.kmaps

    return new_tensor