fused_semantic_head.py 3.47 KB
Newer Older
1
2
3
4
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import kaiming_init

Cao Yuhang's avatar
Cao Yuhang committed
5
from mmdet.core import auto_fp16, force_fp32
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from ..registry import HEADS
from ..utils import ConvModule


@HEADS.register_module
class FusedSemanticHead(nn.Module):
    r"""Multi-level fused semantic segmentation head.

    in_1 -> 1x1 conv ---
                        |
    in_2 -> 1x1 conv -- |
                       ||
    in_3 -> 1x1 conv - ||
                      |||                  /-> 1x1 conv (mask prediction)
    in_4 -> 1x1 conv -----> 3x3 convs (*4)
                        |                  \-> 1x1 conv (feature)
    in_5 -> 1x1 conv ---

    Args:
        num_ins: Number of input feature levels; one 1x1 lateral conv is
            created per level.
        fusion_level: Index of the level whose spatial size every other
            level is resized to before the levels are summed.
        num_convs: Number of 3x3 convs applied to the fused feature.
        in_channels: Channel count of every input level (also the output
            channel count of the lateral convs).
        conv_out_channels: Output channels of the 3x3 convs and of the
            embedding conv.
        num_classes: Output channels of the logits conv.
        ignore_label: Passed to ``nn.CrossEntropyLoss`` as ``ignore_index``.
        loss_weight: Factor the segmentation loss is multiplied by.
        conv_cfg: Forwarded unchanged to every ``ConvModule``.
        norm_cfg: Forwarded unchanged to every ``ConvModule``.
    """

    def __init__(self,
                 num_ins,
                 fusion_level,
                 num_convs=4,
                 in_channels=256,
                 conv_out_channels=256,
                 num_classes=183,
                 ignore_label=255,
                 loss_weight=0.2,
                 conv_cfg=None,
                 norm_cfg=None):
        super(FusedSemanticHead, self).__init__()
        self.num_ins = num_ins
        self.fusion_level = fusion_level
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.ignore_label = ignore_label
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        # One 1x1 lateral conv per input level. ``inplace=False`` is handed
        # to ConvModule as-is (presumably it disables in-place activation so
        # the caller's feature maps are not modified -- confirm in ConvModule).
        self.lateral_convs = nn.ModuleList()
        for _ in range(self.num_ins):
            self.lateral_convs.append(
                ConvModule(
                    self.in_channels,
                    self.in_channels,
                    1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=False))

        # Stack of 3x3 convs applied to the fused feature; only the first
        # one consumes ``in_channels``, the rest chain on conv_out_channels.
        self.convs = nn.ModuleList()
        for i in range(self.num_convs):
            conv_in_channels = (
                self.in_channels if i == 0 else conv_out_channels)
            self.convs.append(
                ConvModule(
                    conv_in_channels,
                    conv_out_channels,
                    3,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

        # Two 1x1 output branches fed by the same fused feature.
        self.conv_embedding = ConvModule(
            conv_out_channels,
            conv_out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)
        self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)

        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)

    def init_weights(self):
        """Apply ``kaiming_init`` to the logits conv."""
        kaiming_init(self.conv_logits)

    @auto_fp16()
    def forward(self, feats):
        """Fuse the input levels and compute both output branches.

        Args:
            feats: Indexable sequence of feature tensors, one per level
                (presumably ``(N, C, H, W)`` each -- confirm with caller).

        Returns:
            tuple: ``(mask_pred, x)`` where ``mask_pred`` is the output of
            the logits conv and ``x`` is the output of the embedding conv.
        """
        x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
        fused_size = tuple(x.shape[-2:])
        for i, feat in enumerate(feats):
            if i != self.fusion_level:
                feat = F.interpolate(
                    feat, size=fused_size, mode='bilinear', align_corners=True)
                # Out-of-place add: ``x += ...`` would overwrite a tensor
                # that autograd may have saved for the backward pass, which
                # newer PyTorch versions reject with an in-place
                # modification RuntimeError.
                x = x + self.lateral_convs[i](feat)

        for conv in self.convs:
            x = conv(x)

        mask_pred = self.conv_logits(x)
        x = self.conv_embedding(x)
        return mask_pred, x

    @force_fp32(apply_to=('mask_pred', ))
    def loss(self, mask_pred, labels):
        """Return the weighted cross-entropy segmentation loss.

        Args:
            mask_pred: Logits produced by ``forward``.
            labels: Ground-truth label map; dimension 1 is squeezed and the
                result cast to ``long`` before the loss is evaluated
                (presumably ``(N, 1, H, W)`` -- confirm with caller).

        Returns:
            The cross-entropy loss multiplied by ``self.loss_weight``.
        """
        labels = labels.squeeze(1).long()
        loss_semantic_seg = self.criterion(mask_pred, labels)
        return loss_semantic_seg * self.loss_weight