Unverified commit 831c0df3, authored by Vasilis Vryniotis, committed by GitHub

[DO NOT MERGE/n00b] Add empty batch support for DeformConv2d (#2782)

* Adding checks on forward and backward passes.

* Adding unit-tests.
parent d5379656
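
The diff below makes `torchvision.ops.deform_conv2d` accept zero-sized batches on both the forward and backward paths. A minimal sketch of the behavior this enables (shapes are illustrative, not taken from the tests):

```python
# A minimal sketch of what this change enables: calling
# torchvision.ops.deform_conv2d on an empty batch without an error.
import torch
from torchvision import ops

batch_sz = 0                                       # zero-sized batch
kh, kw = 3, 3
x = torch.rand(batch_sz, 3, 10, 10)                # (N, C_in, H, W)
weight = torch.rand(5, 3, kh, kw)                  # (C_out, C_in, kH, kW)
offset = torch.rand(batch_sz, 2 * kh * kw, 8, 8)   # (N, 2*kH*kW, H_out, W_out)

out = ops.deform_conv2d(x, offset, weight)
print(out.shape)                                   # torch.Size([0, 5, 8, 8])
```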
@@ -478,8 +478,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
         out += bias.view(1, n_out_channels, 1, 1)
         return out
 
-    def get_fn_args(self, device, contiguous):
-        batch_sz = 33
+    def get_fn_args(self, device, contiguous, batch_sz):
         n_in_channels = 6
         n_out_channels = 2
         n_weight_grps = 2
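
Instead of hard-coding `batch_sz = 33`, the helper now takes the batch size as a parameter so the tests can also build zero-sized inputs. A hedged sketch of such a builder (channel counts mirror the test, spatial sizes are illustrative):

```python
# Hedged sketch of a batch-size-parameterized input builder in the spirit of
# get_fn_args; the helper name and spatial sizes are illustrative only.
import torch

def make_inputs(batch_sz, device="cpu"):
    n_in_channels, height, width = 6, 5, 4
    kernel_h, kernel_w = 3, 2
    n_offset_grps = 2
    out_h = height - kernel_h + 1      # stride 1, no padding/dilation assumed
    out_w = width - kernel_w + 1
    x = torch.rand(batch_sz, n_in_channels, height, width, device=device)
    offset = torch.rand(batch_sz, n_offset_grps * 2 * kernel_h * kernel_w,
                        out_h, out_w, device=device)
    return x, offset

x0, offset0 = make_inputs(0)     # empty batch: shapes (0, 6, 5, 4) and (0, 24, 3, 3)
x1, offset1 = make_inputs(33)    # the usual non-empty case
```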
@@ -516,7 +515,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
         return x, weight, offset, bias, stride, pad, dilation
 
     def _test_forward(self, device, contiguous):
-        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
+        for batch_sz in [0, 33]:
+            self._test_forward_with_batchsize(device, contiguous, batch_sz)
+
+    def _test_forward_with_batchsize(self, device, contiguous, batch_sz):
+        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
         in_channels = 6
         out_channels = 2
         kernel_size = (3, 2)
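
The forward test now loops over an empty and a non-empty batch. A self-contained sketch of the same kind of check, using illustrative spatial sizes rather than the test's exact inputs:

```python
# Hedged sketch of the parameterized forward check: run DeformConv2d for an
# empty and a non-empty batch and verify the output shape.
import torch
from torchvision import ops

def check_forward(device="cpu"):
    in_channels, out_channels, kernel_size = 6, 2, (3, 2)
    layer = ops.DeformConv2d(in_channels, out_channels, kernel_size).to(device)
    for batch_sz in [0, 33]:
        x = torch.rand(batch_sz, in_channels, 5, 4, device=device)
        out_h, out_w = 5 - 3 + 1, 4 - 2 + 1            # stride 1, no padding
        offset = torch.rand(batch_sz, 2 * 3 * 2, out_h, out_w, device=device)
        res = layer(x, offset)
        assert res.shape == (batch_sz, out_channels, out_h, out_w)

check_forward()
```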
@@ -538,7 +541,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
             res = layer(x, wrong_offset)
 
     def _test_backward(self, device, contiguous):
-        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
+        for batch_sz in [0, 33]:
+            self._test_backward_with_batchsize(device, contiguous, batch_sz)
+
+    def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
+        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
 
         def func(x_, offset_, weight_, bias_):
             return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
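
The backward test follows the same pattern and still relies on `torch.autograd.gradcheck`. A hedged, double-precision sketch of that check for both batch sizes (tensor sizes here are illustrative, chosen small so the numerical Jacobian stays cheap):

```python
# Hedged sketch of the parameterized backward check via gradcheck.
import torch
from torchvision import ops

def check_backward(device="cpu"):
    kh, kw = 3, 2
    for batch_sz in [0, 33]:
        x = torch.rand(batch_sz, 6, 4, 3, dtype=torch.float64, device=device,
                       requires_grad=True)
        offset = torch.rand(batch_sz, 2 * kh * kw, 2, 2, dtype=torch.float64,
                            device=device, requires_grad=True)
        weight = torch.rand(2, 3, kh, kw, dtype=torch.float64, device=device,
                            requires_grad=True)   # 2 weight groups: 6 in / 3 per group
        bias = torch.rand(2, dtype=torch.float64, device=device,
                          requires_grad=True)

        def func(x_, offset_, weight_, bias_):
            return ops.deform_conv2d(x_, offset_, weight_, bias_)

        torch.autograd.gradcheck(func, (x, offset, weight, bias))

check_backward()
```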
@@ -326,6 +326,9 @@ at::Tensor DeformConv2d_forward_cpu(
       out_w);
 
   auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
+  if (batch_sz == 0) {
+    return out;
+  }
 
   // Separate batches into blocks
   out = out.view({batch_sz / n_parallel_imgs,
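
The CPU forward kernel now allocates the output first and returns it immediately for an empty batch, so the blocking/im2col code below never runs on zero-sized tensors. A Python sketch of the same guard pattern (the function name is illustrative; the shape arithmetic is the standard convolution output formula):

```python
# Python sketch of the empty-batch guard used on the forward path.
import torch

def forward_with_empty_batch_guard(x, weight, stride=(1, 1), pad=(0, 0), dil=(1, 1)):
    batch_sz, _, in_h, in_w = x.shape
    out_channels, _, weight_h, weight_w = weight.shape
    out_h = (in_h + 2 * pad[0] - (dil[0] * (weight_h - 1) + 1)) // stride[0] + 1
    out_w = (in_w + 2 * pad[1] - (dil[1] * (weight_w - 1) + 1)) // stride[1] + 1
    out = torch.zeros(batch_sz, out_channels, out_h, out_w, dtype=x.dtype)
    if batch_sz == 0:
        return out   # nothing to compute; the zero-sized output is already correct
    # ... for non-empty batches the real kernel would fill `out` here ...
    return out
```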
@@ -713,6 +716,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
   auto grad_input = at::zeros_like(input);
   auto grad_offset = at::zeros_like(offset);
 
+  if (batch_sz == 0) {
+    return std::make_tuple(grad_input, grad_offset);
+  }
   auto columns = at::empty(
       {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
       input.options());
@@ -839,6 +845,9 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
   long out_w = grad_out.size(3);
 
   auto grad_weight = at::zeros_like(weight);
+  if (batch_sz == 0) {
+    return grad_weight;
+  }
 
   at::Tensor grad_out_buf = grad_out
       .reshape({batch_sz / n_parallel_imgs,
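
Both CPU backward helpers return their zero-initialized gradient buffers early when the batch is empty. From Python, this means backward through an empty batch should produce gradients of the expected shapes, with the weight and bias gradients all zeros since no samples contribute; a hedged sketch:

```python
# Hedged sketch of what the backward guards imply at the Python level.
import torch
from torchvision import ops

x = torch.rand(0, 3, 10, 10, requires_grad=True)
weight = torch.rand(5, 3, 3, 3, requires_grad=True)
offset = torch.rand(0, 2 * 3 * 3, 8, 8, requires_grad=True)
bias = torch.rand(5, requires_grad=True)

out = ops.deform_conv2d(x, offset, weight, bias)
out.sum().backward()

assert x.grad.shape == x.shape and offset.grad.shape == offset.shape
assert torch.equal(weight.grad, torch.zeros_like(weight))
assert torch.equal(bias.grad, torch.zeros_like(bias))
```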
@@ -343,6 +343,9 @@ at::Tensor DeformConv2d_forward_cuda(
       out_w);
 
   auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
+  if (batch_sz == 0) {
+    return out;
+  }
 
   // Separate batches into blocks
   out = out.view({batch_sz / n_parallel_imgs,
@@ -743,6 +746,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
   auto grad_input = at::zeros_like(input);
   auto grad_offset = at::zeros_like(offset);
 
+  if (batch_sz == 0) {
+    return std::make_tuple(grad_input, grad_offset);
+  }
   auto columns = at::empty(
       {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
       input.options());
@@ -869,6 +875,9 @@ static at::Tensor deform_conv_backward_parameters_cuda(
   long out_w = grad_out.size(3);
 
   auto grad_weight = at::zeros_like(weight);
+  if (batch_sz == 0) {
+    return grad_weight;
+  }
 
   at::Tensor grad_out_buf = grad_out.reshape(
       {batch_sz / n_parallel_imgs,
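
The CUDA kernels receive the same three guards as the CPU ones, so the empty-batch path can be exercised on both devices. A small device-parameterized sketch (CUDA is used only when available):

```python
# Hedged sketch: run the same empty-batch check on CPU and, if present, CUDA.
import torch
from torchvision import ops

devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
for device in devices:
    x = torch.rand(0, 3, 10, 10, device=device)
    weight = torch.rand(5, 3, 3, 3, device=device)
    offset = torch.rand(0, 2 * 3 * 3, 8, 8, device=device)
    out = ops.deform_conv2d(x, offset, weight)
    assert out.shape == (0, 5, 8, 8), f"unexpected shape on {device}"
```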