Unverified Commit b7af0e9f authored by ychan's avatar ychan Committed by GitHub
Browse files

fix mdconv backward bug (#563)

parent ece32796
...@@ -273,20 +273,57 @@ void ModulatedDeformConvBackwardCUDAKernelLauncher( ...@@ -273,20 +273,57 @@ void ModulatedDeformConvBackwardCUDAKernelLauncher(
} }
for (size_t g = 0; g < group; g++) { for (size_t g = 0; g < group; g++) {
auto grad_weight_g = grad_weight[g].view( auto grad_weight_g = ctx.createDArrayLite(
{grad_weight.dim(1), grad_weight.elemType(),
grad_weight.dim(2) * grad_weight.dim(3) * grad_weight.dim(4)}); DArrayShape(grad_weight.dim(1), grad_weight.dim(2),
gemm(ctx, 1, false, grad_weight.dim(3), grad_weight.dim(4)));
grad_output[b][g].view( copy(ctx, grad_weight_g, grad_weight[g]);
{grad_output.dim(2), grad_output.dim(3) * grad_output.dim(4)}), grad_weight_g = grad_weight_g.view(
true, columns[g], 1, grad_weight_g); {grad_weight_g.dim(0),
grad_weight_g.dim(1) * grad_weight_g.dim(2) * grad_weight_g.dim(3)});
auto columns_g = columns[g];
columns_g = transpose(ctx, columns_g, 0, 1);
auto grad_output_bg = ctx.createDArrayLite(
grad_output.elemType(),
DArrayShape(grad_output.dim(2), grad_output.dim(3),
grad_output.dim(4)));
copy(ctx, grad_output_bg, grad_output[b][g]);
grad_output_bg =
grad_output_bg.view({grad_output_bg.dim(0),
grad_output_bg.dim(1) * grad_output_bg.dim(2)});
grad_weight_g = parrots::op::addmm(ctx, grad_weight_g, grad_output_bg,
columns_g, 1, 1);
auto grad_weight_out = grad_weight[g];
copy(ctx, grad_weight_out, grad_weight_g);
if (with_bias) { if (with_bias) {
auto grad_bias_g = grad_bias[g].view({grad_bias.dim(1), 1}); auto grad_bias_g = ctx.createDArrayLite(grad_bias.elemType(),
gemm(ctx, 1, false, DArrayShape(grad_bias.dim(1)));
grad_output[b][g].view( copy(ctx, grad_bias_g, grad_bias[g]);
{grad_output.dim(2), grad_output.dim(3) * grad_output.dim(4)}), grad_bias_g = grad_bias_g.view({grad_bias_g.dim(0), 1});
false, ones.view({ones.dim(0) * ones.dim(1), 1}), 1, grad_bias_g);
auto grad_output_bg = ctx.createDArrayLite(
grad_output.elemType(),
DArrayShape(grad_output.dim(2), grad_output.dim(3),
grad_output.dim(4)));
copy(ctx, grad_output_bg, grad_output[b][g]);
grad_output_bg = grad_output_bg.view(
{grad_output_bg.dim(0),
grad_output_bg.dim(1) * grad_output_bg.dim(2)});
auto ones_g = ctx.createDArrayLite(
ones.elemType(), DArrayShape(ones.dim(0), ones.dim(1)));
copy(ctx, ones_g, ones);
ones_g = ones_g.view({ones_g.dim(0) * ones_g.dim(1), 1});
grad_bias_g =
parrots::op::addmm(ctx, grad_bias_g, grad_output_bg, ones_g, 1, 1);
auto grad_bias_out = grad_bias[g];
copy(ctx, grad_bias_out, grad_bias_g);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment