"...resnet50_tensorflow.git" did not exist on "7e67dbbcb986434eaef6f1e362abc2a1579cc7b5"
Unverified Commit 64b1c8b6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #470 from hellock/dcn-api

Add some arguments to DCN ops
parents e421e832 c35a11b6
from .dcn import (DeformConv, DeformRoIPooling, DeformRoIPoolingPack, from .dcn import (DeformConv, DeformConvPack, ModulatedDeformConv,
ModulatedDeformRoIPoolingPack, ModulatedDeformConv, ModulatedDeformConvPack, DeformRoIPooling,
ModulatedDeformConvPack, deform_conv, modulated_deform_conv, DeformRoIPoolingPack, ModulatedDeformRoIPoolingPack,
deform_roi_pooling) deform_conv, modulated_deform_conv, deform_roi_pooling)
from .nms import nms, soft_nms from .nms import nms, soft_nms
from .roi_align import RoIAlign, roi_align from .roi_align import RoIAlign, roi_align
from .roi_pool import RoIPool, roi_pool from .roi_pool import RoIPool, roi_pool
__all__ = [ __all__ = [
'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool',
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack', 'DeformConv', 'DeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', 'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv',
'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv', 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
'deform_roi_pooling' 'deform_roi_pooling'
......
from .functions.deform_conv import deform_conv, modulated_deform_conv from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv, from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
ModulatedDeformConvPack) DeformConvPack, ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack, from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
ModulatedDeformRoIPoolingPack) ModulatedDeformRoIPoolingPack)
__all__ = [ __all__ = [
'DeformConv', 'DeformRoIPooling', 'DeformRoIPoolingPack', 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
'ModulatedDeformRoIPoolingPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
'ModulatedDeformConvPack', 'deform_conv', 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
'modulated_deform_conv', 'deform_roi_pooling' 'deform_roi_pooling'
] ]
...@@ -19,15 +19,16 @@ class DeformConv(nn.Module): ...@@ -19,15 +19,16 @@ class DeformConv(nn.Module):
groups=1, groups=1,
deformable_groups=1, deformable_groups=1,
bias=False): bias=False):
assert not bias
super(DeformConv, self).__init__() super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, \ assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format( 'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups) in_channels, groups)
assert out_channels % groups == 0, \ assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format( 'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups) out_channels, groups)
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
self.kernel_size = _pair(kernel_size) self.kernel_size = _pair(kernel_size)
...@@ -50,10 +51,34 @@ class DeformConv(nn.Module): ...@@ -50,10 +51,34 @@ class DeformConv(nn.Module):
stdv = 1. / math.sqrt(n) stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv) self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset): def forward(self, x, offset):
return deform_conv(input, offset, self.weight, self.stride, return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.padding, self.dilation, self.groups, self.dilation, self.groups, self.deformable_groups)
self.deformable_groups)
class DeformConvPack(DeformConv):
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConv(nn.Module): class ModulatedDeformConv(nn.Module):
...@@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module): ...@@ -97,30 +122,19 @@ class ModulatedDeformConv(nn.Module):
if self.bias is not None: if self.bias is not None:
self.bias.data.zero_() self.bias.data.zero_()
def forward(self, input, offset, mask): def forward(self, x, offset, mask):
return modulated_deform_conv( return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
input, offset, mask, self.weight, self.bias, self.stride, self.stride, self.padding, self.dilation,
self.padding, self.dilation, self.groups, self.deformable_groups) self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv): class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self, def __init__(self, *args, **kwargs):
in_channels, super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d( self.conv_offset_mask = nn.Conv2d(
self.in_channels // self.groups, self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] * self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1], self.kernel_size[1],
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
...@@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv): ...@@ -133,11 +147,11 @@ class ModulatedDeformConvPack(ModulatedDeformConv):
self.conv_offset_mask.weight.data.zero_() self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_() self.conv_offset_mask.bias.data.zero_()
def forward(self, input): def forward(self, x):
out = self.conv_offset_mask(input) out = self.conv_offset_mask(x)
o1, o2, mask = torch.chunk(out, 3, dim=1) o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1) offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask) mask = torch.sigmoid(mask)
return modulated_deform_conv( return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
input, offset, mask, self.weight, self.bias, self.stride, self.stride, self.padding, self.dilation,
self.padding, self.dilation, self.groups, self.deformable_groups) self.groups, self.deformable_groups)
...@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling): ...@@ -44,22 +44,28 @@ class DeformRoIPoolingPack(DeformRoIPooling):
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
num_offset_fcs=3,
deform_fc_channels=1024): deform_fc_channels=1024):
super(DeformRoIPoolingPack, super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, out_channels, no_trans, self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std) group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.deform_fc_channels = deform_fc_channels self.deform_fc_channels = deform_fc_channels
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( seq = []
nn.Linear(self.out_size * self.out_size * self.out_channels, ic = self.out_size * self.out_size * self.out_channels
self.deform_fc_channels), for i in range(self.num_offset_fcs):
nn.ReLU(inplace=True), if i < self.num_offset_fcs - 1:
nn.Linear(self.deform_fc_channels, self.deform_fc_channels), oc = self.deform_fc_channels
nn.ReLU(inplace=True), else:
nn.Linear(self.deform_fc_channels, oc = self.out_size * self.out_size * 2
self.out_size * self.out_size * 2)) seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*seq)
self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
...@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling): ...@@ -97,33 +103,49 @@ class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
part_size=None, part_size=None,
sample_per_part=4, sample_per_part=4,
trans_std=.0, trans_std=.0,
num_offset_fcs=3,
num_mask_fcs=2,
deform_fc_channels=1024): deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__( super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, out_channels, no_trans, group_size, spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std) part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.num_mask_fcs = num_mask_fcs
self.deform_fc_channels = deform_fc_channels self.deform_fc_channels = deform_fc_channels
if not no_trans: if not no_trans:
self.offset_fc = nn.Sequential( offset_fc_seq = []
nn.Linear(self.out_size * self.out_size * self.out_channels, ic = self.out_size * self.out_size * self.out_channels
self.deform_fc_channels), for i in range(self.num_offset_fcs):
nn.ReLU(inplace=True), if i < self.num_offset_fcs - 1:
nn.Linear(self.deform_fc_channels, self.deform_fc_channels), oc = self.deform_fc_channels
nn.ReLU(inplace=True), else:
nn.Linear(self.deform_fc_channels, oc = self.out_size * self.out_size * 2
self.out_size * self.out_size * 2)) offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
offset_fc_seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*offset_fc_seq)
self.offset_fc[-1].weight.data.zero_() self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_() self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels, mask_fc_seq = []
self.deform_fc_channels), ic = self.out_size * self.out_size * self.out_channels
nn.ReLU(inplace=True), for i in range(self.num_mask_fcs):
nn.Linear(self.deform_fc_channels, if i < self.num_mask_fcs - 1:
self.out_size * self.out_size * 1), oc = self.deform_fc_channels
nn.Sigmoid()) else:
self.mask_fc[2].weight.data.zero_() oc = self.out_size * self.out_size
self.mask_fc[2].bias.data.zero_() mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
mask_fc_seq.append(nn.ReLU(inplace=True))
else:
mask_fc_seq.append(nn.Sigmoid())
self.mask_fc = nn.Sequential(*mask_fc_seq)
self.mask_fc[-2].weight.data.zero_()
self.mask_fc[-2].bias.data.zero_()
def forward(self, data, rois): def forward(self, data, rois):
assert data.size(1) == self.out_channels assert data.size(1) == self.out_channels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment