Unverified Commit 5e3f56f8 authored by ychan's avatar ychan Committed by GitHub
Browse files

fix mdconv addmm bug for parrots (#450)



* fix mdconv addmm bug for parrots

* fix mdconv ctv save tensor
Co-authored-by: default avatarhanyachao <hanyachao@sensetime.com>
parent 6159dac2
...@@ -219,14 +219,26 @@ void ModulatedDeformConvBackwardCUDAKernelLauncher( ...@@ -219,14 +219,26 @@ void ModulatedDeformConvBackwardCUDAKernelLauncher(
weight.dim(2), weight.dim(3)}); weight.dim(2), weight.dim(3)});
for (size_t g = 0; g < group; g++) { for (size_t g = 0; g < group; g++) {
auto columns_g = columns[g]; auto columns_g = ctx.createDArrayLite(
gemm(ctx, 1, true, weight.elemType(), DArrayShape(columns.dim(1), columns.dim(2)));
weight[g].view( copy(ctx, columns_g, columns[g]);
{weight.dim(1), weight.dim(2) * weight.dim(3) * weight.dim(4)}), auto weight_g = weight[g].view(
false, {weight.dim(1), weight.dim(2) * weight.dim(3) * weight.dim(4)});
grad_output[b][g].view( weight_g = transpose(ctx, weight_g, 0, 1);
{grad_output.dim(2), grad_output.dim(3) * grad_output.dim(4)}),
0, columns_g); auto grad_output_bg = ctx.createDArrayLite(
grad_output.elemType(),
DArrayShape(grad_output.dim(2), grad_output.dim(3),
grad_output.dim(4)));
copy(ctx, grad_output_bg, grad_output[b][g]);
grad_output_bg =
grad_output_bg.view({grad_output_bg.dim(0),
grad_output_bg.dim(1) * grad_output_bg.dim(2)});
columns_g =
parrots::op::addmm(ctx, columns[g], weight_g, grad_output_bg, 0, 1);
auto columns_out = columns[g];
copy(ctx, columns_out, columns_g);
} }
columns = columns.view({columns.dim(0) * columns.dim(1), columns.dim(2)}); columns = columns.view({columns.dim(0) * columns.dim(1), columns.dim(2)});
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <float.h> #include <float.h>
#include <parrots/darray/darraymath.hpp> #include <parrots/darray/darraymath.hpp>
#include <parrots/darray/mathfunctions.hpp>
#include <parrots/extension.hpp> #include <parrots/extension.hpp>
#include <parrots/foundation/darrayutil.hpp> #include <parrots/foundation/darrayutil.hpp>
#include <parrots/foundation/exceptions.hpp> #include <parrots/foundation/exceptions.hpp>
......
...@@ -57,8 +57,6 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -57,8 +57,6 @@ class ModulatedDeformConv2dFunction(Function):
ctx.with_bias = bias is not None ctx.with_bias = bias is not None
if not ctx.with_bias: if not ctx.with_bias:
bias = input.new_empty(0) # fake tensor bias = input.new_empty(0) # fake tensor
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias) ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty( output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment