"examples/kandinsky2_2/text_to_image/requirements.txt" did not exist on "e5eed5235bfcc7867110d2ad61560b710c07e344"
point_fusion.py 17.4 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple, Union

import torch
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.registry import MODELS
from mmdet3d.structures.bbox_3d import (get_proj_mat_by_coord_type,
                                        points_cam2img, points_img2cam)
from mmdet3d.utils import OptConfigType, OptMultiConfig
from . import apply_3d_transformation


def point_sample(img_meta: dict,
                 img_features: Tensor,
                 points: Tensor,
                 proj_mat: Tensor,
                 coord_type: str,
                 img_scale_factor: Tensor,
                 img_crop_offset: Tensor,
                 img_flip: bool,
                 img_pad_shape: Tuple[int],
                 img_shape: Tuple[int],
                 aligned: bool = True,
                 padding_mode: str = 'zeros',
                 align_corners: bool = True,
                 valid_flag: bool = False) -> Tensor:
    """Obtain image features using points.

    Args:
        img_meta (dict): Meta info.
        img_features (Tensor): 1 x C x H x W image features.
        points (Tensor): Nx3 point cloud in LiDAR coordinates.
        proj_mat (Tensor): 4x4 transformation matrix.
        coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
        img_scale_factor (Tensor): Scale factor with shape of
            (w_scale, h_scale).
        img_crop_offset (Tensor): Crop offset used to crop image during
            data augmentation with shape of (w_offset, h_offset).
        img_flip (bool): Whether the image is flipped.
        img_pad_shape (Tuple[int]): Int tuple indicates the h & w after
            padding. This is necessary to obtain features in feature map.
        img_shape (Tuple[int]): Int tuple indicates the h & w before padding
            after scaling. This is necessary for flipping coordinates.
        aligned (bool): Whether to use bilinear interpolation when
            sampling image features for each point. Defaults to True.
        padding_mode (str): Padding mode when padding values for
            features of out-of-image points. Defaults to 'zeros'.
        align_corners (bool): Whether to align corners when
            sampling image features for each point. Defaults to True.
        valid_flag (bool): Whether to filter out points that lie outside the
            image or have non-positive depth. Defaults to False.

    Returns:
        Tensor: NxC image features sampled by point coordinates. If
            ``valid_flag`` is True, an (N, ) boolean mask of valid points is
            also returned.
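
    Examples:
        A minimal illustrative call on synthetic data; the feature map, the
        random points and the identity projection matrix below are
        placeholders rather than values produced by a real data pipeline.

        >>> import torch
        >>> img_feats = torch.rand(1, 16, 40, 60)
        >>> points = torch.rand(100, 3) * 10
        >>> sampled = point_sample(
        ...     img_meta=dict(),
        ...     img_features=img_feats,
        ...     points=points,
        ...     proj_mat=torch.eye(4),
        ...     coord_type='LIDAR',
        ...     img_scale_factor=points.new_tensor([1., 1.]),
        ...     img_crop_offset=points.new_tensor([0., 0.]),
        ...     img_flip=False,
        ...     img_pad_shape=(480, 640),
        ...     img_shape=(480, 640))
        >>> sampled.shape
        torch.Size([100, 16])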
    """

    # apply transformation based on info in img_meta
    points = apply_3d_transformation(
        points, coord_type, img_meta, reverse=True)

    # project points to image coordinate
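    # when valid_flag is set, also keep the projected depth so that points
    # behind the camera (non-positive depth) can be filtered out later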
    if valid_flag:
        proj_pts = points_cam2img(points, proj_mat, with_depth=True)
        pts_2d = proj_pts[..., :2]
        depths = proj_pts[..., 2]
    else:
        pts_2d = points_cam2img(points, proj_mat)

    # img transformation: scale -> crop -> flip
    # the image is resized by img_scale_factor
    img_coors = pts_2d[:, 0:2] * img_scale_factor  # Nx2
    img_coors -= img_crop_offset

    # grid sample, the valid grid range should be in [-1,1]
    coor_x, coor_y = torch.split(img_coors, 1, dim=1)  # each is Nx1

    if img_flip:
        # by default we take it as horizontal flip
        # use img_shape before padding for flip
        ori_h, ori_w = img_shape
        coor_x = ori_w - coor_x

    h, w = img_pad_shape
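    # normalize pixel coordinates to [-1, 1] w.r.t. the padded image size,
    # since the feature map spatially corresponds to the padded image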
    norm_coor_y = coor_y / h * 2 - 1
    norm_coor_x = coor_x / w * 2 - 1
    grid = torch.cat([norm_coor_x, norm_coor_y],
                     dim=1).unsqueeze(0).unsqueeze(0)  # Nx2 -> 1x1xNx2

    # align_corners=True provides higher performance
    mode = 'bilinear' if aligned else 'nearest'
    point_features = F.grid_sample(
        img_features,
        grid,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)  # 1xCx1xN feats

    if valid_flag:
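        # a point is valid only if it projects inside the padded image
        # and lies in front of the camera (depth > 0)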
        # (N, )
        valid = (coor_x.squeeze() < w) & (coor_x.squeeze() > 0) & (
            coor_y.squeeze() < h) & (coor_y.squeeze() > 0) & (
                depths > 0)
        valid_features = point_features.squeeze().t()
        valid_features[~valid] = 0
        return valid_features, valid  # (N, C), (N,)

    return point_features.squeeze().t()


@MODELS.register_module()
class PointFusion(BaseModule):
    """Fuse image features from multi-scale features.

    Args:
        img_channels (List[int] or int): Channels of image features.
            It could be a list if the input is multi-scale image features.
        pts_channels (int): Channels of point features.
        mid_channels (int): Channels of middle layers.
        out_channels (int): Channels of output fused features.
        img_levels (List[int] or int): Indices of the image feature levels to
            use. Defaults to 3.
        coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'. Defaults to 'LIDAR'.
        conv_cfg (:obj:`ConfigDict` or dict): Config dict for convolution
            layers of middle layers. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layers of middle layers. Defaults to None.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or List[:obj:`ConfigDict` or dict],
            optional): Initialization config dict. Defaults to None.
        activate_out (bool): Whether to apply relu activation to output
            features. Defaults to True.
        fuse_out (bool): Whether to apply conv layer to the fused features.
            Defaults to False.
        dropout_ratio (int or float): Dropout ratio of image features to
            prevent overfitting. Defaults to 0.
        aligned (bool): Whether to apply aligned feature fusion.
            Defaults to True.
        align_corners (bool): Whether to align corner when sampling features
            according to points. Defaults to True.
        padding_mode (str): Mode used to pad the features of points that do not
            have corresponding image features. Defaults to 'zeros'.
        lateral_conv (bool): Whether to apply lateral convs to image features.
            Defaults to True.
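
    Examples:
        A minimal construction sketch; the channel sizes and level indices
        below are illustrative only, not values from a released config.

        >>> self = PointFusion(
        ...     img_channels=[256, 512],
        ...     pts_channels=64,
        ...     mid_channels=128,
        ...     out_channels=128,
        ...     img_levels=[0, 1])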
    """

    def __init__(self,
                 img_channels: Union[List[int], int],
                 pts_channels: int,
                 mid_channels: int,
                 out_channels: int,
                 img_levels: Union[List[int], int] = 3,
                 coord_type: str = 'LIDAR',
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 activate_out: bool = True,
                 fuse_out: bool = False,
                 dropout_ratio: Union[int, float] = 0,
                 aligned: bool = True,
                 align_corners: bool = True,
                 padding_mode: str = 'zeros',
                 lateral_conv: bool = True) -> None:
        super(PointFusion, self).__init__(init_cfg=init_cfg)
        if isinstance(img_levels, int):
            img_levels = [img_levels]
        if isinstance(img_channels, int):
            img_channels = [img_channels] * len(img_levels)
        assert isinstance(img_levels, list)
        assert isinstance(img_channels, list)
        assert len(img_channels) == len(img_levels)

        self.img_levels = img_levels
        self.coord_type = coord_type
        self.act_cfg = act_cfg
        self.activate_out = activate_out
        self.fuse_out = fuse_out
        self.dropout_ratio = dropout_ratio
        self.img_channels = img_channels
        self.aligned = aligned
        self.align_corners = align_corners
        self.padding_mode = padding_mode

        self.lateral_convs = None
        if lateral_conv:
            self.lateral_convs = nn.ModuleList()
            for i in range(len(img_channels)):
                l_conv = ConvModule(
                    img_channels[i],
                    mid_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=self.act_cfg,
                    inplace=False)
                self.lateral_convs.append(l_conv)
            self.img_transform = nn.Sequential(
                nn.Linear(mid_channels * len(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        else:
            self.img_transform = nn.Sequential(
                nn.Linear(sum(img_channels), out_channels),
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
            )
        self.pts_transform = nn.Sequential(
            nn.Linear(pts_channels, out_channels),
            nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
        )

        if self.fuse_out:
            self.fuse_conv = nn.Sequential(
                nn.Linear(mid_channels, out_channels),
                # For pts the BN is initialized differently by default
                # TODO: check whether this is necessary
                nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
                nn.ReLU(inplace=False))

        if init_cfg is None:
            self.init_cfg = [
                dict(type='Xavier', layer='Conv2d', distribution='uniform'),
                dict(type='Xavier', layer='Linear', distribution='uniform')
            ]

    def forward(self, img_feats: List[Tensor], pts: List[Tensor],
                pts_feats: Tensor, img_metas: List[dict]) -> Tensor:
        """Forward function.

        Args:
            img_feats (List[Tensor]): Image features.
            pts (List[Tensor]): A batch of points with shape N x 3.
            pts_feats (Tensor): A tensor consisting of point features of the
                total batch.
            img_metas (List[dict]): Meta information of images.

        Returns:
            Tensor: Fused features of each point.
        """
        img_pts = self.obtain_mlvl_feats(img_feats, pts, img_metas)
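        # project both modalities into a common embedding space and fuse them
        # by element-wise addition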
        img_pre_fuse = self.img_transform(img_pts)
        if self.training and self.dropout_ratio > 0:
            img_pre_fuse = F.dropout(img_pre_fuse, self.dropout_ratio)
        pts_pre_fuse = self.pts_transform(pts_feats)

        fuse_out = img_pre_fuse + pts_pre_fuse
        if self.activate_out:
            fuse_out = F.relu(fuse_out)
        if self.fuse_out:
            fuse_out = self.fuse_conv(fuse_out)

        return fuse_out

    def obtain_mlvl_feats(self, img_feats: List[Tensor], pts: List[Tensor],
                          img_metas: List[dict]) -> Tensor:
        """Obtain multi-level features for each point.

        Args:
            img_feats (List[Tensor]): Multi-scale image features produced
                by image backbone in shape (N, C, H, W).
            pts (List[Tensor]): Points of each sample.
            img_metas (List[dict]): Meta information for each sample.

        Returns:
            Tensor: Corresponding image features of each point.
        """
        if self.lateral_convs is not None:
            img_ins = [
                lateral_conv(img_feats[i])
                for i, lateral_conv in zip(self.img_levels, self.lateral_convs)
            ]
        else:
            img_ins = img_feats
        img_feats_per_point = []
        # Sample multi-level features
        for i in range(len(img_metas)):
            mlvl_img_feats = []
            for level in range(len(self.img_levels)):
                mlvl_img_feats.append(
                    self.sample_single(img_ins[level][i:i + 1], pts[i][:, :3],
                                       img_metas[i]))
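            # concatenate the features sampled from all levels along the
            # channel dimension to get per-point multi-level features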
            mlvl_img_feats = torch.cat(mlvl_img_feats, dim=-1)
            img_feats_per_point.append(mlvl_img_feats)

        img_pts = torch.cat(img_feats_per_point, dim=0)
        return img_pts

    def sample_single(self, img_feats: Tensor, pts: Tensor,
                      img_meta: dict) -> Tensor:
        """Sample features from single level image feature map.

        Args:
            img_feats (Tensor): Image feature map in shape (1, C, H, W).
            pts (Tensor): Points of a single sample.
            img_meta (dict): Meta information of the single sample.

        Returns:
            Tensor: Single level image features of each point.
        """
        # TODO: image transformation also extracted
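        # read the augmentation parameters recorded in img_meta so that the
        # raw points can be aligned with the resized/cropped/flipped image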
        img_scale_factor = (
            pts.new_tensor(img_meta['scale_factor'][:2])
            if 'scale_factor' in img_meta.keys() else 1)
        img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
        img_crop_offset = (
            pts.new_tensor(img_meta['img_crop_offset'])
            if 'img_crop_offset' in img_meta.keys() else 0)
        proj_mat = get_proj_mat_by_coord_type(img_meta, self.coord_type)
        img_pts = point_sample(
            img_meta=img_meta,
            img_features=img_feats,
            points=pts,
            proj_mat=pts.new_tensor(proj_mat),
            coord_type=self.coord_type,
            img_scale_factor=img_scale_factor,
            img_crop_offset=img_crop_offset,
            img_flip=img_flip,
            img_pad_shape=img_meta['input_shape'][:2],
            img_shape=img_meta['img_shape'][:2],
            aligned=self.aligned,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
        return img_pts


def voxel_sample(voxel_features: Tensor,
                 voxel_range: List[float],
                 voxel_size: List[float],
                 depth_samples: Tensor,
                 proj_mat: Tensor,
                 downsample_factor: int,
                 img_scale_factor: Tensor,
                 img_crop_offset: Tensor,
                 img_flip: bool,
                 img_pad_shape: Tuple[int],
                 img_shape: Tuple[int],
                 aligned: bool = True,
                 padding_mode: str = 'zeros',
                 align_corners: bool = True) -> Tensor:
    """Obtain frustum features by sampling voxel features.

    Args:
        voxel_features (Tensor): 1 x C x Nx x Ny x Nz voxel features.
        voxel_range (List[float]): The range of voxel features.
        voxel_size (List[float]): The voxel size of voxel features.
        depth_samples (Tensor): N depth samples in LiDAR coordinates.
        proj_mat (Tensor): ORIGINAL LiDAR2img projection matrix for N views.
        downsample_factor (int): The downsample factor in rescaling.
        img_scale_factor (Tensor): Scale factor with shape of
            (w_scale, h_scale).
        img_crop_offset (Tensor): Crop offset used to crop image during
            data augmentation with shape of (w_offset, h_offset).
        img_flip (bool): Whether the image is flipped.
        img_pad_shape (Tuple[int]): Int tuple indicates the h & w after
            padding. This is necessary to obtain features in feature map.
        img_shape (Tuple[int]): Int tuple indicates the h & w before padding
            after scaling. This is necessary for flipping coordinates.
        aligned (bool): Whether to use bilinear interpolation when
            sampling image features for each point. Defaults to True.
        padding_mode (str): Padding mode when padding values for
            features of out-of-image points. Defaults to 'zeros'.
        align_corners (bool): Whether to align corners when
            sampling image features for each point. Defaults to True.

    Returns:
        Tensor: 1xCxDxHxW frustum features sampled from voxel features.
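
    Examples:
        A minimal illustrative call; the voxel volume, depth samples and the
        identity projection matrix below are placeholders rather than values
        from a real detector.

        >>> import torch
        >>> voxel_feats = torch.rand(1, 4, 10, 10, 10)
        >>> frustum_feats = voxel_sample(
        ...     voxel_features=voxel_feats,
        ...     voxel_range=[0., 0., 0., 10., 10., 10.],
        ...     voxel_size=[1., 1., 1.],
        ...     depth_samples=torch.linspace(1., 20., 20),
        ...     proj_mat=torch.eye(4),
        ...     downsample_factor=4,
        ...     img_scale_factor=voxel_feats.new_tensor([1., 1.]),
        ...     img_crop_offset=voxel_feats.new_tensor([0., 0.]),
        ...     img_flip=False,
        ...     img_pad_shape=(32, 48),
        ...     img_shape=(32, 48))
        >>> frustum_feats.shape
        torch.Size([1, 4, 5, 8, 12])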
    """
    # construct frustum grid
    device = voxel_features.device
    h, w = img_pad_shape
    h_out = round(h / downsample_factor)
    w_out = round(w / downsample_factor)
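    # pixel coordinates of the frustum samples in the padded image plane,
    # one sample every `downsample_factor` pixels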
    ws = (torch.linspace(0, w_out - 1, w_out) * downsample_factor).to(device)
    hs = (torch.linspace(0, h_out - 1, h_out) * downsample_factor).to(device)
    depths = depth_samples[::downsample_factor]
    num_depths = len(depths)
    ds_3d, ys_3d, xs_3d = torch.meshgrid(depths, hs, ws)
    # grid: (D, H_out, W_out, 3) -> (D*H_out*W_out, 3)
    grid = torch.stack([xs_3d, ys_3d, ds_3d], dim=-1).view(-1, 3)
    # recover the coordinates in the canonical space
    # reverse order of augmentations: flip -> crop -> scale
    if img_flip:
        # by default we take it as horizontal flip
        # use img_shape before padding for flip
        ori_h, ori_w = img_shape
        grid[:, 0] = ori_w - grid[:, 0]
    grid[:, :2] += img_crop_offset
    grid[:, :2] /= img_scale_factor
    # grid3d: (D*H_out*W_out, 3) in LiDAR coordinate system
    grid3d = points_img2cam(grid, proj_mat)
    # convert the 3D point coordinates to voxel coordinates
    voxel_range = torch.tensor(voxel_range).to(device).view(1, 6)
    voxel_size = torch.tensor(voxel_size).to(device).view(1, 3)
    # suppose the voxel grid is generated with AlignedAnchorGenerator
    # -0.5 given each grid is located at the center of the grid
    # TODO: study whether here needs -0.5
    grid3d = (grid3d - voxel_range[:, :3]) / voxel_size - 0.5
    grid_size = (voxel_range[:, 3:] - voxel_range[:, :3]) / voxel_size
    # normalize grid3d to (-1, 1)
    grid3d = grid3d / grid_size * 2 - 1
    # (x, y, z) -> (z, y, x) for grid_sampling
    grid3d = grid3d.view(1, num_depths, h_out, w_out, 3)[..., [2, 1, 0]]
    # align_corners=True provides higher performance
    mode = 'bilinear' if aligned else 'nearest'
    frustum_features = F.grid_sample(
        voxel_features,
        grid3d,
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)  # 1xCxDxHxW feats

    return frustum_features