Unverified Commit 17ea1482 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Improve error message and avoid segfault in DeformConv2d (#1660)

parent 2d7c0667
...@@ -467,11 +467,9 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -467,11 +467,9 @@ class DeformConvTester(OpTester, unittest.TestCase):
out_channels = 2 out_channels = 2
kernel_size = (3, 2) kernel_size = (3, 2)
groups = 2 groups = 2
offset_groups = 3
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device, dilation=dilation, groups=groups).to(device=x.device, dtype=x.dtype)
dtype=x.dtype)
res = layer(x, offset) res = layer(x, offset)
weight = layer.weight.data weight = layer.weight.data
...@@ -480,6 +478,11 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -480,6 +478,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected)) self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
# test for wrong sizes
with self.assertRaises(RuntimeError):
wrong_offset = torch.rand_like(offset[:, :2])
res = layer(x, wrong_offset)
def _test_backward(self, device, contiguous): def _test_backward(self, device, contiguous):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous) x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
......
...@@ -296,16 +296,16 @@ at::Tensor DeformConv2d_forward_cpu( ...@@ -296,16 +296,16 @@ at::Tensor DeformConv2d_forward_cpu(
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0); TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK( TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), (offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"got: ", "offset.shape[1] is not valid: got: ",
offset.size(1), offset.size(1),
" expected: ", " expected: ",
n_offset_grps * 2 * weight_h * weight_w); n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK( TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w), (offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (", "offset output dims: (",
......
...@@ -314,16 +314,16 @@ at::Tensor DeformConv2d_forward_cuda( ...@@ -314,16 +314,16 @@ at::Tensor DeformConv2d_forward_cuda(
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1)); TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0); TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK( TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w), (offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"got: ", "offset.shape[1] is not valid: got: ",
offset.size(1), offset.size(1),
" expected: ", " expected: ",
n_offset_grps * 2 * weight_h * weight_w); n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK( TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w), (offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (", "offset output dims: (",
......
...@@ -28,6 +28,20 @@ def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0 ...@@ -28,6 +28,20 @@ def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0
Returns: Returns:
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
Examples::
>>> input = torch.rand(1, 3, 10, 10)
>>> kh, kw = 3, 3
>>> weight = torch.rand(5, 3, kh, kw)
>>> # offset should have the same spatial size as the output
>>> # of the convolution. In this case, for an input of 10, stride of 1
>>> # and kernel size of 3, without padding, the output size is 8
>>> offset = torch.rand(5, 2 * kh * kw, 8, 8)
>>> out = deform_conv2d(input, offset, weight)
>>> print(out.shape)
>>> # returns
>>> torch.Size([1, 5, 8, 8])
""" """
out_channels = weight.shape[0] out_channels = weight.shape[0]
...@@ -43,6 +57,13 @@ def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0 ...@@ -43,6 +57,13 @@ def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0
n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w) n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
n_weight_grps = n_in_channels // weight.shape[1] n_weight_grps = n_in_channels // weight.shape[1]
if n_offset_grps == 0:
raise RuntimeError(
"the shape of the offset tensor at dimension 1 is not valid. It should "
"be a multiple of 2 * weight.size[2] * weight.size[3].\n"
"Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format(
offset.shape[1], 2 * weights_h * weights_w))
return torch.ops.torchvision.deform_conv2d( return torch.ops.torchvision.deform_conv2d(
input, input,
weight, weight,
...@@ -60,13 +81,11 @@ class DeformConv2d(nn.Module): ...@@ -60,13 +81,11 @@ class DeformConv2d(nn.Module):
See deform_conv2d See deform_conv2d
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, offset_groups=1, bias=True): dilation=1, groups=1, bias=True):
super(DeformConv2d, self).__init__() super(DeformConv2d, self).__init__()
if in_channels % groups != 0: if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups') raise ValueError('in_channels must be divisible by groups')
if in_channels % offset_groups != 0:
raise ValueError('in_channels must be divisible by offset_groups')
if out_channels % groups != 0: if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups') raise ValueError('out_channels must be divisible by groups')
...@@ -77,7 +96,6 @@ class DeformConv2d(nn.Module): ...@@ -77,7 +96,6 @@ class DeformConv2d(nn.Module):
self.padding = _pair(padding) self.padding = _pair(padding)
self.dilation = _pair(dilation) self.dilation = _pair(dilation)
self.groups = groups self.groups = groups
self.offset_groups = offset_groups
self.weight = Parameter(torch.empty(out_channels, in_channels // groups, self.weight = Parameter(torch.empty(out_channels, in_channels // groups,
self.kernel_size[0], self.kernel_size[1])) self.kernel_size[0], self.kernel_size[1]))
...@@ -100,8 +118,6 @@ class DeformConv2d(nn.Module): ...@@ -100,8 +118,6 @@ class DeformConv2d(nn.Module):
""" """
Arguments: Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
convolution weights, split into groups of size (in_channels // groups)
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the out_height, out_width]): offsets to be applied for each position in the
convolution kernel. convolution kernel.
...@@ -118,7 +134,6 @@ class DeformConv2d(nn.Module): ...@@ -118,7 +134,6 @@ class DeformConv2d(nn.Module):
s += ', padding={padding}' if self.padding != (0, 0) else '' s += ', padding={padding}' if self.padding != (0, 0) else ''
s += ', dilation={dilation}' if self.dilation != (1, 1) else '' s += ', dilation={dilation}' if self.dilation != (1, 1) else ''
s += ', groups={groups}' if self.groups != 1 else '' s += ', groups={groups}' if self.groups != 1 else ''
s += ', offset_groups={offset_groups}' if self.offset_groups != 1 else ''
s += ', bias=False' if self.bias is None else '' s += ', bias=False' if self.bias is None else ''
s += ')' s += ')'
return s.format(**self.__dict__) return s.format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment