votenet.py 3.68 KB
Newer Older
wuyuefeng's avatar
Votenet  
wuyuefeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch

from mmdet3d.core import bbox3d2result
from mmdet.models import DETECTORS, SingleStageDetector


@DETECTORS.register_module()
class VoteNet(SingleStageDetector):
    """VoteNet model.

    https://arxiv.org/pdf/1904.09664.pdf

    Args:
        backbone: Backbone config, forwarded to ``SingleStageDetector``.
        bbox_head: Bbox head config, forwarded to ``SingleStageDetector``.
        train_cfg: Training config. Its ``sample_mod`` attribute is read by
            :meth:`forward_train`, so it must not be None when training.
        test_cfg: Testing config. Its ``sample_mod`` attribute is read by
            :meth:`simple_test`, so it must not be None when testing.
        pretrained: Forwarded to ``SingleStageDetector`` (presumably a
            checkpoint path — confirm).
    """

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super().__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def extract_feat(self, points):
        """Extract features from a batch of points.

        Args:
            points (Tensor): Points of all samples, stacked along a new
                leading batch dimension.

        Returns:
            The backbone output, additionally passed through the neck when
            one is configured (``self.with_neck``).
        """
        x = self.backbone(points)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_train(self,
                      points,
                      img_meta,
                      gt_bboxes_3d,
                      gt_labels_3d,
                      pts_semantic_mask=None,
                      pts_instance_mask=None,
                      gt_bboxes_ignore=None):
        """Forward of training.

        Args:
            points (list[Tensor]): Points of each sample in the batch. They
                are combined with ``torch.stack``, so all samples must have
                the same shape.
            img_meta (list): Image metas.
            gt_bboxes_3d: gt bboxes of each sample (presumably one
                ``BaseInstance3DBoxes`` per sample — confirm).
            gt_labels_3d (list[Tensor]): gt class labels of each sample.
            pts_semantic_mask (None | list[Tensor]): point-wise semantic
                label of each sample.
            pts_instance_mask (None | list[Tensor]): point-wise instance
                label of each sample.
            gt_bboxes_ignore (None | list[Tensor]): gt bboxes to be ignored,
                forwarded unchanged to ``bbox_head.loss``.

        Returns:
            dict: Losses.
        """
        # TODO: temporary batching. torch.stack requires every sample to
        # have the same number of points.
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
        # The loss receives the original per-sample list, not the stacked
        # tensor.
        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
                       pts_instance_mask, img_meta)
        losses = self.bbox_head.loss(
            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        return losses

    def forward_test(self, **kwargs):
        """Test entry point. Delegates directly to :meth:`simple_test`."""
        return self.simple_test(**kwargs)

    def forward(self, return_loss=True, **kwargs):
        """Dispatch to training or testing.

        Args:
            return_loss (bool): If True, call :meth:`forward_train`;
                otherwise call :meth:`forward_test`.
            **kwargs: Forwarded unchanged to the selected method.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def simple_test(self,
                    points,
                    img_meta,
                    gt_bboxes_3d=None,
                    gt_labels_3d=None,
                    pts_semantic_mask=None,
                    pts_instance_mask=None,
                    rescale=False):
        """Forward of testing.

        Args:
            points (list[Tensor]): Points of each sample. They are combined
                with ``torch.stack``, so all samples must have the same shape.
            img_meta (list): Image metas.
            gt_bboxes_3d: Unused here; accepted so that test-time keyword
                arguments can be passed straight through.
            gt_labels_3d: Unused here; see ``gt_bboxes_3d``.
            pts_semantic_mask: Unused here; see ``gt_bboxes_3d``.
            pts_instance_mask: Unused here; see ``gt_bboxes_3d``.
            rescale (bool): Whether to rescale results. Forwarded to
                ``bbox_head.get_bboxes``.

        Returns:
            The ``bbox3d2result`` output for the FIRST sample in the batch
            only (a dict in mmdet3d — confirm). Results for any further
            samples are computed but discarded.
        """
        # TODO: temporary batching. torch.stack requires every sample to
        # have the same number of points.
        points_cat = torch.stack(points)

        x = self.extract_feat(points_cat)
        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
        bbox_list = self.bbox_head.get_bboxes(
            points_cat, bbox_preds, img_meta, rescale=rescale)
        bbox_results = [
            bbox3d2result(bboxes, scores, labels)
            for bboxes, scores, labels in bbox_list
        ]
        # NOTE(review): only the first sample is returned; callers appear to
        # assume batch size 1 at test time.
        return bbox_results[0]