from torch import nn

from ..functions.deform_pool import deform_roi_pooling


def _build_fc_layers(in_features, hidden_features, out_features, num_fcs):
    """Return the layer list for a stack of ``num_fcs`` fully connected layers.

    Every ``nn.Linear`` except the last has ``hidden_features`` outputs and is
    followed by an in-place ``nn.ReLU``; the last one has ``out_features``
    outputs and no activation. The list is empty when ``num_fcs`` is 0.
    """
    layers = []
    for i in range(num_fcs):
        is_last = i == num_fcs - 1
        width = out_features if is_last else hidden_features
        layers.append(nn.Linear(in_features, width))
        in_features = width
        if not is_last:
            layers.append(nn.ReLU(inplace=True))
    return layers


class DeformRoIPooling(nn.Module):
    """``nn.Module`` wrapper around the ``deform_roi_pooling`` function.

    The constructor only stores its arguments; they are forwarded unchanged to
    ``deform_roi_pooling`` on every call.

    Args:
        spatial_scale: forwarded to ``deform_roi_pooling``.
        out_size: forwarded to ``deform_roi_pooling``; also the default for
            ``part_size``. The ``*Pack`` subclasses use it as an integer side
            length (``out_size * out_size``).
        out_channels: forwarded to ``deform_roi_pooling``.
        no_trans: when truthy, any offset given to ``forward`` is replaced by
            an empty tensor.
        group_size: forwarded to ``deform_roi_pooling``.
        part_size: forwarded to ``deform_roi_pooling``; ``None`` means
            "use ``out_size``".
        sample_per_part: forwarded to ``deform_roi_pooling``.
        trans_std: forwarded to ``deform_roi_pooling``.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.out_channels = out_channels
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = out_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def _pool(self, data, rois, offset, no_trans):
        """Call ``deform_roi_pooling`` with this module's stored settings."""
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.out_channels, no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)

    def forward(self, data, rois, offset):
        """Pool ``data`` over ``rois`` using the caller-supplied ``offset``.

        When ``self.no_trans`` is truthy, ``offset`` is ignored and an empty
        tensor created from ``data`` is passed instead.
        """
        if self.no_trans:
            offset = data.new_empty(0)
        return self._pool(data, rois, offset, self.no_trans)


class DeformRoIPoolingPack(DeformRoIPooling):
    """``DeformRoIPooling`` that predicts its own offsets with an FC stack.

    Unless ``no_trans`` is truthy, an ``offset_fc`` ``nn.Sequential`` is built
    from ``num_offset_fcs`` linear layers. Its input width is
    ``out_size * out_size * out_channels`` and its output width is
    ``out_size * out_size * 2``. The last linear layer is zero-initialised, so
    the predicted offsets are all zero until it is trained.

    Args:
        num_offset_fcs: number of linear layers in ``offset_fc``. Must be at
            least 1 when ``no_trans`` is falsy, because the last layer is
            indexed for zero-initialisation.
        deform_fc_channels: width of the hidden linear layers.

    The remaining arguments are those of ``DeformRoIPooling``.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 deform_fc_channels=1024):
        super(DeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, out_channels, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            in_features = self.out_size * self.out_size * self.out_channels
            self.offset_fc = nn.Sequential(*_build_fc_layers(
                in_features, self.deform_fc_channels,
                self.out_size * self.out_size * 2, self.num_offset_fcs))
            # Zero the final layer so the initial offsets are exactly zero.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with offsets predicted from the data.

        Dimension 1 of ``data`` must equal ``out_channels`` and
        ``rois.shape[0]`` is taken as the number of RoIs.

        Raises:
            AssertionError: if ``data.size(1) != self.out_channels``.
        """
        assert data.size(1) == self.out_channels, (
            'data has {} channels, expected {}'.format(
                data.size(1), self.out_channels))

        if self.no_trans:
            return self._pool(data, rois, data.new_empty(0), self.no_trans)

        n = rois.shape[0]
        # First pass without offsets; its output feeds the offset predictor.
        x = self._pool(data, rois, data.new_empty(0), True)
        # An explicit feature count (the width offset_fc expects) instead of
        # -1 keeps the reshape well-defined when there are no RoIs (n == 0).
        in_features = self.out_size * self.out_size * self.out_channels
        offset = self.offset_fc(x.view(n, in_features))
        offset = offset.view(n, 2, self.out_size, self.out_size)
        # Second pass with the predicted offsets.
        return self._pool(data, rois, offset, self.no_trans)


