Unverified Commit 64b1c8b6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #470 from hellock/dcn-api

Add some arguments to DCN ops
parents e421e832 c35a11b6
from .dcn import (DeformConv, DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack, ModulatedDeformConv,
ModulatedDeformConvPack, deform_conv, modulated_deform_conv,
deform_roi_pooling)
from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
ModulatedDeformConvPack, DeformRoIPooling,
DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack,
deform_conv, modulated_deform_conv, deform_roi_pooling)
from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool
__all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
'deform_roi_pooling'
......
from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
ModulatedDeformConvPack)
DeformConvPack, ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv', 'deform_roi_pooling'
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
'deform_roi_pooling'
]
......@@ -19,15 +19,16 @@ class DeformConv(nn.Module):
groups=1,
deformable_groups=1,
bias=False):
assert not bias
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
......@@ -50,10 +51,34 @@ class DeformConv(nn.Module):
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
def forward(self, x, offset):
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module):
......@@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module):
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset_mask = nn.Conv2d(
self.in_channels // self.groups,
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
......@@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
def forward(self, x):
out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
......@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling):
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
deform_fc_channels=1024):
super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
seq = []
ic = self.out_size * self.out_size * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size * 2
seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
......@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
num_mask_fcs=2,
deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.num_mask_fcs = num_mask_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
offset_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size * 2
offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
offset_fc_seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*offset_fc_seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
mask_fc_seq = []
ic = self.out_size * self.out_size * self.out_channels
for i in range(self.num_mask_fcs):
if i < self.num_mask_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size * self.out_size
mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
mask_fc_seq.append(nn.ReLU(inplace=True))
else:
mask_fc_seq.append(nn.Sigmoid())
self.mask_fc = nn.Sequential(*mask_fc_seq)
self.mask_fc[-2].weight.data.zero_()
self.mask_fc[-2].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment