import torch
from mmcv.cnn import ConvModule, xavier_init
from torch import nn
from torch.nn import functional as F

from ..registry import FUSION_LAYERS


def point_sample(
    img_features,
    points,
    lidar2img_rt,
    pcd_rotate_mat,
    img_scale_factor,
    img_crop_offset,
    pcd_trans_factor,
    pcd_scale_factor,
    pcd_flip,
    img_flip,
    img_pad_shape,
    img_shape,
    aligned=True,
    padding_mode='zeros',
    align_corners=True,
):
    """Obtain image features using points.

    Args:
        img_features (Tensor): 1xCxHxW image features.
        points (Tensor): Nx3 point cloud in LiDAR coordinates.
        lidar2img_rt (Tensor): 4x4 transformation matrix from LiDAR to
            image coordinates.
        pcd_rotate_mat (Tensor): 3x3 rotation matrix applied to the points
            during data augmentation.
        img_scale_factor (Tensor): (w_scale, h_scale) scale factor of the
            image resize during data augmentation.
        img_crop_offset (Tensor): (w_offset, h_offset) offset used to crop
            the image during data augmentation.
        pcd_trans_factor (Tensor): Translation applied to the points during
            data augmentation.
        pcd_scale_factor (float): Scale factor applied to the points during
            data augmentation.
        pcd_flip (bool): Whether the points are flipped during data
            augmentation.
        img_flip (bool): Whether the image is flipped during data
            augmentation.
        img_pad_shape (tuple[int]): int tuple indicating the h & w after
            padding; this is necessary to obtain features in the feature map.
        img_shape (tuple[int]): int tuple indicating the h & w before padding
            but after scaling; this is necessary for flipping coordinates.
        aligned (bool, optional): Whether to use bilinear interpolation when
            sampling image features for each point. Defaults to True.
        padding_mode (str, optional): Padding mode when padding values for
            features of out-of-image points. Defaults to 'zeros'.
        align_corners (bool, optional): Whether to align corners when
            sampling image features for each point. Defaults to True.

    Returns:
        Tensor: NxC image features sampled by point coordinates.
    """
    # Undo the point cloud augmentations in the reverse order of the data
    # pipeline, i.e. flip back -> untranslate -> unscale -> unrotate
    if pcd_flip:
        # if the points are flipped, flip them back first
        points[:, 1] = -points[:, 1]

    points -= pcd_trans_factor
    # the points should be scaled to the original scale in velo coordinate
    points /= pcd_scale_factor
    # the points should be rotated back
    # pcd_rotate_mat @ pcd_rotate_mat.inverse() is not exactly an identity
    # matrix, and building the inverse rotation from the angle is not exact
    # either, so the matrix inverse is used here
    points = points @ pcd_rotate_mat.inverse()

    # project points from velo coordinate to camera coordinate
    num_points = points.shape[0]
    pts_4d = torch.cat([points, points.new_ones(size=(num_points, 1))], dim=-1)
    pts_2d = pts_4d @ lidar2img_rt.t()

    # pts_2d is an Nx4 tensor in homogeneous image coordinates;
    # divide by the depth (third column) to get pixel coordinates

    pts_2d[:, 2] = torch.clamp(pts_2d[:, 2], min=1e-5)
    pts_2d[:, 0] /= pts_2d[:, 2]
    pts_2d[:, 1] /= pts_2d[:, 2]

    # img transformation: scale -> crop -> flip
    # the image is resized by img_scale_factor
    img_coors = pts_2d[:, 0:2] * img_scale_factor  # Nx2
    img_coors -= img_crop_offset

    # grid sample, the valid grid range should be in [-1,1]
    coor_x, coor_y = torch.split(img_coors, 1, dim=1)  # each is Nx1

    if img_flip:
        # by default we take it as horizontal flip
        # use img_shape before padding for flip
        orig_h, orig_w = img_shape
        coor_x = orig_w - coor_x

    h, w = img_pad_shape
    coor_y = coor_y / h * 2 - 1
    coor_x = coor_x / w * 2 - 1
    grid = torch.cat([coor_x, coor_y],
                     dim=1).unsqueeze(0).unsqueeze(0)  # Nx2 -> 1x1xNx2

    # align_corners=True provides higher performance
    mode = 'bilinear' if aligned else 'nearest'
    point_features = F.grid_sample(
        img_features,
        grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)  # 1xCx1xN feats

    # 1xCx1xN -> NxC; squeeze explicit dims so a single point is not dropped
    return point_features.squeeze(0).squeeze(1).t()
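

# A minimal usage sketch for `point_sample` (illustrative only, not part of
# the original module): identity lidar2img projection and no augmentation,
# with randomly generated features and points as placeholders.
def _demo_point_sample():
    img_features = torch.rand(1, 64, 48, 160)  # 1xCxHxW feature map
    points = torch.rand(100, 3) * 10  # Nx3 points in LiDAR coordinates
    point_feats = point_sample(
        img_features,
        points,
        lidar2img_rt=torch.eye(4),  # placeholder projection matrix
        pcd_rotate_mat=torch.eye(3),  # no rotation augmentation
        img_scale_factor=points.new_tensor([1.0, 1.0]),  # no image resize
        img_crop_offset=0,  # no image crop
        pcd_trans_factor=0,  # no translation augmentation
        pcd_scale_factor=1.0,  # no scaling augmentation
        pcd_flip=False,
        img_flip=False,
        img_pad_shape=(48, 160),
        img_shape=(48, 160))
    assert point_feats.shape == (100, 64)  # NxC per-point image features
    return point_feats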


@FUSION_LAYERS.register_module()
class PointFusion(nn.Module):
    """Fuse image features from multi-scale features.

    Args:
        img_channels (list[int] | int): Channels of image features.
            It could be a list if the input is multi-scale image features.
        pts_channels (int): Channels of point features.
        mid_channels (int): Channels of middle layers.
        out_channels (int): Channels of output fused features.
        img_levels (int | list[int], optional): Index or indices of the
            image feature levels used for fusion. Defaults to 3.
        conv_cfg (dict, optional): Dict config of conv layers of middle
            layers. Defaults to None.
        norm_cfg (dict, optional): Dict config of norm layers of middle
            layers. Defaults to None.
        act_cfg (dict, optional): Dict config of activation layers.
            Defaults to None.
        activate_out (bool, optional): Whether to apply relu activation
            to output features. Defaults to True.
        fuse_out (bool, optional): Whether to apply a conv layer to the
            fused features. Defaults to False.
        dropout_ratio (float, optional): Dropout ratio of image features to
            prevent overfitting. Defaults to 0.
        aligned (bool, optional): Whether to apply aligned feature fusion.
            Defaults to True.
        align_corners (bool, optional): Whether to align corners when
            sampling image features according to points. Defaults to True.
        padding_mode (str, optional): Mode used to pad the features of
            points that do not have corresponding image features.
            Defaults to 'zeros'.
        lateral_conv (bool, optional): Whether to apply lateral convs
            to image features. Defaults to True.
    """

    def __init__(self,
                 img_channels,
                 pts_channels,
                 mid_channels,
                 out_channels,
                 img_levels=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 activate_out=True,
                 fuse_out=False,
                 dropout_ratio=0,
                 aligned=True,
                 align_corners=True,
                 padding_mode='zeros',
                 lateral_conv=True):
        super(PointFusion, self).__init__()
        if isinstance(img_levels, int):
            img_levels = [img_levels]
        if isinstance(img_channels, int):
            img_channels = [img_channels] * len(img_levels)
        assert isinstance(img_levels, list)
        assert isinstance(img_channels, list)
        assert len(img_channels) == len(img_levels)

        self.img_levels = img_levels
        self.act_cfg = act_cfg
        self.activate_out = activate_out
        self.fuse_out = fuse_out
        self.dropout_ratio = dropout_ratio
        self.img_channels = img_channels
        self.aligned = aligned
        self.align_corners = align_corners
        self.padding_mode = padding_mode

        self.lateral_convs = None
        if lateral_conv:
            self.lateral_convs = nn.ModuleList()
            for i in range(len(img_channels)):
                l_conv = ConvModule(
                    img_channels[i],
                    mid_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False)
                self.lateral_convs.append(l_conv)
            self.img_transform = nn.Sequential(
                nn.Linear(mid_channels * len(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        else:
            self.img_transform = nn.Sequential(
                nn.Linear(sum(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        self.pts_transform = nn.Sequential(
            nn.Linear(pts_channels, out_channels),
            nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
        )

        if self.fuse_out:
            self.fuse_conv = nn.Sequential(
                # the fused features have out_channels dims, so the input
                # dim of this layer must be out_channels as well
                nn.Linear(out_channels, out_channels),
                # For pts the BN is initialized differently by default
                # TODO: check whether this is necessary
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
                nn.ReLU(inplace=False))

        self.init_weights()

    # initialize conv and linear layers with the Xavier-uniform distribution
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                xavier_init(m, distribution='uniform')

    def forward(self, img_feats, pts, pts_feats, img_metas):
        """Forward function.

        Args:
            img_feats (list[Tensor]): Multi-scale image features.
            pts (list[Tensor]): A batch of points with shape Nx3.
            pts_feats (Tensor): A tensor consisting of the point features of
                the whole batch.
            img_metas (list[dict]): Meta information of images.

        Returns:
            torch.Tensor: fused features of each point.
        """
        img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
        img_pre_fuse = self.img_transform(img_pts)
        if self.training and self.dropout_ratio > 0:
            img_pre_fuse = F.dropout(img_pre_fuse, self.dropout_ratio)
        pts_pre_fuse = self.pts_transform(pts_feats)

        fuse_out = img_pre_fuse + pts_pre_fuse
        if self.activate_out:
            fuse_out = F.relu(fuse_out)
        if self.fuse_out:
            fuse_out = self.fuse_conv(fuse_out)

        return fuse_out

    def obtain_mlvl_feats(self, img_feats, pts, img_metas):
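        """Obtain multi-level image features for each point.

        Args:
            img_feats (list[Tensor]): Multi-scale image features produced
                by the image backbone and neck.
            pts (list[Tensor]): Points of each sample.
            img_metas (list[dict]): Meta information for each sample.

        Returns:
            Tensor: Concatenated image features of each point.
        """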
        if self.lateral_convs is not None:
            img_ins = [
                lateral_conv(img_feats[i])
                for i, lateral_conv in zip(self.img_levels, self.lateral_convs)
            ]
        else:
            img_ins = img_feats
        img_feats_per_point = []
        # Sample multi-level features
        for i in range(len(img_metas)):
            mlvl_img_feats = []
            for level in range(len(self.img_levels)):
                mlvl_img_feats.append(
                    self.sample_single(img_ins[level][i:i + 1], pts[i][:, :3],
                                       img_metas[i]))
            mlvl_img_feats = torch.cat(mlvl_img_feats, dim=-1)
            img_feats_per_point.append(mlvl_img_feats)

        img_pts = torch.cat(img_feats_per_point, dim=0)
        return img_pts

    def sample_single(self, img_feats, pts, img_meta):
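        """Sample image features of a single level for one sample's points.

        Args:
            img_feats (Tensor): 1xCxHxW image features of one level.
            pts (Tensor): Nx3 points of one sample in LiDAR coordinates.
            img_meta (dict): Meta information of the sample, including the
                augmentation parameters applied to the image and the points.

        Returns:
            Tensor: NxC image features sampled by point coordinates.
        """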
        pcd_scale_factor = (
            img_meta['pcd_scale_factor']
            if 'pcd_scale_factor' in img_meta.keys() else 1)
        pcd_trans_factor = (
            pts.new_tensor(img_meta['pcd_trans'])
            if 'pcd_trans' in img_meta.keys() else 0)
        pcd_rotate_mat = (
            pts.new_tensor(img_meta['pcd_rotation']) if 'pcd_rotation'
            in img_meta.keys() else torch.eye(3).type_as(pts).to(pts.device))
        img_scale_factor = (
            pts.new_tensor(img_meta['scale_factor'][:2])
            if 'scale_factor' in img_meta.keys() else 1)
        pcd_flip = (
            img_meta['pcd_flip'] if 'pcd_flip' in img_meta.keys() else False)
        img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
        img_crop_offset = (
            pts.new_tensor(img_meta['img_crop_offset'])
            if 'img_crop_offset' in img_meta.keys() else 0)
        img_pts = point_sample(
            img_feats,
            pts,
            pts.new_tensor(img_meta['lidar2img']),
            pcd_rotate_mat,
            img_scale_factor,
            img_crop_offset,
            pcd_trans_factor,
            pcd_scale_factor,
            pcd_flip=pcd_flip,
            img_flip=img_flip,
            img_pad_shape=img_meta['pad_shape'][:2],
            img_shape=img_meta['img_shape'][:2],
            aligned=self.aligned,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        return img_pts
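

# A minimal forward-pass sketch for `PointFusion` (illustrative only, not
# part of the original module): a single image feature level and a batch of
# one sample, with random inputs and the minimal `img_metas` keys required
# by `sample_single`.
def _demo_point_fusion():
    fusion = PointFusion(
        img_channels=64,
        pts_channels=128,
        mid_channels=128,
        out_channels=128,
        img_levels=0)  # use the first (and only) image feature level
    fusion.eval()  # use running stats in BatchNorm1d for this toy batch
    img_feats = [torch.rand(1, 64, 48, 160)]  # one 1xCxHxW feature level
    pts = [torch.rand(50, 4) * 10]  # one sample of Nx4 points (xyz + extra)
    pts_feats = torch.rand(50, 128)  # point features of the whole batch
    img_metas = [
        dict(
            lidar2img=torch.eye(4).numpy(),  # placeholder projection matrix
            pad_shape=(48, 160, 3),
            img_shape=(48, 160, 3))
    ]
    fused = fusion(img_feats, pts, pts_feats, img_metas)
    assert fused.shape == (50, 128)  # NxC fused features per point
    return fused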