Unverified commit ccd797dd, authored by Francisco Massa and committed by GitHub
Browse files

Add test for large batches in DeformConv2d (#2040)

* Add test for large batches in DeformConv2d

* Clean-up and (try) fix DeformConv2d

* Simplifications and bugfixes

* Try fix CUDA now
parent 979bb72e
...@@ -454,7 +454,7 @@ class DeformConvTester(OpTester, unittest.TestCase): ...@@ -454,7 +454,7 @@ class DeformConvTester(OpTester, unittest.TestCase):
return out return out
def get_fn_args(self, device, contiguous): def get_fn_args(self, device, contiguous):
batch_sz = 1 batch_sz = 33
n_in_channels = 6 n_in_channels = 6
n_out_channels = 2 n_out_channels = 2
n_weight_grps = 2 n_weight_grps = 2
......
...@@ -713,55 +713,49 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu( ...@@ -713,55 +713,49 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
auto grad_input = at::zeros_like(input); auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset); auto grad_offset = at::zeros_like(offset);
auto columns = at::zeros( auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options()); input.options());
// Separate into blocks // Separate into blocks
grad_input = grad_input.view( grad_input = grad_input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.view( input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w, n_offset_grps * 2 * weight_h * weight_w,
out_h, out_h,
out_w}); out_w});
offset = offset.view({batch_sz / n_parallel_imgs, offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w, n_offset_grps * 2 * weight_h * weight_w,
out_h, out_h,
out_w}); out_w});
grad_out = grad_out.view({batch_sz / n_parallel_imgs, grad_out = grad_out.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_out_channels, n_weight_grps,
out_h, n_out_channels / n_weight_grps,
out_w}); out_h,
grad_out.transpose_(1, 2); out_w}).permute({0, 2, 3, 1, 4, 5});
grad_out = grad_out.view({grad_out.size(0),
n_weight_grps, weight = weight.reshape({n_weight_grps,
grad_out.size(1) / n_weight_grps, weight.size(0) / n_weight_grps,
grad_out.size(2), weight.size(1),
grad_out.size(3), weight.size(2),
grad_out.size(4)}); weight.size(3)});
weight = weight.view({n_weight_grps, columns = columns.view(
weight.size(0) / n_weight_grps, {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
weight.size(1),
weight.size(2),
weight.size(3)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
columns.zero_();
// Separate into weight groups // Separate into weight groups
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) { for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_( columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
} }
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
compute_grad_offset( compute_grad_offset(
columns, columns,
...@@ -801,20 +795,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu( ...@@ -801,20 +795,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
grad_input[elt]); grad_input[elt]);
} }
grad_out = grad_out.view({grad_out.size(0),
grad_out.size(1) * grad_out.size(2),
grad_out.size(3),
grad_out.size(4),
grad_out.size(5)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
input = input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view( grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset); return std::make_tuple(grad_input, grad_offset);
} }
...@@ -854,46 +837,36 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -854,46 +837,36 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
long out_w = grad_out.size(3); long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight); auto grad_weight = at::zeros_like(weight);
;
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
grad_out = grad_out.view({batch_sz / n_parallel_imgs, at::Tensor grad_out_buf = grad_out.reshape(
n_parallel_imgs, {batch_sz / n_parallel_imgs,
n_out_channels, n_parallel_imgs,
out_h, n_weight_grps,
out_w}); n_out_channels / n_weight_grps,
grad_out.transpose_(1, 2); out_h,
out_w}
at::Tensor grad_out_buf = at::zeros_like(grad_out); ).permute({0, 2, 3, 1, 4, 5}).contiguous();
grad_out_buf.copy_(grad_out);
grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, input = input.reshape(
n_out_channels,
n_parallel_imgs * out_h,
out_w});
grad_out_buf = grad_out_buf.view({grad_out_buf.size(0),
n_weight_grps,
grad_out_buf.size(1) / n_weight_grps,
grad_out_buf.size(2),
grad_out_buf.size(3)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs, offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w, n_offset_grps * 2 * weight_h * weight_w,
out_h, out_h,
out_w}); out_w});
grad_weight = grad_weight.view({n_weight_grps, grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps, grad_weight.size(0) / n_weight_grps,
grad_weight.size(1), grad_weight.size(1),
grad_weight.size(2), grad_weight.size(2),
grad_weight.size(3)}); grad_weight.size(3)});
auto columns = at::empty(
{n_weight_grps,
n_in_channels * weight_w * weight_h / n_weight_grps,
n_parallel_imgs * out_h * out_w},
input.options());
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col( deformable_im2col(
input[elt], input[elt],
...@@ -915,8 +888,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -915,8 +888,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
n_offset_grps, n_offset_grps,
columns); columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) { for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] = grad_weight[g] =
grad_weight[g] grad_weight[g]
...@@ -925,14 +896,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -925,14 +896,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]); .view_as(grad_weight[g]);
} }
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
} }
input = input.view({batch_sz, n_in_channels, in_h, in_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(2),
grad_weight.size(3), grad_weight.size(3),
......
...@@ -744,55 +744,48 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -744,55 +744,48 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
auto grad_input = at::zeros_like(input); auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset); auto grad_offset = at::zeros_like(offset);
auto columns = at::zeros( auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w}, {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options()); input.options());
// Separate into blocks // Separate into blocks
grad_input = grad_input.view( grad_input = grad_input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.view( input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view({batch_sz / n_parallel_imgs, grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w, n_offset_grps * 2 * weight_h * weight_w,
out_h, out_h,
out_w}); out_w});
offset = offset.view({batch_sz / n_parallel_imgs, offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w, n_offset_grps * 2 * weight_h * weight_w,
out_h, out_h,
out_w}); out_w});
grad_out = grad_out.view({batch_sz / n_parallel_imgs, grad_out = grad_out.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_out_channels, n_weight_grps,
out_h, n_out_channels / n_weight_grps,
out_w}); out_h,
grad_out.transpose_(1, 2); out_w}).permute({0, 2, 3, 1, 4, 5});
grad_out = grad_out.view({grad_out.size(0),
n_weight_grps, weight = weight.reshape({n_weight_grps,
grad_out.size(1) / n_weight_grps, weight.size(0) / n_weight_grps,
grad_out.size(2), weight.size(1),
grad_out.size(3), weight.size(2),
grad_out.size(4)}); weight.size(3)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
columns.zero_();
// Separate into weight groups // Separate into weight groups
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) { for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_( columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
} }
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
compute_grad_offset( compute_grad_offset(
columns, columns,
...@@ -832,20 +825,10 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -832,20 +825,10 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
grad_input[elt]); grad_input[elt]);
} }
grad_out = grad_out.view({grad_out.size(0),
grad_out.size(1) * grad_out.size(2),
grad_out.size(3),
grad_out.size(4),
grad_out.size(5)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w}); grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
input = input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view( grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w}); {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset); return std::make_tuple(grad_input, grad_offset);
} }
...@@ -887,46 +870,36 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -887,46 +870,36 @@ static at::Tensor deform_conv_backward_parameters_cuda(
long out_w = grad_out.size(3); long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight); auto grad_weight = at::zeros_like(weight);
;
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
grad_out = grad_out.view({batch_sz / n_parallel_imgs, at::Tensor grad_out_buf = grad_out.reshape(
n_parallel_imgs, {batch_sz / n_parallel_imgs,
n_out_channels, n_parallel_imgs,
out_h, n_weight_grps,
out_w}); n_out_channels / n_weight_grps,
grad_out.transpose_(1, 2); out_h,
out_w}
at::Tensor grad_out_buf = at::zeros_like(grad_out); ).permute({0, 2, 3, 1, 4, 5}).contiguous();
grad_out_buf.copy_(grad_out);
grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs, input = input.reshape(
n_out_channels,
n_parallel_imgs * out_h,
out_w});
grad_out_buf = grad_out_buf.view({grad_out_buf.size(0),
n_weight_grps,
grad_out_buf.size(1) / n_weight_grps,
grad_out_buf.size(2),
grad_out_buf.size(3)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w}); {batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs, offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs, n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w, n_offset_grps * 2 * weight_h * weight_w,
out_h, out_h,
out_w}); out_w});
grad_weight = grad_weight.reshape({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
auto columns = at::empty(
{n_weight_grps,
n_in_channels * weight_w * weight_h / n_weight_grps,
n_parallel_imgs * out_h * out_w},
input.options());
grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col( deformable_im2col(
input[elt], input[elt],
...@@ -948,8 +921,6 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -948,8 +921,6 @@ static at::Tensor deform_conv_backward_parameters_cuda(
n_offset_grps, n_offset_grps,
columns); columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) { for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] = grad_weight[g] =
grad_weight[g] grad_weight[g]
...@@ -958,14 +929,8 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -958,14 +929,8 @@ static at::Tensor deform_conv_backward_parameters_cuda(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0)) grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]); .view_as(grad_weight[g]);
} }
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
} }
input = input.view({batch_sz, n_in_channels, in_h, in_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(2),
grad_weight.size(3), grad_weight.size(3),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment