Commit 6bd60eac authored by yhcao6
Browse files

refactor dcn python interface

parent ef86c404
......@@ -20,3 +20,16 @@ echo "Building nms op..."
cd ../nms
make clean
make PYTHON=${PYTHON}
echo "Building nms op..."
cd ../nms
make clean
make PYTHON=${PYTHON}
echo "Building dcn..."
cd ../dcn
if [ -d "build" ]; then
rm -r build
fi
$PYTHON setup.py build_ext --inplace
$PYTHON setup_modulated.py build_ext --inplace
......@@ -2,49 +2,37 @@ import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import deform_conv
def deform_conv_function(input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
deform_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
f = DeformConvFunction(
_pair(stride), _pair(padding), _pair(dilation), deform_groups,
im2col_step)
return f(input, offset, weight)
from .. import deform_conv_cuda
class DeformConvFunction(Function):
def __init__(self,
stride,
padding,
dilation,
deformable_groups=1,
im2col_step=64):
super(DeformConvFunction, self).__init__()
self.stride = stride
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.im2col_step = im2col_step
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
def forward(self, input, offset, weight):
self.save_for_backward(input, offset, weight)
ctx.save_for_backward(input, offset, weight)
output = input.new(*self._output_size(input, weight))
output = input.new(*DeformConvFunction._output_size(
input, weight, ctx.padding, ctx.dilation, ctx.stride))
self.bufs_ = [input.new(), input.new()] # columns, ones
ctx.bufs_ = [input.new(), input.new()] # columns, ones
if not input.is_cuda:
raise NotImplementedError
......@@ -56,18 +44,19 @@ class DeformConvFunction(Function):
if not isinstance(input, torch.cuda.FloatTensor):
raise NotImplementedError
cur_im2col_step = min(self.im2col_step, input.shape[0])
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv.deform_conv_forward_cuda(
input, weight, offset, output, self.bufs_[0], self.bufs_[1],
weight.size(3), weight.size(2), self.stride[1], self.stride[0],
self.padding[1], self.padding[0], self.dilation[1],
self.dilation[0], self.deformable_groups, cur_im2col_step)
deform_conv_cuda.deform_conv_forward_cuda(
input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
return output
def backward(self, grad_output):
input, offset, weight = self.saved_tensors
@staticmethod
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
......@@ -81,44 +70,46 @@ class DeformConvFunction(Function):
if not isinstance(grad_output, torch.cuda.FloatTensor):
raise NotImplementedError
cur_im2col_step = min(self.im2col_step, input.shape[0])
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
if self.needs_input_grad[0] or self.needs_input_grad[1]:
grad_input = input.new(*input.size()).zero_()
grad_offset = offset.new(*offset.size()).zero_()
deform_conv.deform_conv_backward_input_cuda(
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_cuda.deform_conv_backward_input_cuda(
input, offset, grad_output, grad_input,
grad_offset, weight, self.bufs_[0], weight.size(3),
weight.size(2), self.stride[1], self.stride[0],
self.padding[1], self.padding[0], self.dilation[1],
self.dilation[0], self.deformable_groups, cur_im2col_step)
if self.needs_input_grad[2]:
grad_weight = weight.new(*weight.size()).zero_()
deform_conv.deform_conv_backward_parameters_cuda(
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.deformable_groups, cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_cuda.deform_conv_backward_parameters_cuda(
input, offset, grad_output,
grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3),
weight.size(2), self.stride[1], self.stride[0],
self.padding[1], self.padding[0], self.dilation[1],
self.dilation[0], self.deformable_groups, 1,
cur_im2col_step)
grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.deformable_groups, 1, cur_im2col_step)
return grad_input, grad_offset, grad_weight
return grad_input, grad_offset, grad_weight, None, None, None, None
def _output_size(self, input, weight):
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = self.padding[d]
kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1
stride = self.stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride + 1, )
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
"convolution input is too small (output would be {})".format(
'x'.join(map(str, output_size))))
return output_size
deform_conv = DeformConvFunction.apply
......@@ -6,128 +6,153 @@ from __future__ import print_function
import torch
from torch.autograd import Function
from .. import modulated_dcn as _backend
from .. import modulated_dcn_cuda as _backend
class ModulatedDeformConvFunction(Function):
def __init__(self, stride, padding, dilation=1, deformable_groups=1):
super(ModulatedDeformConvFunction, self).__init__()
self.stride = stride
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
def forward(self, input, offset, mask, weight, bias):
def __init__(ctx, stride, padding, dilation=1, deformable_groups=1):
super(ModulatedDeformConvFunction, ctx).__init__()
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.deformable_groups = deformable_groups
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias,
stride,
padding,
dilation=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.deformable_groups = deformable_groups
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
self.save_for_backward(input, offset, mask, weight, bias)
output = input.new(*self._infer_shape(input, weight))
self._bufs = [input.new(), input.new()]
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new(
*ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new(), input.new()]
_backend.modulated_deform_conv_cuda_forward(
input, weight, bias, self._bufs[0], offset, mask, output,
self._bufs[1], weight.shape[2], weight.shape[3], self.stride,
self.stride, self.padding, self.padding, self.dilation,
self.dilation, self.deformable_groups)
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
return output
def backward(self, grad_output):
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = self.saved_tensors
grad_input = input.new(*input.size()).zero_()
grad_offset = offset.new(*offset.size()).zero_()
grad_mask = mask.new(*mask.size()).zero_()
grad_weight = weight.new(*weight.size()).zero_()
grad_bias = bias.new(*bias.size()).zero_()
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_backend.modulated_deform_conv_cuda_backward(
input, weight, bias, self._bufs[0], offset, mask, self._bufs[1],
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], self.stride,
self.stride, self.padding, self.padding, self.dilation,
self.dilation, self.deformable_groups)
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
return grad_input, grad_offset, grad_mask, grad_weight, grad_bias
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None)
def _infer_shape(self, input, weight):
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * self.padding -
(self.dilation * (kernel_h - 1) + 1)) // self.stride + 1
width_out = (width + 2 * self.padding -
(self.dilation * (kernel_w - 1) + 1)) // self.stride + 1
return (n, channels_out, height_out, width_out)
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
class DeformRoIPoolingFunction(Function):
def __init__(self,
spatial_scale,
pooled_size,
output_dim,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
super(DeformRoIPoolingFunction, self).__init__()
self.spatial_scale = spatial_scale
self.pooled_size = pooled_size
self.output_dim = output_dim
self.no_trans = no_trans
self.group_size = group_size
self.part_size = pooled_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
assert self.trans_std >= 0.0 and self.trans_std <= 1.0
def forward(self, data, rois, offset):
@staticmethod
def forward(ctx,
data,
rois,
offset,
spatial_scale,
pooled_size,
output_dim,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
ctx.spatial_scale = spatial_scale
ctx.pooled_size = pooled_size
ctx.output_dim = output_dim
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = pooled_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
output = data.new(*self._infer_shape(data, rois))
output_count = data.new(*self._infer_shape(data, rois))
output = data.new(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
output_count = data.new(
*DeformRoIPoolingFunction._infer_shape(ctx, data, rois))
_backend.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, self.no_trans,
self.spatial_scale, self.output_dim, self.group_size,
self.pooled_size, self.part_size, self.sample_per_part,
self.trans_std)
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.output_dim, ctx.group_size, ctx.pooled_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
# if data.requires_grad or rois.requires_grad or offset.requires_grad:
# self.save_for_backward(data, rois, offset, output_count)
self.data = data
self.rois = rois
self.offset = offset
self.output_count = output_count
# ctx.save_for_backward(data, rois, offset, output_count)
ctx.data = data
ctx.rois = rois
ctx.offset = offset
ctx.output_count = output_count
return output
def backward(self, grad_output):
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
# data, rois, offset, output_count = self.saved_tensors
data = self.data
rois = self.rois
offset = self.offset
output_count = self.output_count
grad_input = data.new(*data.size()).zero_()
grad_offset = offset.new(*offset.size()).zero_()
# data, rois, offset, output_count = ctx.saved_tensors
data = ctx.data
rois = ctx.rois
offset = ctx.offset
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_offset = torch.zeros_like(offset)
_backend.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, self.no_trans, self.spatial_scale, self.output_dim,
self.group_size, self.pooled_size, self.part_size,
self.sample_per_part, self.trans_std)
return grad_input, torch.zeros(rois.shape).cuda(), grad_offset
def _infer_shape(self, data, rois):
# _, c, h, w = data.shape[:4]
# c = data.shape[1]
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.output_dim,
ctx.group_size, ctx.pooled_size, ctx.part_size,
ctx.sample_per_part, ctx.trans_std)
return (grad_input, torch.zeros_like(rois), grad_offset, None, None,
None, None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, data, rois):
n = rois.shape[0]
return n, self.output_dim, self.pooled_size, self.pooled_size
return n, ctx.output_dim, ctx.pooled_size, ctx.pooled_size
modulated_deform_conv = ModulatedDeformConvFunction.apply
deform_roi_pooling = DeformRoIPoolingFunction.apply
......@@ -2,10 +2,11 @@ import math
import torch
import torch.nn as nn
from mmcv.cnn import uniform_init
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair
from ..functions.deform_conv import deform_conv_function
from ..functions.deform_conv import deform_conv
class DeformConv(Module):
......@@ -37,9 +38,9 @@ class DeformConv(Module):
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
uniform_init(self, -stdv, stdv)
def forward(self, input, offset):
return deform_conv_function(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
......@@ -6,11 +6,12 @@ from __future__ import print_function
import math
import torch
from mmcv.cnn import uniform_init
from torch import nn
from torch.nn.modules.utils import _pair
from ..functions.modulated_dcn_func import DeformRoIPoolingFunction
from ..functions.modulated_dcn_func import ModulatedDeformConvFunction
from ..functions.modulated_dcn_func import deform_roi_pooling
from ..functions.modulated_dcn_func import modulated_deform_conv
class ModulatedDeformConv(nn.Module):
......@@ -46,13 +47,12 @@ class ModulatedDeformConv(nn.Module):
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
uniform_init(self, -stdv, stdv)
def forward(self, input, offset, mask):
func = ModulatedDeformConvFunction(
self.stride, self.padding, self.dilation, self.deformable_groups)
return func(input, offset, mask, self.weight, self.bias)
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
......@@ -89,9 +89,9 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
func = ModulatedDeformConvFunction(
self.stride, self.padding, self.dilation, self.deformable_groups)
return func(input, offset, mask, self.weight, self.bias)
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
class DeformRoIPooling(nn.Module):
......@@ -115,16 +115,14 @@ class DeformRoIPooling(nn.Module):
self.part_size = pooled_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
self.func = DeformRoIPoolingFunction(
self.spatial_scale, self.pooled_size, self.output_dim,
self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
def forward(self, data, rois, offset):
if self.no_trans:
offset = data.new()
return self.func(data, rois, offset)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
......@@ -146,10 +144,6 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
self.deform_fc_dim = deform_fc_dim
if not no_trans:
self.func_offset = DeformRoIPoolingFunction(
self.spatial_scale, self.pooled_size, self.output_dim, True,
self.group_size, self.part_size, self.sample_per_part,
self.trans_std)
self.offset_fc = nn.Sequential(
nn.Linear(
self.pooled_size * self.pooled_size * self.output_dim,
......@@ -176,11 +170,20 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
else:
n = rois.shape[0]
offset = data.new()
x = self.func_offset(data, rois, offset)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.pooled_size, self.output_dim, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.pooled_size, self.pooled_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.pooled_size, self.pooled_size)
feat = self.func(data, rois, offset) * mask
feat = deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
self.output_dim, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
return feat
return self.func(data, rois, offset)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.pooled_size,
self.output_dim, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
......@@ -2,9 +2,9 @@ from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='modulated_deform_conv',
name='modulated_dcn_cuda',
ext_modules=[
CUDAExtension('modulated_dcn', [
CUDAExtension('modulated_dcn_cuda', [
'src/modulated_dcn_cuda.cpp',
'src/modulated_deform_im2col_cuda.cu',
'src/deform_psroi_pooling_cuda.cu'
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment