"sims/mem/memswitch/memswitch.cc" did not exist on "1b258eee5c6f27f7553e9afdb92302d0911d06c6"
Unverified Commit 896122c8 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Merge pull request #308 from hellock/pytorch-1.0

sync pytorch-1.0 with master branch
parents 69f40c98 7f317a0a
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair
from .. import deform_conv_cuda
class DeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(
"Expected 4D tensor as input, got {}D tensor instead.".format(
input.dim()))
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(
DeformConvFunction._output_size(input, weight, ctx.padding,
ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_cuda.deform_conv_forward_cuda(
input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
cur_im2col_step)
return output
@staticmethod
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] %
cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_cuda.deform_conv_backward_input_cuda(
input, offset, grad_output, grad_input,
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_cuda.deform_conv_backward_parameters_cuda(
input, offset, grad_output,
grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None,
None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
"convolution input is too small (output would be {})".format(
'x'.join(map(str, output_size))))
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_cuda.modulated_deform_conv_cuda_forward(
input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_cuda.modulated_deform_conv_cuda_backward(
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding -
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding -
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
import torch
from torch.autograd import Function
from .. import deform_pool_cuda
class DeformRoIPoolingFunction(Function):
@staticmethod
def forward(ctx,
data,
rois,
offset,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
ctx.spatial_scale = spatial_scale
ctx.out_size = out_size
ctx.out_channels = out_channels
ctx.no_trans = no_trans
ctx.group_size = group_size
ctx.part_size = out_size if part_size is None else part_size
ctx.sample_per_part = sample_per_part
ctx.trans_std = trans_std
assert 0.0 <= ctx.trans_std <= 1.0
if not data.is_cuda:
raise NotImplementedError
n = rois.shape[0]
output = data.new_empty(n, out_channels, out_size, out_size)
output_count = data.new_empty(n, out_channels, out_size, out_size)
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
data, rois, offset, output, output_count, ctx.no_trans,
ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
if data.requires_grad or rois.requires_grad or offset.requires_grad:
ctx.save_for_backward(data, rois, offset)
ctx.output_count = output_count
return output
@staticmethod
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
data, rois, offset = ctx.saved_tensors
output_count = ctx.output_count
grad_input = torch.zeros_like(data)
grad_rois = None
grad_offset = torch.zeros_like(offset)
deform_pool_cuda.deform_psroi_pooling_cuda_backward(
grad_output, data, rois, offset, output_count, grad_input,
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
ctx.trans_std)
return (grad_input, grad_rois, grad_offset, None, None, None, None,
None, None, None, None)
deform_roi_pooling = DeformRoIPoolingFunction.apply
import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
from ..functions.deform_conv import deform_conv, modulated_deform_conv
class DeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
assert not bias
super(DeformConv, self).__init__()
assert in_channels % groups == 0, \
'in_channels {} cannot be divisible by groups {}'.format(
in_channels, groups)
assert out_channels % groups == 0, \
'out_channels {} cannot be divisible by groups {}'.format(
out_channels, groups)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // self.groups,
*self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, input, offset):
return deform_conv(input, offset, self.weight, self.stride,
self.padding, self.dilation, self.groups,
self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input, offset, mask):
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConvPack, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, deformable_groups, bias)
self.conv_offset_mask = nn.Conv2d(
self.in_channels // self.groups,
self.deformable_groups * 3 * self.kernel_size[0] *
self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset_mask.weight.data.zero_()
self.conv_offset_mask.bias.data.zero_()
def forward(self, input):
out = self.conv_offset_mask(input)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(
input, offset, mask, self.weight, self.bias, self.stride,
self.padding, self.dilation, self.groups, self.deformable_groups)
from torch import nn
from ..functions.deform_pool import deform_roi_pooling
class DeformRoIPooling(nn.Module):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0):
super(DeformRoIPooling, self).__init__()
self.spatial_scale = spatial_scale
self.out_size = out_size
self.out_channels = out_channels
self.no_trans = no_trans
self.group_size = group_size
self.part_size = out_size if part_size is None else part_size
self.sample_per_part = sample_per_part
self.trans_std = trans_std
def forward(self, data, rois, offset):
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
class DeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
deform_fc_channels=1024):
super(DeformRoIPoolingPack,
self).__init__(spatial_scale, out_size, out_channels, no_trans,
group_size, part_size, sample_per_part, trans_std)
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
else:
n = rois.shape[0]
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
def __init__(self,
spatial_scale,
out_size,
out_channels,
no_trans,
group_size=1,
part_size=None,
sample_per_part=4,
trans_std=.0,
deform_fc_channels=1024):
super(ModulatedDeformRoIPoolingPack, self).__init__(
spatial_scale, out_size, out_channels, no_trans, group_size,
part_size, sample_per_part, trans_std)
self.deform_fc_channels = deform_fc_channels
if not no_trans:
self.offset_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 2))
self.offset_fc[-1].weight.data.zero_()
self.offset_fc[-1].bias.data.zero_()
self.mask_fc = nn.Sequential(
nn.Linear(self.out_size * self.out_size * self.out_channels,
self.deform_fc_channels),
nn.ReLU(inplace=True),
nn.Linear(self.deform_fc_channels,
self.out_size * self.out_size * 1),
nn.Sigmoid())
self.mask_fc[2].weight.data.zero_()
self.mask_fc[2].bias.data.zero_()
def forward(self, data, rois):
assert data.size(1) == self.out_channels
if self.no_trans:
offset = data.new_empty(0)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std)
else:
n = rois.shape[0]
offset = data.new_empty(0)
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
self.out_size, self.out_channels, True,
self.group_size, self.part_size,
self.sample_per_part, self.trans_std)
offset = self.offset_fc(x.view(n, -1))
offset = offset.view(n, 2, self.out_size, self.out_size)
mask = self.mask_fc(x.view(n, -1))
mask = mask.view(n, 1, self.out_size, self.out_size)
return deform_roi_pooling(
data, rois, offset, self.spatial_scale, self.out_size,
self.out_channels, self.no_trans, self.group_size,
self.part_size, self.sample_per_part, self.trans_std) * mask
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='deform_conv',
ext_modules=[
CUDAExtension('deform_conv_cuda', [
'src/deform_conv_cuda.cpp',
'src/deform_conv_cuda_kernel.cu',
]),
CUDAExtension('deform_pool_cuda', [
'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
]),
],
cmdclass={'build_ext': BuildExtension})
This diff is collapsed.
This diff is collapsed.
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
// based on
// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
#include <torch/extension.h>
#include <cmath>
#include <vector>
void DeformablePSROIPoolForward(
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void DeformablePSROIPoolBackwardAcc(
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
at::Tensor trans_grad, const int batch, const int channels,
const int height, const int width, const int num_bbox,
const int channels_trans, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std);
void deform_psroi_pooling_cuda_forward(
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
at::Tensor top_count, const int no_trans, const float spatial_scale,
const int output_dim, const int group_size, const int pooled_size,
const int part_size, const int sample_per_part, const float trans_std) {
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out.size(0), num_bbox);
DeformablePSROIPoolForward(
input, bbox, trans, out, top_count, batch, channels, height, width,
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
pooled_size, part_size, sample_per_part, trans_std);
}
void deform_psroi_pooling_cuda_backward(
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
const int no_trans, const float spatial_scale, const int output_dim,
const int group_size, const int pooled_size, const int part_size,
const int sample_per_part, const float trans_std) {
AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_trans = no_trans ? 2 : trans.size(1);
const int num_bbox = bbox.size(0);
if (num_bbox != out_grad.size(0))
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
out_grad.size(0), num_bbox);
DeformablePSROIPoolBackwardAcc(
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
channels, height, width, num_bbox, channels_trans, no_trans,
spatial_scale, output_dim, group_size, pooled_size, part_size,
sample_per_part, trans_std);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
"deform psroi pooling forward(CUDA)");
m.def("deform_psroi_pooling_cuda_backward",
&deform_psroi_pooling_cuda_backward,
"deform psroi pooling backward(CUDA)");
}
\ No newline at end of file
/*!
* Copyright (c) 2017 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_psroi_pooling.cu
* \brief
* \author Yi Li, Guodong Zhang, Jifeng Dai
*/
/***************** Adapted by Charles Shang *********************/
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <algorithm>
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N)
{
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename scalar_t>
__device__ scalar_t bilinear_interp(
const scalar_t *data,
const scalar_t x,
const scalar_t y,
const int width,
const int height)
{
int x1 = floor(x);
int x2 = ceil(x);
int y1 = floor(y);
int y2 = ceil(y);
scalar_t dist_x = (scalar_t)(x - x1);
scalar_t dist_y = (scalar_t)(y - y1);
scalar_t value11 = data[y1 * width + x1];
scalar_t value12 = data[y2 * width + x1];
scalar_t value21 = data[y1 * width + x2];
scalar_t value22 = data[y2 * width + x2];
scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
return value;
}
template <typename scalar_t>
__global__ void DeformablePSROIPoolForwardKernel(
const int count,
const scalar_t *bottom_data,
const scalar_t spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const scalar_t *bottom_rois, const scalar_t *bottom_trans,
const int no_trans,
const scalar_t trans_std,
const int sample_per_part,
const int output_dim,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class,
scalar_t *top_data,
scalar_t *top_count)
{
CUDA_KERNEL_LOOP(index, count)
{
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to be 1x1
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
// Compute w and h at bottom
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
scalar_t sum = 0;
int count = 0;
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
scalar_t w = wstart + iw * sub_bin_size_w;
scalar_t h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.);
int c = (ctop * group_size + gh) * group_size + gw;
scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
sum += val;
count++;
}
}
top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
top_count[index] = count;
}
}
template <typename scalar_t>
__global__ void DeformablePSROIPoolBackwardAccKernel(
const int count,
const scalar_t *top_diff,
const scalar_t *top_count,
const int num_rois,
const scalar_t spatial_scale,
const int channels,
const int height, const int width,
const int pooled_height, const int pooled_width,
const int output_dim,
scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
const scalar_t *bottom_data,
const scalar_t *bottom_rois,
const scalar_t *bottom_trans,
const int no_trans,
const scalar_t trans_std,
const int sample_per_part,
const int group_size,
const int part_size,
const int num_classes,
const int channels_each_class)
{
CUDA_KERNEL_LOOP(index, count)
{
// The output is in order (n, ctop, ph, pw)
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int ctop = (index / pooled_width / pooled_height) % output_dim;
int n = index / pooled_width / pooled_height / output_dim;
// [start, end) interval for spatial sampling
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
// Force too small ROIs to be 1x1
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
// Compute w and h at bottom
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
int class_id = ctop / channels_each_class;
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
wstart += trans_x * roi_width;
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
hstart += trans_y * roi_height;
if (top_count[index] <= 0)
{
continue;
}
scalar_t diff_val = top_diff[index] / top_count[index];
const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
gw = min(max(gw, 0), group_size - 1);
gh = min(max(gh, 0), group_size - 1);
for (int ih = 0; ih < sample_per_part; ih++)
{
for (int iw = 0; iw < sample_per_part; iw++)
{
scalar_t w = wstart + iw * sub_bin_size_w;
scalar_t h = hstart + ih * sub_bin_size_h;
// bilinear interpolation
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
{
continue;
}
w = min(max(w, 0.), width - 1.);
h = min(max(h, 0.), height - 1.);
int c = (ctop * group_size + gh) * group_size + gw;
// backward on feature
int x0 = floor(w);
int x1 = ceil(w);
int y0 = floor(h);
int y1 = ceil(h);
scalar_t dist_x = w - x0, dist_y = h - y0;
scalar_t q00 = (1 - dist_x) * (1 - dist_y);
scalar_t q01 = (1 - dist_x) * dist_y;
scalar_t q10 = dist_x * (1 - dist_y);
scalar_t q11 = dist_x * dist_y;
int bottom_index_base = c * height * width;
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
if (no_trans)
{
continue;
}
scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
diff_x *= roi_width;
scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
diff_y *= roi_height;
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
}
}
}
}
void DeformablePSROIPoolForward(const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
at::Tensor out,
at::Tensor top_count,
const int batch,
const int channels,
const int height,
const int width,
const int num_bbox,
const int channels_trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int count = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data.type(), "deformable_psroi_pool_forward", ([&] {
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *top_data = out.data<scalar_t>();
scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
const at::Tensor data,
const at::Tensor bbox,
const at::Tensor trans,
const at::Tensor top_count,
at::Tensor in_grad,
at::Tensor trans_grad,
const int batch,
const int channels,
const int height,
const int width,
const int num_bbox,
const int channels_trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
// LOG(INFO) << "DeformablePSROIPoolBackward";
const int num_rois = num_bbox;
const int pooled_height = pooled_size;
const int pooled_width = pooled_size;
const int count = num_bbox * output_dim * pooled_height * pooled_width;
const int num_classes = no_trans ? 1 : channels_trans / 2;
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
const scalar_t *top_diff = out_grad.data<scalar_t>();
const scalar_t *bottom_data = data.data<scalar_t>();
const scalar_t *bottom_rois = bbox.data<scalar_t>();
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
const scalar_t *top_count_data = top_count.data<scalar_t>();
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
group_size, part_size, num_classes, channels_each_class);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
}
}
\ No newline at end of file
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
# Written by Ross Girshick # Written by Ross Girshick
# -------------------------------------------------------- # --------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
# Modified by Kai Chen # Modified by Kai Chen
# ---------------------------------------------------------- # ----------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
# Written by Ross Girshick # Written by Ross Girshick
# -------------------------------------------------------- # --------------------------------------------------------
# cython: language_level=3, boundscheck=False
import numpy as np import numpy as np
cimport numpy as np cimport numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment