Commit d46124c6 authored by Kai Chen's avatar Kai Chen
Browse files

some renames

parent d8c72d8b
...@@ -12,19 +12,19 @@ class DeformRoIPoolingFunction(Function): ...@@ -12,19 +12,19 @@ class DeformRoIPoolingFunction(Function):
rois, rois,
offset, offset,
spatial_scale, spatial_scale,
pooled_size, out_size,
output_dim, out_channels,
no_trans, no_trans,
group_size=1, group_size=1,
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0): trans_std=.0):
ctx.spatial_scale = spatial_scale ctx.spatial_scale = spatial_scale
ctx.pooled_size = pooled_size ctx.out_size = out_size
ctx.output_dim = output_dim ctx.out_channels = out_channels
ctx.no_trans = no_trans ctx.no_trans = no_trans
ctx.group_size = group_size ctx.group_size = group_size
ctx.part_size = pooled_size if part_size is None else part_size ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std ctx.trans_std = trans_std
...@@ -32,13 +32,12 @@ class DeformRoIPoolingFunction(Function): ...@@ -32,13 +32,12 @@ class DeformRoIPoolingFunction(Function):
if not data.is_cuda: if not data.is_cuda:
raise NotImplementedError raise NotImplementedError
output = data.new_empty( n = rois.shape[0]
DeformRoIPoolingFunction._infer_shape(ctx, data, rois)) output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty( output_count = data.new_empty(n, out_channels, out_size, out_size)
DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
deform_pool_cuda.deform_psroi_pooling_cuda_forward( deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans, data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size, ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std) ctx.part_size, ctx.sample_per_part, ctx.trans_std)
if data.requires_grad or rois.requires_grad or offset.requires_grad: if data.requires_grad or rois.requires_grad or offset.requires_grad:
...@@ -55,20 +54,16 @@ class DeformRoIPoolingFunction(Function): ...@@ -55,20 +54,16 @@ class DeformRoIPoolingFunction(Function):
data, rois, offset = ctx.saved_tensors data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count output_count = ctx.output_count
grad_input = torch.zeros_like(data) grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset) grad_offset = torch.zeros_like(offset)
deform_pool_cuda.deform_psroi_pooling_cuda_backward( deform_pool_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input, grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.output_dim, grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
ctx.group_size, ctx.pooled_size, ctx.part_size, ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
ctx.sample_per_part, ctx.trans_std) ctx.trans_std)
return (grad_input, torch.zeros_like(rois), grad_offset, None, None, return (grad_input, grad_rois, grad_offset, None, None, None, None,
None, None, None, None, None, None) None, None, None, None)
@staticmethod
def _infer_shape(ctx, data, rois):
n = rois.shape[0]
return n, ctx.output_dim, ctx.pooled_size, ctx.pooled_size
deform_roi_pooling = DeformRoIPoolingFunction.apply deform_roi_pooling = DeformRoIPoolingFunction.apply
...@@ -8,7 +8,7 @@ class DeformRoIPooling(nn.Module): ...@@ -8,7 +8,7 @@ class DeformRoIPooling(nn.Module):
def __init__(self, def __init__(self,
spatial_scale, spatial_scale,
out_size, out_size,
output_dim, out_channels,
no_trans, no_trans,
group_size=1, group_size=1,
part_size=None, part_size=None,
...@@ -17,7 +17,7 @@ class DeformRoIPooling(nn.Module): ...@@ -17,7 +17,7 @@ class DeformRoIPooling(nn.Module):
super(DeformRoIPooling, self).__init__() super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
self.out_size = out_size self.out_size = out_size
self.output_dim = output_dim self.out_channels = out_channels
self.no_trans = no_trans self.no_trans = no_trans
self.group_size = group_size self.group_size = group_size
self.part_size = out_size if part_size is None else part_size self.part_size = out_size if part_size is None else part_size
...@@ -29,7 +29,7 @@ class DeformRoIPooling(nn.Module): ...@@ -29,7 +29,7 @@ class DeformRoIPooling(nn.Module):
offset = data.new_empty(0) offset = data.new_empty(0)
return deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size, self.out_channels, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
...@@ -38,51 +38,52 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -38,51 +38,52 @@ class DeformRoIPoolingPack(DeformRoIPooling):
def __init__(self, def __init__(self,
spatial_scale, spatial_scale,
out_size, out_size,
output_dim, out_channels,
no_trans, no_trans,
group_size=1, group_size=1,
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
deform_fc_dim=1024): deform_fc_channels=1024):
super(DeformRoIPoolingPack, super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, output_dim, no_trans, self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std) group_size, part_size, sample_per_part, trans_std)
self.deform_fc_dim = deform_fc_dim self.deform_fc_channels = deform_fc_channels
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.output_dim, nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_dim), nn.ReLU(inplace=True), self.deform_fc_channels),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2)) self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[4].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
def forward(self, data, rois): def forward(self, data, rois):
assert data.size(1) == self.out_channels
if self.no_trans: if self.no_trans:
offset = data.new_empty(0) offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
else: else:
n = rois.shape[0] n = rois.shape[0]
offset = data.new_empty(0) offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale, x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True, self.out_size, self.out_channels, True,
self.group_size, self.part_size, self.group_size, self.part_size,
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1)) offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size) offset = offset.view(n, 2, self.out_size, self.out_size)
feat = deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) self.part_size, self.sample_per_part, self.trans_std)
return feat
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
class ModulatedDeformRoIPoolingPack(DeformRoIPooling): class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
...@@ -90,57 +91,60 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -90,57 +91,60 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def __init__(self, def __init__(self,
spatial_scale, spatial_scale,
out_size, out_size,
output_dim, out_channels,
no_trans, no_trans,
group_size=1, group_size=1,
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
deform_fc_dim=1024): deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__( super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, output_dim, no_trans, group_size, spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std) part_size, sample_per_part, trans_std)
self.deform_fc_dim = deform_fc_dim self.deform_fc_channels = deform_fc_channels
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.output_dim, nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_dim), nn.ReLU(inplace=True), self.deform_fc_channels),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2)) self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[4].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential( self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.output_dim, nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_dim), nn.ReLU(inplace=True), self.deform_fc_channels),
nn.Linear(self.deform_fc_dim, nn.ReLU(inplace=True),
self.out_size * self.out_size * 1), nn.Sigmoid()) nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_() self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_() self.mask_fc[2].bias.data.zero_()
def forward(self, data, rois): def forward(self, data, rois):
assert data.size(1) == self.out_channels
if self.no_trans: if self.no_trans:
offset = data.new_empty(0) offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
else: else:
n = rois.shape[0] n = rois.shape[0]
offset = data.new_empty(0) offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale, x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True, self.out_size, self.out_channels, True,
self.group_size, self.part_size, self.group_size, self.part_size,
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1)) offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size) offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1)) mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size) mask = mask.view(n, 1, self.out_size, self.out_size)
feat = deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask self.part_size, self.sample_per_part, self.trans_std) * mask
return feat
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment