# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from ...structures.det3d_data_sample import SampleList
from .single_stage import SingleStage3DDetector


@MODELS.register_module()
class SASSD(SingleStage3DDetector):
    r"""`SASSD <https://github.com/skyhehe123/SA-SSD>`_ for 3D detection.

    Args:
        voxel_encoder (:obj:`ConfigDict` or dict): Config passed to
            ``MODELS.build`` to create ``self.voxel_encoder``.
        middle_encoder (:obj:`ConfigDict` or dict): Config passed to
            ``MODELS.build`` to create ``self.middle_encoder``.
        backbone (:obj:`ConfigDict` or dict): Forwarded unchanged to
            :class:`SingleStage3DDetector`.
        neck (:obj:`ConfigDict` or dict, optional): Forwarded unchanged to
            :class:`SingleStage3DDetector`. Defaults to None.
        bbox_head (:obj:`ConfigDict` or dict, optional): Forwarded unchanged
            to :class:`SingleStage3DDetector`. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Forwarded unchanged
            to :class:`SingleStage3DDetector`. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Forwarded unchanged
            to :class:`SingleStage3DDetector`. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Forwarded
            unchanged to :class:`SingleStage3DDetector`. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or List, optional): Forwarded
            unchanged to :class:`SingleStage3DDetector`. Defaults to None.
    """

    def __init__(self,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # The two encoders are specific to this detector, so they are built
        # here rather than by the single-stage base class.
        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    def extract_feat(
        self,
        batch_inputs_dict: dict,
        test_mode: bool = True
    ) -> Union[Tuple[Tuple[Tensor], Tuple], Tuple[Tensor]]:
        """Extract features from points.

        Args:
            batch_inputs_dict (dict): The batch inputs. It must contain a
                'voxels' entry, itself a dict with the keys 'voxels',
                'num_points' and 'coors'.
            test_mode (bool, optional): Whether test mode. Defaults to True.

        Returns:
            Union[Tuple[Tuple[Tensor], Tuple], Tuple[Tensor]]: In test mode, it
            returns the features of points from multiple levels. In training
            mode, it returns the features of points from multiple levels and a
            tuple containing the mean features of points and the targets of
            classification and regression.
        """
        voxel_dict = batch_inputs_dict['voxels']
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'])
        # The batch size is derived from column 0 of the last `coors` row.
        # NOTE(review): this presumably relies on `coors` being ordered by
        # sample index so that the last row carries the largest one, and on
        # `coors` being non-empty -- confirm against the data preprocessor.
        batch_size = voxel_dict['coors'][-1, 0].item() + 1
        # `point_misc` is a tuple containing the mean features of points and
        # the targets of classification and regression. It's only used for
        # calculating auxiliary loss in training mode.
        x, point_misc = self.middle_encoder(voxel_features,
                                            voxel_dict['coors'], batch_size,
                                            test_mode)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)

        return (x, point_misc) if not test_mode else x

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> dict:
        """Calculate losses from a batch of inputs dict and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict. It is passed to
                :meth:`extract_feat`, which reads its 'voxels' entry.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. Each sample must provide
                `gt_instances_3d.bboxes_3d`, which is used for the auxiliary
                loss; the samples are also passed on to the bbox head.

        Returns:
            dict: A dictionary of loss components: the losses of the bbox
            head updated with the auxiliary losses of the middle encoder.
        """
        x, point_misc = self.extract_feat(batch_inputs_dict, test_mode=False)
        batch_gt_bboxes_3d = [
            data_sample.gt_instances_3d.bboxes_3d
            for data_sample in batch_data_samples
        ]
        aux_loss = self.middle_encoder.aux_loss(*point_misc,
                                                batch_gt_bboxes_3d)
        losses = self.bbox_head.loss(x, batch_data_samples)
        # Merge the auxiliary terms into the head losses; on a key clash the
        # auxiliary value replaces the head value.
        losses.update(aux_loss)
        return losses