Commit 66489d6b authored by yhcao6

separate deform_conv and deform_pool, clean some code

parent 170b66f5
from .functions import deform_conv
from .functions.modulated_dcn_func import (modulated_deform_conv,
deform_roi_pooling)
from .modules.deform_conv import DeformConv
from .modules.modulated_dcn import (
    DeformRoIPooling, ModulatedDeformRoIPoolingPack, ModulatedDeformConv,
    ModulatedDeformConvPack)
from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling,
ModulatedDeformRoIPoolingPack)
__all__ = [
'DeformConv', 'DeformRoIPooling', 'ModulatedDeformRoIPoolingPack',
...
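After the split, the convolution ops and the RoI-pooling ops come from separate function and module files. A hypothetical caller (assuming the package is importable as dcn; the actual package path is not shown in this diff) would pick them up from the reorganized exports:

    # conv side and pooling side now live in separate modules
    from dcn import DeformConv, ModulatedDeformConv, ModulatedDeformConvPack
    from dcn import DeformRoIPooling, ModulatedDeformRoIPoolingPack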
@@ -3,6 +3,7 @@ from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import deform_conv_cuda
from .. import modulated_dcn_cuda as _backend
class DeformConvFunction(Function):
@@ -37,13 +38,6 @@ class DeformConvFunction(Function):
if not input.is_cuda:
raise NotImplementedError
else:
if isinstance(input, torch.autograd.Variable):
if not isinstance(input.data, torch.cuda.FloatTensor):
raise NotImplementedError
else:
if not isinstance(input, torch.cuda.FloatTensor):
raise NotImplementedError
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
@@ -63,13 +57,6 @@ class DeformConvFunction(Function):
if not grad_output.is_cuda:
raise NotImplementedError
else:
if isinstance(grad_output, torch.autograd.Variable):
if not isinstance(grad_output.data, torch.cuda.FloatTensor):
raise NotImplementedError
else:
if not isinstance(grad_output, torch.cuda.FloatTensor):
raise NotImplementedError
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
@@ -112,4 +99,70 @@ class DeformConvFunction(Function):
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias,
stride,
padding,
dilation=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.deformable_groups = deformable_groups
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new(
*ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new(), input.new()]
_backend.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_backend.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
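# Worked example of the formula above (illustrative numbers only): with
# height = 32, padding = 1, dilation = 1, kernel_h = 3 and stride = 1,
# height_out = (32 + 2*1 - (1*(3-1) + 1)) // 1 + 1 = 32, so a 3x3 kernel
# with stride 1 and padding 1 preserves the spatial size, as in a regular
# convolution.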
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
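A minimal usage sketch of the functional exports above (shapes are illustrative, and the CUDA kernels require GPU float tensors; the argument order follows the DeformConv module's forward further down in this diff):

    import torch

    x = torch.randn(2, 16, 32, 32, device='cuda')
    w = torch.randn(32, 16, 3, 3, device='cuda', requires_grad=True)
    # 2 offset channels (dy, dx) per kernel sample; all-zero offsets reduce
    # the op to a regular 3x3 convolution
    offset = torch.zeros(2, 2 * 3 * 3, 32, 32, device='cuda')
    out = deform_conv(x, offset, w, 1, 1, 1, 1)  # stride, padding, dilation, deformable groups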
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
from torch.autograd import Function
from .. import modulated_dcn_cuda as _backend
class ModulatedDeformConvFunction(Function):
def __init__(ctx, stride, padding, dilation=1, deformable_groups=1):
super(ModulatedDeformConvFunction, ctx).__init__()
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.deformable_groups = deformable_groups
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias,
stride,
padding,
dilation=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.deformable_groups = deformable_groups
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new(
*ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new(), input.new()]
_backend.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
_backend.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.deformable_groups)
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
class DeformRoIPoolingFunction(Function):
@staticmethod
@@ -154,5 +77,4 @@ class DeformRoIPoolingFunction(Function):
return n, ctx.output_dim, ctx.pooled_size, ctx.pooled_size
modulated_deform_conv = ModulatedDeformConvFunction.apply
deform_roi_pooling = DeformRoIPoolingFunction.apply
@@ -2,13 +2,12 @@ import math
import torch
import torch.nn as nn
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair
from ..functions.deform_conv import deform_conv
from ..functions.deform_conv import deform_conv, modulated_deform_conv
class DeformConv(Module):
class DeformConv(nn.Module):
def __init__(self,
in_channels,
@@ -43,3 +42,84 @@ class DeformConv(Module):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
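# Note on the offset layout (assumed convention, consistent with the Pack
# variant below, which predicts deformable_groups * 3 * kh * kw channels
# for offset plus mask): the offset tensor carries 2 channels (dy, dx) per
# kernel sample, i.e. num_deformable_groups * 2 * kh * kw channels at the
# output resolution.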
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
deformable_groups=1,
no_bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.no_bias = no_bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
self.bias = nn.Parameter(torch.zeros(out_channels))
self.reset_parameters()
if self.no_bias:
self.bias.requires_grad = False
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
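# Usage note (assumed shapes): for a 3x3 kernel with deformable_groups=1,
# the caller supplies an 18-channel offset (2 per sample) and a 9-channel
# mask with values in [0, 1]; the Pack variant below derives both from the
# input and applies the sigmoid itself.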
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
deformable_groups=1,
no_bias=False):
super(ModulatedDeformConvPack,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, deformable_groups, no_bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=(self.stride, self.stride),
padding=(self.padding, self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
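A short sketch of the Pack variant in use; it predicts its own offsets and mask, so the caller passes only the input (shapes illustrative; a CUDA device is required by the underlying kernels):

    import torch

    conv = ModulatedDeformConvPack(
        16, 32, kernel_size=3, stride=1, padding=1).cuda()
    x = torch.randn(2, 16, 32, 32, device='cuda')
    y = conv(x)  # offsets and mask come from conv_offset_mask internally
    assert y.shape == (2, 32, 32, 32)

Because conv_offset_mask is zero-initialized by init_offset, the layer starts with zero offsets and a constant sigmoid(0) = 0.5 mask, so the deformation is learned from a neutral starting point.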
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
from torch import nn
from torch.nn.modules.utils import _pair
from ..functions.modulated_dcn_func import deform_roi_pooling
from ..functions.modulated_dcn_func import modulated_deform_conv
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
deformable_groups=1,
no_bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.deformable_groups = deformable_groups
self.no_bias = no_bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
self.bias = nn.Parameter(torch.zeros(out_channels))
self.reset_parameters()
if self.no_bias:
self.bias.requires_grad = False
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation=1,
deformable_groups=1,
no_bias=False):
super(ModulatedDeformConvPack,
self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, deformable_groups, no_bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=(self.stride, self.stride),
padding=(self.padding, self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(input, offset, mask, self.weight,
self.bias, self.stride, self.padding,
self.dilation, self.deformable_groups)
from ..functions.deform_pool import deform_roi_pooling
class DeformRoIPooling(nn.Module):
...