# Copyright (c) OpenMMLab. All rights reserved.
2
from abc import ABCMeta, abstractmethod
3
from typing import List
4

5
import torch
6
from mmengine.model import BaseModule, normal_init
7
from torch import Tensor
8
9
from torch import nn as nn

10
from mmdet3d.registry import MODELS
11
from mmdet3d.structures.det3d_data_sample import SampleList
12
from mmdet3d.utils.typing_utils import ConfigType
13
14


15
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for 3D (point-wise) semantic segmentation decode heads.

    1. The ``init_weights`` method is used to initialize decode_head's
    model parameters. After segmentor initialization, ``init_weights``
    is triggered when ``segmentor.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of decode_head,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps, and (2) the ``loss_by_feat``
    method is called on those feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> loss_by_feat()

    3. The ``predict`` method is used to predict segmentation results.
    In this base class it only performs forward propagation and returns
    the raw segmentation logits; no post-processing is applied here.

    .. code:: text

        predict(): forward()

    Subclasses must implement ``forward``. ``cls_seg`` is provided to map
    ``channels``-dimensional point features to per-point class logits.

    Args:
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer applied before
            conv_seg. No dropout layer is created when it is not positive.
            Defaults to 0.5.
        conv_cfg (dict): Config of conv layers. Stored on the instance;
            not used by this base class. Defaults to dict(type='Conv1d').
        norm_cfg (dict): Config of norm layers. Stored on the instance;
            not used by this base class. Defaults to dict(type='BN1d').
        act_cfg (dict): Config of activation layers. Stored on the
            instance; not used by this base class.
            Defaults to dict(type='ReLU').
        loss_decode (dict): Config of decode loss, built via the
            ``MODELS`` registry. Defaults to
            dict(type='mmdet.CrossEntropyLoss', use_sigmoid=False,
            class_weight=None, loss_weight=1.0).
        ignore_index (int): The label index to be ignored by the loss.
            When using masked BCE loss, ignore_index should be set to None.
            Defaults to 255.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 channels: int,
                 num_classes: int,
                 dropout_ratio: float = 0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 loss_decode=dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     class_weight=None,
                     loss_weight=1.0),
                 ignore_index=255,
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        # The layer configs are only stored here; presumably subclasses use
        # them to build their own layers. The default dicts are shared
        # between instances, so they must not be mutated in place.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.loss_decode = MODELS.build(loss_decode)
        self.ignore_index = ignore_index

        # Point-wise (kernel_size=1) classifier: channels -> num_classes.
        self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
        # Only create a dropout layer when it would actually drop something.
        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)
        else:
            self.dropout = None

    def init_weights(self):
        """Initialize weights, then re-initialize the classification layer.

        ``conv_seg`` is re-initialized from a normal distribution with
        mean 0 and std 0.01 after the ``BaseModule`` initialization.
        """
        super().init_weights()
        normal_init(self.conv_seg, mean=0, std=0.01)

    @abstractmethod
    def forward(self, feats_dict: dict):
        """Compute segmentation logits from input features.

        Must be implemented by subclasses.

        Args:
            feats_dict (dict): Input features; the expected keys are
                defined by the concrete subclass.
        """
        pass

    def cls_seg(self, feat: Tensor) -> Tensor:
        """Classify each point.

        Args:
            feat (Tensor): Point features fed to a ``nn.Conv1d``, i.e.
                laid out as [B, channels, N].

        Returns:
            Tensor: Per-point class logits of shape [B, num_classes, N].
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.conv_seg(feat)

    def loss(self, inputs: dict, batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (dict): Input features, passed unchanged to
                :meth:`forward`.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_pts_seg`.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        seg_logits = self.forward(inputs)
        return self.loss_by_feat(seg_logits, batch_data_samples)

    def predict(self, inputs: dict, batch_input_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for testing.

        Args:
            inputs (dict): Input features, passed unchanged to
                :meth:`forward`.
            batch_input_metas (list[dict]): Meta information of each
                sample. Not used by this base implementation.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: The raw segmentation logits returned by
            :meth:`forward` (expected to be [B, num_classes, N], the
            layout consumed by ``loss_by_feat``).
        """
        return self.forward(inputs)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack the per-sample semantic masks into one batched tensor.

        Args:
            batch_data_samples (List[:obj:`Det3DDataSample`]): Samples
                whose ``gt_pts_seg.pts_semantic_mask`` is read.

        Returns:
            Tensor: Masks stacked along a new leading batch dimension.
            ``torch.stack`` requires every mask to have the same shape.
        """
        gt_semantic_segs = [
            data_sample.gt_pts_seg.pts_semantic_mask
            for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logit: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute semantic segmentation loss.

        Args:
            seg_logit (torch.Tensor): Predicted per-point segmentation logits
                of shape [B, num_classes, N].
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_pts_seg`.

        Returns:
            dict[str, Tensor]: ``{'loss_sem_seg': loss}``, computed by
            ``self.loss_decode`` with ``self.ignore_index``.
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        loss['loss_sem_seg'] = self.loss_decode(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss