Unverified commit 8016d880 authored by pc, committed by GitHub

fix dcn compile error in parrots (#1378)

parent fa22d9db
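Summary: this commit adds CPU implementations of the deformable and modulated deformable convolution kernels, renames the `*_cuda` entry points to device-agnostic wrappers (`deform_conv_forward`, `modulated_deform_conv_forward`, ...), and registers CPU dispatch paths in the parrots bindings so the extension also builds when CUDA is unavailable.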
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
template <typename T>
T deformable_im2col_bilinear_cpu(const T *input, const int data_width,
const int height, const int width, T h, T w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) {
return 0;
}
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = input[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = input[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = input[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
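// Weight with which the integer pixel (h, w) contributes to the bilinear
// sample taken at (argmax_h, argmax_w); col2im uses it to scatter an output
// gradient back onto the four neighbouring input pixels.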
template <typename T>
T get_gradient_weight_cpu(T argmax_h, T argmax_w, const int h, const int w,
const int height, const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
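// Derivative of the bilinear sample with respect to the sampling coordinate:
// bp_dir == 0 differentiates w.r.t. the h (row) offset, bp_dir == 1 w.r.t.
// the w (column) offset. Each term is the analytic derivative of one of the
// four bilinear weights multiplied by the corresponding pixel value.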
template <typename T>
T get_coordinate_weight_cpu(T argmax_h, T argmax_w, const int height,
const int width, const T *im_data,
const int data_width, const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
} else if (bp_dir == 1) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
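// Unfolds the input into a column buffer, shifting every kernel tap (i, j) by
// its learned offset. Each output element samples the input at
//   h_im = h_col * stride_h - pad_h + i * dilation_h + offset_h
//   w_im = w_col * stride_w - pad_w + j * dilation_w + offset_w
// and data_col is indexed as
//   (((c_im * kernel_h * kernel_w + tap) * batch_size + b) * height_col +
//    h_col) * width_col + w_col.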
template <typename T>
void deformable_im2col_cpu_kernel(
const int n, const T *data_im, const T *data_offset, const int height,
const int width, const int kernel_h, const int kernel_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int num_channels, const int deformable_group, const int height_col,
const int width_col, T *data_col) {
for (int index = 0; index < n; index++) {
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T *data_offset_ptr =
data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
val = deformable_im2col_bilinear_cpu(data_im_ptr, width, height,
width, h_im, w_im);
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
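// Inverse of the im2col pass: every column-buffer gradient is scattered back
// to the input pixels that produced it. The (dy, dx) window around the
// truncated sampling position, combined with the |distance| < 1 test, visits
// exactly the integer neighbours of the fractional sampling location.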
template <typename T>
void deformable_col2im_cpu_kernel(
const int n, const T *data_col, const T *data_offset, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int deformable_group, const int height_col, const int width_col,
T *grad_im) {
for (int index = 0; index < n; index++) {
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i =
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
const T cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
            cur_w + dx < width &&
            std::abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            std::abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight =
get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data,
cur_h + dy, cur_w + dx, height, width);
*(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
}
}
}
}
}
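// Gradient w.r.t. the learned offsets: for one offset channel (the h or w
// component of one kernel tap in one deformable group), accumulate the
// coordinate weight times the incoming column gradient over all input
// channels of that group. Out-of-bounds samples are clamped to (-2, -2),
// for which get_coordinate_weight_cpu returns 0.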
template <typename T>
void deformable_col2im_coord_cpu_kernel(
const int n, const T *data_col, const T *data_im, const T *data_offset,
const int channels, const int height, const int width, const int kernel_h,
const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int offset_channels, const int deformable_group, const int height_col,
const int width_col, T *grad_offset) {
for (int index = 0; index < n; index++) {
T val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const T *data_col_ptr = data_col + deformable_group_index *
channel_per_deformable_group *
batch_size * width_col * height_col;
const T *data_im_ptr =
data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w *
height * width;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
inv_h = inv_w = -2;
const T weight = get_coordinate_weight_cpu(
inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
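// Host-side wrappers. The output resolution follows the standard convolution
// arithmetic, e.g.
//   height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1))
//                    / stride_h + 1,
// and the kernels above run one loop iteration per output element.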
void deformable_im2col_cpu(Tensor data_im, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor data_col) {
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_cpu", [&] {
deformable_im2col_cpu_kernel<scalar_t>(
num_kernels, data_im.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(), height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels,
deformable_group, height_col, width_col,
data_col.data_ptr<scalar_t>());
});
}
void deformable_col2im_cpu(Tensor data_col, Tensor data_offset,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
Tensor grad_im) {
// todo: make sure parallel_imgs is passed in correctly
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
deformable_col2im_cpu_kernel<scalar_t>(
num_kernels, data_col_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
dilation_w, channel_per_deformable_group, parallel_imgs,
deformable_group, height_col, width_col, grad_im_);
}));
}
void deformable_col2im_coord_cpu(
Tensor data_col, Tensor data_im, Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, Tensor grad_offset) {
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w *
deformable_group * parallel_imgs;
int channel_per_deformable_group =
channels * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_cpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
deformable_col2im_coord_cpu_kernel<scalar_t>(
num_kernels, data_col_, data_im_, data_offset_, channels, height,
width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs,
2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
}
@@ -8,12 +8,6 @@
using namespace parrots;

#ifdef MMCV_WITH_CUDA
-/*void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset,
- *                              Tensor output, Tensor columns, Tensor ones,
- *                              int kW, int kH, int dW, int dH, int padW,
- *                              int padH, int dilationW, int dilationH, int
- *                              group, int deformable_group, int im2col_step);
- */
void deform_conv_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
                                      const OperatorBase::in_list_t& ins,
                                      OperatorBase::out_list_t& outs) {
@@ -41,18 +35,11 @@ void deform_conv_forward_cuda_parrots(CudaContext& ctx, const SSElement& attr,
  auto columns = buildATensor(ctx, outs[1]);
  auto ones = buildATensor(ctx, outs[2]);
-  deform_conv_forward_cuda(input, weight, offset, output, columns, ones, kW, kH,
-                           dW, dH, padW, padH, dilationW, dilationH, group,
-                           deformable_group, im2col_step);
  deform_conv_forward(input, weight, offset, output, columns, ones, kW, kH, dW,
                      dH, padW, padH, dilationW, dilationH, group,
                      deformable_group, im2col_step);
}

-/*void deform_conv_backward_input_cuda(Tensor input, Tensor offset,
- *                                     Tensor gradOutput, Tensor gradInput,
- *                                     Tensor gradOffset, Tensor weight,
- *                                     Tensor columns, int kW, int kH, int dW,
- *                                     int dH, int padW, int padH, int
- *                                     dilationW, int dilationH, int group, int deformable_group, int im2col_step);
- */
void deform_conv_backward_input_cuda_parrots(CudaContext& ctx,
                                             const SSElement& attr,
                                             const OperatorBase::in_list_t& ins,
@@ -82,18 +69,12 @@ void deform_conv_backward_input_cuda_parrots(CudaContext& ctx,
  auto weight = buildATensor(ctx, outs[2]);
  auto columns = buildATensor(ctx, outs[3]);
-  deform_conv_backward_input_cuda(input, offset, gradOutput, gradInput,
-                                  gradOffset, weight, columns, kW, kH, dW, dH,
-                                  padW, padH, dilationW, dilationH, group,
-                                  deformable_group, im2col_step);
  deform_conv_backward_input(input, offset, gradOutput, gradInput, gradOffset,
                             weight, columns, kW, kH, dW, dH, padW, padH,
                             dilationW, dilationH, group, deformable_group,
                             im2col_step);
}

-/*void deform_conv_backward_parameters_cuda(
- *    Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight,
- *    Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW,
- *    int padH, int dilationW, int dilationH, int group, int deformable_group,
- *    float scale, int im2col_step);
- */
void deform_conv_backward_parameters_cuda_parrots(
    CudaContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
    OperatorBase::out_list_t& outs) {
@@ -122,10 +103,112 @@ void deform_conv_backward_parameters_cuda_parrots(
  auto gradWeight = buildATensor(ctx, outs[0]);
  auto columns = buildATensor(ctx, outs[1]);
  auto ones = buildATensor(ctx, outs[2]);
-  deform_conv_backward_parameters_cuda(input, offset, gradOutput, gradWeight,
-                                       columns, ones, kW, kH, dW, dH, padW,
-                                       padH, dilationW, dilationH, group,
-                                       deformable_group, scale, im2col_step);
  deform_conv_backward_parameters(input, offset, gradOutput, gradWeight,
                                  columns, ones, kW, kH, dW, dH, padW, padH,
                                  dilationW, dilationH, group, deformable_group,
                                  scale, im2col_step);
}
}
#endif
void deform_conv_forward_cpu_parrots(HostContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, deformable_group,
im2col_step;
SSAttrs(attr)
.get<int>("kW", kW)
.get<int>("kH", kH)
.get<int>("dW", dW)
.get<int>("dH", dH)
.get<int>("padW", padW)
.get<int>("padH", padH)
.get<int>("dilationW", dilationW)
.get<int>("dilationH", dilationH)
.get<int>("group", group)
.get<int>("deformable_group", deformable_group)
.get<int>("im2col_step", im2col_step)
.done();
const auto& input = buildATensor(ctx, ins[0]);
const auto& weight = buildATensor(ctx, ins[1]);
const auto& offset = buildATensor(ctx, ins[2]);
auto output = buildATensor(ctx, outs[0]);
auto columns = buildATensor(ctx, outs[1]);
auto ones = buildATensor(ctx, outs[2]);
deform_conv_forward(input, weight, offset, output, columns, ones, kW, kH, dW,
dH, padW, padH, dilationW, dilationH, group,
deformable_group, im2col_step);
}
void deform_conv_backward_input_cpu_parrots(HostContext& ctx,
const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, deformable_group,
im2col_step;
SSAttrs(attr)
.get<int>("kW", kW)
.get<int>("kH", kH)
.get<int>("dW", dW)
.get<int>("dH", dH)
.get<int>("padW", padW)
.get<int>("padH", padH)
.get<int>("dilationW", dilationW)
.get<int>("dilationH", dilationH)
.get<int>("group", group)
.get<int>("deformable_group", deformable_group)
.get<int>("im2col_step", im2col_step)
.done();
const auto& input = buildATensor(ctx, ins[0]);
const auto& offset = buildATensor(ctx, ins[1]);
const auto& gradOutput = buildATensor(ctx, ins[2]);
auto gradInput = buildATensor(ctx, outs[0]);
auto gradOffset = buildATensor(ctx, outs[1]);
auto weight = buildATensor(ctx, outs[2]);
auto columns = buildATensor(ctx, outs[3]);
deform_conv_backward_input(input, offset, gradOutput, gradInput, gradOffset,
weight, columns, kW, kH, dW, dH, padW, padH,
dilationW, dilationH, group, deformable_group,
im2col_step);
}
void deform_conv_backward_parameters_cpu_parrots(
HostContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, deformable_group,
im2col_step;
float scale;
SSAttrs(attr)
.get<int>("kW", kW)
.get<int>("kH", kH)
.get<int>("dW", dW)
.get<int>("dH", dH)
.get<int>("padW", padW)
.get<int>("padH", padH)
.get<int>("dilationW", dilationW)
.get<int>("dilationH", dilationH)
.get<int>("group", group)
.get<int>("deformable_group", deformable_group)
.get<float>("scale", scale)
.get<int>("im2col_step", im2col_step)
.done();
const auto& input = buildATensor(ctx, ins[0]);
const auto& offset = buildATensor(ctx, ins[1]);
const auto& gradOutput = buildATensor(ctx, ins[2]);
auto gradWeight = buildATensor(ctx, outs[0]);
auto columns = buildATensor(ctx, outs[1]);
auto ones = buildATensor(ctx, outs[2]);
deform_conv_backward_parameters(input, offset, gradOutput, gradWeight,
columns, ones, kW, kH, dW, dH, padW, padH,
dilationW, dilationH, group, deformable_group,
scale, im2col_step);
}

PARROTS_EXTENSION_REGISTER(deform_conv_forward)
@@ -142,7 +225,10 @@ PARROTS_EXTENSION_REGISTER(deform_conv_forward)
    .attr("im2col_step")
    .input(3)
    .output(3)
.apply(deform_conv_forward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(deform_conv_forward_cuda_parrots)
#endif
    .done();

PARROTS_EXTENSION_REGISTER(deform_conv_backward_input)
@@ -159,7 +245,10 @@ PARROTS_EXTENSION_REGISTER(deform_conv_backward_input)
    .attr("im2col_step")
    .input(3)
    .output(4)
.apply(deform_conv_backward_input_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(deform_conv_backward_input_cuda_parrots)
#endif
    .done();

PARROTS_EXTENSION_REGISTER(deform_conv_backward_parameters)
@@ -177,6 +266,8 @@ PARROTS_EXTENSION_REGISTER(deform_conv_backward_parameters)
    .attr("im2col_step")
    .input(3)
    .output(3)
.apply(deform_conv_backward_parameters_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(deform_conv_backward_parameters_cuda_parrots)
-    .done();
#endif
.done();
@@ -4,24 +4,25 @@
#include <torch/extension.h>
using namespace at;

-void deform_conv_forward_cuda(Tensor input, Tensor weight, Tensor offset,
-                              Tensor output, Tensor columns, Tensor ones,
-                              int kW, int kH, int dW, int dH, int padW,
-                              int padH, int dilationW, int dilationH, int group,
-                              int deformable_group, int im2col_step);
void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
                         Tensor output, Tensor columns, Tensor ones, int kW,
                         int kH, int dW, int dH, int padW, int padH,
                         int dilationW, int dilationH, int group,
                         int deformable_group, int im2col_step);

-void deform_conv_backward_input_cuda(Tensor input, Tensor offset,
-                                     Tensor gradOutput, Tensor gradInput,
-                                     Tensor gradOffset, Tensor weight,
-                                     Tensor columns, int kW, int kH, int dW,
-                                     int dH, int padW, int padH, int dilationW,
-                                     int dilationH, int group,
-                                     int deformable_group, int im2col_step);
void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput,
                                Tensor gradInput, Tensor gradOffset,
                                Tensor weight, Tensor columns, int kW, int kH,
                                int dW, int dH, int padW, int padH,
                                int dilationW, int dilationH, int group,
                                int deformable_group, int im2col_step);

-void deform_conv_backward_parameters_cuda(
-    Tensor input, Tensor offset, Tensor gradOutput, Tensor gradWeight,
-    Tensor columns, Tensor ones, int kW, int kH, int dW, int dH, int padW,
-    int padH, int dilationW, int dilationH, int group, int deformable_group,
-    float scale, int im2col_step);
void deform_conv_backward_parameters(Tensor input, Tensor offset,
                                     Tensor gradOutput, Tensor gradWeight,
                                     Tensor columns, Tensor ones, int kW,
                                     int kH, int dW, int dH, int padW, int padH,
                                     int dilationW, int dilationH, int group,
                                     int deformable_group, float scale,
                                     int im2col_step);

#endif  // DEFORM_CONV_PYTORCH_H
@@ -2,48 +2,59 @@
#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
-void ModulatedDeformConvForwardCUDAKernelLauncher(
-    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
-    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
-    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
-    const int dilation_h, const int dilation_w, const int group,
-    const int deformable_group, const bool with_bias);
-
-void ModulatedDeformConvBackwardCUDAKernelLauncher(
-    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
-    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
-    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
-    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
-    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
-    const bool with_bias);
-
-void modulated_deform_conv_forward_cuda(
-    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
-    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
-    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
-    const int dilation_h, const int dilation_w, const int group,
-    const int deformable_group, const bool with_bias) {
-  ModulatedDeformConvForwardCUDAKernelLauncher(
-      input, weight, bias, ones, offset, mask, output, columns, kernel_h,
-      kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
-      deformable_group, with_bias);
-}
-
-void modulated_deform_conv_backward_cuda(
-    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
-    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
-    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
-    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
-    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
-    const bool with_bias) {
-  ModulatedDeformConvBackwardCUDAKernelLauncher(
-      input, weight, bias, ones, offset, mask, columns, grad_input, grad_weight,
-      grad_bias, grad_offset, grad_mask, grad_output, kernel_h, kernel_w,
-      stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
-      deformable_group, with_bias);
-}
void modulated_deformable_im2col_cuda(
    const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, Tensor data_col);

void modulated_deformable_col2im_cuda(
    const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, Tensor grad_im);

void modulated_deformable_col2im_coord_cuda(
    const Tensor data_col, const Tensor data_im, const Tensor data_offset,
    const Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    Tensor grad_offset, Tensor grad_mask);
#endif
void modulated_deformable_im2col_cpu(
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor data_col);
void modulated_deformable_col2im_cpu(
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor grad_im);
void modulated_deformable_col2im_coord_cpu(
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
const Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
Tensor grad_offset, Tensor grad_mask);
void modulated_deform_conv_forward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
@@ -61,15 +72,98 @@ void modulated_deform_conv_forward(
    CHECK_CUDA_INPUT(output);
    CHECK_CUDA_INPUT(columns);
-    modulated_deform_conv_forward_cuda(
-        input, weight, bias, ones, offset, mask, output, columns, kernel_h,
-        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
-        group, deformable_group, with_bias);
#else
    AT_ERROR("ModulatedDeformConv is not compiled with GPU support");
#endif
  } else {
-    AT_ERROR("ModulatedDeformConv is not implemented on CPU");
    CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(weight);
CHECK_CPU_INPUT(bias);
CHECK_CPU_INPUT(ones);
CHECK_CPU_INPUT(offset);
CHECK_CPU_INPUT(mask);
CHECK_CPU_INPUT(output);
CHECK_CPU_INPUT(columns);
}
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
#endif
} else {
modulated_deformable_im2col_cpu(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
}
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
  }
}
@@ -96,15 +190,149 @@ void modulated_deform_conv_backward(
    CHECK_CUDA_INPUT(grad_mask);
    CHECK_CUDA_INPUT(grad_output);
-    modulated_deform_conv_backward_cuda(
-        input, weight, bias, ones, offset, mask, columns, grad_input,
-        grad_weight, grad_bias, grad_offset, grad_mask, grad_output, kernel_h,
-        kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
-        group, deformable_group, with_bias);
#else
    AT_ERROR("ModulatedDeformConv is not compiled with GPU support");
#endif
  } else {
-    AT_ERROR("ModulatedDeformConv is not implemented on CPU");
    CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(weight);
CHECK_CPU_INPUT(bias);
CHECK_CPU_INPUT(ones);
CHECK_CPU_INPUT(offset);
CHECK_CPU_INPUT(mask);
CHECK_CPU_INPUT(columns);
CHECK_CPU_INPUT(grad_input);
CHECK_CPU_INPUT(grad_weight);
CHECK_CPU_INPUT(grad_bias);
CHECK_CPU_INPUT(grad_offset);
CHECK_CPU_INPUT(grad_mask);
CHECK_CPU_INPUT(grad_output);
}
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
    // divide into group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
#endif
} else {
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cpu(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cpu(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cpu(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
}
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
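Taken together, modulated_deform_conv_forward now dispatches to the CPU kernels in the file below when it receives host tensors. A minimal smoke-test sketch (hypothetical, not part of the commit; it assumes the declaration from modulated_deform_conv_pytorch.h and links against the code above):

// smoke_test.cpp -- illustrative only
#include <torch/torch.h>

void smoke_test_modulated_deform_conv_cpu() {
  const int b = 1, c_in = 2, c_out = 4, h = 8, w = 8, k = 3;
  auto input = torch::randn({b, c_in, h, w});
  auto weight = torch::randn({c_out, c_in, k, k});
  auto bias = torch::zeros({c_out});
  auto ones = torch::empty({0});  // recreated internally if the shape is wrong
  // zero offsets and an all-ones mask reduce DCNv2 to a plain 3x3 convolution
  auto offset = torch::zeros({b, 2 * k * k, h, w});
  auto mask = torch::ones({b, k * k, h, w});
  auto output = torch::zeros({b, c_out, h, w});  // stride 1, pad 1 keep h x w
  auto columns = torch::empty({0});              // resized internally
  modulated_deform_conv_forward(
      input, weight, bias, ones, offset, mask, output, columns,
      /*kernel_h=*/k, /*kernel_w=*/k, /*stride_h=*/1, /*stride_w=*/1,
      /*pad_h=*/1, /*pad_w=*/1, /*dilation_h=*/1, /*dilation_w=*/1,
      /*group=*/1, /*deformable_group=*/1, /*with_bias=*/false);
  // `output` is filled through shared storage by the in-place addmm_ calls.
}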
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
template <typename T>
T dmcn_im2col_bilinear_cpu(const T *input, const int data_width,
const int height, const int width, T h, T w) {
  int h_low = std::floor(h);
  int w_low = std::floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
T lh = h - h_low;
T lw = w - w_low;
T hh = 1 - lh, hw = 1 - lw;
T v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low];
T v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = input[h_low * data_width + w_high];
T v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = input[h_high * data_width + w_low];
T v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = input[h_high * data_width + w_high];
T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
T dmcn_get_gradient_weight_cpu(T argmax_h, T argmax_w, const int h, const int w,
const int height, const int width) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
  int argmax_h_low = std::floor(argmax_h);
  int argmax_w_low = std::floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename T>
T dmcn_get_coordinate_weight_cpu(T argmax_h, T argmax_w, const int height,
const int width, const T *im_data,
const int data_width, const int bp_dir) {
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
argmax_w >= width) {
// empty
return 0;
}
  int argmax_h_low = std::floor(argmax_h);
  int argmax_w_low = std::floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
T weight = 0;
if (bp_dir == 0) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
} else if (bp_dir == 1) {
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) *
im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) *
im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename T>
void modulated_deformable_im2col_cpu_kernel(
const int n, const T *data_im, const T *data_offset, const T *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int num_channels, const int deformable_group, const int height_col,
const int width_col, T *data_col) {
for (int index = 0; index < n; index++) {
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
const T *data_offset_ptr =
data_offset + (b_col * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr =
data_mask + (b_col * deformable_group + deformable_group_index) *
kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col +
w_col;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T val = static_cast<T>(0);
const T h_im = h_in + i * dilation_h + offset_h;
const T w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width,
h_im, w_im);
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
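// Same scatter scheme as the DCNv1 col2im, except that the upstream gradient
// of each column element is pre-multiplied by its modulation mask before
// being distributed to the neighbouring input pixels.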
template <typename T>
void modulated_deformable_col2im_cpu_kernel(
const int n, const T *data_col, const T *data_offset, const T *data_mask,
const int channels, const int height, const int width, const int kernel_h,
const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group, const int batch_size,
const int deformable_group, const int height_col, const int width_col,
T *grad_im) {
for (int index = 0; index < n; index++) {
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i =
(index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr =
data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
const int data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
const T cur_inv_h_data = h_in + i * dilation_h + offset_h;
const T cur_inv_w_data = w_in + j * dilation_w + offset_w;
const T cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
            cur_w + dx < width &&
            std::abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            std::abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data,
cur_inv_w_data, cur_h + dy,
cur_w + dx, height, width);
*(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad;
}
}
}
}
}
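// Offset and mask gradients in one pass: `val` accumulates the offset
// gradient (coordinate weight * column gradient * mask) while `mval`
// accumulates the mask gradient (column gradient * bilinearly sampled
// input). grad_mask is written once per kernel tap, on the even (h-offset)
// channel.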
template <typename T>
void modulated_deformable_col2im_coord_cpu_kernel(
const int n, const T *data_col, const T *data_im, const T *data_offset,
const T *data_mask, const int channels, const int height, const int width,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col, T *grad_offset, T *grad_mask) {
for (int index = 0; index < n; index++) {
T val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const T *data_col_ptr = data_col + deformable_group_index *
channel_per_deformable_group *
batch_size * width_col * height_col;
const T *data_im_ptr =
data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w *
height * width;
const T *data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const T *data_mask_ptr =
data_mask + (b * deformable_group + deformable_group_index) * kernel_h *
kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const int data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
const T mask = data_mask_ptr[data_mask_hw_ptr];
T inv_h = h_in + i * dilation_h + offset_h;
T inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
inv_h = inv_w = -2;
else
mval += data_col_ptr[col_pos] *
dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width,
width, height, width, inv_h, inv_w);
const T weight = dmcn_get_coordinate_weight_cpu(
inv_h, inv_w, height, width, data_im_ptr + cnt * height * width,
width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group +
// deformable_group_index) * kernel_h * kernel_w + offset_c / 2) *
// height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h *
kernel_w +
offset_c / 2) *
height_col +
h) *
width_col +
w] = mval;
}
}
void modulated_deformable_im2col_cpu(
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor data_col) {
  // one loop iteration per (channel, batch, output-pixel) element
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
modulated_deformable_im2col_cpu_kernel(
num_kernels, data_im_, data_offset_, data_mask_, height_im,
            width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group, batch_size,
channels, deformable_group, height_col, width_col, data_col_);
}));
}
void modulated_deformable_col2im_cpu(
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor grad_im) {
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels =
channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_cpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
modulated_deformable_col2im_cpu_kernel(
num_kernels, data_col_, data_offset_, data_mask_, channels,
height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
}
void modulated_deformable_col2im_coord_cpu(
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
const Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
Tensor grad_offset, Tensor grad_mask) {
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h *
kernel_w * deformable_group;
const int channel_per_deformable_group =
channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_cpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
modulated_deformable_col2im_coord_cpu_kernel(
num_kernels, data_col_, data_im_, data_offset_, data_mask_,
channels, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, batch_size,
2 * kernel_h * kernel_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_, grad_mask_);
}));
}
@@ -37,9 +37,9 @@ void modulated_deform_conv_forward_cuda_parrots(
  auto output = buildATensor(ctx, outs[0]);
  auto columns = buildATensor(ctx, outs[1]);
-  modulated_deform_conv_forward_cuda(
-      input, weight, bias, ones, offset, mask, output, columns, kernel_h,
-      kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
-      deformable_group, with_bias);
  modulated_deform_conv_forward(input, weight, bias, ones, offset, mask, output,
                                columns, kernel_h, kernel_w, stride_h, stride_w,
                                pad_h, pad_w, dilation_h, dilation_w, group,
                                deformable_group, with_bias);
}
@@ -76,13 +76,88 @@ void modulated_deform_conv_backward_cuda_parrots(
  auto grad_offset = buildATensor(ctx, outs[4]);
  auto grad_mask = buildATensor(ctx, outs[5]);
  auto grad_output = buildATensor(ctx, outs[6]);
-  modulated_deform_conv_backward_cuda(
  modulated_deform_conv_backward(
      input, weight, bias, ones, offset, mask, columns, grad_input, grad_weight,
      grad_bias, grad_offset, grad_mask, grad_output, kernel_h, kernel_w,
      stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
      deformable_group, with_bias);
}
#endif
void modulated_deform_conv_forward_cpu_parrots(
HostContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, deformable_group, with_bias;
SSAttrs(attr)
.get<int>("kernel_h", kernel_h)
.get<int>("kernel_w", kernel_w)
.get<int>("stride_h", stride_h)
.get<int>("stride_w", stride_w)
.get<int>("pad_h", pad_h)
.get<int>("pad_w", pad_w)
.get<int>("dilation_h", dilation_h)
.get<int>("dilation_w", dilation_w)
.get<int>("group", group)
.get<int>("deformable_group", deformable_group)
.get<int>("with_bias", with_bias)
.done();
const auto& input = buildATensor(ctx, ins[0]);
const auto& weight = buildATensor(ctx, ins[1]);
const auto& bias = buildATensor(ctx, ins[2]);
const auto& ones = buildATensor(ctx, ins[3]);
const auto& offset = buildATensor(ctx, ins[4]);
const auto& mask = buildATensor(ctx, ins[5]);
auto output = buildATensor(ctx, outs[0]);
auto columns = buildATensor(ctx, outs[1]);
modulated_deform_conv_forward(input, weight, bias, ones, offset, mask, output,
columns, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
}
void modulated_deform_conv_backward_cpu_parrots(
HostContext& ctx, const SSElement& attr, const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
int kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, group, deformable_group, with_bias;
SSAttrs(attr)
.get<int>("kernel_h", kernel_h)
.get<int>("kernel_w", kernel_w)
.get<int>("stride_h", stride_h)
.get<int>("stride_w", stride_w)
.get<int>("pad_h", pad_h)
.get<int>("pad_w", pad_w)
.get<int>("dilation_h", dilation_h)
.get<int>("dilation_w", dilation_w)
.get<int>("group", group)
.get<int>("deformable_group", deformable_group)
.get<int>("with_bias", with_bias)
.done();
const auto& input = buildATensor(ctx, ins[0]);
const auto& weight = buildATensor(ctx, ins[1]);
const auto& bias = buildATensor(ctx, ins[2]);
const auto& ones = buildATensor(ctx, ins[3]);
const auto& offset = buildATensor(ctx, ins[4]);
const auto& mask = buildATensor(ctx, ins[5]);
auto columns = buildATensor(ctx, outs[0]);
auto grad_input = buildATensor(ctx, outs[1]);
auto grad_weight = buildATensor(ctx, outs[2]);
auto grad_bias = buildATensor(ctx, outs[3]);
auto grad_offset = buildATensor(ctx, outs[4]);
auto grad_mask = buildATensor(ctx, outs[5]);
auto grad_output = buildATensor(ctx, outs[6]);
modulated_deform_conv_backward(
input, weight, bias, ones, offset, mask, columns, grad_input, grad_weight,
grad_bias, grad_offset, grad_mask, grad_output, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
}
PARROTS_EXTENSION_REGISTER(modulated_deform_conv_forward)
    .attr("kernel_h")
    .attr("kernel_w")
@@ -97,7 +172,10 @@ PARROTS_EXTENSION_REGISTER(modulated_deform_conv_forward)
    .attr("with_bias")
    .input(6)
    .output(2)
.apply(modulated_deform_conv_forward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(modulated_deform_conv_forward_cuda_parrots)
#endif
    .done();

PARROTS_EXTENSION_REGISTER(modulated_deform_conv_backward)
@@ -114,6 +192,8 @@ PARROTS_EXTENSION_REGISTER(modulated_deform_conv_backward)
    .attr("with_bias")
    .input(6)
    .output(7)
.apply(modulated_deform_conv_backward_cpu_parrots)
#ifdef MMCV_WITH_CUDA
    .apply(modulated_deform_conv_backward_cuda_parrots)
-    .done();
#endif
.done();
@@ -4,14 +4,14 @@
#include <torch/extension.h>
using namespace at;

-void modulated_deform_conv_forward_cuda(
void modulated_deform_conv_forward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias);

-void modulated_deform_conv_backward_cuda(
void modulated_deform_conv_backward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
...