groupfree3dnet.py 3.28 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
hjin2902's avatar
hjin2902 committed
2

3
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
4
from ...structures.det3d_data_sample import SampleList
hjin2902's avatar
hjin2902 committed
5
6
7
from .single_stage import SingleStage3DDetector


8
@MODELS.register_module()
class GroupFree3DNet(SingleStage3DDetector):
    """`Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.

    A thin specialisation of :class:`SingleStage3DDetector`: features come
    from ``self.extract_feat`` and are handed, together with the raw points,
    to ``self.bbox_head`` for loss computation or prediction.

    Args:
        backbone: Backbone configuration, forwarded to the parent class.
        bbox_head: Detection head configuration, forwarded to the parent
            class. Defaults to None.
        train_cfg: Training configuration, forwarded to the parent class.
            Defaults to None.
        test_cfg: Testing configuration, forwarded to the parent class.
            Defaults to None.
        init_cfg: Initialization configuration, forwarded to the parent
            class. Defaults to None.
        **kwargs: Any further keyword arguments, forwarded unchanged to the
            parent class.

    NOTE(review): the accepted types of the configuration arguments are
    defined by ``SingleStage3DDetector`` and are not visible in this file.
    """

    def __init__(self,
                 backbone,
                 bbox_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 **kwargs):
        # This class adds no state of its own; everything is delegated to
        # the single-stage base detector.
        super().__init__(
            backbone=backbone,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg,
            **kwargs)

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> dict:
        """Calculate losses from a batch of inputs dict and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points', 'imgs' keys. Only 'points' is read directly by
                this method; the whole dict is passed to
                ``self.extract_feat``.

                    - points (list[torch.Tensor]): Point cloud of each sample.
                    - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_pts_seg`.
            **kwargs: Extra keyword arguments, forwarded unchanged to
                ``self.bbox_head.loss``.

        Returns:
            dict: A dictionary of loss components.
        """
        # Features are extracted before 'points' is looked up, so a missing
        # key surfaces only after the backbone has run (original ordering).
        feats = self.extract_feat(batch_inputs_dict)
        points = batch_inputs_dict['points']
        # Unlike a generic single-stage head, this head also takes the raw
        # points as its first argument.
        return self.bbox_head.loss(points, feats, batch_data_samples,
                                   **kwargs)

    def predict(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
                **kwargs) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points', 'imgs' keys. Only 'points' is read directly by
                this method; the whole dict is passed to
                ``self.extract_feat``.

                    - points (list[torch.Tensor]): Point cloud of each sample.
                    - imgs (torch.Tensor, optional): Image of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_pts_seg`.
            **kwargs: Extra keyword arguments, forwarded unchanged to
                ``self.bbox_head.predict``.

                NOTE(review): a ``rescale (bool)`` option is not a named
                parameter of this method; it can only reach the head through
                ``**kwargs`` - confirm against the head's ``predict``.

        Returns:
            list[:obj:`Det3DDataSample`]: Detection results of the
            input samples. Each Det3DDataSample usually contain
            'pred_instances_3d'. And the ``pred_instances_3d`` usually
            contains following keys.

            - scores_3d (Tensor): Classification scores, has a shape
                (num_instances, )
            - labels_3d (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - bbox_3d (Tensor): Contains a tensor with shape
                (num_instances, C) where C >=7.

            NOTE(review): these keys are produced by the head and by
            ``convert_to_datasample``, neither of which is visible here;
            the box key may actually be spelled ``bboxes_3d`` - confirm.
        """
        feats = self.extract_feat(batch_inputs_dict)
        points = batch_inputs_dict['points']
        results_list = self.bbox_head.predict(points, feats,
                                              batch_data_samples, **kwargs)
        # Wrap the raw per-sample results into data samples for the caller.
        return self.convert_to_datasample(results_list)