"megatron/legacy/model/utils.py" did not exist on "c99fa80ce9f99303a07ced0cf3f063eaa34ac407"
decode_head.py 4.32 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from abc import ABCMeta, abstractmethod
from mmcv.cnn import normal_init
4
from mmcv.runner import BaseModule, auto_fp16, force_fp32
5
6
7
8
9
from torch import nn as nn

from mmseg.models.builder import build_loss


10
class Base3DDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for 3D (per-point) semantic segmentation decode heads.

    Subclasses must implement :meth:`forward`. This base class provides the
    shared per-point classifier (:meth:`cls_seg`), the training / testing
    entry points (:meth:`forward_train`, :meth:`forward_test`) and the
    segmentation loss (:meth:`losses`).

    Args:
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float, optional): Ratio of the dropout layer applied
            before ``conv_seg``. A value ``<= 0`` disables dropout.
            Default: 0.5.
        conv_cfg (dict, optional): Config of conv layers. Only stored here
            for use by subclasses. Default: dict(type='Conv1d').
        norm_cfg (dict, optional): Config of norm layers. Only stored here
            for use by subclasses. Default: dict(type='BN1d').
        act_cfg (dict, optional): Config of activation layers. Only stored
            here for use by subclasses. Default: dict(type='ReLU').
        loss_decode (dict, optional): Config of decode loss, built with
            ``build_loss``. Default: dict(type='CrossEntropyLoss').
        ignore_index (int, optional): The label index to be ignored.
            When using masked BCE loss, ignore_index should be set to None.
            Default: 255.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``. Default: None.
    """

    def __init__(self,
                 channels,
                 num_classes,
                 dropout_ratio=0.5,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     class_weight=None,
                     loss_weight=1.0),
                 ignore_index=255,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        # NOTE: the default config dicts in the signature are created once
        # and shared by every instance that relies on them. Treat
        # conv_cfg / norm_cfg / act_cfg as read-only; mutating them in place
        # would silently change the defaults of later instances.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.loss_decode = build_loss(loss_decode)
        self.ignore_index = ignore_index

        # Per-point classifier: a kernel-size-1 Conv1d mapping `channels`
        # features to `num_classes` logits.
        self.conv_seg = nn.Conv1d(channels, num_classes, kernel_size=1)
        # Dropout is only created for a strictly positive ratio.
        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)
        else:
            self.dropout = None
        # Flag read by the auto_fp16 / force_fp32 decorators below;
        # presumably switched on externally when mixed precision is used.
        self.fp16_enabled = False

    def init_weights(self):
        """Initialize weights of the classification layer.

        Runs the parent ``BaseModule`` initialization first, then re-draws
        ``conv_seg`` from a normal distribution (mean 0, std 0.01).
        """
        super().init_weights()
        normal_init(self.conv_seg, mean=0, std=0.01)

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function; implemented by subclasses."""
        pass

    def forward_train(self, inputs, img_metas, pts_semantic_mask, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[torch.Tensor]): List of multi-level point features.
            img_metas (list[dict]): Meta information of each sample. Not
                used by this base implementation.
            pts_semantic_mask (torch.Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not used by this base
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, pts_semantic_mask)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level point features.
            img_metas (list[dict]): Meta information of each sample. Not
                used by this base implementation.
            test_cfg (dict): The testing config. Not used by this base
                implementation.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each point.

        Applies the optional dropout, then the ``conv_seg`` classifier.

        Args:
            feat (torch.Tensor): Per-point features with ``channels``
                channels (presumably [B, channels, N], as expected by
                ``nn.Conv1d``).

        Returns:
            torch.Tensor: Per-point logits with ``num_classes`` channels.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute semantic segmentation loss.

        Args:
            seg_logit (torch.Tensor): Predicted per-point segmentation logits
                of shape [B, num_classes, N].
            seg_label (torch.Tensor): Ground-truth segmentation label of
                shape [B, N].

        Returns:
            dict[str, Tensor]: ``{'loss_sem_seg': loss}``, computed by
            ``self.loss_decode`` with ``self.ignore_index``.
        """
        loss = dict()
        loss['loss_sem_seg'] = self.loss_decode(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss