# -*- coding: utf-8 -*-
# @Time    : 2019/9/13 10:29
# @Author  : zhoujun
import torch
import torch.nn.functional as F
from torch import nn

from models.basic import ConvBnRelu


class FPEM_FFM(nn.Module):
    """PAN neck: 1x1 channel reduction, cascaded FPEM blocks, then FFM fusion."""

    def __init__(self, in_channels, inner_channels=128, fpem_repeat=2, **kwargs):
        """
        PANnet neck.

        :param in_channels: output channel counts of the backbone's four feature
            maps, indexed 0..3 and matched to c2, c3, c4, c5 in that order.
        :param inner_channels: channel count every level is reduced to before the
            FPEM blocks; also the channel count of each FPEM.
        :param fpem_repeat: number of cascaded FPEM blocks; must be >= 1.
        :param kwargs: accepted and ignored (presumably so the neck can be built
            from a shared config dict -- TODO confirm against the model builder).
        :raises ValueError: if ``fpem_repeat`` < 1; FFM needs at least one FPEM
            output to fuse (the original code crashed later, inside ``forward``).
        """
        super().__init__()
        if fpem_repeat < 1:
            raise ValueError(f"fpem_repeat must be >= 1, got {fpem_repeat}")
        self.conv_out = inner_channels
        inplace = True
        # reduce layers: 1x1 ConvBnRelu projecting each backbone level to inner_channels
        self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
        self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
        self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
        self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
        self.fpems = nn.ModuleList([FPEM(self.conv_out) for _ in range(fpem_repeat)])
        # FFM concatenates the four fused levels, each with conv_out channels.
        self.out_channels = self.conv_out * 4

    def forward(self, x):
        """
        Run the neck on the four backbone feature maps.

        :param x: sequence of exactly four tensors, unpacked as (c2, c3, c4, c5).
        :return: the four level-wise FPEM sums; c3..c5 are resized to c2's
            spatial size. All four are concatenated along dim=1 (the channel dim,
            assuming NCHW inputs), giving ``self.out_channels`` channels.
        """
        c2, c3, c4, c5 = x
        # reduce channel
        c2 = self.reduce_conv_c2(c2)
        c3 = self.reduce_conv_c3(c3)
        c4 = self.reduce_conv_c4(c4)
        c5 = self.reduce_conv_c5(c5)

        # FPEM: cascade the blocks and sum their outputs level by level.
        # The sums must be out-of-place. Each FPEM output is the result of an
        # nn.ReLU, and autograd keeps that result for ReLU's backward. An
        # in-place `+=` on it (the first-iteration tensors are aliased here)
        # makes backward() fail with an in-place modification error.
        for i, fpem in enumerate(self.fpems):
            c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
            if i == 0:
                c2_ffm, c3_ffm, c4_ffm, c5_ffm = c2, c3, c4, c5
            else:
                c2_ffm = c2_ffm + c2
                c3_ffm = c3_ffm + c3
                c4_ffm = c4_ffm + c4
                c5_ffm = c5_ffm + c5

        # FFM: resize c3..c5 sums to c2's spatial size (F.interpolate's default
        # mode, nearest) and concatenate with the c2 sum.
        target_size = c2_ffm.size()[-2:]
        c5 = F.interpolate(c5_ffm, target_size)
        c4 = F.interpolate(c4_ffm, target_size)
        c3 = F.interpolate(c3_ffm, target_size)
        return torch.cat([c2_ffm, c3, c4, c5], dim=1)


class FPEM(nn.Module):
    """Feature Pyramid Enhancement Module: an up pass then a down pass over four levels."""

    def __init__(self, in_channels=128):
        """
        :param in_channels: channel count of every input level; all outputs keep it.
        """
        super().__init__()
        # stride-1 separable convs used in the up pass
        self.up_add1 = SeparableConv2d(in_channels, in_channels, 1)
        self.up_add2 = SeparableConv2d(in_channels, in_channels, 1)
        self.up_add3 = SeparableConv2d(in_channels, in_channels, 1)
        # stride-2 separable convs used in the down pass
        self.down_add1 = SeparableConv2d(in_channels, in_channels, 2)
        self.down_add2 = SeparableConv2d(in_channels, in_channels, 2)
        self.down_add3 = SeparableConv2d(in_channels, in_channels, 2)

    def forward(self, c2, c3, c4, c5):
        """
        Enhance the four levels and return them as (c2, c3, c4, c5).

        Each step merges two neighbouring levels with ``_upsample_add`` and then
        applies a SeparableConv2d.
        """
        # up stage: resize the level above to the current level's size, add, refine (stride 1)
        c4 = self.up_add1(self._upsample_add(c5, c4))
        c3 = self.up_add2(self._upsample_add(c4, c3))
        c2 = self.up_add3(self._upsample_add(c3, c2))

        # down stage: resize the current level to the level below's size and add.
        # The stride-2 conv then shrinks the sum again; this presumably restores
        # the current level's size, assuming adjacent levels differ by 2x -- TODO confirm.
        c3 = self.down_add1(self._upsample_add(c3, c2))
        c4 = self.down_add2(self._upsample_add(c4, c3))
        c5 = self.down_add3(self._upsample_add(c5, c4))
        return c2, c3, c4, c5

    def _upsample_add(self, x, y):
        """Resize ``x`` to ``y``'s spatial size (nearest, the F.interpolate default) and add ``y``."""
        return F.interpolate(x, size=y.size()[2:]) + y


class SeparableConv2d(nn.Module):
    """Depthwise 3x3 conv -> pointwise 1x1 conv -> BatchNorm2d -> ReLU."""

    def __init__(self, in_channels, out_channels, stride=1):
        """
        :param in_channels: input channels (also the depthwise conv's group count).
        :param out_channels: channels produced by the pointwise conv.
        :param stride: stride of the depthwise 3x3 conv (padding is 1).
        """
        super().__init__()
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1,
                                        stride=stride, groups=in_channels)
        self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        # Not in-place; callers receive this ReLU's output tensor directly.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, BatchNorm and ReLU in sequence."""
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x