Unverified Commit 831c0df3 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[DO NOT MERGE/n00b] Add empty batch support for DeformConv2d (#2782)

* Adding checks on forward and backward passes.

* Adding unit-tests.
parent d5379656
...@@ -478,8 +478,7 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -478,8 +478,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
out += bias.view(1, n_out_channels, 1, 1) out += bias.view(1, n_out_channels, 1, 1)
return out return out
def get_fn_args(self, device, contiguous): def get_fn_args(self, device, contiguous, batch_sz):
batch_sz = 33
n_in_channels = 6 n_in_channels = 6
n_out_channels = 2 n_out_channels = 2
n_weight_grps = 2 n_weight_grps = 2
...@@ -516,7 +515,11 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -516,7 +515,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
return x, weight, offset, bias, stride, pad, dilation return x, weight, offset, bias, stride, pad, dilation
def _test_forward(self, device, contiguous): def _test_forward(self, device, contiguous):
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous) for batch_sz in [0, 33]:
self._test_forward_with_batchsize(device, contiguous, batch_sz)
def _test_forward_with_batchsize(self, device, contiguous, batch_sz):
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
in_channels = 6 in_channels = 6
out_channels = 2 out_channels = 2
kernel_size = (3, 2) kernel_size = (3, 2)
...@@ -538,7 +541,11 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -538,7 +541,11 @@ class DeformConvTester(OpTester, unittest.TestCase):
res = layer(x, wrong_offset) res = layer(x, wrong_offset)
def _test_backward(self, device, contiguous): def _test_backward(self, device, contiguous):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous) for batch_sz in [0, 33]:
self._test_backward_with_batchsize(device, contiguous, batch_sz)
def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz)
def func(x_, offset_, weight_, bias_): def func(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation) return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
......
...@@ -326,6 +326,9 @@ at::Tensor DeformConv2d_forward_cpu( ...@@ -326,6 +326,9 @@ at::Tensor DeformConv2d_forward_cpu(
out_w); out_w);
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
if (batch_sz == 0) {
return out;
}
// Separate batches into blocks // Separate batches into blocks
out = out.view({batch_sz / n_parallel_imgs, out = out.view({batch_sz / n_parallel_imgs,
...@@ -713,6 +716,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu( ...@@ -713,6 +716,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
auto grad_input = at::zeros_like(input); auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset); auto grad_offset = at::zeros_like(offset);
if (batch_sz == 0) {
return std::make_tuple(grad_input, grad_offset);
}
auto columns = at::empty( auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options()); input.options());
...@@ -839,6 +845,9 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -839,6 +845,9 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
long out_w = grad_out.size(3); long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight); auto grad_weight = at::zeros_like(weight);
if (batch_sz == 0) {
return grad_weight;
}
at::Tensor grad_out_buf = grad_out at::Tensor grad_out_buf = grad_out
.reshape({batch_sz / n_parallel_imgs, .reshape({batch_sz / n_parallel_imgs,
......
...@@ -343,6 +343,9 @@ at::Tensor DeformConv2d_forward_cuda( ...@@ -343,6 +343,9 @@ at::Tensor DeformConv2d_forward_cuda(
out_w); out_w);
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options()); auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
if (batch_sz == 0) {
return out;
}
// Separate batches into blocks // Separate batches into blocks
out = out.view({batch_sz / n_parallel_imgs, out = out.view({batch_sz / n_parallel_imgs,
...@@ -743,6 +746,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -743,6 +746,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
auto grad_input = at::zeros_like(input); auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset); auto grad_offset = at::zeros_like(offset);
if (batch_sz == 0) {
return std::make_tuple(grad_input, grad_offset);
}
auto columns = at::empty( auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options()); input.options());
...@@ -869,6 +875,9 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -869,6 +875,9 @@ static at::Tensor deform_conv_backward_parameters_cuda(
long out_w = grad_out.size(3); long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight); auto grad_weight = at::zeros_like(weight);
if (batch_sz == 0) {
return grad_weight;
}
at::Tensor grad_out_buf = grad_out.reshape( at::Tensor grad_out_buf = grad_out.reshape(
{batch_sz / n_parallel_imgs, {batch_sz / n_parallel_imgs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment