Commit 62304425 authored by Kai Chen
Browse files

use extension.h instead of torch.h and code formatting

parent 441015ea
This diff is collapsed.
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu

#include <torch/extension.h>
#include <torch/torch.h>
#include <cmath> #include <cmath>
#include <vector> #include <vector>
// Forward declaration of the deformable PS-RoI pooling forward routine.
//
// No body is provided in this file; the definition has to be supplied by
// another translation unit (presumably the companion CUDA source -- confirm).
// deform_psroi_pooling_cuda_forward calls this with its own tensor and scalar
// arguments passed through unchanged, plus size values it computes itself
// (for example batch = input.size(0)).
//
// Tensor arguments:
//   data, bbox, trans   passed as const tensor handles; the wrapper supplies
//                       its `input`, `bbox` and `trans` arguments here.
//   out, top_count      passed as non-const tensor handles (presumably the
//                       buffers the routine writes -- confirm in the kernel).
// Size arguments:
//   batch, channels, height, width, num_bbox, channels_trans
//                       computed by the wrapper (presumably from tensor
//                       shapes -- only batch's derivation is visible here).
// Pooling configuration (forwarded verbatim from the Python-facing wrapper):
//   no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
//   sample_per_part, trans_std
void DeformablePSROIPoolForward(
    const at::Tensor data,
    const at::Tensor bbox,
    const at::Tensor trans,
    at::Tensor out,
    at::Tensor top_count,
    const int batch,
    const int channels,
    const int height,
    const int width,
    const int num_bbox,
    const int channels_trans,
    const int no_trans,
    const float spatial_scale,
    const int output_dim,
    const int group_size,
    const int pooled_size,
    const int part_size,
    const int sample_per_part,
    const float trans_std);
// Forward declaration of the deformable PS-RoI pooling backward routine.
//
// No body is provided in this file; the definition has to be supplied by
// another translation unit (presumably the companion CUDA source -- confirm).
// deform_psroi_pooling_cuda_backward calls this with its own tensor and
// scalar arguments passed through unchanged, plus size values it computes
// itself.
//
// Tensor arguments:
//   out_grad, data, bbox, trans, top_count
//                       passed as const tensor handles; the wrapper supplies
//                       its `out_grad`, `input`, `bbox`, `trans` and
//                       `top_count` arguments here.
//   in_grad, trans_grad passed as non-const tensor handles; the wrapper
//                       supplies `input_grad` and `trans_grad` (presumably
//                       the gradient buffers being filled -- confirm).
// Size arguments:
//   batch, channels, height, width, num_bbox, channels_trans
//                       computed by the wrapper (derivation not visible in
//                       this view -- confirm).
// Pooling configuration (forwarded verbatim from the Python-facing wrapper):
//   no_trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
//   sample_per_part, trans_std
void DeformablePSROIPoolBackwardAcc(
    const at::Tensor out_grad,
    const at::Tensor data,
    const at::Tensor bbox,
    const at::Tensor trans,
    const at::Tensor top_count,
    at::Tensor in_grad,
    at::Tensor trans_grad,
    const int batch,
    const int channels,
    const int height,
    const int width,
    const int num_bbox,
    const int channels_trans,
    const int no_trans,
    const float spatial_scale,
    const int output_dim,
    const int group_size,
    const int pooled_size,
    const int part_size,
    const int sample_per_part,
    const float trans_std);
void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox, void deform_psroi_pooling_cuda_forward(
at::Tensor trans, at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor out, at::Tensor top_count, at::Tensor top_count, const int no_trans, const float spatial_scale,
const int no_trans, const int output_dim, const int group_size, const int pooled_size,
const float spatial_scale, const int part_size, const int sample_per_part, const float trans_std) {
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0); const int batch = input.size(0);
...@@ -75,33 +45,18 @@ void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox, ...@@ -75,33 +45,18 @@ void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox,
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out.size(0), num_bbox); out.size(0), num_bbox);
DeformablePSROIPoolForward(input, bbox, trans, out, top_count, DeformablePSROIPoolForward(
batch, channels, height, width, input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
channels_trans, pooled_size, part_size, sample_per_part, trans_std);
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
} }
void deform_psroi_pooling_cuda_backward(at::Tensor out_grad, void deform_psroi_pooling_cuda_backward(
at::Tensor input, at::Tensor bbox, at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor trans, at::Tensor top_count, at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
at::Tensor input_grad, at::Tensor trans_grad, const int no_trans, const float spatial_scale, const int output_dim,
const int no_trans, const int group_size, const int pooled_size, const int part_size,
const float spatial_scale, const int sample_per_part, const float trans_std) {
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
...@@ -116,29 +71,17 @@ void deform_psroi_pooling_cuda_backward(at::Tensor out_grad, ...@@ -116,29 +71,17 @@ void deform_psroi_pooling_cuda_backward(at::Tensor out_grad,
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).", AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out_grad.size(0), num_bbox); out_grad.size(0), num_bbox);
DeformablePSROIPoolBackwardAcc(out_grad, DeformablePSROIPoolBackwardAcc(
input, out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
bbox, channels, height, width, num_bbox, channels_trans, no_trans,
trans, spatial_scale, output_dim, group_size, pooled_size, part_size,
top_count, sample_per_part, trans_std);
input_grad,
trans_grad,
batch, channels, height, width, num_bbox,
channels_trans,
no_trans,
spatial_scale,
output_dim,
group_size,
pooled_size,
part_size,
sample_per_part,
trans_std);
} }
// Python bindings for this extension.
//
// Registers the two wrapper functions defined in this file on module `m`,
// each under a Python name identical to its C++ name and with a short
// docstring. The module's own name comes from the TORCH_EXTENSION_NAME macro.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_psroi_pooling_cuda_forward",
        &deform_psroi_pooling_cuda_forward,
        "deform psroi pooling forward(CUDA)");
  m.def("deform_psroi_pooling_cuda_backward",
        &deform_psroi_pooling_cuda_backward,
        "deform psroi pooling backward(CUDA)");
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment