# lraspp.py
1
from collections import OrderedDict
2
from typing import Dict
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

from torch import nn, Tensor
from torch.nn import functional as F


__all__ = ["LRASPP"]


class LRASPP(nn.Module):
    """
    Implements a Lite R-ASPP Network for semantic segmentation from
    `"Searching for MobileNetV3"
    <https://arxiv.org/abs/1905.02244>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
            Default: 128.
    """

    def __init__(
        self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        """
        Run the backbone and the segmentation head, then resize the result to the input size.

        Args:
            input (Tensor): batch fed to the backbone; its last two dimensions give the
                spatial size the output is resized to.

        Returns:
            Dict[str, Tensor]: an OrderedDict with a single key ``"out"`` holding the
            head's per-class scores, bilinearly resized to ``input.shape[-2:]``.
        """
        features = self.backbone(input)
        out = self.classifier(features)
        # The head's output is not at the input resolution in general; bring it back
        # to the spatial size of the input image.
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out

        return result


class LRASPPHead(nn.Module):
    """
    Segmentation head of the Lite R-ASPP network.

    The high level features are projected to ``inter_channels`` by a 1x1
    conv-BN-ReLU branch and gated channel-wise by a second branch
    (global average pool -> 1x1 conv -> sigmoid). The gated features are resized
    to the spatial size of the low level features, and the outputs of two 1x1
    classifiers (one on the low level features, one on the gated high level
    features) are summed.

    Args:
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes.
        inter_channels (int): the number of channels for intermediate computations.
    """

    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        # 1x1 conv -> BN -> ReLU projection of the high level features.
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        # Pooling to 1x1 makes this a per-channel gate that broadcasts over the spatial dims.
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        """
        Args:
            input (Dict[str, Tensor]): feature maps under the keys ``"low"`` and ``"high"``.

        Returns:
            Tensor: per-class scores with ``num_classes`` channels at the spatial
            size of ``input["low"]``.
        """
        low = input["low"]
        high = input["high"]

        x = self.cbr(high)
        s = self.scale(high)
        x = x * s
        # Match the low level feature resolution so the two classifier outputs can be summed.
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)