sassd.py 3.92 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
from ...structures.det3d_data_sample import SampleList
from .single_stage import SingleStage3DDetector


@MODELS.register_module()
class SASSD(SingleStage3DDetector):
    r"""`SASSD <https://github.com/skyhehe123/SA-SSD>` _ for 3D detection."""

    def __init__(self,
                 voxel_encoder: ConfigType,
                 middle_encoder: ConfigType,
                 backbone: ConfigType,
                 neck: OptConfigType = None,
                 bbox_head: OptConfigType = None,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super(SASSD, self).__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        self.voxel_encoder = MODELS.build(voxel_encoder)
        self.middle_encoder = MODELS.build(middle_encoder)

    def extract_feat(
        self,
        batch_inputs_dict: dict,
        test_mode: bool = True
    ) -> Union[Tuple[Tuple[Tensor], Tuple], Tuple[Tensor]]:
        """Extract features from points.

        Args:
            batch_inputs_dict (dict): The batch inputs.
            test_mode (bool, optional): Whether test mode. Defaults to True.

        Returns:
            Union[Tuple[Tuple[Tensor], Tuple], Tuple[Tensor]]: In test mode, it
            returns the features of points from multiple levels. In training
            mode, it returns the features of points from multiple levels and a
            tuple containing the mean features of points and the targets of
            clssification and regression.
        """
        voxel_dict = batch_inputs_dict['voxels']
        voxel_features = self.voxel_encoder(voxel_dict['voxels'],
                                            voxel_dict['num_points'],
                                            voxel_dict['coors'])
        batch_size = voxel_dict['coors'][-1, 0].item() + 1
        # `point_misc` is a tuple containing the mean features of points and
        # the targets of clssification and regression. It's only used for
        # calculating auxiliary loss in training mode.
        x, point_misc = self.middle_encoder(voxel_features,
                                            voxel_dict['coors'], batch_size,
                                            test_mode)
        x = self.backbone(x)
        if self.with_neck:
            x = self.neck(x)

        return (x, point_misc) if not test_mode else x

    def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
             **kwargs) -> dict:
        """Calculate losses from a batch of inputs dict and data samples.

        Args:
            batch_inputs_dict (dict): The model input dict which include
                'points' keys.
                    - points (list[torch.Tensor]): Point cloud of each sample.

            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`.

        Returns:
            dict: A dictionary of loss components.
        """
        x, point_misc = self.extract_feat(batch_inputs_dict, test_mode=False)
        batch_gt_bboxes_3d = [
            data_sample.gt_instances_3d.bboxes_3d
            for data_sample in batch_data_samples
        ]
        aux_loss = self.middle_encoder.aux_loss(*point_misc,
                                                batch_gt_bboxes_3d)
        losses = self.bbox_head.loss(x, batch_data_samples)
        losses.update(aux_loss)
        return losses