Commit 9266cc35 authored by yizhou-wang's avatar yizhou-wang
Browse files

source code for MNet and TDC

parent 7c3fd6f9
__pycache__/ __pycache__/
*.egg-info/ *.egg-info/
*.so
build/
data/ data/
checkpoints/ checkpoints/
results/ results/
......
import math
import torch
import torch.nn as nn
class MNet(nn.Module):
def __init__(self, in_chirps, out_channels, conv_op=None):
super(MNet, self).__init__()
self.in_chirps = in_chirps
self.out_channels = out_channels
if conv_op is None:
conv_op = nn.Conv3d
self.conv_op = conv_op
self.t_conv3d = conv_op(in_channels=2, out_channels=out_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1),
padding=(1, 0, 0))
t_conv_out = math.floor((in_chirps + 2 * 1 - (3 - 1) - 1) / 2 + 1)
self.t_maxpool = nn.MaxPool3d(kernel_size=(t_conv_out, 1, 1))
def forward(self, x):
batch_size, n_channels, win_size, in_chirps, w, h = x.shape
x_out = torch.zeros((batch_size, self.out_channels, win_size, w, h)).cuda()
for win in range(win_size):
x_win = self.t_conv3d(x[:, :, win, :, :, :])
x_win = self.t_maxpool(x_win)
x_win = x_win.view(batch_size, self.out_channels, w, h)
x_out[:, :, win, ] = x_win
return x_out
if __name__ == '__main__':
batch_size = 4
in_channels = 2
win_size = 32
in_chirps = 4
w = 128
h = 128
out_channels = 32
mnet = MNet(in_chirps=in_chirps, out_channels=out_channels)
input = torch.randn(batch_size, in_channels, win_size, in_chirps, w, h)
output = mnet(input)
import torch.nn as nn
from .backbones.cdc import RODEncode, RODDecode
from .modules.mnet import MNet
from ..ops.dcn import DeformConvPack3D
class RODNetCDCDCN(nn.Module):
def __init__(self, n_class, mnet_cfg=None):
super(RODNetCDCDCN, self).__init__()
if mnet_cfg is not None:
in_chirps_mnet, out_channels_mnet = mnet_cfg
self.mnet = MNet(in_chirps_mnet, out_channels_mnet, conv_op=DeformConvPack3D)
self.with_mnet = True
self.c3d_encode = RODEncode(in_channels=out_channels_mnet)
else:
self.with_mnet = False
self.c3d_encode = RODEncode()
self.c3d_decode = RODDecode(n_class)
def forward(self, x):
if self.with_mnet:
x = self.mnet(x)
x = self.c3d_encode(x)
dets = self.c3d_decode(x)
return dets
import torch
import torch.nn as nn
from .backbones.hg_dcn import RadarStackedHourglass
from .prep_layers.mnet import MNet
from ..ops.dcn import DeformConvPack3D
class RODNetHGDCNv3(nn.Module):
def __init__(self, n_class, stacked_num=2, mnet_cfg=None):
super(RODNetHGDCNv3, self).__init__()
if mnet_cfg is not None:
in_chirps_mnet, out_channels_mnet = mnet_cfg
self.mnet = MNet(in_chirps_mnet, out_channels_mnet, conv_op=DeformConvPack3D)
self.with_mnet = True
self.stacked_hourglass = RadarStackedHourglass(n_class, stacked_num=stacked_num,
in_channels=out_channels_mnet, conv_op=DeformConvPack3D)
else:
self.with_mnet = False
self.stacked_hourglass = RadarStackedHourglass(n_class, stacked_num=stacked_num, conv_op=DeformConvPack3D)
def forward(self, x):
if self.with_mnet:
x = self.mnet(x)
out, offsets = self.stacked_hourglass(x)
return out, offsets
if __name__ == '__main__':
torch.cuda.set_device(torch.device('cuda:0'))
batch_size = 1
in_channels = 2
win_size = 4
in_chirps = 6
w = 64
h = 64
out_channels = 8
model = RODNetHGDCN(n_class=3, stacked_num=1, mnet_cfg=(in_chirps, out_channels)).cuda()
for iter in range(10):
input = torch.randn(batch_size, in_channels, win_size, in_chirps, w, h).cuda()
output = model(input)
print("forward done")
output_gt = torch.randn(batch_size, 3, win_size, w, h).cuda()
criterion = nn.BCELoss()
loss = criterion(output[0], output_gt)
loss.backward()
from .deform_conv_2d import DeformConv2D, DeformConvPack2D
from .deform_conv_2d import ModulatedDeformConv2D, ModulatedDeformConvPack2D
from .deform_pool_2d import DeformRoIPooling2D, DeformRoIPoolingPack2D
from .deform_pool_2d import ModulatedDeformRoIPoolingPack2D
from .deform_conv_3d import DeformConv3D, DeformConvPack3D
from .deform_conv_3d import ModulatedDeformConv3D, ModulatedDeformConvPack3D
# from .deform_pool_3d import DeformRoIPooling3D, DeformRoIPoolingPack3D
# from .deform_pool_3d import ModulatedDeformRoIPoolingPack3D
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
# from mmdet.utils import print_log
from . import deform_conv_2d_cuda
class DeformConvFunction2D(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(
'Expected 4D tensor as input, got {}D tensor instead.'.format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction2D._output_size(input, weight, ctx.padding,
ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_2d_cuda.deform_conv_forward_cuda(
input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_2d_cuda.deform_conv_backward_input_cuda(
input, offset, grad_output, grad_input,
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_2d_cuda.deform_conv_backward_parameters_cuda(
input, offset, grad_output,
grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None,
None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
'convolution input is too small (output would be {})'.format(
'x'.join(map(str, output_size))))
return output_size
class ModulatedDeformConvFunction2D(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConvFunction2D._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_2d_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_2d_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction2D.apply
modulated_deform_conv = ModulatedDeformConvFunction2D.apply
class DeformConv2D(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv2D, self).__init__()
assert not bias
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
# To fix an assert error in deform_conv_cuda.cpp:128
# input image is smaller than kernel
input_pad = (
x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
if input_pad:
pad_h = max(self.kernel_size[0] - x.size(2), 0)
pad_w = max(self.kernel_size[1] - x.size(3), 0)
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant',
0).contiguous()
out = deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
pad_w].contiguous()
return out
class DeformConvPack2D(DeformConv2D):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConvPack2D, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, DeformConvPack loads previous benchmark models.
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
if version is not None and version > 1:
print_log(
'DeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')),
logger='root')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
class ModulatedDeformConv2D(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv2D, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack2D(ModulatedDeformConv2D):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack2D, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, ModulatedDeformConvPack
# loads previous benchmark models.
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
if version is not None and version > 1:
print_log(
'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')),
logger='root')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _triple, _pair, _single
from . import deform_conv_3d_cuda
class DeformConvFunction3D(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 5:
raise ValueError(
'Expected 5D tensor as input, got {}D tensor instead.'.format(
input.dim()))
ctx.stride = _triple(stride)
ctx.padding = _triple(padding)
ctx.dilation = _triple(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction3D._output_size(input, weight, ctx.padding,
ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_3d_cuda.deform_conv_forward_cuda(
input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
weight.size(4), weight.size(3), weight.size(2),
ctx.stride[2], ctx.stride[1], ctx.stride[0],
ctx.padding[2], ctx.padding[1], ctx.padding[0],
ctx.dilation[2], ctx.dilation[1], ctx.dilation[0],
ctx.groups, ctx.deformable_groups,
cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
# needs_input_grad[0][1] for input and offset, [2] for kernel weights
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_3d_cuda.deform_conv_backward_input_cuda(
input, offset, grad_output, grad_input,
grad_offset, weight, ctx.bufs_[0],
weight.size(4), weight.size(3), weight.size(2),
ctx.stride[2], ctx.stride[1], ctx.stride[0],
ctx.padding[2], ctx.padding[1], ctx.padding[0],
ctx.dilation[2], ctx.dilation[1], ctx.dilation[0],
ctx.groups, ctx.deformable_groups,
cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_3d_cuda.deform_conv_backward_parameters_cuda(
input, offset, grad_output,
grad_weight, ctx.bufs_[0], ctx.bufs_[1],
weight.size(4), weight.size(3), weight.size(2),
ctx.stride[2], ctx.stride[1], ctx.stride[0],
ctx.padding[2], ctx.padding[1], ctx.padding[0],
ctx.dilation[2], ctx.dilation[1], ctx.dilation[0],
ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None,
None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
'convolution input is too small (output would be {})'.format(
'x'.join(map(str, output_size))))
return output_size
class ModulatedDeformConvFunction3D(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConvFunction3D._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_3d_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_3d_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction3D.apply
modulated_deform_conv = ModulatedDeformConvFunction3D.apply
class DeformConv3D(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv3D, self).__init__()
assert not bias
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _triple(kernel_size)
self.stride = _triple(stride)
self.padding = _triple(padding)
self.dilation = _triple(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
# To fix an assert error in deform_conv_cuda.cpp:128
# input image is smaller than kernel
# TODO: add t to input_pad
input_pad = (
x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
if input_pad:
pad_h = max(self.kernel_size[0] - x.size(2), 0)
pad_w = max(self.kernel_size[1] - x.size(3), 0)
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant',
0).contiguous()
out = deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
pad_w].contiguous()
return out
class DeformConvPack3D(DeformConv3D):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConvPack3D, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv3d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] *
self.kernel_size[1] * self.kernel_size[2],
kernel_size=self.kernel_size,
stride=_triple(self.stride),
padding=_triple(self.padding),
dilation=_triple(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, DeformConvPack loads previous benchmark models.
# TODO: check here
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
class ModulatedDeformConv3D(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv3D, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack3D(ModulatedDeformConv3D):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack3D, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, ModulatedDeformConvPack
# loads previous benchmark models.
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
if version is not None and version > 1:
print_log(
'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
prefix.rstrip('.')),
logger='root')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from . import deform_pool_2d_cuda
class DeformRoIPoolingFunction2D(Function):
@staticmethod
def forward(ctx,
data,
rois,
offset,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
# TODO: support unsquare RoIs
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
assert out_h == out_w
out_size = out_h # out_h and out_w must be equal
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
n = rois.shape[0]
output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty(n, out_channels, out_size, out_size)
deform_pool_2d_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
ctx.output_count = output_count
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset)
deform_pool_2d_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
ctx.trans_std)
return (grad_input, grad_rois, grad_offset, None, None, None, None,
None, None, None, None)
deform_roi_pooling = DeformRoIPoolingFunction2D.apply
class DeformRoIPooling2D(nn.Module):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
super(DeformRoIPooling2D, self).__init__()
self.spatial_scale = spatial_scale
self.out_size = _pair(out_size)
self.out_channels = out_channels
self.no_trans = no_trans
self.group_size = group_size
self.part_size = out_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
def forward(self, data, rois, offset):
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
class DeformRoIPoolingPack2D(DeformRoIPooling2D):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
deform_fc_channels=1024):
super(DeformRoIPoolingPack2D,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
seq = []
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size[0] * self.out_size[1] * 2
seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
n = rois.shape[0]
if n == 0:
return data.new_empty(n, self.out_channels, self.out_size[0],
self.out_size[1])
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
else:
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
class ModulatedDeformRoIPoolingPack2D(DeformRoIPooling2D):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
num_mask_fcs=2,
deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack2D,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.num_mask_fcs = num_mask_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
offset_fc_seq = []
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size[0] * self.out_size[1] * 2
offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
offset_fc_seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*offset_fc_seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
mask_fc_seq = []
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_mask_fcs):
if i < self.num_mask_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size[0] * self.out_size[1]
mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
mask_fc_seq.append(nn.ReLU(inplace=True))
else:
mask_fc_seq.append(nn.Sigmoid())
self.mask_fc = nn.Sequential(*mask_fc_seq)
self.mask_fc[-2].weight.data.zero_()
self.mask_fc[-2].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
n = rois.shape[0]
if n == 0:
return data.new_empty(n, self.out_channels, self.out_size[0],
self.out_size[1])
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
else:
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size[0], self.out_size[1])
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair
from . import deform_pool_3d_cuda
class DeformRoIPoolingFunction3D(Function):
@staticmethod
def forward(ctx,
data,
rois,
offset,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
# TODO: support unsquare RoIs
out_h, out_w = _pair(out_size)
assert isinstance(out_h, int) and isinstance(out_w, int)
assert out_h == out_w
out_size = out_h # out_h and out_w must be equal
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
n = rois.shape[0]
output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty(n, out_channels, out_size, out_size)
deform_pool_3d_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
ctx.output_count = output_count
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset)
deform_pool_3d_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
ctx.trans_std)
return (grad_input, grad_rois, grad_offset, None, None, None, None,
None, None, None, None)
deform_roi_pooling = DeformRoIPoolingFunction3D.apply
class DeformRoIPooling3D(nn.Module):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
super(DeformRoIPooling3D, self).__init__()
self.spatial_scale = spatial_scale
self.out_size = _pair(out_size)
self.out_channels = out_channels
self.no_trans = no_trans
self.group_size = group_size
self.part_size = out_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
def forward(self, data, rois, offset):
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
class DeformRoIPoolingPack3D(DeformRoIPooling3D):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
deform_fc_channels=1024):
super(DeformRoIPoolingPack3D,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
seq = []
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size[0] * self.out_size[1] * 2
seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
n = rois.shape[0]
if n == 0:
return data.new_empty(n, self.out_channels, self.out_size[0],
self.out_size[1])
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
else:
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
class ModulatedDeformRoIPoolingPack3D(DeformRoIPooling3D):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
num_offset_fcs=3,
num_mask_fcs=2,
deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack3D,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.num_offset_fcs = num_offset_fcs
self.num_mask_fcs = num_mask_fcs
self.deform_fc_channels = deform_fc_channels
if not no_trans:
offset_fc_seq = []
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_offset_fcs):
if i < self.num_offset_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size[0] * self.out_size[1] * 2
offset_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_offset_fcs - 1:
offset_fc_seq.append(nn.ReLU(inplace=True))
self.offset_fc = nn.Sequential(*offset_fc_seq)
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
mask_fc_seq = []
ic = self.out_size[0] * self.out_size[1] * self.out_channels
for i in range(self.num_mask_fcs):
if i < self.num_mask_fcs - 1:
oc = self.deform_fc_channels
else:
oc = self.out_size[0] * self.out_size[1]
mask_fc_seq.append(nn.Linear(ic, oc))
ic = oc
if i < self.num_mask_fcs - 1:
mask_fc_seq.append(nn.ReLU(inplace=True))
else:
mask_fc_seq.append(nn.Sigmoid())
self.mask_fc = nn.Sequential(*mask_fc_seq)
self.mask_fc[-2].weight.data.zero_()
self.mask_fc[-2].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
n = rois.shape[0]
if n == 0:
return data.new_empty(n, self.out_channels, self.out_size[0],
self.out_size[1])
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels,
self.no_trans, self.group_size,
self.part_size, self.sample_per_part,
self.trans_std)
else:
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size[0], self.out_size[1])
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size[0], self.out_size[1])
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath>
#include <vector>
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor data_col);
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const int channels, const int height,
const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
int padW, int dilationH, int dilationW, int group,
int deformable_group) {
AT_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
AT_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
kW);
AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
AT_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
AT_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
AT_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
AT_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
AT_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
AT_CHECK((gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output) todo: possibly change data indexing because of
// parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert batchsize dividable by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
}
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide int group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
"deform forward (CUDA)");
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
"deform_conv_backward_input (CUDA)");
m.def("deform_conv_backward_parameters_cuda",
&deform_conv_backward_parameters_cuda,
"deform_conv_backward_parameters (CUDA)");
m.def("modulated_deform_conv_cuda_forward",
&modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward",
&modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
}
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
inline int GET_BLOCKS(const int N)
{
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const scalar_t map_h = i * dilation_h + offset_h;
//const scalar_t map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
void deformable_im2col(
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n, const scalar_t *data_col, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
void deformable_col2im(
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im)
{
// todo: make sure parallel_imgs is passed in correctly
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
const scalar_t *data_im, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col, scalar_t *grad_offset)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
const scalar_t weight = get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
}
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const float map_h = i * dilation_h + offset_h;
//const float map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
//data_col_ptr += height_col * width_col;
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_im,
const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_offset, scalar_t *grad_mask)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
else
{
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
}
const scalar_t weight = dmcn_get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
}
}
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im)
{
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group,
at::Tensor grad_offset, at::Tensor grad_mask)
{
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
grad_offset_, grad_mask_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
}
}
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath>
#include <vector>
#include <iostream>
//#define DEBUG_INFO
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int time, const int height, const int width,
const int ksize_t, const int ksize_h, const int ksize_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor data_col);
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
const int channels, const int time, const int height, const int width,
const int ksize_t, const int ksize_h, const int ksize_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const int channels,
const int time, const int height, const int width,
const int ksize_t, const int ksize_h, const int ksize_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
at::Tensor weight, int kH, int kW, int kT, int dH, int dW, int dT,
int padH, int padW, int padT, int dilationH, int dilationW, int dilationT,
int group, int deformable_group) {
AT_CHECK(weight.ndimension() == 5,
"5D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
AT_CHECK(kW > 0 && kH > 0 && kT > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d kT: %d", kH,
kW, kT);
AT_CHECK((weight.size(2) == kT && weight.size(3) == kH && weight.size(4) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d kT: %d weight.size(2): %d, weight.size(3): %d, weight.size(4): %d", kH,
kW, kT, weight.size(2), weight.size(3), weight.size(4));
AT_CHECK(dW > 0 && dH > 0 && dT > 0,
"stride should be greater than zero, but got dH: %d dW: %d dT: %d", dH, dW, dT);
AT_CHECK(
dilationW > 0 && dilationH > 0 && dilationT > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d dilationT: %d",
dilationH, dilationW, dilationT);
int ndim = input.ndimension();
int dimf = 0;
int dimt = 1;
int dimh = 2;
int dimw = 3;
if (ndim == 5) {
dimf++;
dimt++;
dimh++;
dimw++;
}
AT_CHECK(ndim == 4 || ndim == 5, "4D or 5D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputTime = input.size(dimt);
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputTime =
(inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1;
AT_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
AT_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
AT_CHECK((inputHeight >= kH && inputWidth >= kW && inputTime >= kT),
"input data is smaller than kernel");
AT_CHECK((offset.size(2) == outputTime && offset.size(3) == outputHeight && offset.size(4) == outputWidth),
"invalid spatial size of offset, expected time: %d height: %d width: %d, but "
"got time: %d height: %d width: %d",
outputTime, outputHeight, outputWidth, offset.size(2), offset.size(3), offset.size(4));
AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW * kT),
"invalid number of channels of offset");
if (gradOutput != NULL) {
AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
AT_CHECK((gradOutput->size(dimt) == outputTime &&
gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected time: %d height: %d width: %d, but "
"got time: %d height: %d width: %d",
outputTime, outputHeight, outputWidth,
gradOutput->size(dimt), gradOutput->size(dimh), gradOutput->size(dimw));
}
}
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones,
int kW, int kH, int kT, int dW, int dH, int dT,
int padW, int padH, int padT,
int dilationW, int dilationH, int dilationT,
int group, int deformable_group, int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output) todo: possibly change data indexing because of
// parallel_imgs
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: forward start" << "\n";
#endif
shape_check(input, offset, NULL, weight, kH, kW, kT, dH, dW, dT, padH, padW, padT,
dilationH, dilationW, dilationT, group, deformable_group);
at::DeviceGuard guard(input.device());
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: finish shape_check()" << "\n";
#endif
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 4) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert batchsize dividable by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputTime = input.size(2);
long inputHeight = input.size(3);
long inputWidth = input.size(4);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputTime =
(inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputTime, outputHeight, outputWidth});
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: columns.size=" << nInputPlane * kW * kH * kT << " " << im2col_step * outputHeight * outputWidth * outputTime << std::endl;
#endif
columns = at::zeros(
{nInputPlane * kW * kH * kT, im2col_step * outputHeight * outputWidth * outputTime},
input.options());
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: finish build columns" << "\n";
#endif
if (ones.ndimension() != 3 ||
ones.size(0) * ones.size(1) * ones.size(2) < outputTime * outputHeight * outputWidth) {
ones = at::ones({outputTime, outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputTime, inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW * kT, outputTime, outputHeight, outputWidth});
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: finish build input & offset" << "\n";
#endif
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step, outputTime, outputHeight, outputWidth},
output.options()); // TODO: dim different from original mmdet: flatten(1) following ???TO CHECK???
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3),
output_buffer.size(4), output_buffer.size(5)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputTime, inputHeight, inputWidth,
kT, kH, kW, padT, padH, padW, dT, dH, dW, dilationT, dilationH, dilationW,
im2col_step, deformable_group, columns);
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: finish deformable_im2col()" << "\n";
#endif
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3), weight.size(4)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
#ifdef DEBUG_INFO
std::cout << "[cpp]deform_conv_forward_cuda: finish calculate output_buffer" << "\n";
#endif
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4), output_buffer.size(5), output_buffer.size(6)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputTime, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputTime, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputTime, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW * kT, outputTime, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputTime, outputHeight, outputWidth});
input = input.view({nInputPlane, inputTime, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3), offset.size(4)});
}
return 1;
}
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns,
int kW, int kH, int kT, int dW, int dH, int dT,
int padW, int padH, int padT,
int dilationW, int dilationH, int dilationT,
int group, int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, kT, dH, dW, dT, padH, padW, padT,
dilationH, dilationW, dilationT, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 4) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2), input.size(3)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2), offset.size(3)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputTime = input.size(2);
long inputHeight = input.size(3);
long inputWidth = input.size(4);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputTime =
(inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1;
AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputTime, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH * kT, im2col_step * outputTime * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputTime, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputTime, inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputTime, inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kT * kH * kW, outputTime, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kT * kH * kW, outputTime, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3), weight.size(4)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5), gradOutput.size(6)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputTime, inputHeight, inputWidth, kT, kH, kW, padT, padH, padW, dT, dH, dW,
dilationT, dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputTime, inputHeight, inputWidth,
kT, kH, kW, padT, padH, padW, dT, dH, dW, dilationT, dilationH, dilationW,
im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputTime, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputTime, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputTime, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kT * kH * kW, outputTime, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kT * kH * kW, outputTime, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputTime, outputHeight, outputWidth});
input = input.view({nInputPlane, inputTime, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputTime, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3), offset.size(4)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3), offset.size(4)});
}
return 1;
}
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones,
int kW, int kH, int kT, int dW, int dH, int dT,
int padW, int padH, int padT,
int dilationW, int dilationH, int dilationT,
int group,
int deformable_group, float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, kT, dH, dW, dT, padH, padW, padT,
dilationH, dilationW, dilationT, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 4) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2), input.size(3)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputTime = input.size(2);
long inputHeight = input.size(3);
long inputWidth = input.size(4);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputTime =
(inputTime + 2 * padT - (dilationT * (kT - 1) + 1)) / dT + 1;
AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH * kT, im2col_step * outputHeight * outputWidth * outputTime},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputTime, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputTime, outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
// gradOutputBuffer =
// gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
// im2col_step * outputHeight, outputWidth}); // TODO: dim different from original mmdet: flatten(1) following ???TO CHECK???
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputTime, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputTime, inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW * kT, outputTime, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputTime, inputHeight, inputWidth,
kT, kH, kW, padT, padH, padW, dT, dH, dW, dilationT, dilationH, dilationW,
im2col_step, deformable_group, columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3), gradOutputBuffer.size(4), gradOutputBuffer.size(5)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3), gradWeight.size(4)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4), gradOutputBuffer.size(5), gradOutputBuffer.size(6)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4), gradWeight.size(5)});
}
input = input.view({batchSize, nInputPlane, inputTime, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kT * kH * kW, outputTime, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputTime, outputHeight, outputWidth});
input = input.view({nInputPlane, inputTime, inputHeight, inputWidth});
}
return 1;
}
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide int group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
"deform forward (CUDA)");
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
"deform_conv_backward_input (CUDA)");
m.def("deform_conv_backward_parameters_cuda",
&deform_conv_backward_parameters_cuda,
"deform_conv_backward_parameters (CUDA)");
m.def("modulated_deform_conv_cuda_forward",
&modulated_deform_conv_cuda_forward,
"modulated deform conv forward (CUDA)");
m.def("modulated_deform_conv_cuda_backward",
&modulated_deform_conv_cuda_backward,
"modulated deform conv backward (CUDA)");
}
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include <iostream>
using namespace at;
//#define DEBUG_INFO
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
#ifndef DEBUG_INFO //normal mode
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
//const int CUDA_NUM_THREADS = 2;
//const int kMaxGridNum = 2;
#else //debug mode
const int CUDA_NUM_THREADS = 16;
const int kMaxGridNum = 1;
//const int CUDA_NUM_THREADS = 1;
//const int kMaxGridNum = 1;
#endif
inline int GET_BLOCKS(const int N)
{
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data,
const int data_height, const int data_width,
const int height, const int width,
const int t, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[t * data_height * data_width + h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[t * data_height * data_width + h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[t * data_height * data_width + h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[t * data_height * data_width + h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(const int cur_t, scalar_t argmax_h, scalar_t argmax_w,
const int time, const int height, const int width, const scalar_t *im_data,
const int data_height, const int data_width, const int bp_dir)
{
if (cur_t <= -1 || cur_t >= time || argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[cur_t * data_height * data_width + argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[cur_t * data_height * data_width + argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[cur_t * data_height * data_width + argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[cur_t * data_height * data_width + argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[cur_t * data_height * data_width + argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[cur_t * data_height * data_width + argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[cur_t * data_height * data_width + argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[cur_t * data_height * data_width + argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
const int time, const int height, const int width,
const int kernel_t, const int kernel_h, const int kernel_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int time_col, const int height_col, const int width_col,
scalar_t *data_col)
{
// #ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
// printf("[cu]deformable_im2col_gpu_kernel\n");
// }
// #endif
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int t_col = (index / width_col / height_col) % time_col;
const int b_col = (index / width_col / height_col / time_col) % batch_size;
const int c_im = (index / width_col / height_col / time_col) / batch_size;
const int c_col = c_im * kernel_t * kernel_h * kernel_w;
// #ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
// printf("[cu:%d|%d]deformable_im2col_gpu_kernel: index(%d/%d)=%d,%d,%d,%d,%d,%d|%d,%d,%d\n",
// blockIdx.x, threadIdx.x, index, n, w_col, h_col, t_col, b_col, c_im, c_col, time_col, height_col, width_col);
// }
// #endif
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int t_in = t_col * stride_t - pad_t;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
// #ifdef DEBUG_INFO
// printf("[cu]deformable_im2col_gpu_kernel: data_col_ptr+=%d\n",
// ((((c_col * batch_size + b_col)) * time_col + t_col) * height_col + h_col) * width_col + w_col);
// printf("[cu]deformable_im2col_gpu_kernel: data_im_ptr+=%d\n",
// (b_col * num_channels + c_im) * time * height * width);
// printf("[cu]deformable_im2col_gpu_kernel: data_offset_ptr+=%d\n",
// (b_col * deformable_group + deformable_group_index) * 2 * kernel_t * kernel_h * kernel_w * time_col * height_col * width_col);
// #endif
scalar_t *data_col_ptr = data_col + ((((c_col * batch_size + b_col)) * time_col + t_col) * height_col + h_col) * width_col + w_col;
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * time * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) *
2 * kernel_t * kernel_h * kernel_w * time_col * height_col * width_col;
for (int k = 0; k < kernel_t; ++k)
{
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = (((2 * ((k * kernel_h + i) * kernel_w + j)) * time_col + t_col) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = (((2 * ((k * kernel_h + i) * kernel_w + j) + 1) * time_col + t_col) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t val = static_cast<scalar_t>(0);
const int t_im = t_in + k * dilation_t;
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
if (t_im > -1 && h_im > -1 && w_im > -1 && t_im < time && h_im < height && w_im < width)
{
//const scalar_t map_h = i * dilation_h + offset_h;
//const scalar_t map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
// #ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
// printf("[cu:%d|%d]deformable_im2col_gpu_kernel: thw_im=%d,%f,%f\n", blockIdx.x, threadIdx.x, t_im, h_im, w_im);
// printf("[cu:%d|%d]deformable_im2col_gpu_kernel: offset=%f,%f\n", blockIdx.x, threadIdx.x, offset_h, offset_w);
// }
// #endif
val = deformable_im2col_bilinear(data_im_ptr, height, width, height, width, t_im, h_im, w_im);
}
*data_col_ptr = val;
data_col_ptr += batch_size * time_col * height_col * width_col;
}
}
}
}
}
void deformable_im2col(
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
const int time, const int height, const int width,
const int ksize_t, const int ksize_h, const int ksize_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group, at::Tensor data_col)
{
#ifdef DEBUG_INFO
printf("[cu]deformable_im2col\n");
#endif
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int time_col = (time + 2 * pad_t - (dilation_t * (ksize_t - 1) + 1)) / stride_t + 1;
int num_kernels = channels * height_col * width_col * time_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
#ifdef DEBUG_INFO
printf("[cu]deformable_im2col: thw_col=%d,%d,%d\n", time_col, height_col, width_col);
printf("[cu]deformable_im2col: num_kernels=%d\n", num_kernels);
#endif
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, time, height, width, ksize_t, ksize_h, ksize_w,
pad_t, pad_h, pad_w, stride_t, stride_h, stride_w, dilation_t, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
time_col, height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n, const scalar_t *data_col, const scalar_t *data_offset,
const int channels, const int time, const int height, const int width,
const int kernel_t, const int kernel_h, const int kernel_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int time_col, const int height_col, const int width_col,
scalar_t *grad_im)
{
// #ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
// printf("[cu]deformable_col2im_gpu_kernel\n");
// }
// #endif
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / time_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / time_col / batch_size / kernel_w) % kernel_h;
const int k = (index / width_col / height_col / time_col / batch_size / kernel_w / kernel_h) % kernel_t;
const int c = index / width_col / height_col / time_col / batch_size / kernel_w / kernel_h / kernel_t;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int t_out = (index / width_col / height_col) % time_col;
int b = (index / width_col / height_col / time_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
int t_in = t_out * stride_t - pad_t;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
2 * kernel_t * kernel_h * kernel_w * time_col * height_col * width_col;
const int data_offset_h_ptr = (((2 * ((k * kernel_h + i) * kernel_w + j)) * time_col + t_out) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = (((2 * ((k * kernel_h + i) * kernel_w + j) + 1) * time_col + t_out) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const int cur_inv_t_data = t_in + k * dilation_t;
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index];
const int cur_t = cur_inv_t_data;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_t >= 0 && cur_t < time &&
cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = (((b * channels + c) * time + cur_t) * height + cur_h + dy) * width + cur_w + dx;
#ifdef DEBUG_INFO
if (threadIdx.x == 0) {
printf("[cu:%d|%d]deformable_col2im_gpu_kernel: cur_thw=%d,%d,%d dyx=%d,%d\n",
blockIdx.x, threadIdx.x, cur_t, cur_h + dy, cur_w + dx, dy, dx);
}
#endif
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
void deformable_col2im(
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
const int time, const int height, const int width,
const int ksize_t, const int ksize_h, const int ksize_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im)
{
#ifdef DEBUG_INFO
printf("[cu]deformable_col2im\n");
#endif
// todo: make sure parallel_imgs is passed in correctly
int time_col = (time + 2 * pad_t - (dilation_t * (ksize_t - 1) + 1)) / stride_t + 1;
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * ksize_t * ksize_h * ksize_w * time_col * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, channels, time, height, width,
ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w,
dilation_t, dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, deformable_group, time_col, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
const scalar_t *data_im, const scalar_t *data_offset,
const int channels, const int time, const int height, const int width,
const int kernel_t, const int kernel_h, const int kernel_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int time_col, const int height_col, const int width_col,
scalar_t *grad_offset)
{
// #ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
// printf("[cu]deformable_col2im_coord_gpu_kernel\n");
// }
// #endif
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int t = (index / width_col / height_col) % time_col;
int c = (index / width_col / height_col / time_col) % offset_channels;
int b = (index / width_col / height_col / time_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_t * kernel_h * kernel_w);
const int col_step = kernel_t * kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
batch_size * width_col * height_col * time_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_t / kernel_h / kernel_w * time * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_t * kernel_h * kernel_w * time_col * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_t * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * time_col + t) * height_col + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / time_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / time_col / batch_size / kernel_w) % kernel_h;
int k = (col_pos / width_col / height_col / time_col / batch_size / kernel_w / kernel_h) % kernel_t;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int t_out = (col_pos / width_col / height_col) % time_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
int t_in = t_out * stride_t - pad_t;
const int data_offset_h_ptr = (((2 * ((k * kernel_h + i) * kernel_w + j)) * time_col + t_out) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = (((2 * ((k * kernel_h + i) * kernel_w + j) + 1) * time_col + t_out) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
int inv_t = t_in + k * dilation_t;
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_t <= -1 || inv_h <= -1 || inv_w <= -1 || inv_t >= time || inv_h >= height || inv_w >= width)
{
inv_t = inv_h = inv_w = -2;
}
// #ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
// printf("[cu]deformable_col2im_coord_gpu_kernel: inv_thw=%d,%f,%f\n", inv_t, inv_h, inv_w);
// }
// #endif
const scalar_t weight = get_coordinate_weight(
inv_t, inv_h, inv_w,
time, height, width, data_im_ptr + cnt * time * height * width, height, width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int time, const int height, const int width,
const int ksize_t, const int ksize_h, const int ksize_w,
const int pad_t, const int pad_h, const int pad_w,
const int stride_t, const int stride_h, const int stride_w,
const int dilation_t, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
#ifdef DEBUG_INFO
// if (threadIdx.x == 0) {
printf("[cu]deformable_col2im_coord\n");
// }
#endif
int time_col = (time + 2 * pad_t - (dilation_t * (ksize_t - 1) + 1)) / stride_t + 1;
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = time_col * height_col * width_col * 2 * ksize_t * ksize_h * ksize_w * deformable_group * parallel_imgs;
int channel_per_deformable_group = channels * ksize_t * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, channels, time, height, width,
ksize_t, ksize_h, ksize_w, pad_t, pad_h, pad_w, stride_t, stride_h, stride_w,
dilation_t, dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, 2 * ksize_t * ksize_h * ksize_w * deformable_group, deformable_group,
time_col, height_col, width_col, grad_offset_);
}));
}
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t val = static_cast<scalar_t>(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const float map_h = i * dilation_h + offset_h;
//const float map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
//data_col_ptr += height_col * width_col;
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_im,
const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_offset, scalar_t *grad_mask)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
else
{
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
}
const scalar_t weight = dmcn_get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
}
}
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *data_col_ = data_col.data<scalar_t>();
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im)
{
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_im_ = grad_im.data<scalar_t>();
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group,
at::Tensor grad_offset, at::Tensor grad_mask)
{
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data<scalar_t>();
const scalar_t *data_im_ = data_im.data<scalar_t>();
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
grad_offset_, grad_mask_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
}
}
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath>
#include <vector>
void DeformablePSROIPoolForward(
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void DeformablePSROIPoolBackwardAcc(
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
at::Tensor trans_grad, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out.size(0), num_bbox);
DeformablePSROIPoolForward(
input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std);
}
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std) {
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out_grad.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out_grad.size(0), num_bbox);
DeformablePSROIPoolBackwardAcc(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
channels, height, width, num_bbox, channels_trans, no_trans,
spatial_scale, output_dim, group_size, pooled_size, part_size,
sample_per_part, trans_std);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
"deform psroi pooling forward(CUDA)");
m.def("deform_psroi_pooling_cuda_backward",
&deform_psroi_pooling_cuda_backward,
"deform psroi pooling backward(CUDA)");
}
/*!
* Copyright (c) 2017 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_psroi_pooling.cu
* \brief
* \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <algorithm>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename scalar_t>
__device__ scalar_t bilinear_interp(
const scalar_t *data,
const scalar_t x,
const scalar_t y,
const int width,
const int height)
{
int x1 = floor(x);
int x2 = ceil(x);
int y1 = floor(y);
int y2 = ceil(y);
scalar_t dist_x = (scalar_t)(x - x1);
scalar_t dist_y = (scalar_t)(y - y1);
scalar_t value11 = data[y1 * width + x1];
scalar_t value12 = data[y2 * width + x1];
scalar_t value21 = data[y1 * width + x2];
scalar_t value22 = data[y2 * width + x2];
scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
return value;
}
template <typename scalar_t>
__global__ void DeformablePSROIPoolForwardKernel(
const int count,
const scalar_t *bottom_data,
const scalar_t spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const scalar_t *bottom_rois, const scalar_t *bottom_trans,
const int no_trans,
const scalar_t trans_std,
const int sample_per_part,
const int output_dim,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class,
scalar_t *top_data,
scalar_t *top_count)
{
CUDA_KERNEL_LOOP(index, count)
{
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to be 1x1
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
// Compute w and h at bottom
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
scalar_t sum = 0;
int count = 0;
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
scalar_t w = wstart + iw * sub_bin_size_w;
scalar_t h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.);
int c = (ctop * group_size + gh) * group_size + gw;
scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
sum += val;
count++;
}
}
top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
top_count[index] = count;
}
}
template <typename scalar_t>
__global__ void DeformablePSROIPoolBackwardAccKernel(
const int count,
const scalar_t *top_diff,
const scalar_t *top_count,
const int num_rois,
const scalar_t spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int output_dim,
scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
const scalar_t *bottom_data,
const scalar_t *bottom_rois,
const scalar_t *bottom_trans,
const int no_trans,
const scalar_t trans_std,
const int sample_per_part,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class)
{
CUDA_KERNEL_LOOP(index, count)
{
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to be 1x1
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
// Compute w and h at bottom
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
if (top_count[index] <= 0)
{
continue;
}
scalar_t diff_val = top_diff[index] / top_count[index];
const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
scalar_t w = wstart + iw * sub_bin_size_w;
scalar_t h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.);
int c = (ctop * group_size + gh) * group_size + gw;
// backward on feature
int x0 = floor(w);
int x1 = ceil(w);
int y0 = floor(h);
int y1 = ceil(h);
scalar_t dist_x = w - x0, dist_y = h - y0;
scalar_t q00 = (1 - dist_x) * (1 - dist_y);
scalar_t q01 = (1 - dist_x) * dist_y;
scalar_t q10 = dist_x * (1 - dist_y);
scalar_t q11 = dist_x * dist_y;
int bottom_index_base = c * height * width;
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
if (no_trans)
{
continue;
}
scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
diff_x *= roi_width;
scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
diff_y *= roi_height;
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
}
}
}
}
void DeformablePSROIPoolForward(const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
at::Tensor out,
at::Tensor top_count,
const int batch,
const int channels,
const int height,
const int width,
const int num_bbox,
const int channels_trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int count = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data.scalar_type(), "deformable_psroi_pool_forward", ([&] {
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *top_data = out.data<scalar_t>();
scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
const at::Tensor top_count,
at::Tensor in_grad,
at::Tensor trans_grad,
const int batch,
const int channels,
const int height,
const int width,
const int num_bbox,
const int channels_trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
// LOG(INFO) << "DeformablePSROIPoolBackward";
const int num_rois = num_bbox;
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int count = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
out_grad.scalar_type(), "deformable_psroi_pool_backward_acc", ([&] {
const scalar_t *top_diff = out_grad.data<scalar_t>();
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
const scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
group_size, part_size, num_classes, channels_each_class);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
#include <torch/extension.h>
#include <ATen/DeviceGuard.h>
#include <cmath>
#include <vector>
void DeformablePSROIPoolForward(
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void DeformablePSROIPoolBackwardAcc(
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
at::Tensor trans_grad, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out.size(0), num_bbox);
DeformablePSROIPoolForward(
input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std);
}
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std) {
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out_grad.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out_grad.size(0), num_bbox);
DeformablePSROIPoolBackwardAcc(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
channels, height, width, num_bbox, channels_trans, no_trans,
spatial_scale, output_dim, group_size, pooled_size, part_size,
sample_per_part, trans_std);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
"deform psroi pooling forward(CUDA)");
m.def("deform_psroi_pooling_cuda_backward",
&deform_psroi_pooling_cuda_backward,
"deform psroi pooling backward(CUDA)");
}
/*!
* Copyright (c) 2017 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_psroi_pooling.cu
* \brief
* \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <algorithm>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename scalar_t>
__device__ scalar_t bilinear_interp(
const scalar_t *data,
const scalar_t x,
const scalar_t y,
const int width,
const int height)
{
int x1 = floor(x);
int x2 = ceil(x);
int y1 = floor(y);
int y2 = ceil(y);
scalar_t dist_x = (scalar_t)(x - x1);
scalar_t dist_y = (scalar_t)(y - y1);
scalar_t value11 = data[y1 * width + x1];
scalar_t value12 = data[y2 * width + x1];
scalar_t value21 = data[y1 * width + x2];
scalar_t value22 = data[y2 * width + x2];
scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
return value;
}
template <typename scalar_t>
__global__ void DeformablePSROIPoolForwardKernel(
const int count,
const scalar_t *bottom_data,
const scalar_t spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const scalar_t *bottom_rois, const scalar_t *bottom_trans,
const int no_trans,
const scalar_t trans_std,
const int sample_per_part,
const int output_dim,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class,
scalar_t *top_data,
scalar_t *top_count)
{
CUDA_KERNEL_LOOP(index, count)
{
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to be 1x1
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
// Compute w and h at bottom
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
scalar_t sum = 0;
int count = 0;
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
scalar_t w = wstart + iw * sub_bin_size_w;
scalar_t h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.);
int c = (ctop * group_size + gh) * group_size + gw;
scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
sum += val;
count++;
}
}
top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
top_count[index] = count;
}
}
template <typename scalar_t>
__global__ void DeformablePSROIPoolBackwardAccKernel(
const int count,
const scalar_t *top_diff,
const scalar_t *top_count,
const int num_rois,
const scalar_t spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int output_dim,
scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
const scalar_t *bottom_data,
const scalar_t *bottom_rois,
const scalar_t *bottom_trans,
const int no_trans,
const scalar_t trans_std,
const int sample_per_part,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class)
{
CUDA_KERNEL_LOOP(index, count)
{
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to be 1x1
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
// Compute w and h at bottom
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
if (top_count[index] <= 0)
{
continue;
}
scalar_t diff_val = top_diff[index] / top_count[index];
const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
scalar_t w = wstart + iw * sub_bin_size_w;
scalar_t h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.);
int c = (ctop * group_size + gh) * group_size + gw;
// backward on feature
int x0 = floor(w);
int x1 = ceil(w);
int y0 = floor(h);
int y1 = ceil(h);
scalar_t dist_x = w - x0, dist_y = h - y0;
scalar_t q00 = (1 - dist_x) * (1 - dist_y);
scalar_t q01 = (1 - dist_x) * dist_y;
scalar_t q10 = dist_x * (1 - dist_y);
scalar_t q11 = dist_x * dist_y;
int bottom_index_base = c * height * width;
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
if (no_trans)
{
continue;
}
scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
diff_x *= roi_width;
scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
diff_y *= roi_height;
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
}
}
}
}
void DeformablePSROIPoolForward(const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
at::Tensor out,
at::Tensor top_count,
const int batch,
const int channels,
const int height,
const int width,
const int num_bbox,
const int channels_trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int count = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data.scalar_type(), "deformable_psroi_pool_forward", ([&] {
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *top_data = out.data<scalar_t>();
scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
const at::Tensor top_count,
at::Tensor in_grad,
at::Tensor trans_grad,
const int batch,
const int channels,
const int height,
const int width,
const int num_bbox,
const int channels_trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
// LOG(INFO) << "DeformablePSROIPoolBackward";
const int num_rois = num_bbox;
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int count = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
out_grad.scalar_type(), "deformable_psroi_pool_backward_acc", ([&] {
const scalar_t *top_diff = out_grad.data<scalar_t>();
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
const scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
group_size, part_size, num_classes, channels_each_class);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
...@@ -44,7 +44,7 @@ def make_cuda_ext(name, module, sources): ...@@ -44,7 +44,7 @@ def make_cuda_ext(name, module, sources):
if __name__ == '__main__': if __name__ == '__main__':
setup( setup(
name='rodnet', name='rodnet',
version='1.0', version='1.1',
description='RODNet: Object Detection from Radar Data', description='RODNet: Object Detection from Radar Data',
long_description=readme(), long_description=readme(),
long_description_content_type='text/markdown', long_description_content_type='text/markdown',
...@@ -77,7 +77,39 @@ if __name__ == '__main__': ...@@ -77,7 +77,39 @@ if __name__ == '__main__':
keywords='rodnet, object detection, radar, autonomous driving', keywords='rodnet, object detection, radar, autonomous driving',
packages=find_packages(include=["rodnet.*"]), packages=find_packages(include=["rodnet.*"]),
package_data={'rodnet.ops': ['*/*.so']},
python_requires='>=3.6', python_requires='>=3.6',
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=[
make_cuda_ext(
name='deform_conv_2d_cuda',
module='rodnet.ops.dcn',
sources=[
'src/deform_conv_2d_cuda.cpp',
'src/deform_conv_2d_cuda_kernel.cu'
]),
make_cuda_ext(
name='deform_conv_3d_cuda',
module='rodnet.ops.dcn',
sources=[
'src/deform_conv_3d_cuda.cpp',
'src/deform_conv_3d_cuda_kernel.cu'
]),
make_cuda_ext(
name='deform_pool_2d_cuda',
module='rodnet.ops.dcn',
sources=[
'src/deform_pool_2d_cuda.cpp',
'src/deform_pool_2d_cuda_kernel.cu'
]),
make_cuda_ext(
name='deform_pool_3d_cuda',
module='rodnet.ops.dcn',
sources=[
'src/deform_pool_3d_cuda.cpp',
'src/deform_pool_3d_cuda_kernel.cu'
]),
],
cmdclass={'build_ext': BuildExtension},
zip_safe=False zip_safe=False
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment