"vscode:/vscode.git/clone" did not exist on "cb1c3df1a220425fd3e5f853e3713744e16a952c"
Commit dc57735f authored by yhcao6's avatar yhcao6
Browse files

move bias check from module to function

parent b1ba5939
...@@ -107,17 +107,18 @@ class ModulatedDeformConvFunction(Function): ...@@ -107,17 +107,18 @@ class ModulatedDeformConvFunction(Function):
offset, offset,
mask, mask,
weight, weight,
bias, bias=None,
stride, stride=1,
padding, padding=0,
dilation=1, dilation=1,
deformable_groups=1, deformable_groups=1):
with_bias=False):
ctx.stride = stride ctx.stride = stride
ctx.padding = padding ctx.padding = padding
ctx.dilation = dilation ctx.dilation = dilation
ctx.deformable_groups = deformable_groups ctx.deformable_groups = deformable_groups
ctx.with_bias = with_bias ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda: if not input.is_cuda:
raise NotImplementedError raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
...@@ -149,9 +150,11 @@ class ModulatedDeformConvFunction(Function): ...@@ -149,9 +150,11 @@ class ModulatedDeformConvFunction(Function):
grad_output, weight.shape[2], weight.shape[3], ctx.stride, grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups, ctx.with_bias) ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None) None, None, None, None)
@staticmethod @staticmethod
def _infer_shape(ctx, input, weight): def _infer_shape(ctx, input, weight):
......
...@@ -43,9 +43,6 @@ class DeformRoIPoolingFunction(Function): ...@@ -43,9 +43,6 @@ class DeformRoIPoolingFunction(Function):
if data.requires_grad or rois.requires_grad or offset.requires_grad: if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset) ctx.save_for_backward(data, rois, offset)
# ctx.data = data
# ctx.rois = rois
# ctx.offset = offset
ctx.output_count = output_count ctx.output_count = output_count
return output return output
...@@ -56,9 +53,6 @@ class DeformRoIPoolingFunction(Function): ...@@ -56,9 +53,6 @@ class DeformRoIPoolingFunction(Function):
raise NotImplementedError raise NotImplementedError
data, rois, offset = ctx.saved_tensors data, rois, offset = ctx.saved_tensors
# data = ctx.data
# rois = ctx.rois
# offset = ctx.offset
output_count = ctx.output_count output_count = ctx.output_count
grad_input = torch.zeros_like(data) grad_input = torch.zeros_like(data)
grad_offset = torch.zeros_like(offset) grad_offset = torch.zeros_like(offset)
......
...@@ -71,7 +71,7 @@ class ModulatedDeformConv(nn.Module): ...@@ -71,7 +71,7 @@ class ModulatedDeformConv(nn.Module):
if bias: if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels)) self.bias = nn.Parameter(torch.Tensor(out_channels))
else: else:
self.bias = nn.Parameter(torch.Tensor([0])) # fake tensor self.register_parameter('bias', None)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
...@@ -80,14 +80,13 @@ class ModulatedDeformConv(nn.Module): ...@@ -80,14 +80,13 @@ class ModulatedDeformConv(nn.Module):
n *= k n *= k
stdv = 1. / math.sqrt(n) stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv) self.weight.data.uniform_(-stdv, stdv)
if self.with_bias: if self.bias is not None:
self.bias.data.zero_() self.bias.data.zero_()
def forward(self, input, offset, mask): def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight, return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding, self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups, self.dilation, self.deformable_groups)
self.with_bias)
class ModulatedDeformConvPack(ModulatedDeformConv): class ModulatedDeformConvPack(ModulatedDeformConv):
...@@ -110,8 +109,8 @@ class ModulatedDeformConvPack(ModulatedDeformConv): ...@@ -110,8 +109,8 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
self.deformable_groups * 3 * self.kernel_size[0] * self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1], self.kernel_size[1],
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
stride=(self.stride, self.stride), stride=_pair(self.stride),
padding=(self.padding, self.padding), padding=_pair(self.padding),
bias=True) bias=True)
self.init_offset() self.init_offset()
...@@ -126,5 +125,4 @@ class ModulatedDeformConvPack(ModulatedDeformConv): ...@@ -126,5 +125,4 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
mask = torch.sigmoid(mask) mask = torch.sigmoid(mask)
return modulated_deform_conv(input, offset, mask, self.weight, return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding, self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups, self.dilation, self.deformable_groups)
self.with_bias)
...@@ -33,7 +33,7 @@ class DeformRoIPooling(nn.Module): ...@@ -33,7 +33,7 @@ class DeformRoIPooling(nn.Module):
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
class ModulatedDeformRoIPoolingPack(DeformRoIPooling): class DeformRoIPoolingPack(DeformRoIPooling):
def __init__(self, def __init__(self,
spatial_scale, spatial_scale,
...@@ -45,32 +45,22 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -45,32 +45,22 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
deform_fc_dim=1024): deform_fc_dim=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__( super(DeformRoIPoolingPack,
spatial_scale, out_size, output_dim, no_trans, group_size, self).__init__(spatial_scale, out_size, output_dim, no_trans,
part_size, sample_per_part, trans_std) group_size, part_size, sample_per_part, trans_std)
self.deform_fc_dim = deform_fc_dim self.deform_fc_dim = deform_fc_dim
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( self.offset_fc = nn.Sequential(
nn.Linear( nn.Linear(self.out_size * self.out_size * self.output_dim,
self.out_size * self.out_size * self.output_dim, self.deform_fc_dim), nn.ReLU(inplace=True),
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, nn.Linear(self.deform_fc_dim,
self.out_size * self.out_size * 2)) self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_() self.offset_fc[4].weight.data.zero_()
self.offset_fc[4].bias.data.zero_() self.offset_fc[4].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(
self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
def forward(self, data, rois): def forward(self, data, rois):
if self.no_trans: if self.no_trans:
...@@ -84,12 +74,10 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -84,12 +74,10 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1)) offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size) offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size)
feat = deform_roi_pooling( feat = deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.output_dim, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask self.part_size, self.sample_per_part, self.trans_std)
return feat return feat
return deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
...@@ -97,7 +85,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -97,7 +85,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
class DeformRoIPoolingPack(DeformRoIPooling): class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def __init__(self, def __init__(self,
spatial_scale, spatial_scale,
...@@ -109,7 +97,7 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -109,7 +97,7 @@ class DeformRoIPoolingPack(DeformRoIPooling):
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
deform_fc_dim=1024): deform_fc_dim=1024):
super(DeformRoIPoolingPack, self).__init__( super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, output_dim, no_trans, group_size, spatial_scale, out_size, output_dim, no_trans, group_size,
part_size, sample_per_part, trans_std) part_size, sample_per_part, trans_std)
...@@ -117,15 +105,21 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -117,15 +105,21 @@ class DeformRoIPoolingPack(DeformRoIPooling):
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( self.offset_fc = nn.Sequential(
nn.Linear( nn.Linear(self.out_size * self.out_size * self.output_dim,
self.out_size * self.out_size * self.output_dim, self.deform_fc_dim), nn.ReLU(inplace=True),
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, nn.Linear(self.deform_fc_dim,
self.out_size * self.out_size * 2)) self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_() self.offset_fc[4].weight.data.zero_()
self.offset_fc[4].bias.data.zero_() self.offset_fc[4].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.out_size * self.out_size * 1), nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
def forward(self, data, rois): def forward(self, data, rois):
if self.no_trans: if self.no_trans:
...@@ -139,10 +133,12 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -139,10 +133,12 @@ class DeformRoIPoolingPack(DeformRoIPooling):
self.sample_per_part, self.trans_std) self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1)) offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size) offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size)
feat = deform_roi_pooling( feat = deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.output_dim, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) self.part_size, self.sample_per_part, self.trans_std) * mask
return feat return feat
return deform_roi_pooling( return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size, data, rois, offset, self.spatial_scale, self.out_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment