h3d_roi_head.py 4.33 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
jshilong's avatar
jshilong committed
2
3
4
5
6
from typing import Dict, List

from mmengine import InstanceData
from torch import Tensor

7
from mmdet3d.registry import MODELS
zhangshilong's avatar
zhangshilong committed
8
from mmdet3d.structures import Det3DDataSample
encore-zhou's avatar
encore-zhou committed
9
10
11
from .base_3droi_head import Base3DRoIHead


12
@MODELS.register_module()
class H3DRoIHead(Base3DRoIHead):
    """H3D roi head for H3DNet.

    The head owns three primitive heads and one bbox head. During both
    training and inference the primitive heads run first (z, xy, line, in
    that order) and the bbox head runs last on the accumulated features.

    Args:
        primitive_list (List[dict]): Configs of the three primitive heads,
            in the order ``primitive_z``, ``primitive_xy``,
            ``primitive_line``.
        bbox_head (dict, optional): Config of bbox_head. Defaults to None.
        train_cfg (dict, optional): Training config. Defaults to None.
        test_cfg (dict, optional): Testing config. Defaults to None.
        init_cfg (dict, optional): Initialization config, forwarded to the
            base class. Defaults to None.
    """

    def __init__(self,
                 primitive_list: List[dict],
                 bbox_head: dict = None,
                 train_cfg: dict = None,
                 test_cfg: dict = None,
                 init_cfg: dict = None):
        super().__init__(
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        # Primitive module: exactly one config per primitive head is
        # required, consumed positionally below.
        assert len(primitive_list) == 3, (
            'primitive_list must contain exactly 3 configs (z, xy, line), '
            f'but got {len(primitive_list)}')
        self.primitive_z = MODELS.build(primitive_list[0])
        self.primitive_xy = MODELS.build(primitive_list[1])
        self.primitive_line = MODELS.build(primitive_list[2])

    def init_mask_head(self) -> None:
        """Initialize mask head, skip since ``H3DRoIHead`` does not have
        one."""
        pass

    def init_bbox_head(self, dummy_args, bbox_head: dict) -> None:
        """Initialize box head.

        Args:
            dummy_args (optional): Unused. Only present to stay compatible
                with the interface in the base class.
            bbox_head (dict): Config for bbox head. It is not modified; the
                train/test configs are injected into a shallow copy.
        """
        # Work on a copy so the caller's config object is not silently
        # extended with ``train_cfg`` / ``test_cfg`` entries.
        bbox_head = bbox_head.copy()
        bbox_head['train_cfg'] = self.train_cfg
        bbox_head['test_cfg'] = self.test_cfg
        self.bbox_head = MODELS.build(bbox_head)

    def init_assigner_sampler(self) -> None:
        """Initialize assigner and sampler.

        Intentionally a no-op for this head.
        """
        pass

    def loss(self, points: List[Tensor], feats_dict: dict,
             batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
        """Training forward function of H3DRoIHead.

        Args:
            points (list[torch.Tensor]): Point cloud of each sample.
            feats_dict (dict): Dict of feature. It is modified in place:
                the ``'targets'`` entry is removed before the bbox head is
                called.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance_3d`.

        Returns:
            dict: losses from each head.

        Raises:
            KeyError: If ``feats_dict`` has no ``'targets'`` entry once the
                primitive heads have run.
        """
        losses = dict()

        primitive_loss_inputs = (points, feats_dict, batch_data_samples)
        # note the feats_dict would be added new key and value in each head,
        # so the heads must run in this fixed order on the same dict.
        primitive_heads = (self.primitive_z, self.primitive_xy,
                           self.primitive_line)
        for primitive_head in primitive_heads:
            losses.update(primitive_head.loss(*primitive_loss_inputs))

        # 'targets' is handed to the bbox head as a separate argument rather
        # than left inside feats_dict.
        # NOTE(review): where 'targets' gets populated is not visible in
        # this file - confirm against the caller / primitive heads.
        targets = feats_dict.pop('targets')

        bbox_loss = self.bbox_head.loss(
            points,
            feats_dict,
            rpn_targets=targets,
            batch_data_samples=batch_data_samples)
        losses.update(bbox_loss)
        return losses

    def predict(self,
                points: List[Tensor],
                feats_dict: Dict[str, Tensor],
                batch_data_samples: List[Det3DDataSample],
                suffix: str = '_optimized',
                **kwargs) -> List[InstanceData]:
        """Inference forward function of H3DRoIHead.

        Args:
            points (list[tensor]): Point clouds of multiple samples.
            feats_dict (dict): Features from FPN or backbone. It is updated
                in place with the outputs of every head.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
                Samples. It usually includes meta information of data.
            suffix (str): Forwarded unchanged to ``bbox_head.predict``.
                Defaults to '_optimized'.

        Returns:
            list[:obj:`InstanceData`]: List of processed predictions. Each
            InstanceData contains 3d Bounding boxes and corresponding
            scores and labels.
        """
        # Each head reads the features accumulated so far and contributes
        # its own outputs, so the order z -> xy -> line -> bbox matters.
        primitive_heads = (self.primitive_z, self.primitive_xy,
                           self.primitive_line)
        for primitive_head in primitive_heads:
            feats_dict.update(primitive_head(feats_dict))

        bbox_preds = self.bbox_head(feats_dict)
        feats_dict.update(bbox_preds)

        results_list = self.bbox_head.predict(
            points, feats_dict, batch_data_samples, suffix=suffix)

        return results_list