# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
from ..layers import MLP
from .base import Base3DDetector


def sample_valid_seeds(mask: Tensor, num_sampled_seed: int = 1024) -> Tensor:
    r"""Randomly sample seeds from all imvotes.

    Modified from `<https://github.com/facebookresearch/imvotenet/blob/a8856345146bacf29a57266a2f0b874406fd8823/models/imvotenet.py#L26>`_

    Args:
        mask (torch.Tensor): Bool tensor in shape
            (batch_size, seed_num*max_imvote_per_pixel), indicating
            whether each imvote corresponds to a 2D bbox.
        num_sampled_seed (int): How many seeds to sample from all imvotes.

    Returns:
        torch.Tensor: Sampled indices with shape
        (batch_size, num_sampled_seed).
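
    Example:
        A minimal shape check (illustrative; assumes 1024 seeds with
        3 imvotes per seed, so the mask has 1024 * 3 columns):

        >>> mask = torch.rand(2, 1024 * 3) > 0.5
        >>> inds = sample_valid_seeds(mask, num_sampled_seed=1024)
        >>> inds.shape
        torch.Size([2, 1024])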
    """  # noqa: E501
    device = mask.device
    batch_size = mask.shape[0]
    sample_inds = mask.new_zeros((batch_size, num_sampled_seed),
                                 dtype=torch.int64)
    for bidx in range(batch_size):
        # return index of non zero elements
        valid_inds = torch.nonzero(mask[bidx, :]).squeeze(-1)
        if len(valid_inds) < num_sampled_seed:
            # too few valid imvotes: keep them all and pad with indices
            # from arange(num_sampled_seed) that are not already taken,
            # i.e. the set difference t1 - t2
            t1 = torch.arange(num_sampled_seed, device=device)
            t2 = valid_inds % num_sampled_seed
            combined = torch.cat((t1, t2))
            uniques, counts = combined.unique(return_counts=True)
            difference = uniques[counts == 1]

            rand_inds = torch.randperm(
                len(difference),
                device=device)[:num_sampled_seed - len(valid_inds)]
            cur_sample_inds = difference[rand_inds]
            cur_sample_inds = torch.cat((valid_inds, cur_sample_inds))
        else:
            rand_inds = torch.randperm(
                len(valid_inds), device=device)[:num_sampled_seed]
            cur_sample_inds = valid_inds[rand_inds]
        sample_inds[bidx, :] = cur_sample_inds
    return sample_inds


@MODELS.register_module()
class ImVoteNet(Base3DDetector):
    r"""`ImVoteNet <https://arxiv.org/abs/2001.10692>`_ for 3D detection.

    ImVoteNet is based on fusing 2D votes in images and 3D votes in point
    clouds, which explicitly extract both geometric and semantic features
    from the 2D images. It leverage camera parameters to lift these
    features to 3D. A multi-tower training scheme also improve the synergy
    of 2D-3D feature fusion.

    """

    def __init__(self,
                 pts_backbone: Optional[dict] = None,
                 pts_bbox_heads: Optional[dict] = None,
                 pts_neck: Optional[dict] = None,
                 img_backbone: Optional[dict] = None,
                 img_neck: Optional[dict] = None,
                 img_roi_head: Optional[dict] = None,
                 img_rpn_head: Optional[dict] = None,
                 img_mlp: Optional[dict] = None,
                 freeze_img_branch: bool = False,
                 fusion_layer: Optional[dict] = None,
                 num_sampled_seed: Optional[int] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 **kwargs) -> None:

        super(ImVoteNet, self).__init__(init_cfg=init_cfg, **kwargs)

        # point branch
        if pts_backbone is not None:
            self.pts_backbone = MODELS.build(pts_backbone)
        if pts_neck is not None:
            self.pts_neck = MODELS.build(pts_neck)
        if pts_bbox_heads is not None:
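            # the three bbox-head towers (joint, point-only, image-only)
            # share a common config, updated per tower below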
            pts_bbox_head_common = pts_bbox_heads.common
            pts_bbox_head_common.update(
                train_cfg=train_cfg.pts if train_cfg is not None else None)
            pts_bbox_head_common.update(test_cfg=test_cfg.pts)
            pts_bbox_head_joint = pts_bbox_head_common.copy()
            pts_bbox_head_joint.update(pts_bbox_heads.joint)
            pts_bbox_head_pts = pts_bbox_head_common.copy()
            pts_bbox_head_pts.update(pts_bbox_heads.pts)
            pts_bbox_head_img = pts_bbox_head_common.copy()
            pts_bbox_head_img.update(pts_bbox_heads.img)

            self.pts_bbox_head_joint = MODELS.build(pts_bbox_head_joint)
            self.pts_bbox_head_pts = MODELS.build(pts_bbox_head_pts)
            self.pts_bbox_head_img = MODELS.build(pts_bbox_head_img)
            self.pts_bbox_heads = [
                self.pts_bbox_head_joint, self.pts_bbox_head_pts,
                self.pts_bbox_head_img
            ]
            self.loss_weights = pts_bbox_heads.loss_weights

        # image branch
        if img_backbone:
            self.img_backbone = MODELS.build(img_backbone)
        if img_neck is not None:
            self.img_neck = MODELS.build(img_neck)
        if img_rpn_head is not None:
            rpn_train_cfg = train_cfg.img_rpn if train_cfg \
                is not None else None
            img_rpn_head_ = img_rpn_head.copy()
            img_rpn_head_.update(
                train_cfg=rpn_train_cfg, test_cfg=test_cfg.img_rpn)
            self.img_rpn_head = MODELS.build(img_rpn_head_)
        if img_roi_head is not None:
            rcnn_train_cfg = train_cfg.img_rcnn if train_cfg \
                is not None else None
            img_roi_head.update(
                train_cfg=rcnn_train_cfg, test_cfg=test_cfg.img_rcnn)
            self.img_roi_head = MODELS.build(img_roi_head)

        # fusion
        if fusion_layer is not None:
            self.fusion_layer = MODELS.build(fusion_layer)
            self.max_imvote_per_pixel = fusion_layer.max_imvote_per_pixel

        self.freeze_img_branch = freeze_img_branch
        if freeze_img_branch:
            self.freeze_img_branch_params()

        if img_mlp is not None:
            # MLP that projects image vote features before fusion
            self.img_mlp = MLP(**img_mlp)

        self.num_sampled_seed = num_sampled_seed

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def _forward(self):
        raise NotImplementedError

    def freeze_img_branch_params(self):
        """Freeze all image branch parameters."""
        if self.with_img_bbox_head:
            for param in self.img_bbox_head.parameters():
                param.requires_grad = False
        if self.with_img_backbone:
            for param in self.img_backbone.parameters():
                param.requires_grad = False
        if self.with_img_neck:
            for param in self.img_neck.parameters():
                param.requires_grad = False
        if self.with_img_rpn:
            for param in self.img_rpn_head.parameters():
                param.requires_grad = False
        if self.with_img_roi_head:
            for param in self.img_roi_head.parameters():
                param.requires_grad = False

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Overload in order to load img network ckpts into img branch."""
        module_names = ['backbone', 'neck', 'roi_head', 'rpn_head']
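        # a checkpoint from a plain 2D detector stores keys like
        # 'backbone.conv1.weight'; prefix them with 'img_' so they load
        # into this model's image branch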
        for key in list(state_dict):
            for module_name in module_names:
                if key.startswith(module_name) and ('img_' +
                                                    key) not in state_dict:
                    state_dict['img_' + key] = state_dict.pop(key)

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def train(self, mode=True):
        """Overload in order to keep image branch modules in eval mode."""
        super(ImVoteNet, self).train(mode)
        if self.freeze_img_branch:
            if self.with_img_bbox_head:
                self.img_bbox_head.eval()
            if self.with_img_backbone:
                self.img_backbone.eval()
            if self.with_img_neck:
                self.img_neck.eval()
            if self.with_img_rpn:
                self.img_rpn_head.eval()
            if self.with_img_roi_head:
                self.img_roi_head.eval()

    @property
    def with_img_bbox(self):
        """bool: Whether the detector has a 2D image box head."""
        return ((hasattr(self, 'img_roi_head') and self.img_roi_head.with_bbox)
                or (hasattr(self, 'img_bbox_head')
                    and self.img_bbox_head is not None))

    @property
    def with_img_bbox_head(self):
        """bool: Whether the detector has a 2D image box head (not roi)."""
        return hasattr(self,
                       'img_bbox_head') and self.img_bbox_head is not None

    @property
    def with_img_backbone(self):
        """bool: Whether the detector has a 2D image backbone."""
        return hasattr(self, 'img_backbone') and self.img_backbone is not None

    @property
    def with_img_neck(self):
        """bool: Whether the detector has a neck in image branch."""
        return hasattr(self, 'img_neck') and self.img_neck is not None

    @property
    def with_img_rpn(self):
        """bool: Whether the detector has a 2D RPN in image detector branch."""
        return hasattr(self, 'img_rpn_head') and self.img_rpn_head is not None

    @property
    def with_img_roi_head(self):
        """bool: Whether the detector has a RoI Head in image branch."""
        return hasattr(self, 'img_roi_head') and self.img_roi_head is not None

    @property
    def with_pts_bbox(self):
        """bool: Whether the detector has a 3D box head."""
        return hasattr(self,
                       'pts_bbox_head') and self.pts_bbox_head is not None

    @property
    def with_pts_backbone(self):
        """bool: Whether the detector has a 3D backbone."""
        return hasattr(self, 'pts_backbone') and self.pts_backbone is not None

    @property
    def with_pts_neck(self):
        """bool: Whether the detector has a neck in 3D detector branch."""
        return hasattr(self, 'pts_neck') and self.pts_neck is not None

    def extract_feat(self, imgs):
        """Just to inherit from abstract method."""
        pass

    def extract_img_feat(self, img: Tensor) -> Sequence[Tensor]:
        """Directly extract features from the img backbone+neck."""
        x = self.img_backbone(img)
        if self.with_img_neck:
            x = self.img_neck(x)
        return x

    def extract_pts_feat(self, pts: Tensor) -> Tuple[Tensor]:
        """Extract features of points."""
        x = self.pts_backbone(pts)
        if self.with_pts_neck:
            x = self.pts_neck(x)

        # take the last feature propagation (FP) level as the seed points
        seed_points = x['fp_xyz'][-1]
        seed_features = x['fp_features'][-1]
        seed_indices = x['fp_indices'][-1]

        return (seed_points, seed_features, seed_indices)

    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
             batch_data_samples: List[Det3DDataSample],
             **kwargs) -> Dict[str, Tensor]:
        """
        Args:
            batch_inputs_dict (dict): The model input dict, which includes
                'points' and 'imgs' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (list[torch.Tensor]): Image tensor with shape
                  (N, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instances_3d`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        imgs = batch_inputs_dict.get('imgs', None)
        points = batch_inputs_dict.get('points', None)
        if points is None:
            x = self.extract_img_feat(imgs)
            losses = dict()
            # RPN forward and loss
            if self.with_img_rpn:
                proposal_cfg = self.train_cfg.get('img_rpn_proposal',
                                                  self.test_cfg.img_rpn)
                rpn_data_samples = copy.deepcopy(batch_data_samples)
                # set cat_id of gt_labels to 0 in RPN
                for data_sample in rpn_data_samples:
                    data_sample.gt_instances.labels = \
                        torch.zeros_like(data_sample.gt_instances.labels)

                rpn_losses, rpn_results_list = \
                    self.img_rpn_head.loss_and_predict(
                        x, rpn_data_samples,
                        proposal_cfg=proposal_cfg, **kwargs)
                # rename to avoid clashing with roi_head loss names;
                # iterate over a snapshot since the dict is mutated
                for key in list(rpn_losses.keys()):
                    if 'loss' in key and 'rpn' not in key:
                        rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
                losses.update(rpn_losses)
            else:
                assert batch_data_samples[0].get('proposals', None) is not None
                # use pre-defined proposals in InstanceData for the second
                # stage to extract ROI features
                rpn_results_list = [
                    data_sample.proposals for data_sample in batch_data_samples
                ]

            roi_losses = self.img_roi_head.loss(x, rpn_results_list,
                                                batch_data_samples, **kwargs)
            losses.update(roi_losses)
            return losses
        else:
            with torch.no_grad():
                results_2d = self.predict_img_only(
                    batch_inputs_dict['imgs'],
                    batch_data_samples,
                    rescale=False)
            # tensor with shape (n, 6), the 6 values arranged
            # as [x1, y1, x2, y2, score, label]
            pred_bboxes_with_label_list = []
            for single_results in results_2d:
                cat_preds = torch.cat(
                    (single_results.bboxes, single_results.scores[:, None],
                     single_results.labels[:, None]),
                    dim=-1)
                cat_preds = cat_preds[torch.argsort(
                    cat_preds[:, 4], descending=True)]
                # drop half bboxes during training for better generalization
                if self.training:
                    rand_drop = torch.randperm(
                        len(cat_preds))[:(len(cat_preds) + 1) // 2]
                    rand_drop = torch.sort(rand_drop)[0]
                    cat_preds = cat_preds[rand_drop]

                pred_bboxes_with_label_list.append(cat_preds)

            stack_points = torch.stack(points)
            seeds_3d, seed_3d_features, seed_indices = \
                self.extract_pts_feat(stack_points)
            img_metas = [item.metainfo for item in batch_data_samples]
            img_features, masks = self.fusion_layer(
                imgs, pred_bboxes_with_label_list, seeds_3d, img_metas)

            inds = sample_valid_seeds(masks, self.num_sampled_seed)
            batch_size, img_feat_size = img_features.shape[:2]
            pts_feat_size = seed_3d_features.shape[1]
            inds_img = inds.view(batch_size, 1,
                                 -1).expand(-1, img_feat_size, -1)
            img_features = img_features.gather(-1, inds_img)
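            # map imvote indices back to seed indices; this assumes the
            # fusion layer lays votes out vote-major and that the number
            # of seeds equals num_sampled_seed (as in the default configs)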
            inds = inds % inds.shape[1]
            inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3)
            seeds_3d = seeds_3d.gather(1, inds_seed_xyz)
            inds_seed_feats = inds.view(batch_size, 1,
                                        -1).expand(-1, pts_feat_size, -1)
            seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats)
            seed_indices = seed_indices.gather(1, inds)

            img_features = self.img_mlp(img_features)
            fused_features = torch.cat([seed_3d_features, img_features], dim=1)

            feat_dict_joint = dict(
                seed_points=seeds_3d,
                seed_features=fused_features,
                seed_indices=seed_indices)
            feat_dict_pts = dict(
                seed_points=seeds_3d,
                seed_features=seed_3d_features,
                seed_indices=seed_indices)
            feat_dict_img = dict(
                seed_points=seeds_3d,
                seed_features=img_features,
                seed_indices=seed_indices)

            losses_towers = []
            losses_joint = self.pts_bbox_head_joint.loss(
                points, feat_dict_joint, batch_data_samples)
            losses_pts = self.pts_bbox_head_pts.loss(points, feat_dict_pts,
                                                     batch_data_samples)
            losses_img = self.pts_bbox_head_img.loss(points, feat_dict_img,
                                                     batch_data_samples)
            losses_towers.append(losses_joint)
            losses_towers.append(losses_pts)
            losses_towers.append(losses_img)
            combined_losses = dict()
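            # combine the three towers: weighted sum of each loss term
            # with the per-tower weights in `loss_weights`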
            for loss_term in losses_joint:
                if 'loss' in loss_term:
                    combined_losses[loss_term] = 0
                    for i in range(len(losses_towers)):
                        combined_losses[loss_term] += \
                            losses_towers[i][loss_term] * \
                            self.loss_weights[i]
                else:
                    # only save the metric of the joint head
                    # if it is not a loss
                    combined_losses[loss_term] = \
                        losses_towers[0][loss_term]

            return combined_losses

    def predict(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
                batch_data_samples: List[Det3DDataSample],
                **kwargs) -> List[Det3DDataSample]:
        """Forward of testing.

        Args:
            batch_inputs_dict (dict): The model input dict, which includes
                'points' and 'imgs' keys.

                - points (list[torch.Tensor]): Point cloud of each sample.
                - imgs (list[torch.Tensor]): Image tensor with shape
                  (N, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instances_3d`.

        Returns:
            List[:obj:`Det3DDataSample`]: Detection results of the input
            samples.
        """
        points = batch_inputs_dict.get('points', None)
        imgs = batch_inputs_dict.get('imgs', None)
        if points is None:
            assert imgs is not None
            results_2d = self.predict_img_only(imgs, batch_data_samples)
            return self.convert_to_datasample(results_list_2d=results_2d)

        else:
            results_2d = self.predict_img_only(
                batch_inputs_dict['imgs'], batch_data_samples, rescale=False)
            # tensor with shape (n, 6), the 6 values arranged
            # as [x1, y1, x2, y2, score, label]
            pred_bboxes_with_label_list = []
            for single_results in results_2d:
                cat_preds = torch.cat(
                    (single_results.bboxes, single_results.scores[:, None],
                     single_results.labels[:, None]),
                    dim=-1)
                cat_preds = cat_preds[torch.argsort(
                    cat_preds[:, 4], descending=True)]
                pred_bboxes_with_label_list.append(cat_preds)

            stack_points = torch.stack(points)
            seeds_3d, seed_3d_features, seed_indices = \
                self.extract_pts_feat(stack_points)

            img_features, masks = self.fusion_layer(
                imgs, pred_bboxes_with_label_list, seeds_3d,
                [item.metainfo for item in batch_data_samples])

            inds = sample_valid_seeds(masks, self.num_sampled_seed)
            batch_size, img_feat_size = img_features.shape[:2]
            pts_feat_size = seed_3d_features.shape[1]
            inds_img = inds.view(batch_size, 1,
                                 -1).expand(-1, img_feat_size, -1)
            img_features = img_features.gather(-1, inds_img)
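            # map imvote indices back to seed indices; same layout
            # assumption as in `loss` (vote-major, number of seeds equal
            # to num_sampled_seed)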
            inds = inds % inds.shape[1]
            inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3)
            seeds_3d = seeds_3d.gather(1, inds_seed_xyz)
            inds_seed_feats = inds.view(batch_size, 1,
                                        -1).expand(-1, pts_feat_size, -1)
            seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats)
            seed_indices = seed_indices.gather(1, inds)

            img_features = self.img_mlp(img_features)

            fused_features = torch.cat([seed_3d_features, img_features], dim=1)

            feat_dict = dict(
                seed_points=seeds_3d,
                seed_features=fused_features,
                seed_indices=seed_indices)

            results_3d = self.pts_bbox_head_joint.predict(
                batch_inputs_dict['points'],
                feat_dict,
                batch_data_samples,
                rescale=True)

            return self.convert_to_datasample(results_3d)

    def predict_img_only(self,
                         imgs: Tensor,
                         batch_data_samples: List[Det3DDataSample],
                         rescale: bool = True) -> List[InstanceData]:
        """Predict results from a batch of imgs with post- processing.

        Args:
            imgs (Tensor): Inputs images with shape (N, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Return the list of detection
            results of the input images, usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
                (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
        """

        assert self.with_img_bbox, 'Img bbox head must be implemented.'
        assert self.with_img_backbone, 'Img backbone must be implemented.'
        assert self.with_img_rpn, 'Img rpn must be implemented.'
        assert self.with_img_roi_head, 'Img roi head must be implemented.'
        x = self.extract_img_feat(imgs)

        # If there are no pre-defined proposals, use RPN to get proposals
        if batch_data_samples[0].get('proposals', None) is None:
            rpn_results_list = self.img_rpn_head.predict(
                x, batch_data_samples, rescale=False)
        else:
            rpn_results_list = [
                data_sample.proposals for data_sample in batch_data_samples
            ]

        results_list = self.img_roi_head.predict(
            x, rpn_results_list, batch_data_samples, rescale=rescale)

        return results_list