"test/algo/compression/v1/test_model_speedup.py" did not exist on "5c861676fc0ce336b89ba5db40de05a81d8e987b"
point_fusion.py 16.7 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import (get_proj_mat_by_coord_type,
                                        points_cam2img, points_img2cam)
from . import apply_3d_transformation


def point_sample(img_meta,
                 img_features,
                 points,
                 proj_mat,
                 coord_type,
                 img_scale_factor,
                 img_crop_offset,
                 img_flip,
                 img_pad_shape,
                 img_shape,
                 aligned=True,
                 padding_mode='zeros',
                 align_corners=True,
                 valid_flag=False):
    """Obtain image features using points.
zhangwenwei's avatar
zhangwenwei committed
29

zhangwenwei's avatar
zhangwenwei committed
30
    Args:
31
        img_meta (dict): Meta info.
wangtai's avatar
wangtai committed
32
33
        img_features (torch.Tensor): 1 x C x H x W image features.
        points (torch.Tensor): Nx3 point cloud in LiDAR coordinates.
34
35
        proj_mat (torch.Tensor): 4x4 transformation matrix.
        coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
36
        img_scale_factor (torch.Tensor): Scale factor with shape of
wangtai's avatar
wangtai committed
37
            (w_scale, h_scale).
38
        img_crop_offset (torch.Tensor): Crop offset used to crop
wangtai's avatar
wangtai committed
39
            image during data augmentation with shape of (w_offset, h_offset).
zhangwenwei's avatar
zhangwenwei committed
40
41
        img_flip (bool): Whether the image is flipped.
        img_pad_shape (tuple[int]): int tuple indicates the h & w after
wangtai's avatar
wangtai committed
42
            padding, this is necessary to obtain features in feature map.
zhangwenwei's avatar
zhangwenwei committed
43
        img_shape (tuple[int]): int tuple indicates the h & w before padding
wangtai's avatar
wangtai committed
44
            after scaling, this is necessary for flipping coordinates.
45
        aligned (bool): Whether use bilinear interpolation when
zhangwenwei's avatar
zhangwenwei committed
46
            sampling image features for each point. Defaults to True.
47
        padding_mode (str): Padding mode when padding values for
zhangwenwei's avatar
zhangwenwei committed
48
            features of out-of-image points. Defaults to 'zeros'.
49
        align_corners (bool): Whether to align corners when
zhangwenwei's avatar
zhangwenwei committed
50
            sampling image features for each point. Defaults to True.
51
52
53
        valid_flag (bool): Whether to filter out the points that
            outside the image and with depth smaller than 0. Defaults to
            False.
zhangwenwei's avatar
zhangwenwei committed
54
55

    Returns:
wangtai's avatar
wangtai committed
56
        torch.Tensor: NxC image features sampled by point coordinates.
zhangwenwei's avatar
zhangwenwei committed
57
    """

    # apply transformation based on info in img_meta
    points = apply_3d_transformation(
        points, coord_type, img_meta, reverse=True)

    # project points to image coordinate
    if valid_flag:
        proj_pts = points_cam2img(points, proj_mat, with_depth=True)
        pts_2d = proj_pts[..., :2]
        depths = proj_pts[..., 2]
    else:
        pts_2d = points_cam2img(points, proj_mat)

    # img transformation: scale -> crop -> flip
    # the image is resized by img_scale_factor
    img_coors = pts_2d[:, 0:2] * img_scale_factor  # Nx2
    img_coors -= img_crop_offset

    # grid sample, the valid grid range should be in [-1,1]
    coor_x, coor_y = torch.split(img_coors, 1, dim=1)  # each is Nx1

    if img_flip:
        # by default we take it as horizontal flip
        # use img_shape before padding for flip
        ori_h, ori_w = img_shape
        coor_x = ori_w - coor_x

    h, w = img_pad_shape
    norm_coor_y = coor_y / h * 2 - 1
    norm_coor_x = coor_x / w * 2 - 1
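    # e.g. coor_x = 0 maps to -1 and coor_x = w maps to +1, matching the
    # normalized sampling grid expected by F.grid_sample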
    grid = torch.cat([norm_coor_x, norm_coor_y],
                     dim=1).unsqueeze(0).unsqueeze(0)  # Nx2 -> 1x1xNx2

    # align_corners=True provides higher performance
    mode = 'bilinear' if aligned else 'nearest'
    point_features = F.grid_sample(
        img_features,
        grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)  # 1xCx1xN feats

    if valid_flag:
        # (N, )
        valid = (coor_x.squeeze() < w) & (coor_x.squeeze() > 0) & (
            coor_y.squeeze() < h) & (coor_y.squeeze() > 0) & (
                depths > 0)
        valid_features = point_features.squeeze().t()
        valid_features[~valid] = 0
        return valid_features, valid  # (N, C), (N,)

    return point_features.squeeze().t()
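
# Illustrative usage sketch (not part of the upstream module): a minimal call
# to `point_sample` for a single sample. The tensor shapes, the identity
# projection matrix and the empty meta dict below are assumptions chosen for
# demonstration only, assuming `apply_3d_transformation` falls back to an
# identity 3D transform when the meta dict carries no augmentation keys.
def _demo_point_sample():
    img_features = torch.rand(1, 64, 48, 64)  # 1 x C x H x W feature map
    points = torch.rand(100, 3) * 10  # 100 LiDAR points with positive depth
    pts_feats = point_sample(
        img_meta=dict(),
        img_features=img_features,
        points=points,
        proj_mat=torch.eye(4),  # placeholder LiDAR-to-image projection
        coord_type='LIDAR',
        img_scale_factor=points.new_tensor([1.0, 1.0]),
        img_crop_offset=points.new_tensor([0.0, 0.0]),
        img_flip=False,
        img_pad_shape=(48, 64),
        img_shape=(48, 64))
    return pts_feats  # (100, 64): one sampled feature vector per point
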


@MODELS.register_module()
class PointFusion(BaseModule):
    """Fuse image features from multi-scale features.
zhangwenwei's avatar
zhangwenwei committed
115
116
117
118
119
120
121
122

    Args:
        img_channels (list[int] | int): Channels of image features.
            It could be a list if the input is multi-scale image features.
        pts_channels (int): Channels of point features
        mid_channels (int): Channels of middle layers
        out_channels (int): Channels of output fused features
        img_levels (int, optional): Number of image levels. Defaults to 3.
123
124
        coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
            Defaults to 'LIDAR'.
zhangwenwei's avatar
zhangwenwei committed
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
        conv_cfg (dict, optional): Dict config of conv layers of middle
            layers. Defaults to None.
        norm_cfg (dict, optional): Dict config of norm layers of middle
            layers. Defaults to None.
        act_cfg (dict, optional): Dict config of activatation layers.
            Defaults to None.
        activate_out (bool, optional): Whether to apply relu activation
            to output features. Defaults to True.
        fuse_out (bool, optional): Whether apply conv layer to the fused
            features. Defaults to False.
        dropout_ratio (int, float, optional): Dropout ratio of image
            features to prevent overfitting. Defaults to 0.
        aligned (bool, optional): Whether apply aligned feature fusion.
            Defaults to True.
        align_corners (bool, optional): Whether to align corner when
            sampling features according to points. Defaults to True.
        padding_mode (str, optional): Mode used to pad the features of
            points that do not have corresponding image features.
            Defaults to 'zeros'.
        lateral_conv (bool, optional): Whether to apply lateral convs
            to image features. Defaults to True.
zhangwenwei's avatar
zhangwenwei committed
146
147
148
149
150
151
152
153
    """

    def __init__(self,
                 img_channels,
                 pts_channels,
                 mid_channels,
                 out_channels,
                 img_levels=3,
                 coord_type='LIDAR',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 init_cfg=None,
                 activate_out=True,
                 fuse_out=False,
                 dropout_ratio=0,
                 aligned=True,
                 align_corners=True,
                 padding_mode='zeros',
                 lateral_conv=True):
        super(PointFusion, self).__init__(init_cfg=init_cfg)
        if isinstance(img_levels, int):
            img_levels = [img_levels]
        if isinstance(img_channels, int):
            img_channels = [img_channels] * len(img_levels)
        assert isinstance(img_levels, list)
        assert isinstance(img_channels, list)
        assert len(img_channels) == len(img_levels)

        self.img_levels = img_levels
        self.coord_type = coord_type
        self.act_cfg = act_cfg
        self.activate_out = activate_out
        self.fuse_out = fuse_out
        self.dropout_ratio = dropout_ratio
        self.img_channels = img_channels
        self.aligned = aligned
        self.align_corners = align_corners
        self.padding_mode = padding_mode

        self.lateral_convs = None
        if lateral_conv:
            self.lateral_convs = nn.ModuleList()
            for i in range(len(img_channels)):
                l_conv = ConvModule(
                    img_channels[i],
                    mid_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False)
                self.lateral_convs.append(l_conv)
            self.img_transform = nn.Sequential(
                nn.Linear(mid_channels * len(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        else:
            self.img_transform = nn.Sequential(
                nn.Linear(sum(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        self.pts_transform = nn.Sequential(
            nn.Linear(pts_channels, out_channels),
            nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
        )

        if self.fuse_out:
            self.fuse_conv = nn.Sequential(
                nn.Linear(mid_channels, out_channels),
                # For pts the BN is initialized differently by default
                # TODO: check whether this is necessary
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
                nn.ReLU(inplace=False))

        if init_cfg is None:
            self.init_cfg = [
                dict(type='Xavier', layer='Conv2d', distribution='uniform'),
                dict(type='Xavier', layer='Linear', distribution='uniform')
            ]

    def forward(self, img_feats, pts, pts_feats, img_metas):
        """Forward function.
zhangwenwei's avatar
zhangwenwei committed
230
231

        Args:
wangtai's avatar
wangtai committed
232
233
234
235
236
            img_feats (list[torch.Tensor]): Image features.
            pts: [list[torch.Tensor]]: A batch of points with shape N x 3.
            pts_feats (torch.Tensor): A tensor consist of point features of the
                total batch.
            img_metas (list[dict]): Meta information of images.
zhangwenwei's avatar
zhangwenwei committed
237

zhangwenwei's avatar
zhangwenwei committed
238
        Returns:
wangtai's avatar
wangtai committed
239
            torch.Tensor: Fused features of each point.
zhangwenwei's avatar
zhangwenwei committed
240
        """
        img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
        img_pre_fuse = self.img_transform(img_pts)
        if self.training and self.dropout_ratio > 0:
            img_pre_fuse = F.dropout(img_pre_fuse, self.dropout_ratio)
        pts_pre_fuse = self.pts_transform(pts_feats)

        fuse_out = img_pre_fuse + pts_pre_fuse
        if self.activate_out:
            fuse_out = F.relu(fuse_out)
        if self.fuse_out:
            fuse_out = self.fuse_conv(fuse_out)

        return fuse_out

    def obtain_mlvl_feats(self, img_feats, pts, img_metas):
        """Obtain multi-level features for each point.

        Args:
            img_feats (list[torch.Tensor]): Multi-scale image features
                produced by the image backbone, each in shape (N, C, H, W).
            pts (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Meta information for each sample.

        Returns:
            torch.Tensor: Corresponding image features of each point.
        """
        if self.lateral_convs is not None:
            img_ins = [
                lateral_conv(img_feats[i])
                for i, lateral_conv in zip(self.img_levels, self.lateral_convs)
            ]
        else:
            img_ins = img_feats
        img_feats_per_point = []
        # Sample multi-level features
        for i in range(len(img_metas)):
            mlvl_img_feats = []
            for level in range(len(self.img_levels)):
                mlvl_img_feats.append(
                    self.sample_single(img_ins[level][i:i + 1], pts[i][:, :3],
                                       img_metas[i]))
            mlvl_img_feats = torch.cat(mlvl_img_feats, dim=-1)
            img_feats_per_point.append(mlvl_img_feats)

        img_pts = torch.cat(img_feats_per_point, dim=0)
        return img_pts

    def sample_single(self, img_feats, pts, img_meta):
        """Sample features from single level image feature map.

        Args:
            img_feats (torch.Tensor): Image feature map in shape
                (1, C, H, W).
            pts (torch.Tensor): Points of a single sample.
            img_meta (dict): Meta information of the single sample.

        Returns:
            torch.Tensor: Single level image features of each point.
        """
        # TODO: image transformation also extracted
        img_scale_factor = (
            pts.new_tensor(img_meta['scale_factor'][:2])
            if 'scale_factor' in img_meta.keys() else 1)
        img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
        img_crop_offset = (
            pts.new_tensor(img_meta['img_crop_offset'])
            if 'img_crop_offset' in img_meta.keys() else 0)
        proj_mat = get_proj_mat_by_coord_type(img_meta, self.coord_type)
        img_pts = point_sample(
            img_meta=img_meta,
            img_features=img_feats,
            points=pts,
            proj_mat=pts.new_tensor(proj_mat),
            coord_type=self.coord_type,
            img_scale_factor=img_scale_factor,
            img_crop_offset=img_crop_offset,
            img_flip=img_flip,
            img_pad_shape=img_meta['input_shape'][:2],
            img_shape=img_meta['img_shape'][:2],
            aligned=self.aligned,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        return img_pts
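

# Illustrative usage sketch (not part of the upstream module): building a
# PointFusion layer and fusing dummy single-level image features with dummy
# point features. The channel sizes, point count and meta keys below are
# assumptions chosen only so that the shapes line up.
def _demo_point_fusion():
    fusion = PointFusion(
        img_channels=64,
        pts_channels=128,
        mid_channels=128,
        out_channels=128,
        img_levels=[0],
        coord_type='LIDAR')
    fusion.eval()  # use running BN statistics for this toy forward pass
    img_feats = [torch.rand(1, 64, 48, 64)]  # one feature level, batch of 1
    pts = [torch.rand(100, 4)]  # x, y, z, intensity; only xyz is sampled
    pts_feats = torch.rand(100, 128)
    img_metas = [
        dict(
            scale_factor=[1.0, 1.0],
            input_shape=(48, 64),
            img_shape=(48, 64),
            lidar2img=torch.eye(4))  # placeholder projection matrix
    ]
    return fusion(img_feats, pts, pts_feats, img_metas)  # (100, 128)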


def voxel_sample(voxel_features,
                 voxel_range,
                 voxel_size,
                 depth_samples,
                 proj_mat,
                 downsample_factor,
                 img_scale_factor,
                 img_crop_offset,
                 img_flip,
                 img_pad_shape,
                 img_shape,
                 aligned=True,
                 padding_mode='zeros',
                 align_corners=True):
    """Obtain image features using points.

    Args:
        voxel_features (torch.Tensor): 1 x C x Nx x Ny x Nz voxel features.
        voxel_range (list): The range of voxel features.
        voxel_size (:obj:`ConfigDict` or dict): The voxel size of voxel
            features.
        depth_samples (torch.Tensor): N depth samples in LiDAR coordinates.
        proj_mat (torch.Tensor): ORIGINAL LiDAR2img projection matrix
            for N views.
        downsample_factor (int): The downsample factor in rescaling.
        img_scale_factor (tuple[torch.Tensor]): Scale factor with shape of
            (w_scale, h_scale).
        img_crop_offset (tuple[torch.Tensor]): Crop offset used to crop
            image during data augmentation with shape of (w_offset, h_offset).
        img_flip (bool): Whether the image is flipped.
        img_pad_shape (tuple[int]): int tuple indicates the h & w after
            padding, this is necessary to obtain features in feature map.
        img_shape (tuple[int]): int tuple indicates the h & w before padding
            after scaling, this is necessary for flipping coordinates.
        aligned (bool, optional): Whether use bilinear interpolation when
            sampling image features for each point. Defaults to True.
        padding_mode (str, optional): Padding mode when padding values for
            features of out-of-image points. Defaults to 'zeros'.
        align_corners (bool, optional): Whether to align corners when
            sampling image features for each point. Defaults to True.

    Returns:
        torch.Tensor: 1xCxDxHxW frustum features sampled from voxel features.
    """
    # construct frustum grid
    device = voxel_features.device
    h, w = img_pad_shape
    h_out = round(h / downsample_factor)
    w_out = round(w / downsample_factor)
    ws = (torch.linspace(0, w_out - 1, w_out) * downsample_factor).to(device)
    hs = (torch.linspace(0, h_out - 1, h_out) * downsample_factor).to(device)
    depths = depth_samples[::downsample_factor]
    num_depths = len(depths)
    ds_3d, ys_3d, xs_3d = torch.meshgrid(depths, hs, ws)
    # grid: (D, H_out, W_out, 3) -> (D*H_out*W_out, 3)
    grid = torch.stack([xs_3d, ys_3d, ds_3d], dim=-1).view(-1, 3)
    # recover the coordinates in the canonical space
    # reverse order of augmentations: flip -> crop -> scale
    if img_flip:
        # by default we take it as horizontal flip
        # use img_shape before padding for flip
        ori_h, ori_w = img_shape
        grid[:, 0] = ori_w - grid[:, 0]
    grid[:, :2] += img_crop_offset
    grid[:, :2] /= img_scale_factor
    # grid3d: (D*H_out*W_out, 3) in LiDAR coordinate system
    grid3d = points_img2cam(grid, proj_mat)
    # convert the 3D point coordinates to voxel coordinates
    voxel_range = torch.tensor(voxel_range).to(device).view(1, 6)
    voxel_size = torch.tensor(voxel_size).to(device).view(1, 3)
    # suppose the voxel grid is generated with AlignedAnchorGenerator
    # -0.5 given each grid is located at the center of the grid
    # TODO: study whether here needs -0.5
    grid3d = (grid3d - voxel_range[:, :3]) / voxel_size - 0.5
    grid_size = (voxel_range[:, 3:] - voxel_range[:, :3]) / voxel_size
    # normalize grid3d to (-1, 1)
    grid3d = grid3d / grid_size * 2 - 1
    # (x, y, z) -> (z, y, x) for grid_sampling
    grid3d = grid3d.view(1, num_depths, h_out, w_out, 3)[..., [2, 1, 0]]
    # align_corners=True provides higher performance
    mode = 'bilinear' if aligned else 'nearest'
    frustum_features = F.grid_sample(
        voxel_features,
        grid3d,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)  # 1xCxDxHxW feats

    return frustum_features
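

# Illustrative usage sketch (not part of the upstream module): sampling a
# frustum of features from a dummy voxel volume. The voxel range/size, the
# identity projection matrix and the image shapes are assumptions chosen so
# that the implied voxel grid matches the feature volume below.
def _demo_voxel_sample():
    voxel_features = torch.rand(1, 32, 40, 40, 8)  # 1 x C x Nx x Ny x Nz
    frustum_features = voxel_sample(
        voxel_features,
        voxel_range=[0.0, -20.0, -2.0, 40.0, 20.0, 2.0],  # 40 x 40 x 8 voxels
        voxel_size=[1.0, 1.0, 0.5],
        depth_samples=torch.linspace(1.0, 40.0, 64),
        proj_mat=torch.eye(4),  # placeholder LiDAR-to-image projection
        downsample_factor=4,
        img_scale_factor=torch.tensor([1.0, 1.0]),
        img_crop_offset=torch.tensor([0.0, 0.0]),
        img_flip=False,
        img_pad_shape=(48, 64),
        img_shape=(48, 64))
    return frustum_features  # 1 x 32 x 16 x 12 x 16 (1 x C x D x H x W)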