Commit 52b8685b authored by pedrofreire, committed by Francisco Massa

Add Deformable Convolution operation. (#1586)

* Add Deformable Convolution operation.

This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

- The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp; it was modified and refactored to remove redundancies, increase clarity, and adapt it to torchvision.

- The CPU part is a direct copy of the CUDA code; it might make sense to do follow-up adjustments to simplify or optimize the CPU code, or to reuse functionality between CPU and CUDA.

- We also add tests (with a non-trivial set of parameters); they can be made more robust by randomizing the parameters and executing multiple times.
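As a quick reference, here is a minimal usage sketch of the new op; the shapes and argument names mirror the test code further down, and the exact values are illustrative rather than authoritative:

import torch
from torchvision import ops

# 1 image, 6 input channels, 5x4 spatial size (same setup as the tests below).
x = torch.rand(1, 6, 5, 4)

# Module form: the offset is supplied at forward time rather than stored as a parameter.
layer = ops.DeformConv2d(6, 2, kernel_size=(3, 2), stride=1, padding=0,
                         dilation=1, groups=2, offset_groups=3)

# Offset layout: (N, offset_groups * 2 * kernel_h * kernel_w, out_h, out_w), i.e.
# one (dy, dx) pair per offset group and kernel position at every output location.
# For a (3, 2) kernel with stride 1 and no padding on a 5x4 input, the output is 3x3.
offset = torch.zeros(1, 3 * 2 * 3 * 2, 3, 3)

out = layer(x, offset)  # zero offsets sample the regular grid, i.e. a plain grouped conv
print(out.shape)        # torch.Size([1, 2, 3, 3])

# Functional form; note the argument order (input, offset, weight, bias):
out2 = ops.deform_conv2d(x, offset, layer.weight, layer.bias,
                         stride=(1, 1), padding=(0, 0), dilation=(1, 1))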

* Update DeformConv to be more consistent w/ Conv2d

* rename some variables and arguments to match Conv2d;
* add optional bias;
* add weight, offset and bias as module parameters;
* remove the n_parallel_imgs parameter;
* fix __repr__;
* etc.

Initialization of weight and bias is the same as in Conv2d, and
initialization of offsets to zero is the same as in the paper.
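As a rough sketch of what this means (Conv2d's reset_parameters scheme, written here as a free function; the actual module code may organize this differently):

import math
from torch import nn

def reset_parameters(weight, bias):
    # Conv2d-style init: Kaiming-uniform weight, bias uniform in +/- 1/sqrt(fan_in).
    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
    if bias is not None:
        fan_in = weight.size(1) * weight.size(2) * weight.size(3)  # (in_c / groups) * kh * kw
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(bias, -bound, bound)

# The offsets start at zero, so the layer initially samples the regular convolution grid.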

This also includes some other small unrelated fixes/improvements.

* Apply clang-format in DeformConv files.

* Import Optional type annotation

* Remove offset param from DeformConv2d module

- We pass the offset in the forward of DeformConv2d, instead of having
an internal parameter. This adds some complexity to creating the module
(e.g. you now have to compute the output size to create the offset
tensor; see the sketch after this list), but it gives more flexibility.
- We also use make_tuple for tuple creation, in an attempt to fix errors
with older compilers.
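A sketch of the extra bookkeeping this implies, using the standard Conv2d output-size formula (the helper below is only illustrative and not part of this change):

import torch
from torch.nn.modules.utils import _pair

def offset_shape(input, kernel_size, offset_groups, stride=1, padding=0, dilation=1):
    # Shape of the offset tensor the caller now has to construct:
    # (N, offset_groups * 2 * kernel_h * kernel_w, out_h, out_w).
    n, _, in_h, in_w = input.shape
    kh, kw = _pair(kernel_size)
    sh, sw = _pair(stride)
    ph, pw = _pair(padding)
    dh, dw = _pair(dilation)
    out_h = (in_h + 2 * ph - (dh * (kh - 1) + 1)) // sh + 1
    out_w = (in_w + 2 * pw - (dw * (kw - 1) + 1)) // sw + 1
    return (n, offset_groups * 2 * kh * kw, out_h, out_w)

x = torch.rand(1, 6, 5, 4)
offset = torch.zeros(offset_shape(x, (3, 2), offset_groups=3))  # -> (1, 36, 3, 3)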

* Replace abs by std::abs

Old gcc versions were giving wrong results here, because they would
resolve abs to the int -> int overload, causing undesired truncation.
Replacing abs with std::abs allows the float -> float overload to be
selected.

* Reorder declarations for clarity

* Reorder weight and offset args in deform_conv2d

We place the offset arg before the weight arg, to be consistent with
DeformConv2d.forward(input, offset).

* Replace abs by std::abs in DeformConv_cuda
from __future__ import division
import math
import unittest
import numpy as np
import torch
from torch import Tensor
from torch.autograd import gradcheck
from torch.jit.annotations import Tuple
from torch.nn.modules.utils import _pair
from torchvision import ops
from itertools import product
class RoIOpTester(object):
class OpTester(object):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float64
@@ -42,6 +45,14 @@ class RoIOpTester(object):
def test_backward_cuda_non_contiguous(self):
self._test_backward(device=torch.device('cuda'), contiguous=False)
def _test_forward(self, device, contiguous):
pass
def _test_backward(self, device, contiguous):
pass
class RoIOpTester(OpTester):
def _test_forward(self, device, contiguous):
pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -79,7 +90,6 @@ class RoIOpTester(object):
self.assertTrue(gradcheck(func, (x,)))
self.assertTrue(gradcheck(script_func, (x,)))
return
def fn(*args, **kwargs):
pass
@@ -98,7 +108,7 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
# type: (Tensor, Tensor, int) -> Tensor
return ops.roi_pool(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -137,7 +147,7 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
# type: (Tensor, Tensor, int) -> Tensor
return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -174,29 +184,35 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
return y
def bilinear_interpolate(data, height, width, y, x):
if y < -1.0 or y > height or x < -1.0 or x > width:
return 0.
def bilinear_interpolate(data, y, x, snap_border=False):
height, width = data.shape
y = min(max(0, y), height - 1)
x = min(max(0, x), width - 1)
if snap_border:
if -1 < y <= 0:
y = 0
elif height - 1 <= y < height:
y = height - 1
y_low = int(y)
y_high = min(y_low + 1, height - 1)
if -1 < x <= 0:
x = 0
elif width - 1 <= x < width:
x = width - 1
x_low = int(x)
x_high = min(x_low + 1, width - 1)
y_low = int(math.floor(y))
x_low = int(math.floor(x))
y_high = y_low + 1
x_high = x_low + 1
wy_h = y - y_low
wy_l = 1 - wy_h
wx_h = x - x_low
wy_l = 1 - wy_h
wx_l = 1 - wx_h
val = 0
for wx, x in zip((wx_l, wx_h), (x_low, x_high)):
for wy, y in zip((wy_l, wy_h), (y_low, y_high)):
val += wx * wy * data[y * width + x]
for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
if 0 <= yp < height and 0 <= xp < width:
val += wx * wy * data[yp, xp]
return val
@@ -208,7 +224,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
# type: (Tensor, Tensor, int) -> Tensor
return ops.roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -242,12 +258,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
y = start_h + (iy + 0.5) * bin_h / grid_h
for ix in range(0, grid_w):
x = start_w + (ix + 0.5) * bin_w / grid_w
val += bilinear_interpolate(
in_data[batch_idx, channel, :, :].flatten(),
in_data.size(-2),
in_data.size(-1),
y, x
)
val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
val /= grid_h * grid_w
out_data[r, channel, i, j] = val
@@ -262,7 +273,7 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (torch.Tensor, torch.Tensor, int) -> torch.Tensor
# type: (Tensor, Tensor, int) -> Tensor
return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -298,12 +309,7 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
y = start_h + (iy + 0.5) * bin_h / grid_h
for ix in range(0, grid_w):
x = start_w + (ix + 0.5) * bin_w / grid_w
val += bilinear_interpolate(
in_data[batch_idx, c_in, :, :].flatten(),
in_data.size(-2),
in_data.size(-1),
y, x
)
val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
val /= grid_h * grid_w
out_data[r, c_out, i, j] = val
@@ -376,5 +382,120 @@ class NewEmptyTensorTester(unittest.TestCase):
assert out.dtype == input.dtype
class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
weight_h, weight_w = weight.shape[-2:]
n_batches, n_in_channels, in_h, in_w = x.shape
n_out_channels = weight.shape[0]
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
in_c_per_offset_grp = n_in_channels // n_offset_grps
n_weight_grps = n_in_channels // weight.shape[1]
in_c_per_weight_grp = weight.shape[1]
out_c_per_weight_grp = n_out_channels // n_weight_grps
out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
for b in range(n_batches):
for c_out in range(n_out_channels):
for i in range(out_h):
for j in range(out_w):
for di in range(weight_h):
for dj in range(weight_w):
for c in range(in_c_per_weight_grp):
weight_grp = c_out // out_c_per_weight_grp
c_in = weight_grp * in_c_per_weight_grp + c
offset_grp = c_in // in_c_per_offset_grp
offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
out += bias.view(1, n_out_channels, 1, 1)
return out
def get_fn_args(self, device, contiguous):
batch_sz = 1
n_in_channels = 6
n_out_channels = 2
n_weight_grps = 2
n_offset_grps = 3
stride = (2, 1)
pad = (1, 0)
dilation = (2, 1)
stride_h, stride_w = stride
pad_h, pad_w = pad
dil_h, dil_w = dilation
weight_h, weight_w = (3, 2)
in_h, in_w = (5, 4)
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
device=device, dtype=self.dtype, requires_grad=True)
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
device=device, dtype=self.dtype, requires_grad=True)
bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
if not contiguous:
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
return x, weight, offset, bias, stride, pad, dilation
def _test_forward(self, device, contiguous):
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
in_channels = 6
out_channels = 2
kernel_size = (3, 2)
groups = 2
offset_groups = 3
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device,
dtype=x.dtype)
res = layer(x, offset)
weight = layer.weight.data
bias = layer.bias.data
expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
def _test_backward(self, device, contiguous):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
def func(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
@torch.jit.script
def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5)
if __name__ == '__main__':
unittest.main()
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
at::Tensor DeformConv2d_forward(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const std::pair<int, int>& stride,
const std::pair<int, int>& padding,
const std::pair<int, int>& dilation,
const int groups,
const int offset_groups) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return DeformConv2d_forward_cuda(
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return DeformConv2d_forward_cpu(
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const std::pair<int, int>& stride,
const std::pair<int, int>& padding,
const std::pair<int, int>& dilation,
const int groups,
const int offset_groups) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return DeformConv2d_backward_cuda(
grad.contiguous(),
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return DeformConv2d_backward_cpu(
grad.contiguous(),
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
}
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable weight,
Variable offset,
Variable bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto output = DeformConv2d_forward(
input,
weight,
offset,
bias,
{stride_h, stride_w},
{pad_h, pad_w},
{dilation_h, dilation_w},
groups,
offset_groups);
ctx->save_for_backward({input, weight, offset, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
ctx->saved_data["pad_w"] = pad_w;
ctx->saved_data["dilation_h"] = dilation_h;
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
return {
output,
};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto bias = saved[3];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto grads = DeformConv2d_backward(
grad_output[0],
input,
weight,
offset,
bias,
{stride_h, stride_w},
{pad_h, pad_w},
{dilation_h, dilation_w},
groups,
offset_groups);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_bias = std::get<3>(grads);
return {
grad_input,
grad_weight,
grad_offset,
grad_bias,
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
};
}
};
at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto result = DeformConv2dFunction::apply(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
return result[0];
}
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <cmath>
#include <iostream>
#include <tuple>
using namespace at;
const int kMaxParallelImgs = 32;
template <typename scalar_t>
static scalar_t bilinear_interpolate(
const scalar_t* in,
const int height,
const int width,
scalar_t h,
scalar_t w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) {
return 0;
}
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = in[h_low * width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = in[h_low * width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = in[h_high * width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = in[h_high * width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
static void deformable_im2col_kernel(
const int n,
const scalar_t* input,
const scalar_t* offset,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dil_h,
const int dil_w,
const int batch_sz,
const int n_in_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* columns) {
for (int index = 0; index != n; ++index) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int out_b = (index / (out_w * out_h)) % batch_sz;
const int in_c = index / (out_w * out_h * batch_sz);
const int out_c = in_c * weight_h * weight_w;
int c_per_offset_grp = n_in_channels / n_offset_grps;
const int grp_idx = in_c / c_per_offset_grp;
auto columns_ptr = columns +
(out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
out_y * out_w + out_x);
auto input_ptr = input +
(out_b * (n_in_channels * height * width) + in_c * (height * width));
auto offset_ptr = offset +
(out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h *
out_w;
for (int i = 0; i < weight_h; ++i) {
for (int j = 0; j < weight_w; ++j) {
const int offset_idx = 2 * (i * weight_w + j);
const scalar_t offset_h =
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t offset_w = offset_ptr
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x);
columns_ptr += batch_sz * out_h * out_w;
}
}
}
}
static void deformable_im2col(
const at::Tensor input,
const at::Tensor data_offset,
int n_in_channels,
int height,
int width,
int weight_h,
int weight_w,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dil_h,
int dil_w,
int out_h,
int out_w,
int parallel_imgs,
int deformable_group,
at::Tensor data_col) {
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "deformable_im2col", ([&] {
deformable_im2col_kernel(
num_kernels,
input.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(),
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
parallel_imgs,
n_in_channels,
deformable_group,
out_h,
out_w,
data_col.data_ptr<scalar_t>());
}));
}
static int get_greatest_divisor_below_bound(int n, int bound) {
for (int k = bound; k > 1; --k) {
if (n % k == 0) {
return k;
}
}
return 1;
}
at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps) {
at::Tensor input = input_param;
at::Tensor offset = offset_param;
at::Tensor weight = weight_param;
TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(offset.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
// Unpack shapes and args
int out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
TORCH_CHECK(
weight_h > 0 && weight_w > 0,
"weight_h: ",
weight_h,
" weight_w: ",
weight_w);
TORCH_CHECK(
stride_h > 0 && stride_w > 0,
"stride_h: ",
stride_h,
" stride_w: ",
stride_w);
TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w);
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"got: ",
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (",
offset.size(2),
", ",
offset.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
out_h,
" out_w: ",
out_w);
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
// Separate batches into blocks
out = out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
out_channels,
out_h,
out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs * out_h,
out_w},
out.options());
// Separate channels into convolution groups
out_buf = out_buf.view({out_buf.size(0),
n_weight_grps,
out_buf.size(1) / n_weight_grps,
out_buf.size(2),
out_buf.size(3)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
// Sample points and perform convolution
auto columns = at::zeros(
{n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
input[b],
offset[b],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
out_buf[b][g] = out_buf[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]);
}
}
out_buf = out_buf.view({batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs,
out_h,
out_w});
out_buf.transpose_(1, 2);
out.copy_(out_buf);
out = out.view({batch_sz, out_channels, out_h, out_w});
return out + bias.view({1, out_channels, 1, 1});
}
template <typename scalar_t>
static void deformable_col2im_kernel(
const int n,
const scalar_t* col,
const scalar_t* offset,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_im) {
for (int index = 0; index != n; ++index) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int b = (index / (out_w * out_h)) % batch_sz;
const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
int c_per_offset_grp = channels / n_offset_grps;
const int offset_grp = c / c_per_offset_grp;
auto offset_ptr = offset +
(b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h *
out_w;
const int offset_h_ptr =
((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x;
const int offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x;
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
for (int dy = -1; dy <= 1; dy++) {
for (int dx = -1; dx <= 1; dx++) {
int yp = int(y) + dy;
int xp = int(x) + dx;
if (0 <= yp && yp < height && 0 <= xp && xp < width &&
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
int grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
grad_im[grad_pos] += weight * col[index];
}
}
}
}
}
static void compute_grad_input(
const at::Tensor columns,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_im) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im", ([&] {
deformable_col2im_kernel(
num_kernels,
columns.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
n_offset_grps,
out_h,
out_w,
grad_im.data_ptr<scalar_t>());
}));
}
template <typename scalar_t>
static scalar_t get_coordinate_weight(
const scalar_t* im_data,
const int height,
const int width,
scalar_t y,
scalar_t x,
bool is_y_direction) {
int y_l = floor(y);
int x_l = floor(x);
int y_h = y_l + 1;
int x_h = x_l + 1;
bool valid_y_l = 0 <= y_l && y_l < height;
bool valid_y_h = 0 <= y_h && y_h < height;
bool valid_x_l = 0 <= x_l && x_l < width;
bool valid_x_h = 0 <= x_h && x_h < width;
scalar_t zero = 0;
scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;
scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;
scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;
scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;
if (is_y_direction) {
scalar_t dx = x - x_l;
return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);
} else {
scalar_t dy = y - y_l;
return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);
}
}
template <typename scalar_t>
static void deformable_col2im_coord_kernel(
const int n,
const scalar_t* col,
const scalar_t* im,
const scalar_t* offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int offset_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_offset) {
for (int index = 0; index != n; ++index) {
scalar_t val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
const int offset_grp = c / (2 * weight_h * weight_w);
const int col_step = weight_h * weight_w;
int c_per_offset_grp = channels / n_offset_grps;
auto col_ptr = col +
offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w *
out_h;
auto im_ptr = im +
(b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;
auto offset_ptr = offset +
(b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h *
out_w;
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const int is_y_direction = offset_c % 2 == 0;
const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
int out_x = col_pos % out_w;
int out_y = (col_pos / out_w) % out_h;
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int offset_h_idx =
(((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x);
const int offset_w_idx =
(((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_idx];
const scalar_t offset_w = offset_ptr[offset_w_idx];
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
val += weight * col_ptr[col_pos];
im_ptr += height * width;
}
grad_offset[index] = val;
}
}
static void compute_grad_offset(
const at::Tensor columns,
const at::Tensor input,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_offset) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_coord", ([&] {
deformable_col2im_coord_kernel(
num_kernels,
columns.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
2 * weight_h * weight_w * n_offset_grps,
n_offset_grps,
out_h,
out_w,
grad_offset.data_ptr<scalar_t>());
}));
}
static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;
auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
// Separate into blocks
grad_input = grad_input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_out = grad_out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_out_channels,
out_h,
out_w});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({grad_out.size(0),
n_weight_grps,
grad_out.size(1) / n_weight_grps,
grad_out.size(2),
grad_out.size(3),
grad_out.size(4)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
// Separate into weight groups
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
compute_grad_offset(
columns,
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_offset[elt]);
compute_grad_input(
columns,
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_input[elt]);
}
grad_out = grad_out.view({grad_out.size(0),
grad_out.size(1) * grad_out.size(2),
grad_out.size(3),
grad_out.size(4),
grad_out.size(5)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
input = input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset);
}
static at::Tensor deform_conv2d_backward_parameters_cpu(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_h = grad_out.size(2);
long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight);
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
grad_out = grad_out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_out_channels,
out_h,
out_w});
grad_out.transpose_(1, 2);
at::Tensor grad_out_buf = at::zeros_like(grad_out);
grad_out_buf.copy_(grad_out);
grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs,
n_out_channels,
n_parallel_imgs * out_h,
out_w});
grad_out_buf = grad_out_buf.view({grad_out_buf.size(0),
n_weight_grps,
grad_out_buf.size(1) / n_weight_grps,
grad_out_buf.size(2),
grad_out_buf.size(3)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col(
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
input = input.view({batch_sz, n_in_channels, in_h, in_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3),
grad_weight.size(4)});
return grad_weight;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps) {
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv2d_backward_input_cpu(
input,
weight,
offset,
grad_out,
stride,
pad,
dilation,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_weight = deform_conv2d_backward_parameters_cpu(
input,
weight,
offset,
grad_out,
stride,
pad,
dilation,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3});
return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias);
}
@@ -84,3 +84,27 @@ at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer
*****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
*AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
*IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
*FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
*DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
*SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
*CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
*OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
*OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer
*********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "cuda_helpers.h"
#include <cmath>
#include <iostream>
#include <tuple>
using namespace at;
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
const int kMaxParallelImgs = 32;
inline int GET_BLOCKS(const int N) {
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
template <typename scalar_t>
__device__ scalar_t bilinear_interpolate(
const scalar_t* in,
const int height,
const int width,
scalar_t h,
scalar_t w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) {
return 0;
}
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = in[h_low * width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = in[h_low * width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = in[h_high * width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = in[h_high * width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(
const int n,
const scalar_t* input_ptr,
const scalar_t* offset_ptr,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dil_h,
const int dil_w,
const int batch_sz,
const int n_in_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* columns_ptr) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int out_b = (index / (out_w * out_h)) % batch_sz;
const int in_c = index / (out_w * out_h * batch_sz);
const int out_c = in_c * weight_h * weight_w;
int c_per_offset_grp = n_in_channels / n_offset_grps;
const int grp_idx = in_c / c_per_offset_grp;
columns_ptr +=
(out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
out_y * out_w + out_x);
input_ptr +=
(out_b * (n_in_channels * height * width) + in_c * (height * width));
offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w *
out_h * out_w;
for (int i = 0; i < weight_h; ++i) {
for (int j = 0; j < weight_w; ++j) {
const int offset_idx = 2 * (i * weight_w + j);
const scalar_t offset_h =
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t offset_w = offset_ptr
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x);
columns_ptr += batch_sz * out_h * out_w;
}
}
}
}
static void deformable_im2col(
const at::Tensor input,
const at::Tensor data_offset,
int n_in_channels,
int height,
int width,
int weight_h,
int weight_w,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dil_h,
int dil_w,
int out_h,
int out_w,
int parallel_imgs,
int deformable_group,
at::Tensor data_col) {
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "deformable_im2col_gpu", ([&] {
deformable_im2col_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels,
input.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(),
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
parallel_imgs,
n_in_channels,
deformable_group,
out_h,
out_w,
data_col.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
static int get_greatest_divisor_below_bound(int n, int bound) {
for (int k = bound; k > 1; --k) {
if (n % k == 0) {
return k;
}
}
return 1;
}
at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps) {
at::Tensor input = input_param;
at::Tensor weight = weight_param;
at::Tensor offset = offset_param;
TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(offset.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
int in_channels = input.size(1);
int in_h = input.size(2);
int in_w = input.size(3);
int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
int out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
TORCH_CHECK(
weight_h > 0 && weight_w > 0,
"weight_h: ",
weight_h,
" weight_w: ",
weight_w);
TORCH_CHECK(
stride_h > 0 && stride_w > 0,
"stride_h: ",
stride_h,
" stride_w: ",
stride_w);
TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
TORCH_CHECK(dil_h > 0 && dil_w > 0, "dil_h: ", dil_h, " dil_w: ", dil_w);
TORCH_CHECK(weight.size(1) * n_weight_grps == input.size(1));
TORCH_CHECK(weight.size(0) % n_weight_grps == 0);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
(offset.size(0) == input.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"got: ",
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(
(offset.size(2) == out_h && offset.size(3) == out_w),
"offset output dims: (",
offset.size(2),
", ",
offset.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
out_h,
" out_w: ",
out_w);
auto out = at::zeros({batch_sz, out_channels, out_h, out_w}, input.options());
// Separate batches into blocks
out = out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
out_channels,
out_h,
out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs * out_h,
out_w},
out.options());
// Separate channels into convolution groups
out_buf = out_buf.view({out_buf.size(0),
n_weight_grps,
out_buf.size(1) / n_weight_grps,
out_buf.size(2),
out_buf.size(3)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
// Sample points and perform convolution
auto columns = at::zeros(
{in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
input[b],
offset[b],
in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
out_buf[b][g] = out_buf[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]);
}
}
out_buf = out_buf.view({batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs,
out_h,
out_w});
out_buf.transpose_(1, 2);
out.copy_(out_buf);
out = out.view({batch_sz, out_channels, out_h, out_w});
return out + bias.view({1, out_channels, 1, 1});
}
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
const int n,
const scalar_t* col,
const scalar_t* offset_ptr,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_im) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w;
const int out_y = (index / out_w) % out_h;
const int b = (index / (out_w * out_h)) % batch_sz;
const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
int c_per_offset_grp = channels / n_offset_grps;
const int offset_grp = c / c_per_offset_grp;
offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w *
out_h * out_w;
const int offset_h_ptr =
((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x;
const int offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x;
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
for (int dy = -1; dy <= 1; dy++) {
for (int dx = -1; dx <= 1; dx++) {
int yp = int(y) + dy;
int xp = int(x) + dx;
if (0 <= yp && yp < height && 0 <= xp && xp < width &&
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
int grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
atomicAdd(grad_im + grad_pos, weight * col[index]);
}
}
}
}
}
static void compute_grad_input(
const at::Tensor columns,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_im) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_gpu", ([&] {
deformable_col2im_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels,
columns.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
n_offset_grps,
out_h,
out_w,
grad_im.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in compute_grad_input: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(
const scalar_t* im_data,
const int height,
const int width,
scalar_t y,
scalar_t x,
bool is_y_direction) {
int y_l = floor(y);
int x_l = floor(x);
int y_h = y_l + 1;
int x_h = x_l + 1;
bool valid_y_l = 0 <= y_l && y_l < height;
bool valid_y_h = 0 <= y_h && y_h < height;
bool valid_x_l = 0 <= x_l && x_l < width;
bool valid_x_h = 0 <= x_h && x_h < width;
scalar_t zero = 0;
scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;
scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;
scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;
scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;
if (is_y_direction) {
scalar_t dx = x - x_l;
return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);
} else {
scalar_t dy = y - y_l;
return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);
}
}
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(
const int n,
const scalar_t* col_ptr,
const scalar_t* im_ptr,
const scalar_t* offset_ptr,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int batch_sz,
const int offset_channels,
const int n_offset_grps,
const int out_h,
const int out_w,
scalar_t* grad_offset) {
CUDA_1D_KERNEL_LOOP(index, n) {
scalar_t val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
const int offset_grp = c / (2 * weight_h * weight_w);
const int col_step = weight_h * weight_w;
int c_per_offset_grp = channels / n_offset_grps;
col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz *
out_w * out_h;
im_ptr +=
(b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;
offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w *
out_h * out_w;
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const int is_y_direction = offset_c % 2 == 0;
const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
int out_x = col_pos % out_w;
int out_y = (col_pos / out_w) % out_h;
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int offset_h_ptr =
(((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x);
const int offset_w_ptr =
(((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
val += weight * col_ptr[col_pos];
im_ptr += height * width;
}
grad_offset[index] = val;
}
}
static void compute_grad_offset(
const at::Tensor columns,
const at::Tensor input,
const at::Tensor offset,
const int channels,
const int height,
const int width,
const int weight_h,
const int weight_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int parallel_imgs,
const int n_offset_grps,
at::Tensor grad_offset) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
int num_kernels =
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels,
columns.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
channels,
height,
width,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
parallel_imgs,
2 * weight_h * weight_w * n_offset_grps,
n_offset_grps,
out_h,
out_w,
grad_offset.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in compute_grad_offset: %s\n", cudaGetErrorString(err));
}
}
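// Backward pass w.r.t. input and offset: columns are rebuilt as
// weight^T * grad_out per weight group, then scattered back to grad_offset
// (col2im_coord) and grad_input (col2im).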
static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
long n_in_channels = input.size(1);
long in_h = input.size(2);
long in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;
long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
// Separate into blocks
grad_input = grad_input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_out = grad_out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_out_channels,
out_h,
out_w});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({grad_out.size(0),
n_weight_grps,
grad_out.size(1) / n_weight_grps,
grad_out.size(2),
grad_out.size(3),
grad_out.size(4)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
// Separate into weight groups
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
compute_grad_offset(
columns,
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_offset[elt]);
compute_grad_input(
columns,
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_input[elt]);
}
grad_out = grad_out.view({grad_out.size(0),
grad_out.size(1) * grad_out.size(2),
grad_out.size(3),
grad_out.size(4),
grad_out.size(5)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
input = input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset);
}
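// Backward pass w.r.t. weight: deformable im2col of (input, offset) builds the
// columns, and each weight group accumulates grad_out[g] * columns[g]^T.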
static at::Tensor deform_conv_backward_parameters_cuda(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor grad_out,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
long n_in_channels = input.size(1);
long in_h = input.size(2);
long in_w = input.size(3);
n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);
long n_out_channels = weight.size(0);
int weight_h = weight.size(2);
int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_h = grad_out.size(2);
long out_w = grad_out.size(3);
auto grad_weight = at::zeros_like(weight);
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
grad_out = grad_out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_out_channels,
out_h,
out_w});
grad_out.transpose_(1, 2);
at::Tensor grad_out_buf = at::zeros_like(grad_out);
grad_out_buf.copy_(grad_out);
grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs,
n_out_channels,
n_parallel_imgs * out_h,
out_w});
grad_out_buf = grad_out_buf.view({grad_out_buf.size(0),
n_weight_grps,
grad_out_buf.size(1) / n_weight_grps,
grad_out_buf.size(2),
grad_out_buf.size(3)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});
for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col(
input[elt],
offset[elt],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dil_h,
dil_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
columns);
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
input = input.view({batch_sz, n_in_channels, in_h, in_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3),
grad_weight.size(4)});
return grad_weight;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int n_weight_grps,
int n_offset_grps) {
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv_backward_input_cuda(
input,
weight,
offset,
grad_out,
stride,
pad,
dilation,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_weight = deform_conv_backward_parameters_cuda(
input,
weight,
offset,
grad_out,
stride,
pad,
dilation,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
auto value = grad_out.sum({0, 2, 3});
auto grad_bias = at::ones_like(bias) * value;
return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias);
}
......@@ -85,3 +85,27 @@ at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
......@@ -5,6 +5,7 @@
#include <cuda.h>
#endif
#include "DeformConv.h"
#include "PSROIAlign.h"
#include "PSROIPool.h"
#include "ROIAlign.h"
......@@ -47,4 +48,5 @@ static auto registry =
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
.op("torchvision::ps_roi_align", &ps_roi_align)
.op("torchvision::ps_roi_pool", &ps_roi_pool)
.op("torchvision::deform_conv2d", &deform_conv2d)
.op("torchvision::_cuda_version", &_cuda_version);
from .boxes import nms, box_iou
from .new_empty_tensor import _new_empty_tensor
from .deform_conv import deform_conv2d, DeformConv2d
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .ps_roi_align import ps_roi_align, PSRoIAlign
......@@ -13,7 +14,7 @@ _register_custom_op()
__all__ = [
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool', '_new_empty_tensor',
'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool',
'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
'deform_conv2d', 'DeformConv2d', 'nms', 'roi_align', 'RoIAlign', 'roi_pool',
'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
]
import math
import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple
def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
"""
Performs Deformable Convolution, described in Deformable Convolutional Networks
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
convolution weights, split into groups of size (in_channels // groups)
bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
padding (int or Tuple[int, int]): height/width of padding of zeroes around
each image. Default: 0
dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
Returns:
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
"""
out_channels = weight.shape[0]
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
weights_h, weights_w = weight.shape[-2:]
_, n_in_channels, in_h, in_w = input.shape
n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
n_weight_grps = n_in_channels // weight.shape[1]
return torch.ops.torchvision.deform_conv2d(
input,
weight,
offset,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps)
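# Illustrative usage sketch (example values, not part of this commit): the offset
# tensor must have 2 * offset_groups * kernel_h * kernel_w channels and the same
# spatial size as the convolution output.
#
#   input = torch.randn(4, 3, 10, 10)
#   kh, kw = 3, 3
#   weight = torch.randn(5, 3, kh, kw)
#   # stride=1, padding=0, dilation=1 -> out_h = out_w = 10 - (3 - 1) = 8
#   offset = torch.randn(4, 2 * kh * kw, 8, 8)
#   out = deform_conv2d(input, offset, weight)
#   # out.shape == (4, 5, 8, 8)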
class DeformConv2d(nn.Module):
"""
See deform_conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, offset_groups=1, bias=True):
super(DeformConv2d, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if in_channels % offset_groups != 0:
raise ValueError('in_channels must be divisible by offset_groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.offset_groups = offset_groups
self.weight = Parameter(torch.empty(out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1]))
if bias:
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input, offset):
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
"""
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
padding=self.padding, dilation=self.dilation)
def __repr__(self):
s = self.__class__.__name__ + '('
s += '{in_channels}'
s += ', {out_channels}'
s += ', kernel_size={kernel_size}'
s += ', stride={stride}'
s += ', padding={padding}' if self.padding != (0, 0) else ''
s += ', dilation={dilation}' if self.dilation != (1, 1) else ''
s += ', groups={groups}' if self.groups != 1 else ''
s += ', offset_groups={offset_groups}' if self.offset_groups != 1 else ''
s += ', bias=False' if self.bias is None else ''
s += ')'
return s.format(**self.__dict__)
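# Illustrative usage sketch (example values, not part of this commit): the caller
# derives the offset shape from the expected output size and passes it to forward
# together with the input.
#
#   m = DeformConv2d(3, 5, kernel_size=3, padding=1)
#   x = torch.randn(4, 3, 10, 10)
#   # padding=1 preserves the 10x10 spatial size, so the offset matches it
#   offset = torch.zeros(4, 2 * 1 * 3 * 3, 10, 10)
#   out = m(x, offset)
#   # out.shape == (4, 5, 10, 10)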