Unverified Commit 21143568 authored by ychan's avatar ychan Committed by GitHub
Browse files

fix dcon forward and backward bug (#565)

parent b7af0e9f
...@@ -268,6 +268,7 @@ void DeformConvForwardCUDAKernelLauncher( ...@@ -268,6 +268,7 @@ void DeformConvForwardCUDAKernelLauncher(
gemm(ctx, 1, false, weight_g, false, columns_g, 1, output_g); gemm(ctx, 1, false, weight_g, false, columns_g, 1, output_g);
} }
columns = columns.view({columns.dim(0) * columns.dim(1), columns.dim(2)}); columns = columns.view({columns.dim(0) * columns.dim(1), columns.dim(2)});
weight = weight.view({nOutputPlane, nInputPlane, kH, kW});
} }
output_buffer = output_buffer.view( output_buffer = output_buffer.view(
...@@ -372,6 +373,7 @@ void DeformConvBackwardInputCUDAKernelLauncher( ...@@ -372,6 +373,7 @@ void DeformConvBackwardInputCUDAKernelLauncher(
gradOutput = gradOutput.view({gradOutput.dim(0), gradOutput = gradOutput.view({gradOutput.dim(0),
gradOutput.dim(1) * gradOutput.dim(2), gradOutput.dim(1) * gradOutput.dim(2),
im2col_step, outputHeight, outputWidth}); im2col_step, outputHeight, outputWidth});
weight = weight.view({nOutputPlane, nInputPlane, kH, kW});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
......
...@@ -278,6 +278,8 @@ void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight, ...@@ -278,6 +278,8 @@ void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
} }
columns = columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)}); columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
} }
output_buffer = output_buffer.view( output_buffer = output_buffer.view(
...@@ -375,6 +377,7 @@ void DeformConvBackwardInputCUDAKernelLauncher( ...@@ -375,6 +377,7 @@ void DeformConvBackwardInputCUDAKernelLauncher(
gradOutput = gradOutput.view( gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
weight = weight.view({nOutputPlane, nInputPlane, kH, kW});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment