multiview_dfm.py 22.5 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Union

4
5
import numpy as np
import torch
6
7
from mmengine.structures import InstanceData
from torch import Tensor
8
9
10
11
12
13

from mmdet3d.models.layers.fusion_layers.point_fusion import (point_sample,
                                                              voxel_sample)
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.bbox_3d.utils import get_lidar2img
from mmdet3d.structures.det3d_data_sample import SampleList
14
from mmdet3d.utils import ConfigType, OptConfigType, OptInstanceList
15
16
17
18
from .dfm import DfM


@MODELS.register_module()
19
class MultiViewDfM(DfM):
20
21
22
23
24
25
26
27
28
29
    r"""Waymo challenge solution of `MV-FCOS3D++
    <https://arxiv.org/abs/2207.12716>`_.

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        backbone_stereo (:obj:`ConfigDict` or dict): The stereo backbone
        config.
        backbone_3d (:obj:`ConfigDict` or dict): The 3d backbone config.
        neck_3d (:obj:`ConfigDict` or dict): The 3D neck config.
30
        bbox_head_3d (:obj:`ConfigDict` or dict): The bbox head config.
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
        voxel_size (:obj:`ConfigDict` or dict): The voxel size.
        anchor_generator (:obj:`ConfigDict` or dict): The anchor generator
            config.
        neck_2d (:obj:`ConfigDict` or dict, optional): The 2D neck config
            for 2D object detection. Defaults to None.
        bbox_head_2d (:obj:`ConfigDict` or dict, optional): The 2D bbox
            head config for 2D object detection. Defaults to None.
        depth_head_2d (:obj:`ConfigDict` or dict, optional): The 2D depth
            head config for depth estimation in fov space. Defaults to None.
        depth_head (:obj:`ConfigDict` or dict, optional): The depth head
            config for depth estimation in 3D voxel projected to fov space .
        train_cfg (:obj:`ConfigDict` or dict, optional): Config dict of
            training hyper-parameters. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Config dict of test
            hyper-parameters. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`BaseDataPreprocessor`.  it usually includes,
                ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
        valid_sample (bool): Whether to filter invalid points in view
            transformation. Defaults to True.
        temporal_aggregate (str): Key to determine the aggregation way in
            temporal fusion. Defaults to 'concat'.
        transform_depth (bool): Key to determine the transformation of depth.
            Defaults to True.
        init_cfg (:obj:`ConfigDict` or dict, optional): The initialization
            config. Defaults to None.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 backbone_stereo: ConfigType,
                 backbone_3d: ConfigType,
                 neck_3d: ConfigType,
65
                 bbox_head_3d: ConfigType,
66
67
68
69
70
71
72
73
74
75
                 voxel_size: ConfigType,
                 anchor_generator: ConfigType,
                 neck_2d: ConfigType = None,
                 bbox_head_2d: ConfigType = None,
                 depth_head_2d: ConfigType = None,
                 depth_head: ConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 valid_sample: bool = True,
76
                 temporal_aggregate: str = 'mean',
77
78
                 transform_depth: bool = True,
                 init_cfg: OptConfigType = None):
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
        super().__init__(
            data_preprocessor=data_preprocessor,
            backbone=backbone,
            neck=neck,
            backbone_stereo=backbone_stereo,
            backbone_3d=backbone_3d,
            neck_3d=neck_3d,
            bbox_head_3d=bbox_head_3d,
            neck_2d=neck_2d,
            bbox_head_2d=bbox_head_2d,
            depth_head_2d=depth_head_2d,
            depth_head=depth_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
        self.voxel_size = voxel_size
        self.voxel_range = anchor_generator['ranges'][0]
        self.n_voxels = [
            round((self.voxel_range[3] - self.voxel_range[0]) /
                  self.voxel_size[0]),
            round((self.voxel_range[4] - self.voxel_range[1]) /
                  self.voxel_size[1]),
            round((self.voxel_range[5] - self.voxel_range[2]) /
                  self.voxel_size[2])
        ]
        self.anchor_generator = TASK_UTILS.build(anchor_generator)
        self.valid_sample = valid_sample
        self.temporal_aggregate = temporal_aggregate
        self.transform_depth = transform_depth

    def extract_feat(self, batch_inputs_dict: dict,
                     batch_data_samples: SampleList):
        """Extract 3d features from the backbone -> fpn -> 3d projection.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                the 'imgs' key.

                    - imgs (torch.Tensor, optional): Image of each sample.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
        """
        # TODO: Nt means the number of frames temporally
        # num_views means the number of views of a frame
        img = batch_inputs_dict['imgs']
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]
        batch_size, _, C_in, H, W = img.shape
        num_views = batch_img_metas[0]['num_views']
        num_ref_frames = batch_img_metas[0]['num_ref_frames']
        if num_ref_frames > 0:
            num_frames = num_ref_frames + 1
        else:
            num_frames = 1
        input_shape = img.shape[-2:]
        # NOTE: input_shape is the largest pad_shape of the batch of images
        for img_meta in batch_img_metas:
            img_meta.update(input_shape=input_shape)
        if num_ref_frames > 0:
            cur_imgs = img[:, :num_views].reshape(-1, C_in, H, W)
            prev_imgs = img[:, num_views:].reshape(-1, C_in, H, W)
            cur_feats = self.backbone(cur_imgs)
            cur_feats = self.neck(cur_feats)[0]
            with torch.no_grad():
                prev_feats = self.backbone(prev_imgs)
                prev_feats = self.neck(prev_feats)[0]
            _, C_feat, H_feat, W_feat = cur_feats.shape
            cur_feats = cur_feats.view(batch_size, -1, C_feat, H_feat, W_feat)
            prev_feats = prev_feats.view(batch_size, -1, C_feat, H_feat,
                                         W_feat)
            batch_feats = torch.cat([cur_feats, prev_feats], dim=1)
        else:
            batch_imgs = img.view(-1, C_in, H, W)
            batch_feats = self.backbone(batch_imgs)
            # TODO: support SPP module neck
            batch_feats = self.neck(batch_feats)[0]
            _, C_feat, H_feat, W_feat = batch_feats.shape
            batch_feats = batch_feats.view(batch_size, -1, C_feat, H_feat,
                                           W_feat)
        # transform the feature to voxel & stereo space
        transform_feats = self.feature_transformation(batch_feats,
                                                      batch_img_metas,
                                                      num_views, num_frames)
        if self.with_depth_head_2d:
            transform_feats += (batch_feats[:, :num_views], )
        return transform_feats

    def feature_transformation(self, batch_feats, batch_img_metas, num_views,
                               num_frames):
        """Feature transformation from perspective view to BEV.

        Args:
            batch_feats (torch.Tensor): Perspective view features of shape
                (batch_size, num_views, C, H, W).
            batch_img_metas (list[dict]): Image meta information. Each element
                corresponds to a group of images. len(img_metas) == B.
            num_views (int): Number of views.
            num_frames (int): Number of consecutive frames.

        Returns:
            tuple[torch.Tensor]: Volume features and (optionally) stereo \
            features.
        """
        # TODO: support more complicated 2D feature sampling
        points = self.anchor_generator.grid_anchors(
            [self.n_voxels[::-1]], device=batch_feats.device)[0][:, :3]
        volumes = []
        img_scale_factors = []
        img_flips = []
        img_crop_offsets = []
        for feature, img_meta in zip(batch_feats, batch_img_metas):

            # TODO: remove feature sampling from back
            # TODO: support different scale_factors/flip/crop_offset for
            # different views
            frame_volume = []
            frame_valid_nums = []
            for frame_idx in range(num_frames):
                volume = []
                valid_flags = []
                if isinstance(img_meta['img_shape'], list):
                    img_shape = img_meta['img_shape'][frame_idx][:2]
                else:
                    img_shape = img_meta['img_shape'][:2]

                for view_idx in range(num_views):

                    sample_idx = frame_idx * num_views + view_idx

                    if 'scale_factor' in img_meta:
                        img_scale_factor = img_meta['scale_factor'][sample_idx]
                        if isinstance(img_scale_factor, np.ndarray) and \
                                len(img_meta['scale_factor']) >= 2:
                            img_scale_factor = (
                                points.new_tensor(img_scale_factor[:2]))
                        else:
                            img_scale_factor = (
                                points.new_tensor(img_scale_factor))
                    else:
                        img_scale_factor = (1)
                    img_flip = img_meta['flip'][sample_idx] \
                        if 'flip' in img_meta.keys() else False
                    img_crop_offset = (
                        points.new_tensor(
                            img_meta['img_crop_offset'][sample_idx])
                        if 'img_crop_offset' in img_meta.keys() else 0)
                    lidar2cam = points.new_tensor(
                        img_meta['lidar2cam'][sample_idx])
                    cam2img = points.new_tensor(
                        img_meta['ori_cam2img'][sample_idx])
                    # align the precision, the tensor is converted to float32
                    lidar2img = get_lidar2img(cam2img.double(),
                                              lidar2cam.double())
                    lidar2img = lidar2img.float()

                    sample_results = point_sample(
                        img_meta,
                        img_features=feature[sample_idx][None, ...],
                        points=points,
                        proj_mat=lidar2img,
                        coord_type='LIDAR',
                        img_scale_factor=img_scale_factor,
                        img_crop_offset=img_crop_offset,
                        img_flip=img_flip,
                        img_pad_shape=img_meta['input_shape'],
                        img_shape=img_shape,
                        aligned=False,
                        valid_flag=self.valid_sample)
                    if self.valid_sample:
                        volume.append(sample_results[0])
                        valid_flags.append(sample_results[1])
                    else:
                        volume.append(sample_results)
                    # TODO: save valid flags, more reasonable feat fusion
                if self.valid_sample:
                    valid_nums = torch.stack(
                        valid_flags, dim=0).sum(0)  # (N, )
                    volume = torch.stack(volume, dim=0).sum(0)
                    valid_mask = valid_nums > 0
                    volume[~valid_mask] = 0
                    frame_valid_nums.append(valid_nums)
                else:
                    volume = torch.stack(volume, dim=0).mean(0)
                frame_volume.append(volume)

            img_scale_factors.append(img_scale_factor)
            img_flips.append(img_flip)
            img_crop_offsets.append(img_crop_offset)

            if self.valid_sample:
                if self.temporal_aggregate == 'mean':
                    frame_volume = torch.stack(frame_volume, dim=0).sum(0)
                    frame_valid_nums = torch.stack(
                        frame_valid_nums, dim=0).sum(0)
                    frame_valid_mask = frame_valid_nums > 0
                    frame_volume[~frame_valid_mask] = 0
                    frame_volume = frame_volume / torch.clamp(
                        frame_valid_nums[:, None], min=1)
                elif self.temporal_aggregate == 'concat':
                    frame_valid_nums = torch.stack(frame_valid_nums, dim=1)
                    frame_volume = torch.stack(frame_volume, dim=1)
                    frame_valid_mask = frame_valid_nums > 0
                    frame_volume[~frame_valid_mask] = 0
                    frame_volume = (frame_volume / torch.clamp(
                        frame_valid_nums[:, :, None], min=1)).flatten(
                            start_dim=1, end_dim=2)
            else:
                frame_volume = torch.stack(frame_volume, dim=0).mean(0)
            volumes.append(
                frame_volume.reshape(self.n_voxels[::-1] + [-1]).permute(
                    3, 2, 1, 0))
        volume_feat = torch.stack(volumes)  # (B, C, N_x, N_y, N_z)
        if self.with_backbone_3d:
            outputs = self.backbone_3d(volume_feat)
            volume_feat = outputs[0]
            if self.backbone_3d.output_bev:
                # use outputs[0] if len(outputs) == 1
                # use outputs[1] if len(outputs) == 2
                # TODO: unify the output formats
                bev_feat = outputs[-1]
        # grid_sample stereo features from the volume feature
        # TODO: also support temporal modeling for depth head
        if self.with_depth_head:
            batch_stereo_feats = []
            for batch_idx in range(volume_feat.shape[0]):
                stereo_feat = []
                for view_idx in range(num_views):
                    img_scale_factor = img_scale_factors[batch_idx] \
                        if self.transform_depth else points.new_tensor(
                            [1., 1.])
                    img_crop_offset = img_crop_offsets[batch_idx] \
                        if self.transform_depth else points.new_tensor(
                            [0., 0.])
                    img_flip = img_flips[batch_idx] if self.transform_depth \
                        else False
                    img_pad_shape = img_meta['input_shape'] \
                        if self.transform_depth else img_meta['ori_shape'][:2]
                    lidar2cam = points.new_tensor(
                        batch_img_metas[batch_idx]['lidar2cam'][view_idx])
                    cam2img = points.new_tensor(
                        img_meta[batch_idx]['lidar2cam'][view_idx])
                    proj_mat = torch.matmul(cam2img, lidar2cam)
                    stereo_feat.append(
                        voxel_sample(
                            volume_feat[batch_idx][None],
                            voxel_range=self.voxel_range,
                            voxel_size=self.voxel_size,
                            depth_samples=volume_feat.new_tensor(
                                self.depth_samples),
                            proj_mat=proj_mat,
                            downsample_factor=self.depth_head.
                            downsample_factor,
                            img_scale_factor=img_scale_factor,
                            img_crop_offset=img_crop_offset,
                            img_flip=img_flip,
                            img_pad_shape=img_pad_shape,
                            img_shape=batch_img_metas[batch_idx]['img_shape']
                            [view_idx][:2],
                            aligned=True))  # TODO: study the aligned setting
                batch_stereo_feats.append(torch.cat(stereo_feat))
            # cat (N, C, D, H, W) -> (B*N, C, D, H, W)
            batch_stereo_feats = torch.cat(batch_stereo_feats)
        if self.with_neck_3d:
            if self.with_backbone_3d and self.backbone_3d.output_bev:
                spatial_features = self.neck_3d(bev_feat)
                # TODO: unify the outputs of neck_3d
                volume_feat = spatial_features[1]
            else:
                volume_feat = self.neck_3d(volume_feat)[0]
        # TODO: unify the output format of neck_3d
        transform_feats = (volume_feat, )
        if self.with_depth_head:
            transform_feats += (batch_stereo_feats, )
        return transform_feats

359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, tuple]:
        """Calculate losses from a batch of inputs dict and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points', 'img' keys.

                    - points (list[torch.Tensor]): Point cloud of each sample.
                    - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            dict: A dictionary of loss components.
        """
        feats = self.extract_feat(batch_inputs, batch_data_samples)
        bev_feat = feats[0]
        losses = self.bbox_head_3d.loss([bev_feat], batch_data_samples)
        return losses

    def predict(self, batch_inputs: Tensor,
                batch_data_samples: SampleList) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points', 'imgs' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input samples. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
                (num_instance, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
                (num_instances, C) where C >=7.
        """
        feats = self.extract_feat(batch_inputs, batch_data_samples)
        bev_feat = feats[0]
        results_list = self.bbox_head_3d.predict([bev_feat],
                                                 batch_data_samples)
        predictions = self.add_pred_to_datasample(batch_data_samples,
                                                  results_list)
        return predictions

    def _forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: SampleList = None):
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        feats = self.extract_feat(batch_inputs, batch_data_samples)
        bev_feat = feats[0]
        self.bbox_head.forward(bev_feat, batch_data_samples)

    def add_pred_to_datasample(
        self,
        data_samples: SampleList,
        data_instances_3d: OptInstanceList = None,
        data_instances_2d: OptInstanceList = None,
    ) -> SampleList:
        """Convert results list to `Det3DDataSample`.

        Subclasses could override it to be compatible for some multi-modality
        3D detectors.

        Args:
            data_samples (list[:obj:`Det3DDataSample`]): The input data.
            data_instances_3d (list[:obj:`InstanceData`], optional): 3D
                Detection results of each sample.
            data_instances_2d (list[:obj:`InstanceData`], optional): 2D
                Detection results of each sample.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input. Each Det3DDataSample usually contains
            'pred_instances_3d'. And the ``pred_instances_3d`` normally
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
              (num_instances, ).
            - bboxes_3d (Tensor): Contains a tensor with shape
              (num_instances, C) where C >=7.

            When there are image prediction in some models, it should
            contains  `pred_instances`, And the ``pred_instances`` normally
            contains following keys.

            - scores (Tensor): Classification scores of image, has a shape
              (num_instance, )
            - labels (Tensor): Predict Labels of 2D bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Contains a tensor with shape
              (num_instances, 4).
        """

        assert (data_instances_2d is not None) or \
               (data_instances_3d is not None),\
               'please pass at least one type of data_samples'

        if data_instances_2d is None:
            data_instances_2d = [
                InstanceData() for _ in range(len(data_instances_3d))
            ]
        if data_instances_3d is None:
            data_instances_3d = [
                InstanceData() for _ in range(len(data_instances_2d))
            ]

        for i, data_sample in enumerate(data_samples):
            data_sample.pred_instances_3d = data_instances_3d[i]
            data_sample.pred_instances = data_instances_2d[i]
        return data_samples

492
493
494
495
496
497
498
499
500
501
502
    def aug_test(self, imgs, img_metas, **kwargs):
        """Test with augmentations.

        Args:
            imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W).
            img_metas (list): Image metas.

        Returns:
            list[dict]: Predicted 3d boxes.
        """
        raise NotImplementedError