import torch
from torch import nn
from torch.nn import functional as F

from ._utils import _SimpleSegmentationModel


__all__ = ["DeepLabV3"]


class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """
    pass


class DeepLabHead(nn.Sequential):
    """Classification head: ASPP followed by a 3x3 conv block and a 1x1 classifier.

    Args:
        in_channels (int): number of channels of the feature map fed to the head.
        num_classes (int): number of output channels of the final 1x1 convolution.
        atrous_rates (iterable of int, optional): dilation rates of the ASPP
            branches. Defaults to ``(12, 24, 36)``.
    """

    def __init__(self, in_channels, num_classes, atrous_rates=(12, 24, 36)):
        # Width shared by the ASPP output and the 3x3 conv block that follows it.
        hidden_channels = 256
        # The positional order below fixes the Sequential indices (0..4) and
        # therefore the state_dict keys; do not reorder.
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, atrous_rates, hidden_channels),
            nn.Conv2d(hidden_channels, hidden_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, num_classes, 1),
        )


class ASPPConv(nn.Sequential):
    """One dilated ASPP branch: 3x3 atrous convolution, batch norm, ReLU.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        dilation (int): dilation of the 3x3 convolution; also used as its padding.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super(ASPPConv, self).__init__(
            # padding == dilation for a 3x3 kernel, so every branch can be
            # concatenated with the others along the channel axis.
            nn.Conv2d(in_channels, out_channels, 3,
                      padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """Pooling ASPP branch: adaptive average pool to 1x1, 1x1 conv, batch norm, ReLU.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the pooled branch and resize the result to the input's last two dims.

        Args:
            x (Tensor): input feature map.

        Returns:
            Tensor: branch output, bilinearly interpolated to ``x.shape[-2:]``.
        """
        # Remember the incoming size before pooling collapses it.
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling module.

    Runs the input through parallel branches (a 1x1 conv branch, one
    :class:`ASPPConv` per atrous rate, and an :class:`ASPPPooling` branch),
    concatenates the branch outputs along ``dim=1`` and projects them back to
    ``out_channels`` with a 1x1 conv, batch norm, ReLU and dropout (p=0.5).

    Args:
        in_channels (int): number of input channels.
        atrous_rates (iterable of int): one dilated branch is created per rate.
        out_channels (int, optional): number of channels produced by every
            branch and by the final projection. Defaults to 256.
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()

        # Branch order defines the ModuleList indices (state_dict keys):
        # [1x1 conv, dilated convs in the order of `atrous_rates`, pooling].
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            )
        ]
        branches.extend(
            ASPPConv(in_channels, out_channels, rate) for rate in tuple(atrous_rates)
        )
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        # Every branch emits `out_channels` channels, so the concatenated
        # tensor has len(self.convs) * out_channels channels.
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Apply all branches to ``x``, concatenate on ``dim=1`` and project.

        Args:
            x (Tensor): input feature map.

        Returns:
            Tensor: output of the projection block.
        """
        branch_outputs = [conv(x) for conv in self.convs]
        return self.project(torch.cat(branch_outputs, dim=1))