Unverified Commit 17ea1482 authored by Francisco Massa, committed by GitHub

Improve error message and avoid segfault in DeformConv2d (#1660)

parent 2d7c0667
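
As an illustration of the behavior this patch targets (a minimal sketch, not part of the commit, assuming a torchvision build that includes this change): an offset tensor whose channel dimension is not a multiple of 2 * kernel_h * kernel_w previously reached the C++ kernel with an inferred offset-group count of zero and could segfault; it is now rejected in Python with a descriptive RuntimeError.

    import torch
    from torchvision import ops

    input = torch.rand(1, 3, 10, 10)
    weight = torch.rand(5, 3, 3, 3)
    # 4 offset channels is not a multiple of 2 * 3 * 3 = 18, so the inferred
    # number of offset groups is 0 and the new Python-side check fires
    bad_offset = torch.rand(1, 4, 8, 8)
    try:
        ops.deform_conv2d(input, bad_offset, weight)
    except RuntimeError as err:
        print(err)  # "the shape of the offset tensor at dimension 1 is not valid ..."
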
@@ -467,11 +467,9 @@ class DeformConvTester(OpTester, unittest.TestCase):
out_channels = 2
kernel_size = (3, 2)
groups = 2
offset_groups = 3
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device,
dtype=x.dtype)
dilation=dilation, groups=groups).to(device=x.device, dtype=x.dtype)
res = layer(x, offset)
weight = layer.weight.data
@@ -480,6 +478,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
# test for wrong sizes
with self.assertRaises(RuntimeError):
wrong_offset = torch.rand_like(offset[:, :2])
res = layer(x, wrong_offset)
def _test_backward(self, device, contiguous):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
......
@@ -296,16 +296,16 @@ at::Tensor DeformConv2d_forward_cpu(
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"got: ",
"offset.shape[1] is not valid: got: ",
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (",
......
@@ -314,16 +314,16 @@ at::Tensor DeformConv2d_forward_cuda(
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"got: ",
"offset.shape[1] is not valid: got: ",
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (",
......
@@ -28,6 +28,20 @@ def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0
Returns:
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
Examples::
>>> input = torch.rand(1, 3, 10, 10)
>>> kh, kw = 3, 3
>>> weight = torch.rand(5, 3, kh, kw)
>>> # offset should have the same spatial size as the output
>>> # of the convolution. In this case, for an input of 10, stride of 1
>>> # and kernel size of 3, without padding, the output size is 8
>>> offset = torch.rand(1, 2 * kh * kw, 8, 8)
>>> out = deform_conv2d(input, offset, weight)
>>> print(out.shape)
torch.Size([1, 5, 8, 8])
"""
out_channels = weight.shape[0]
@@ -43,6 +57,13 @@ def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0
n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
n_weight_grps = n_in_channels // weight.shape[1]
if n_offset_grps == 0:
raise RuntimeError(
"the shape of the offset tensor at dimension 1 is not valid. It should "
"be a multiple of 2 * weight.size[2] * weight.size[3].\n"
"Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format(
offset.shape[1], 2 * weights_h * weights_w))
return torch.ops.torchvision.deform_conv2d(
input,
weight,
@@ -60,13 +81,11 @@ class DeformConv2d(nn.Module):
See deform_conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, offset_groups=1, bias=True):
dilation=1, groups=1, bias=True):
super(DeformConv2d, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if in_channels % offset_groups != 0:
raise ValueError('in_channels must be divisible by offset_groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
@@ -77,7 +96,6 @@ class DeformConv2d(nn.Module):
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.offset_groups = offset_groups
self.weight = Parameter(torch.empty(out_channels, in_channels // groups,
self.kernel_size[0], self.kernel_size[1]))
@@ -100,8 +118,6 @@ class DeformConv2d(nn.Module):
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
convolution weights, split into groups of size (in_channels // groups)
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
@@ -118,7 +134,6 @@ class DeformConv2d(nn.Module):
s += ', padding={padding}' if self.padding != (0, 0) else ''
s += ', dilation={dilation}' if self.dilation != (1, 1) else ''
s += ', groups={groups}' if self.groups != 1 else ''
s += ', offset_groups={offset_groups}' if self.offset_groups != 1 else ''
s += ', bias=False' if self.bias is None else ''
s += ')'
return s.format(**self.__dict__)
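
For completeness, a minimal usage sketch of the module API after this change (illustrative only; shapes chosen to match the docstring example above): offset_groups is no longer a constructor argument, and the number of offset groups is inferred from the channel dimension of the offset tensor passed to forward.

    import torch
    from torchvision import ops

    layer = ops.DeformConv2d(in_channels=3, out_channels=5, kernel_size=3)

    x = torch.rand(1, 3, 10, 10)
    kh, kw = layer.kernel_size
    # a single offset group: 2 * 1 * kh * kw channels, with the same 8x8
    # spatial size as the output of a 3x3 convolution on a 10x10 input
    offset = torch.rand(1, 2 * kh * kw, 8, 8)
    out = layer(x, offset)
    print(out.shape)  # torch.Size([1, 5, 8, 8])
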