Unverified Commit 7ee5a8b7 authored by Yuwen Xiong's avatar Yuwen Xiong Committed by GitHub
Browse files

Fix shape error for deform conv (#2027)

* fix shape error for deform conv gpu op

recover shape of columns for next iteration in for loops, previous version will cause error when batch_sz / n_parallel_imgs > 1

* fix shape error for deform conv cpu op

recover shape of columns for next iteration in for loops, previous version will cause error when batch_sz / n_parallel_imgs > 1
parent 1c7aa0c0
...@@ -392,6 +392,8 @@ at::Tensor DeformConv2d_forward_cpu( ...@@ -392,6 +392,8 @@ at::Tensor DeformConv2d_forward_cpu(
.addmm_(weight[g].flatten(1), columns[g]) .addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]); .view_as(out_buf[b][g]);
} }
columns = columns.view(
{columns.size(0) * columns.size(1), columns.size(2)});
} }
out_buf = out_buf.view({batch_sz / n_parallel_imgs, out_buf = out_buf.view({batch_sz / n_parallel_imgs,
...@@ -744,15 +746,16 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu( ...@@ -744,15 +746,16 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
grad_out.size(3), grad_out.size(3),
grad_out.size(4)}); grad_out.size(4)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
// Separate into weight groups // Separate into weight groups
columns = columns.view( columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int g = 0; g < n_weight_grps; g++) { for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_( columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
......
...@@ -410,6 +410,8 @@ at::Tensor DeformConv2d_forward_cuda( ...@@ -410,6 +410,8 @@ at::Tensor DeformConv2d_forward_cuda(
.addmm_(weight[g].flatten(1), columns[g]) .addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]); .view_as(out_buf[b][g]);
} }
columns = columns.view(
{columns.size(0) * columns.size(1), columns.size(2)});
} }
out_buf = out_buf.view({batch_sz / n_parallel_imgs, out_buf = out_buf.view({batch_sz / n_parallel_imgs,
...@@ -775,15 +777,16 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -775,15 +777,16 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
grad_out.size(3), grad_out.size(3),
grad_out.size(4)}); grad_out.size(4)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) { for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
// Separate into weight groups // Separate into weight groups
columns = columns.view( columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)}); {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int g = 0; g < n_weight_grps; g++) { for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_( columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1)); weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment