Commit 03c38bdc authored by yhcao6

separate dcn and dpool cpp, restructure some code

parent 66489d6b
......@@ -2,11 +2,12 @@ from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling,
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
]
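
The refactor adds the plain `DeformRoIPoolingPack` to the public exports alongside the modulated variant. A minimal import sketch, assuming the package shown here is importable as `ops.dcn` (the parent package name is not visible in this diff):

```python
# Hedged sketch: `ops.dcn` is an assumed import path for the package above.
from ops.dcn import (DeformConv, ModulatedDeformConv, ModulatedDeformConvPack,
                     DeformRoIPooling, DeformRoIPoolingPack,
                     ModulatedDeformRoIPoolingPack, deform_conv,
                     modulated_deform_conv, deform_roi_pooling)
```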
......@@ -3,7 +3,6 @@ from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import deform_conv_cuda
from .. import modulated_dcn_cuda as _backend
class DeformConvFunction(Function):
......@@ -124,7 +123,7 @@ class ModulatedDeformConvFunction(Function):
output = input.new(
*ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new(), input.new()]
_backend.modulated_deform_conv_cuda_forward(
deform_conv_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
......@@ -141,7 +140,7 @@ class ModulatedDeformConvFunction(Function):
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_backend.modulated_deform_conv_cuda_backward(
deform_conv_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
......
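
With this change the modulated deform conv ops dispatch through the consolidated `deform_conv_cuda` extension instead of the old `modulated_dcn_cuda` backend. For orientation, a self-contained toy illustrating the `torch.autograd.Function` pattern these wrappers follow (the real forward/backward bodies call the extension's kernels instead of the arithmetic here):

```python
import torch
from torch.autograd import Function

class SquareFunction(Function):
    # Toy stand-in: forward runs the op (the real code calls
    # deform_conv_cuda.modulated_deform_conv_cuda_forward), backward
    # consumes saved tensors and returns a gradient for each input.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
SquareFunction.apply(x).sum().backward()
assert torch.allclose(x.grad, 2 * x)
```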
import torch
from torch.autograd import Function
from .. import modulated_dcn_cuda as _backend
from .. import deform_pool_cuda
class DeformRoIPoolingFunction(Function):
......@@ -36,7 +36,7 @@ class DeformRoIPoolingFunction(Function):
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output_count = data.new(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
_backend.deform_psroi_pooling_cuda_forward(
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
......@@ -63,7 +63,7 @@ class DeformRoIPoolingFunction(Function):
grad_input = torch.zeros_like(data)
grad_offset = torch.zeros_like(offset)
_backend.deform_psroi_pooling_cuda_backward(
deform_pool_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.output_dim,
ctx.group_size, ctx.pooled_size, ctx.part_size,
......
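
`DeformRoIPoolingFunction` likewise moves from `modulated_dcn_cuda` to the dedicated `deform_pool_cuda` extension. A hedged call sketch (needs the compiled extension and a GPU; the import path and RoI layout are assumed, and the positional arguments mirror the module call below):

```python
import torch
from functions.deform_pool import deform_roi_pooling  # import path assumed

if torch.cuda.is_available():
    data = torch.randn(1, 256, 32, 32, device='cuda')
    rois = torch.tensor([[0., 4., 4., 20., 20.]], device='cuda')  # (idx, x1, y1, x2, y2) assumed
    offset = data.new()  # empty placeholder, as the no_trans branch does
    # spatial_scale, out_size, output_dim, no_trans, group_size,
    # part_size, sample_per_part, trans_std
    out = deform_roi_pooling(data, rois, offset, 1. / 16, 7, 256,
                             True, 1, 7, 4, 0.0)
```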
......@@ -16,7 +16,7 @@ class DeformConv(nn.Module):
stride=1,
padding=0,
dilation=1,
num_deformable_groups=1):
deformable_groups=1):
super(DeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -24,7 +24,7 @@ class DeformConv(nn.Module):
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.num_deformable_groups = num_deformable_groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
......@@ -41,7 +41,7 @@ class DeformConv(nn.Module):
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
self.deformable_groups)
class ModulatedDeformConv(nn.Module):
......@@ -54,7 +54,7 @@ class ModulatedDeformConv(nn.Module):
padding,
dilation=1,
deformable_groups=1,
no_bias=True):
bias=False):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
......@@ -63,13 +63,12 @@ class ModulatedDeformConv(nn.Module):
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.no_bias = no_bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
self.bias = nn.Parameter(torch.zeros(out_channels))
self.reset_parameters()
if self.no_bias:
if not bias:
self.bias.requires_grad = False
def reset_parameters(self):
......@@ -96,10 +95,10 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
padding,
dilation=1,
deformable_groups=1,
no_bias=False):
bias=True):
super(ModulatedDeformConvPack,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, deformable_groups, no_bias)
padding, dilation, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
......
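
Two user-facing renames land in this file: `num_deformable_groups` becomes `deformable_groups` on `DeformConv`, and the inverted `no_bias` flag becomes a conventional `bias` flag (default `False` on `ModulatedDeformConv`, `True` on the Pack variant; `bias=False` freezes the bias parameter rather than removing it). A hedged construction sketch using the signatures visible in this diff, with defaults assumed where not shown:

```python
from modules.deform_conv import DeformConv, ModulatedDeformConvPack  # path assumed

# Renamed kwarg: deformable_groups (was num_deformable_groups).
conv = DeformConv(64, 64, kernel_size=3, padding=1, deformable_groups=2)

# bias=True is now the Pack default (was no_bias=False).
pack = ModulatedDeformConvPack(64, 64, kernel_size=3, stride=1, padding=1,
                               deformable_groups=2, bias=True)
assert pack.bias.requires_grad
```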
......@@ -7,7 +7,7 @@ class DeformRoIPooling(nn.Module):
def __init__(self,
spatial_scale,
pooled_size,
out_size,
output_dim,
no_trans,
group_size=1,
......@@ -16,12 +16,11 @@ class DeformRoIPooling(nn.Module):
trans_std=.0):
super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale
self.pooled_size = pooled_size
self.out_size = pooled_size
self.out_size = out_size
self.output_dim = output_dim
self.no_trans = no_trans
self.group_size = group_size
self.part_size = pooled_size if part_size is None else part_size
self.part_size = out_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
......@@ -29,7 +28,7 @@ class DeformRoIPooling(nn.Module):
if self.no_trans:
offset = data.new()
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
......@@ -38,7 +37,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
pooled_size,
out_size,
output_dim,
no_trans,
group_size=1,
......@@ -47,7 +46,7 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
trans_std=.0,
deform_fc_dim=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, pooled_size, output_dim, no_trans, group_size,
spatial_scale, out_size, output_dim, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.deform_fc_dim = deform_fc_dim
......@@ -55,20 +54,20 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(
self.pooled_size * self.pooled_size * self.output_dim,
self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.pooled_size * self.pooled_size * 2))
self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_()
self.offset_fc[4].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(
self.pooled_size * self.pooled_size * self.output_dim,
self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.pooled_size * self.pooled_size * 1),
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
......@@ -80,19 +79,72 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
n = rois.shape[0]
offset = data.new()
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.pooled_size, self.output_dim, True,
self.out_size, self.output_dim, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.pooled_size, self.pooled_size)
offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.pooled_size, self.pooled_size)
mask = mask.view(n, 1, self.out_size, self.out_size)
feat = deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
return feat
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
class DeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
out_size,
output_dim,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
deform_fc_dim=1024):
super(DeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, output_dim, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.deform_fc_dim = deform_fc_dim
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(
self.out_size * self.out_size * self.output_dim,
self.deform_fc_dim), nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim, self.deform_fc_dim),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_dim,
self.out_size * self.out_size * 2))
self.offset_fc[4].weight.data.zero_()
self.offset_fc[4].bias.data.zero_()
def forward(self, data, rois):
if self.no_trans:
offset = data.new()
else:
n = rois.shape[0]
offset = data.new()
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.output_dim, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
feat = deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
return feat
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
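
The new `DeformRoIPoolingPack` mirrors the modulated variant but learns only offsets, with no mask branch. A construction sketch against the signature above (only the import path is assumed):

```python
from modules.deform_pool import DeformRoIPoolingPack  # import path assumed

pool = DeformRoIPoolingPack(spatial_scale=1. / 16, out_size=7, output_dim=256,
                            no_trans=False, group_size=1, trans_std=0.1)
# The last offset_fc layer predicts out_size * out_size * 2 values
# (an x/y offset per output bin), zero-initialized above.
assert pool.offset_fc[-1].out_features == 7 * 7 * 2
```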
......@@ -2,11 +2,14 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='deform_conv_cuda',
name='deform_conv',
ext_modules=[
CUDAExtension('deform_conv_cuda', [
'src/deform_conv_cuda.cpp',
'src/deform_conv_cuda_kernel.cu',
]),
CUDAExtension('deform_pool_cuda', [
'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
]),
],
cmdclass={'build_ext': BuildExtension})
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='modulated_dcn_cuda',
ext_modules=[
CUDAExtension('modulated_dcn_cuda', [
'src/modulated_dcn_cuda.cpp',
'src/modulated_deform_im2col_cuda.cu',
'src/deform_psroi_pooling_cuda.cu'
]),
],
cmdclass={'build_ext': BuildExtension})
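
The two setup scripts collapse into one that builds two independent extensions from the split sources. A hedged build-and-import sketch (the in-place build invocation is the usual setuptools convention, not shown in this diff):

```python
# After building in place, e.g.:
#   python setup.py build_ext --inplace
# the two extensions import separately:
import deform_conv_cuda   # deform conv + modulated deform conv kernels
import deform_pool_cuda   # deform psroi pooling kernels
```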
......@@ -33,6 +33,32 @@ void deformable_col2im_coord(const at::Tensor data_col,
const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col);
void modulated_deformable_col2im_cuda(const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset,
at::Tensor *gradOutput, at::Tensor weight, int kH, int kW,
int dH, int dW, int padH, int padW, int dilationH,
......@@ -256,16 +282,6 @@ int deform_conv_backward_input_cuda(
{batchSize / im2col_step, im2col_step, nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer = gradOutputBuffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step, outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer = gradOutputBuffer.view(
{batchSize / im2col_step, nOutputPlane, im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput = gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view(
{batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, inputHeight, inputWidth});
......@@ -276,7 +292,7 @@ int deform_conv_backward_input_cuda(
for (int elt = 0; elt < batchSize / im2col_step; elt++)
{
columns = columns.addmm_(weight.flatten(1).transpose(0, 1), gradOutputBuffer[elt].flatten(1), 0.0f, 1.0f);
columns = columns.addmm_(weight.flatten(1).transpose(0, 1), gradOutput[elt].flatten(1), 0.0f, 1.0f);
deformable_col2im_coord(
columns, input[elt], offset[elt],
......@@ -289,6 +305,9 @@ int deform_conv_backward_input_cuda(
deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput = gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
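
The deleted `gradOutputBuffer` was an explicit contiguous copy of the transposed `gradOutput`; the new code indexes the transposed view directly and lets `flatten(1)` reshape on demand, restoring the original layout after the loop. A hedged PyTorch check, with shapes assumed, that the two paths feed `addmm_` the same matrix:

```python
import torch

batch, step, c_out, h, w = 4, 2, 8, 5, 5
grad_output = torch.randn(batch // step, step, c_out, h, w)

t = grad_output.transpose(1, 2)            # new path: a view, no copy
direct = t[0].flatten(1)                   # (c_out, step*h*w)
buffered = t.contiguous()[0].flatten(1)    # old path: explicit buffer copy
assert torch.equal(direct, buffered)
```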
......@@ -394,6 +413,148 @@ int deform_conv_backward_parameters_cuda(
return 1;
}
void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask,
at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int deformable_group)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out)
{
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.type());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out});
// resize temporary columns
columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());
for (int b = 0; b < batch; b++)
{
// Do Bias first:
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
// (N x 1) (1 x M)
output[b] = output[b].flatten(1).addmm_(bias.view({-1, 1}), ones.view({1, -1}), 0.0f, 1.0f).view_as(output[b]);
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group, columns);
//(k * m) x (m * n)
// Y = WC
output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
}
}
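
In the forward pass the bias is seeded first as a rank-1 `addmm_` against a plane of ones, and the im2col matmul then accumulates `W @ columns` on top. A hedged PyTorch restatement of that bias step:

```python
import torch

c_out, h_out, w_out = 8, 5, 5
bias = torch.randn(c_out)
out = torch.empty(c_out, h_out * w_out)

# Mirrors: output[b].flatten(1).addmm_(bias.view({-1,1}), ones.view({1,-1}), 0.0f, 1.0f)
out.addmm_(bias.view(-1, 1), torch.ones(1, h_out * w_out), beta=0.0, alpha=1.0)
assert torch.allclose(out, bias.view(-1, 1).expand(c_out, h_out * w_out))
```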
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight,
at::Tensor grad_bias, at::Tensor grad_offset,
at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out)
{
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.type());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());
for (int b = 0; b < batch; b++)
{
columns.addmm_(weight.flatten(1).transpose(0, 1), grad_output[b].flatten(1), 0.0f, 1.0f);
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset[b], grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(columns, offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns);
grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
}
}
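
The final `addmm_` lines accumulate `grad_weight` across the batch and reduce `grad_bias` to a per-channel sum of `grad_output` over spatial positions via the ones column. A single-step check of the bias reduction:

```python
import torch

c_out, h, w = 8, 5, 5
g = torch.randn(c_out, h, w)  # stands in for grad_output[b]

# Mirrors: grad_bias.view({-1,1}).addmm_(grad_output[b].flatten(1), ones.view({-1,1}))
grad_bias = g.flatten(1).mm(torch.ones(h * w, 1)).view(-1)
assert torch.allclose(grad_bias, g.sum(dim=(1, 2)))
```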
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda, "deform forward (CUDA)");
......@@ -401,4 +562,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
"deform_conv_backward_input (CUDA)");
m.def("deform_conv_backward_parameters_cuda", &deform_conv_backward_parameters_cuda,
"deform_conv_backward_parameters (CUDA)");
m.def("modulated_deform_conv_cuda_forward", &modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward", &modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
}
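
Both entry points recompute the output spatial size with the standard dilated-convolution formula. A small Python mirror with worked checks:

```python
def conv_out_size(size, pad, dilation, kernel, stride):
    # (size + 2*pad - (dilation*(kernel-1) + 1)) / stride + 1, as in the C++ above
    return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1

assert conv_out_size(224, pad=1, dilation=1, kernel=3, stride=1) == 224
assert conv_out_size(224, pad=0, dilation=1, kernel=3, stride=2) == 111
assert conv_out_size(224, pad=2, dilation=2, kernel=3, stride=1) == 224
```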
......@@ -8,32 +8,6 @@
#include <cmath>
#include <vector>
void modulated_deformable_im2col_cuda(const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col);
void modulated_deformable_col2im_cuda(const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void DeformablePSROIPoolForward(const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
......@@ -76,148 +50,6 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
const int sample_per_part,
const float trans_std);
void modulated_deform_conv_cuda_forward(at::Tensor input, at::Tensor weight,
at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask,
at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w,
const int stride_h, const int stride_w,
const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w,
const int deformable_group)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out)
{
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.type());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out});
// resize temporary columns
columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.type());
for (int b = 0; b < batch; b++)
{
// Do Bias first:
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
// (N x 1) (1 x M)
output[b] = output[b].flatten(1).addmm_(bias.view({-1, 1}), ones.view({1, -1}), 0.0f, 1.0f).view_as(output[b]);
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group, columns);
//(k * m) x (m * n)
// Y = WC
output[b] = output[b].flatten(1).addmm_(weight.flatten(1), columns).view_as(output[b]);
}
}
void modulated_deform_conv_cuda_backward(at::Tensor input, at::Tensor weight,
at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask,
at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight,
at::Tensor grad_bias, at::Tensor grad_offset,
at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
{
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel);
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out)
{
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.type());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, input.type());
for (int b = 0; b < batch; b++)
{
columns.addmm_(weight.flatten(1).transpose(0, 1), grad_output[b].flatten(1), 0.0f, 1.0f);
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(columns, input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset[b], grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(columns, offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(input[b], offset[b], mask[b],
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns);
grad_weight = grad_weight.flatten(1).addmm_(grad_output[b].flatten(1), columns.transpose(0, 1)).view_as(grad_weight);
grad_bias = grad_bias.view({-1, 1}).addmm_(grad_output[b].flatten(1), ones.view({-1, 1})).view(-1);
}
}
void deform_psroi_pooling_cuda_forward(at::Tensor input, at::Tensor bbox,
at::Tensor trans,
at::Tensor out, at::Tensor top_count,
......@@ -305,10 +137,6 @@ void deform_psroi_pooling_cuda_backward(at::Tensor out_grad,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("modulated_deform_conv_cuda_forward", &modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward", &modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
"deform psroi pooling forward(CUDA)");
m.def("deform_psroi_pooling_cuda_backward", &deform_psroi_pooling_cuda_backward,
......
......@@ -356,7 +356,6 @@ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
group_size, part_size, num_classes, channels_each_class);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
......