from collections import OrderedDict
from typing import Dict

from torch import nn, Tensor
from torch.nn import functional as F


__all__ = ["LRASPP"]


class LRASPP(nn.Module):
    """
    Implements a Lite R-ASPP Network for semantic segmentation from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[str, Tensor], with the key being
            "high" for the high level feature map and "low" for the low level feature map.
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes of the model (including the background).
        inter_channels (int, optional): the number of channels for intermediate computations.
            Default: 128.
    """

    def __init__(
        self,
        backbone: nn.Module,
        low_channels: int,
        high_channels: int,
        num_classes: int,
        inter_channels: int = 128,
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.classifier = LRASPPHead(low_channels, high_channels, num_classes, inter_channels)

    def forward(self, input: Tensor) -> Dict[str, Tensor]:
        """
        Run the backbone and the LR-ASPP head, then resize the logits to the input size.

        Args:
            input (Tensor): the image batch fed to ``backbone``; its last two
                dimensions are taken as the output height and width.

        Returns:
            Dict[str, Tensor]: an ``OrderedDict`` with the single key ``"out"``,
            holding raw per-class scores (no softmax is applied) whose spatial
            size matches ``input.shape[-2:]``.
        """
        features = self.backbone(input)
        out = self.classifier(features)
        # The head emits logits at the low-level feature resolution; resize them
        # back to the input resolution so they align pixel-for-pixel with it.
        out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False)

        result = OrderedDict()
        result["out"] = out
        return result


class LRASPPHead(nn.Module):
    """
    Lite R-ASPP segmentation head that fuses a low-level and a high-level feature map.

    The high-level features are projected to ``inter_channels`` by a 1x1
    conv + BatchNorm + ReLU ("cbr" branch). They are then multiplied by a
    per-channel sigmoid gate computed from their global average ("scale"
    branch). The gated features are resized to the low-level spatial size, and
    each branch is classified with its own 1x1 conv. The two sets of logits are
    summed.

    Args:
        low_channels (int): the number of channels of the low level features.
        high_channels (int): the number of channels of the high level features.
        num_classes (int): number of output classes (including the background).
        inter_channels (int): the number of channels for intermediate computations.
    """

    def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None:
        super().__init__()
        self.cbr = nn.Sequential(
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
        )
        # Global context: pool high-level features to 1x1 and map them to a
        # per-channel weight in (0, 1), broadcast over H and W in forward().
        self.scale = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(high_channels, inter_channels, 1, bias=False),
            nn.Sigmoid(),
        )
        self.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
        self.high_classifier = nn.Conv2d(inter_channels, num_classes, 1)

    def forward(self, input: Dict[str, Tensor]) -> Tensor:
        """
        Fuse the ``"low"`` and ``"high"`` feature maps into per-class logits.

        Args:
            input (Dict[str, Tensor]): must contain the keys ``"low"`` and
                ``"high"``. The tensors are assumed to be 4-D (N, C, H, W); the
                BatchNorm2d in the high branch requires 4-D input.

        Returns:
            Tensor: logits with ``num_classes`` channels at the spatial size of
            ``input["low"]``.
        """
        low = input["low"]
        high = input["high"]

        x = self.cbr(high)
        s = self.scale(high)
        x = x * s  # (N, inter, 1, 1) gate broadcasts over the spatial dims
        x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False)

        return self.low_classifier(low) + self.high_classifier(x)