import torch

from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
from mmdet.models import DETECTORS
from .single_stage import SingleStage3DDetector


@DETECTORS.register_module()
class VoteNet(SingleStage3DDetector):
    r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection."""

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(VoteNet, self).__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)
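
    # A minimal, hypothetical config sketch for building this detector through
    # the mmdet3d registry; the concrete type names (e.g. ``PointNet2SASSG``
    # and ``VoteHead``) follow the reference configs and are assumptions here,
    # not something this class enforces:
    #
    #   model = dict(
    #       type='VoteNet',
    #       backbone=dict(type='PointNet2SASSG', in_channels=4),
    #       bbox_head=dict(type='VoteHead', num_classes=18))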

    def extract_feat(self, points, img_metas=None):
        """Directly extract features from the backbone+neck.

        Args:
            points (torch.Tensor): Input points.
        """
        x = self.backbone(points)
        if self.with_neck:
            x = self.neck(x)

        seed_points = x['fp_xyz'][-1]
        seed_features = x['fp_features'][-1]
        seed_indices = x['fp_indices'][-1]
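        # The backbone is assumed to expose PointNet++-style feature
        # propagation (FP) outputs; the deepest level ([-1]) yields the seed
        # points (B, N, 3), seed features (B, C, N) and seed indices (B, N)
        # that the voting bbox head consumes.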

        feat_dict = {
            'seed_points': seed_points,
            'seed_features': seed_features,
            'seed_indices': seed_indices
        }

        return feat_dict

    def forward_train(self,
                      points,
                      img_metas,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      gt_bboxes_ignore=None):
        """Forward of training.

        Args:
            points (list[torch.Tensor]): Points of each batch.
            img_metas (list[dict]): Meta information of each sample.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): gt bboxes of
                each batch.
            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
            pts_semantic_mask (None | list[torch.Tensor]): point-wise semantic
                label of each batch.
            pts_instance_mask (None | list[torch.Tensor]): point-wise instance
                label of each batch.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify which
                bounding boxes to ignore.

        Returns:
            dict: Losses.
        """
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
                       pts_instance_mask, img_metas)
        losses = self.bbox_head.loss(
            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def simple_test(self, points, img_metas, imgs=None, rescale=False):
        """Forward of testing.

        Args:
            points (list[torch.Tensor]): Points of each sample.
            img_metas (list[dict]): Meta information of each sample.
            rescale (bool): Whether to rescale results.

        Returns:
            list[dict]: Predicted 3D boxes, scores and labels of each sample.
        """
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
        bbox_list = self.bbox_head.get_bboxes(
            points_cat, bbox_preds, img_metas, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
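        # Each element is a dict produced by ``bbox3d2result`` with keys
        # ``boxes_3d`` (:obj:`BaseInstance3DBoxes`), ``scores_3d`` and
        # ``labels_3d`` (torch.Tensor), one per sample.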
        return bbox_results

    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test with augmentation."""
        points_cat = [torch.stack(pts) for pts in points]
        feats = self.extract_feats(points_cat, img_metas)

        # only support aug_test for one sample
        aug_bboxes = []
        for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
            bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
            bbox_list = self.bbox_head.get_bboxes(
                pts_cat, bbox_preds, img_meta, rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_list[0])

        # after merging, bboxes are back in the original point cloud frame
        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
                                            self.bbox_head.test_cfg)

        return [merged_bboxes]