class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
    """``DeformRoIPooling`` that predicts both offsets and a modulation mask.

    Unless ``no_trans`` is truthy, two ``nn.Sequential`` stacks are built, both
    taking ``out_size * out_size * out_channels`` input features:

    * ``offset_fc``: ``num_offset_fcs`` linear layers ending in
      ``out_size * out_size * 2`` outputs, last layer zero-initialised.
    * ``mask_fc``: ``num_mask_fcs`` linear layers ending in
      ``out_size * out_size`` outputs followed by ``nn.Sigmoid``. The last
      linear layer is zero-initialised, so the initial mask is 0.5 everywhere.

    Args:
        num_offset_fcs: number of linear layers in ``offset_fc`` (at least 1
            when ``no_trans`` is falsy).
        num_mask_fcs: number of linear layers in ``mask_fc`` (at least 1 when
            ``no_trans`` is falsy).
        deform_fc_channels: width of the hidden linear layers in both stacks.

    The remaining arguments are those of ``DeformRoIPooling``.
    """

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 num_mask_fcs=2,
                 deform_fc_channels=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, out_channels, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.num_mask_fcs = num_mask_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            in_features = self.out_size * self.out_size * self.out_channels

            self.offset_fc = nn.Sequential(*_build_fc_layers(
                in_features, self.deform_fc_channels,
                self.out_size * self.out_size * 2, self.num_offset_fcs))
            # Zero the final layer so the initial offsets are exactly zero.
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

            mask_layers = _build_fc_layers(
                in_features, self.deform_fc_channels,
                self.out_size * self.out_size, self.num_mask_fcs)
            mask_layers.append(nn.Sigmoid())
            self.mask_fc = nn.Sequential(*mask_layers)
            # [-1] is the Sigmoid; [-2] is the final linear layer. Zeroing it
            # makes the initial mask sigmoid(0) = 0.5 everywhere.
            self.mask_fc[-2].weight.data.zero_()
            self.mask_fc[-2].bias.data.zero_()

    def forward(self, data, rois):
        """Pool ``data`` over ``rois`` with predicted offsets, then apply the mask.

        Dimension 1 of ``data`` must equal ``out_channels`` and
        ``rois.shape[0]`` is taken as the number of RoIs. When ``no_trans`` is
        truthy, no offsets or mask are predicted and the plain pooled result is
        returned.

        Raises:
            AssertionError: if ``data.size(1) != self.out_channels``.
        """
        assert data.size(1) == self.out_channels, (
            'data has {} channels, expected {}'.format(
                data.size(1), self.out_channels))

        if self.no_trans:
            return self._pool(data, rois, data.new_empty(0), self.no_trans)

        n = rois.shape[0]
        # First pass without offsets; its output feeds both predictors.
        x = self._pool(data, rois, data.new_empty(0), True)
        # An explicit feature count (the width both FC stacks expect) instead
        # of -1 keeps the reshape well-defined when there are no RoIs (n == 0).
        in_features = self.out_size * self.out_size * self.out_channels
        flat = x.view(n, in_features)
        offset = self.offset_fc(flat)
        offset = offset.view(n, 2, self.out_size, self.out_size)
        mask = self.mask_fc(flat)
        # Singleton dim 1 lets the mask broadcast over the pooled output
        # (presumably its channel dimension - TODO confirm output layout).
        mask = mask.view(n, 1, self.out_size, self.out_size)
        # Second pass with the predicted offsets, modulated by the mask.
        return self._pool(data, rois, offset, self.no_trans) * mask