Commit c64beaf1, authored by Cao Yuhang and committed by Kai Chen
Browse files

Fix dpool (#1390)

* fix dpool

* add _pair in dpool func
parent 69e93f6f
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from . import deform_pool_cuda from . import deform_pool_cuda
...@@ -21,6 +22,12 @@ class DeformRoIPoolingFunction(Function): ...@@ -21,6 +22,12 @@ class DeformRoIPoolingFunction(Function):
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0): trans_std=.0):
# TODO: support unsquare RoIs
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
assert out_h == out_w
out_size = out_h # out_h and out_w must be equal
ctx.spatial_scale = spatial_scale ctx.spatial_scale = spatial_scale
ctx.out_size = out_size ctx.out_size = out_size
ctx.out_channels = out_channels ctx.out_channels = out_channels
...@@ -85,7 +92,7 @@ class DeformRoIPooling(nn.Module): ...@@ -85,7 +92,7 @@ class DeformRoIPooling(nn.Module):
trans_std=.0): trans_std=.0):
super(DeformRoIPooling, self).__init__() super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
self.out_size = out_size self.out_size = _pair(out_size)
self.out_channels = out_channels self.out_channels = out_channels
self.no_trans = no_trans self.no_trans = no_trans
self.group_size = group_size self.group_size = group_size
...@@ -125,12 +132,12 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -125,12 +132,12 @@ class DeformRoIPoolingPack(DeformRoIPooling):
if not no_trans: if not no_trans:
seq = [] seq = []
ic = self.out_size * self.out_size * self.out_channels ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs): for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1: if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels oc = self.deform_fc_channels
else: else:
oc = self.out_size * self.out_size * 2 oc = self.out_size[0] * self.out_size[1] * 2
seq.append(nn.Linear(ic, oc)) seq.append(nn.Linear(ic, oc))
ic = oc ic = oc
if i < self.num_offset_fcs - 1: if i < self.num_offset_fcs - 1:
...@@ -156,7 +163,7 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -156,7 +163,7 @@ class DeformRoIPoolingPack(DeformRoIPooling):
self.group_size, self.part_size, self.group_size, self.part_size,
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1)) offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size) offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
return deform_roi_pooling(data, rois, offset, self.spatial_scale, return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, self.out_size, self.out_channels,
self.no_trans, self.group_size, self.no_trans, self.group_size,
...@@ -188,12 +195,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -188,12 +195,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
if not no_trans: if not no_trans:
offset_fc_seq = [] offset_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs): for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1: if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels oc = self.deform_fc_channels
else: else:
oc = self.out_size * self.out_size * 2 oc = self.out_size[0] * self.out_size[1] * 2
offset_fc_seq.append(nn.Linear(ic, oc)) offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc ic = oc
if i < self.num_offset_fcs - 1: if i < self.num_offset_fcs - 1:
...@@ -203,12 +210,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -203,12 +210,12 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.offset_fc[-1].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
mask_fc_seq = [] mask_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_mask_fcs): for i in range(self.num_mask_fcs):
if i < self.num_mask_fcs - 1: if i < self.num_mask_fcs - 1:
oc = self.deform_fc_channels oc = self.deform_fc_channels
else: else:
oc = self.out_size * self.out_size oc = self.out_size[0] * self.out_size[1]
mask_fc_seq.append(nn.Linear(ic, oc)) mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc ic = oc
if i < self.num_mask_fcs - 1: if i < self.num_mask_fcs - 1:
...@@ -236,9 +243,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -236,9 +243,9 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.group_size, self.part_size, self.group_size, self.part_size,
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1)) offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size) offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
mask = self.mask_fc(x.view(n, -1)) mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size) mask = mask.view(n, 1, self.out_size[0], self.out_size[1])
return deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size, self.out_channels, self.no_trans, self.group_size,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment