Commit 4a4b10fb authored by yhcao6's avatar yhcao6
Browse files

fix some bugs, add group divisible check

parent a87cb824
......@@ -53,6 +53,7 @@ class Bottleneck(_Bottleneck):
groups=groups,
bias=False)
else:
groups = self.dcn.get('groups', 1)
deformable_groups = self.dcn.get('deformable_groups', 1)
if not self.with_modulated_dcn:
conv_op = DeformConv
......@@ -194,7 +195,8 @@ class ResNeXt(ResNet):
base_width=self.base_width,
style=self.style,
with_cp=self.with_cp,
normalize=self.normalize)
normalize=self.normalize,
dcn=self.dcn)
self.inplanes = planes * self.block.expansion
layer_name = 'layer{}'.format(i + 1)
self.add_module(layer_name, res_layer)
......
......@@ -21,6 +21,13 @@ class DeformConv(nn.Module):
bias=False):
assert not bias
super(DeformConv, self).__init__()
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
......
......@@ -601,7 +601,8 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group, grad_weight.size(1), grad_weight.size(2), grad_weight.size(3)});
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++){
grad_weight[g] = grad_weight[g].flatten(1).addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)).view_as(grad_weight[g]);
......@@ -612,7 +613,8 @@ void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
columns = columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight.size(2), grad_weight.size(3), grad_weight.size(4)});
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), grad_output.size(2), grad_output.size(3), grad_output.size(4)});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment