Unverified Commit e3c63f34 authored by WangJiaZhen's avatar WangJiaZhen Committed by GitHub
Browse files

[Fix] Fix deform conv by adding an extra argument im2col_step (#1459)

* [Fix] fix deform conv by add argument

* [Fix] replace useless func with np.repeat and add necessary comment

* [Fix] reduce batch_size and remove useless lines and modify some var name

* [Fix] change position of comments

* [Fix] add im2col_step in the docstring and add two test cases

* [Fix] fix docstring and add comments for test cases

* [Fix] fix docstring

* [Fix] add note, fix issue link and other details

* [Fix] fix docstring details

* [Fix] fix links in docstring

* [Fix] fix docstring details
parent 0633f911
...@@ -117,8 +117,8 @@ class DeformConv2dFunction(Function): ...@@ -117,8 +117,8 @@ class DeformConv2dFunction(Function):
grad_input = grad_offset = grad_weight = None grad_input = grad_offset = grad_weight = None
cur_im2col_step = min(ctx.im2col_step, input.size(0)) cur_im2col_step = min(ctx.im2col_step, input.size(0))
assert (input.size(0) % assert (input.size(0) % cur_im2col_step
cur_im2col_step) == 0, 'im2col step must divide batchsize' ) == 0, 'batch size must be divisible by im2col_step'
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
...@@ -197,6 +197,13 @@ class DeformConv2d(nn.Module): ...@@ -197,6 +197,13 @@ class DeformConv2d(nn.Module):
`Deformable Convolutional Networks `Deformable Convolutional Networks
<https://arxiv.org/pdf/1703.06211.pdf>`_ <https://arxiv.org/pdf/1703.06211.pdf>`_
Note:
The argument ``im2col_step`` was added in version 1.3.17, which means
number of samples processed by the ``im2col_cuda_kernel`` per call.
It enables users to define ``batch_size`` and ``im2col_step`` more
flexibly and solved `issue mmcv#1440
<https://github.com/open-mmlab/mmcv/issues/1440>`_.
Args: Args:
in_channels (int): Number of channels in the input image. in_channels (int): Number of channels in the input image.
out_channels (int): Number of channels produced by the convolution. out_channels (int): Number of channels produced by the convolution.
...@@ -210,7 +217,10 @@ class DeformConv2d(nn.Module): ...@@ -210,7 +217,10 @@ class DeformConv2d(nn.Module):
deform_groups (int): Number of deformable group partitions. deform_groups (int): Number of deformable group partitions.
bias (bool): If True, adds a learnable bias to the output. bias (bool): If True, adds a learnable bias to the output.
Default: False. Default: False.
im2col_step (int): Number of samples processed by im2col_cuda_kernel
per call. It will work when ``batch_size`` > ``im2col_step``, but
``batch_size`` must be divisible by ``im2col_step``. Default: 32.
`New in version 1.3.17.`
""" """
@deprecated_api_warning({'deformable_groups': 'deform_groups'}, @deprecated_api_warning({'deformable_groups': 'deform_groups'},
...@@ -224,7 +234,8 @@ class DeformConv2d(nn.Module): ...@@ -224,7 +234,8 @@ class DeformConv2d(nn.Module):
dilation: Union[int, Tuple[int, ...]] = 1, dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1, groups: int = 1,
deform_groups: int = 1, deform_groups: int = 1,
bias: bool = False) -> None: bias: bool = False,
im2col_step: int = 32) -> None:
super(DeformConv2d, self).__init__() super(DeformConv2d, self).__init__()
assert not bias, \ assert not bias, \
...@@ -243,6 +254,7 @@ class DeformConv2d(nn.Module): ...@@ -243,6 +254,7 @@ class DeformConv2d(nn.Module):
self.dilation = _pair(dilation) self.dilation = _pair(dilation)
self.groups = groups self.groups = groups
self.deform_groups = deform_groups self.deform_groups = deform_groups
self.im2col_step = im2col_step
# enable compatibility with nn.Conv2d # enable compatibility with nn.Conv2d
self.transposed = False self.transposed = False
self.output_padding = _single(0) self.output_padding = _single(0)
...@@ -293,7 +305,8 @@ class DeformConv2d(nn.Module): ...@@ -293,7 +305,8 @@ class DeformConv2d(nn.Module):
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0) offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
offset = offset.contiguous() offset = offset.contiguous()
out = deform_conv2d(x, offset, self.weight, self.stride, self.padding, out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups) self.dilation, self.groups, self.deform_groups,
False, self.im2col_step)
if input_pad: if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) - out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
pad_w].contiguous() pad_w].contiguous()
...@@ -361,7 +374,8 @@ class DeformConv2dPack(DeformConv2d): ...@@ -361,7 +374,8 @@ class DeformConv2dPack(DeformConv2d):
def forward(self, x): def forward(self, x):
offset = self.conv_offset(x) offset = self.conv_offset(x)
return deform_conv2d(x, offset, self.weight, self.stride, self.padding, return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups) self.dilation, self.groups, self.deform_groups,
False, self.im2col_step)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs): missing_keys, unexpected_keys, error_msgs):
......
...@@ -39,15 +39,27 @@ class TestDeformconv(object): ...@@ -39,15 +39,27 @@ class TestDeformconv(object):
def _test_deformconv(self, def _test_deformconv(self,
dtype=torch.float, dtype=torch.float,
threshold=1e-3, threshold=1e-3,
device='cuda'): device='cuda',
batch_size=10,
im2col_step=2):
if not torch.cuda.is_available() and device == 'cuda': if not torch.cuda.is_available() and device == 'cuda':
pytest.skip('test requires GPU') pytest.skip('test requires GPU')
from mmcv.ops import DeformConv2dPack from mmcv.ops import DeformConv2dPack
c_in = 1 c_in = 1
c_out = 1 c_out = 1
x = torch.tensor(input, device=device, dtype=dtype) batch_size = 10
repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.tensor(repeated_input, device=device, dtype=dtype)
x.requires_grad = True x.requires_grad = True
model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0) model = DeformConv2dPack(
in_channels=c_in,
out_channels=c_out,
kernel_size=2,
stride=1,
padding=0,
im2col_step=im2col_step)
model.conv_offset.weight.data = torch.nn.Parameter( model.conv_offset.weight.data = torch.nn.Parameter(
torch.Tensor(offset_weight).reshape(8, 1, 2, 2)) torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
model.conv_offset.bias.data = torch.nn.Parameter( model.conv_offset.bias.data = torch.nn.Parameter(
...@@ -61,15 +73,21 @@ class TestDeformconv(object): ...@@ -61,15 +73,21 @@ class TestDeformconv(object):
out = model(x) out = model(x)
out.backward(torch.ones_like(out)) out.backward(torch.ones_like(out))
assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold) assert np.allclose(out.data.detach().cpu().numpy(), repeated_gt_out,
assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold) threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), repeated_gt_x_grad,
threshold)
# the batch size of the input is increased which results in
# a larger gradient so we need to divide by the batch_size
assert np.allclose( assert np.allclose(
model.conv_offset.weight.grad.detach().cpu().numpy(), model.conv_offset.weight.grad.detach().cpu().numpy() / batch_size,
gt_offset_weight_grad, threshold) gt_offset_weight_grad, threshold)
assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(), assert np.allclose(
gt_offset_bias_grad, threshold) model.conv_offset.bias.grad.detach().cpu().numpy() / batch_size,
assert np.allclose(model.weight.grad.detach().cpu().numpy(), gt_offset_bias_grad, threshold)
gt_deform_weight_grad, threshold) assert np.allclose(
model.weight.grad.detach().cpu().numpy() / batch_size,
gt_deform_weight_grad, threshold)
from mmcv.ops import DeformConv2d from mmcv.ops import DeformConv2d
...@@ -86,7 +104,11 @@ class TestDeformconv(object): ...@@ -86,7 +104,11 @@ class TestDeformconv(object):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
model = DeformConv2d(3, 4, 3, groups=3) model = DeformConv2d(3, 4, 3, groups=3)
def _test_amp_deformconv(self, input_dtype, threshold=1e-3): def _test_amp_deformconv(self,
input_dtype,
threshold=1e-3,
batch_size=10,
im2col_step=2):
"""The function to test amp released on pytorch 1.6.0. """The function to test amp released on pytorch 1.6.0.
The type of input data might be torch.float or torch.half, The type of input data might be torch.float or torch.half,
...@@ -102,9 +124,18 @@ class TestDeformconv(object): ...@@ -102,9 +124,18 @@ class TestDeformconv(object):
from mmcv.ops import DeformConv2dPack from mmcv.ops import DeformConv2dPack
c_in = 1 c_in = 1
c_out = 1 c_out = 1
x = torch.Tensor(input).cuda().type(input_dtype) repeated_input = np.repeat(input, batch_size, axis=0)
repeated_gt_out = np.repeat(gt_out, batch_size, axis=0)
repeated_gt_x_grad = np.repeat(gt_x_grad, batch_size, axis=0)
x = torch.Tensor(repeated_input).cuda().type(input_dtype)
x.requires_grad = True x.requires_grad = True
model = DeformConv2dPack(c_in, c_out, 2, stride=1, padding=0) model = DeformConv2dPack(
in_channels=c_in,
out_channels=c_out,
kernel_size=2,
stride=1,
padding=0,
im2col_step=im2col_step)
model.conv_offset.weight.data = torch.nn.Parameter( model.conv_offset.weight.data = torch.nn.Parameter(
torch.Tensor(offset_weight).reshape(8, 1, 2, 2)) torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
model.conv_offset.bias.data = torch.nn.Parameter( model.conv_offset.bias.data = torch.nn.Parameter(
...@@ -116,15 +147,19 @@ class TestDeformconv(object): ...@@ -116,15 +147,19 @@ class TestDeformconv(object):
out = model(x) out = model(x)
out.backward(torch.ones_like(out)) out.backward(torch.ones_like(out))
assert np.allclose(out.data.detach().cpu().numpy(), gt_out, threshold) assert np.allclose(out.data.detach().cpu().numpy(), repeated_gt_out,
assert np.allclose(x.grad.detach().cpu().numpy(), gt_x_grad, threshold) threshold)
assert np.allclose(x.grad.detach().cpu().numpy(), repeated_gt_x_grad,
threshold)
assert np.allclose( assert np.allclose(
model.conv_offset.weight.grad.detach().cpu().numpy(), model.conv_offset.weight.grad.detach().cpu().numpy() / batch_size,
gt_offset_weight_grad, threshold) gt_offset_weight_grad, threshold)
assert np.allclose(model.conv_offset.bias.grad.detach().cpu().numpy(), assert np.allclose(
gt_offset_bias_grad, threshold) model.conv_offset.bias.grad.detach().cpu().numpy() / batch_size,
assert np.allclose(model.weight.grad.detach().cpu().numpy(), gt_offset_bias_grad, threshold)
gt_deform_weight_grad, threshold) assert np.allclose(
model.weight.grad.detach().cpu().numpy() / batch_size,
gt_deform_weight_grad, threshold)
from mmcv.ops import DeformConv2d from mmcv.ops import DeformConv2d
...@@ -147,6 +182,13 @@ class TestDeformconv(object): ...@@ -147,6 +182,13 @@ class TestDeformconv(object):
self._test_deformconv(torch.double) self._test_deformconv(torch.double)
self._test_deformconv(torch.float) self._test_deformconv(torch.float)
self._test_deformconv(torch.half, threshold=1e-1) self._test_deformconv(torch.half, threshold=1e-1)
# test batch_size < im2col_step
self._test_deformconv(torch.float, batch_size=1, im2col_step=2)
# test bach_size % im2col_step != 0
with pytest.raises(
AssertionError,
match='batch size must be divisible by im2col_step'):
self._test_deformconv(torch.float, batch_size=10, im2col_step=3)
# test amp when torch version >= '1.6.0', the type of # test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half # input data for deformconv might be torch.float or torch.half
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment