# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List

import torch
from mmcv.cnn import normal_init
from mmcv.runner import BaseModule, auto_fp16
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils.typing import ConfigType


class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
    """Abstract base class for 3D segmentation decode heads.

    1. The ``init_weights`` method is used to initialize decode_head's
    model parameters. After segmentor initialization, ``init_weights``
    is triggered when ``segmentor.init_weights()`` is called externally.

    2. The ``loss`` method is used to calculate the loss of decode_head,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``loss_by_feat`` method
    is called based on the feature maps to calculate the loss.

    .. code:: text

        loss(): forward() -> loss_by_feat()

    3. The ``predict`` method is used to predict segmentation results,
    which includes two steps: (1) the decode_head model performs forward
    propagation to obtain the feature maps (2) The ``predict_by_feat`` method
    is called based on the feature maps to predict segmentation results
    including post-processing.

    .. code:: text

        predict(): forward() -> predict_by_feat()

    NOTE(review): this base class does not define ``predict_by_feat``;
    ``predict`` here returns the output of ``forward`` unchanged.

    Args:
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float, optional): Ratio of dropout layer. A value
            that is not greater than 0 disables dropout. Default: 0.5.
        conv_cfg (dict, optional): Config of conv layers.
            Default: dict(type='Conv1d').
        norm_cfg (dict, optional): Config of norm layers.
            Default: dict(type='BN1d').
        act_cfg (dict, optional): Config of activation layers.
            Default: dict(type='ReLU').
        loss_decode (dict, optional): Config of decode loss.
            Default: dict(type='mmdet.CrossEntropyLoss').
        ignore_index (int, optional): The label index to be ignored.
            When using masked BCE loss, ignore_index should be set to None.
            Default: 255.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 channels,
                 num_classes,
                 dropout_ratio=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 loss_decode=dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=False,
                     class_weight=None,
                     loss_weight=1.0),
                 ignore_index=255,
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        # The layer configs are only stored here; subclasses are expected to
        # consume them when building their own feature layers.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.loss_decode = MODELS.build(loss_decode)
        self.ignore_index = ignore_index

        # Final 1x1 convolution mapping ``channels`` features to
        # ``num_classes`` logits.
        self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)
        else:
            # ``cls_seg`` checks for ``None`` to skip dropout entirely.
            self.dropout = None

        self.fp16_enabled = False

    def init_weights(self):
        """Initialize weights of classification layer."""
        super().init_weights()
        normal_init(self.conv_seg, mean=0, std=0.01)

    @auto_fp16()
    @abstractmethod
    def forward(self, feats_dict: dict):
        """Placeholder of forward function.

        Subclasses must override this method. ``loss`` and ``predict`` call
        it with their ``inputs`` argument and treat the result as the
        segmentation logits.
        """

    def cls_seg(self, feat: Tensor) -> Tensor:
        """Classify each points.

        Args:
            feat (Tensor): Features fed to the classification layer.

        Returns:
            Tensor: Output of ``conv_seg``, computed after dropout when
            dropout is enabled.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def loss(self, inputs: List[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType) -> dict:
        """Forward function for training.

        Args:
            inputs (list[torch.Tensor]): List of multi-level point features.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples. Each sample must provide
                ``gt_pts_seg.pts_semantic_mask``.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs: List[Tensor], batch_input_metas: List[dict],
                test_cfg: ConfigType) -> Tensor:
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level point features.
            batch_input_metas (list[dict]): Meta information of each sample.
                Not used by this base implementation.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: Output segmentation map.
        """
        seg_logits = self.forward(inputs)

        return seg_logits

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack the per-sample semantic masks into one batched tensor.

        Args:
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples. ``gt_pts_seg.pts_semantic_mask`` is read from
                every sample.

        Returns:
            Tensor: The masks stacked along a new leading (batch) dimension.
            ``torch.stack`` requires all masks to have the same shape.
        """
        gt_semantic_segs = [
            data_sample.gt_pts_seg.pts_semantic_mask
            for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logit: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute semantic segmentation loss.

        Args:
            seg_logit (torch.Tensor): Predicted per-point segmentation logits
                of shape [B, num_classes, N].
            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_pts_seg`.

        Returns:
            dict[str, Tensor]: A dictionary with the single key
            ``loss_sem_seg`` holding the decode loss.
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        loss['loss_sem_seg'] = self.loss_decode(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss