Commit 52b8685b authored by pedrofreire, committed by Francisco Massa

Add Deformable Convolution operation. (#1586)

* Add Deformable Convolution operation.

This adds the deformable convolution operation, as described in Deformable Convolutional Networks (https://arxiv.org/abs/1703.06211).

- The code is based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp; it was modified and refactored to remove redundancies, increase clarity, and adapt it to torchvision.

- The CPU part is a direct copy of the CUDA code; it might make sense to follow up with adjustments that simplify or optimize the CPU code, or that share functionality between the CPU and CUDA implementations.

- We also add tests (with a non-trivial set of parameters); they can be made more robust by randomizing the parameters and executing multiple times.

* Update DeformConv to be more consistent w/ Conv2d

* rename some variables and arguments to match Conv2d;
* add optional bias;
* add weight, offset and bias as module parameters;
* remove the n_parallel_imgs parameter;
* fix __repr__;
* etc.

Initialization of weight and bias is the same as in Conv2d, and
initialization of offsets to zero is the same as in the paper.

This also includes some other small unrelated fixes/improvements.

* Apply clang-format in DeformConv files.

* Import Optional type annotation

* Remove offset param from DeformConv2d module

- We pass the offset in the forward of DeformConv2d, instead of keeping it as
an internal parameter. This adds some complexity when creating the module
(e.g. you now need to know the output size in order to create the offset; see
the sketch below), but it gives more flexibility.
- We also use make_tuple for tuple creation, in an attempt to fix errors
with older compilers.
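To illustrate the resulting calling pattern (a minimal sketch, not from the
commit itself; shapes and sizes are illustrative):

import torch
from torchvision import ops

x = torch.rand(2, 4, 16, 16)  # (batch, in_channels, in_h, in_w)
conv = ops.DeformConv2d(in_channels=4, out_channels=6, kernel_size=3, padding=1)

# With a 3x3 kernel, stride 1 and padding 1, the output is also 16x16; the
# offset must match the *output* size, with 2 * offset_groups * kh * kw channels.
offset = torch.zeros(2, 2 * 1 * 3 * 3, 16, 16)
out = conv(x, offset)  # -> (2, 6, 16, 16)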

* Replace abs by std::abs

Old gcc versions gave wrong results here because they resolved abs as
int -> int, causing undesired truncation. Using std::abs lets the call
resolve correctly to the float -> float overload.

* Reorder declarations for clarity

* Reorder weight and offset args in deform_conv2d

We place the offset arg before the weight arg, for consistency with
DeformConv2d.forward(input, offset).

* Replace abs by std::abs in DeformConv_cuda
parent 5b1716a2
from __future__ import division
import math
import unittest
import numpy as np
import torch
from torch import Tensor
from torch.autograd import gradcheck
from torch.jit.annotations import Tuple
from torch.nn.modules.utils import _pair
from torchvision import ops
from itertools import product
class OpTester(object):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float64
@@ -42,6 +45,14 @@ class RoIOpTester(object):
def test_backward_cuda_non_contiguous(self):
self._test_backward(device=torch.device('cuda'), contiguous=False)
def _test_forward(self, device, contiguous):
pass
def _test_backward(self, device, contiguous):
pass
class RoIOpTester(OpTester):
def _test_forward(self, device, contiguous):
pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -79,7 +90,6 @@ class RoIOpTester(object):
self.assertTrue(gradcheck(func, (x,)))
self.assertTrue(gradcheck(script_func, (x,)))
return
def fn(*args, **kwargs):
pass
@@ -98,7 +108,7 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.roi_pool(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -137,7 +147,7 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -174,29 +184,35 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
return y
def bilinear_interpolate(data, y, x, snap_border=False):
    height, width = data.shape

    if snap_border:
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    y_low = int(math.floor(y))
    x_low = int(math.floor(x))
    y_high = y_low + 1
    x_high = x_low + 1

    wy_h = y - y_low
    wx_h = x - x_low
    wy_l = 1 - wy_h
    wx_l = 1 - wx_h

    val = 0
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
            if 0 <= yp < height and 0 <= xp < width:
                val += wx * wy * data[yp, xp]
    return val
@@ -208,7 +224,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -242,12 +258,7 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
y = start_h + (iy + 0.5) * bin_h / grid_h
for ix in range(0, grid_w):
x = start_w + (ix + 0.5) * bin_w / grid_w
val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
val /= grid_h * grid_w
out_data[r, channel, i, j] = val
@@ -262,7 +273,7 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
@@ -298,12 +309,7 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
y = start_h + (iy + 0.5) * bin_h / grid_h
for ix in range(0, grid_w):
x = start_w + (ix + 0.5) * bin_w / grid_w
val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
val /= grid_h * grid_w
out_data[r, c_out, i, j] = val
@@ -376,5 +382,120 @@ class NewEmptyTensorTester(unittest.TestCase):
assert out.dtype == input.dtype
class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
weight_h, weight_w = weight.shape[-2:]
n_batches, n_in_channels, in_h, in_w = x.shape
n_out_channels = weight.shape[0]
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
in_c_per_offset_grp = n_in_channels // n_offset_grps
n_weight_grps = n_in_channels // weight.shape[1]
in_c_per_weight_grp = weight.shape[1]
out_c_per_weight_grp = n_out_channels // n_weight_grps
out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
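        # Naive reference implementation: for every output location, accumulate
        # weight * bilinearly-interpolated input at each offset-shifted kernel
        # position, mirroring the grouped-convolution semantics.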
for b in range(n_batches):
for c_out in range(n_out_channels):
for i in range(out_h):
for j in range(out_w):
for di in range(weight_h):
for dj in range(weight_w):
for c in range(in_c_per_weight_grp):
weight_grp = c_out // out_c_per_weight_grp
c_in = weight_grp * in_c_per_weight_grp + c
offset_grp = c_in // in_c_per_offset_grp
offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
out += bias.view(1, n_out_channels, 1, 1)
return out
def get_fn_args(self, device, contiguous):
batch_sz = 1
n_in_channels = 6
n_out_channels = 2
n_weight_grps = 2
n_offset_grps = 3
stride = (2, 1)
pad = (1, 0)
dilation = (2, 1)
stride_h, stride_w = stride
pad_h, pad_w = pad
dil_h, dil_w = dilation
weight_h, weight_w = (3, 2)
in_h, in_w = (5, 4)
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
device=device, dtype=self.dtype, requires_grad=True)
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
device=device, dtype=self.dtype, requires_grad=True)
bias = torch.randn(n_out_channels, device=device, dtype=self.dtype, requires_grad=True)
if not contiguous:
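            # permute -> contiguous -> inverse permute keeps the same values
            # but yields non-contiguous memory, exercising the stride handling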
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
return x, weight, offset, bias, stride, pad, dilation
def _test_forward(self, device, contiguous):
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous)
in_channels = 6
out_channels = 2
kernel_size = (3, 2)
groups = 2
offset_groups = 3
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, offset_groups=offset_groups).to(device=x.device,
dtype=x.dtype)
res = layer(x, offset)
weight = layer.weight.data
bias = layer.bias.data
expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
self.assertTrue(torch.allclose(res, expected), '\nres:\n{}\nexpected:\n{}'.format(res, expected))
def _test_backward(self, device, contiguous):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous)
def func(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
@torch.jit.script
def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5)
if __name__ == '__main__':
unittest.main()
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
at::Tensor DeformConv2d_forward(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const std::pair<int, int>& stride,
const std::pair<int, int>& padding,
const std::pair<int, int>& dilation,
const int groups,
const int offset_groups) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return DeformConv2d_forward_cuda(
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return DeformConv2d_forward_cpu(
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const std::pair<int, int>& stride,
const std::pair<int, int>& padding,
const std::pair<int, int>& dilation,
const int groups,
const int offset_groups) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return DeformConv2d_backward_cuda(
grad.contiguous(),
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return DeformConv2d_backward_cpu(
grad.contiguous(),
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups,
offset_groups);
}
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable weight,
Variable offset,
Variable bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto output = DeformConv2d_forward(
input,
weight,
offset,
bias,
{stride_h, stride_w},
{pad_h, pad_w},
{dilation_h, dilation_w},
groups,
offset_groups);
ctx->save_for_backward({input, weight, offset, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
ctx->saved_data["pad_w"] = pad_w;
ctx->saved_data["dilation_h"] = dilation_h;
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
return {
output,
};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto bias = saved[3];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto grads = DeformConv2d_backward(
grad_output[0],
input,
weight,
offset,
bias,
{stride_h, stride_w},
{pad_h, pad_w},
{dilation_h, dilation_w},
groups,
offset_groups);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_bias = std::get<3>(grads);
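    // backward() must return one entry per forward() argument; the eight
    // non-tensor arguments (strides, padding, dilation, groups,
    // offset_groups) receive undefined Variables, as they have no gradient.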
return {
grad_input,
grad_weight,
grad_offset,
grad_bias,
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
};
}
};
at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
auto result = DeformConv2dFunction::apply(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
return result[0];
}
@@ -84,3 +84,27 @@ at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
@@ -85,3 +85,27 @@ at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
std::pair<int, int> stride,
std::pair<int, int> pad,
std::pair<int, int> dilation,
int groups,
int deformable_groups);
@@ -5,6 +5,7 @@
#include <cuda.h>
#endif
#include "DeformConv.h"
#include "PSROIAlign.h"
#include "PSROIPool.h"
#include "ROIAlign.h"
@@ -47,4 +48,5 @@ static auto registry =
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
.op("torchvision::ps_roi_align", &ps_roi_align)
.op("torchvision::ps_roi_pool", &ps_roi_pool)
.op("torchvision::deform_conv2d", &deform_conv2d)
.op("torchvision::_cuda_version", &_cuda_version);
from .boxes import nms, box_iou
from .new_empty_tensor import _new_empty_tensor
from .deform_conv import deform_conv2d, DeformConv2d
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .ps_roi_align import ps_roi_align, PSRoIAlign
@@ -13,7 +14,7 @@ _register_custom_op()
__all__ = [
'deform_conv2d', 'DeformConv2d', 'nms', 'roi_align', 'RoIAlign', 'roi_pool',
'RoIPool', '_new_empty_tensor', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool',
'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
]
import math
import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch.jit.annotations import Optional, Tuple
def deform_conv2d(input, offset, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
# type: (Tensor, Tensor, Tensor, Optional[Tensor], Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
"""
Performs Deformable Convolution, described in Deformable Convolutional Networks
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
convolution weights, split into groups of size (in_channels // groups)
bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
padding (int or Tuple[int, int]): height/width of padding of zeroes around
each image. Default: 0
dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
Returns:
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
"""
out_channels = weight.shape[0]
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
weights_h, weights_w = weight.shape[-2:]
_, n_in_channels, in_h, in_w = input.shape
n_offset_grps = offset.shape[1] // (2 * weights_h * weights_w)
n_weight_grps = n_in_channels // weight.shape[1]
return torch.ops.torchvision.deform_conv2d(
input,
weight,
offset,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps)
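A quick sanity check one can run against this function (a sketch, not part of
the commit's test suite): with an all-zero offset and default stride, padding
and dilation, deform_conv2d should reduce to a plain convolution, since
bilinear sampling at integer positions returns the input values exactly.

import torch
import torch.nn.functional as F
from torchvision import ops

x = torch.rand(1, 3, 10, 10, dtype=torch.float64)
weight = torch.rand(5, 3, 3, 3, dtype=torch.float64)
# for a 3x3 kernel with no padding and stride 1, the output is 8x8
offset = torch.zeros(1, 2 * 3 * 3, 8, 8, dtype=torch.float64)

assert torch.allclose(ops.deform_conv2d(x, offset, weight),
                      F.conv2d(x, weight), atol=1e-12)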
class DeformConv2d(nn.Module):
"""
See deform_conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, offset_groups=1, bias=True):
super(DeformConv2d, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if in_channels % offset_groups != 0:
raise ValueError('in_channels must be divisible by offset_groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.offset_groups = offset_groups
        self.weight = Parameter(torch.empty(out_channels, in_channels // groups,
                                            self.kernel_size[0], self.kernel_size[1]))
if bias:
self.bias = Parameter(torch.empty(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input, offset):
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]):
convolution weights, split into groups of size (in_channels // groups)
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
"""
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
padding=self.padding, dilation=self.dilation)
def __repr__(self):
s = self.__class__.__name__ + '('
s += '{in_channels}'
s += ', {out_channels}'
s += ', kernel_size={kernel_size}'
s += ', stride={stride}'
s += ', padding={padding}' if self.padding != (0, 0) else ''
s += ', dilation={dilation}' if self.dilation != (1, 1) else ''
s += ', groups={groups}' if self.groups != 1 else ''
s += ', offset_groups={offset_groups}' if self.offset_groups != 1 else ''
s += ', bias=False' if self.bias is None else ''
s += ')'
return s.format(**self.__dict__)
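A typical usage pattern, following the paper, predicts the offsets with a
regular convolution over the same input; a minimal sketch (the DeformBlock
wrapper and its zero-initialized offset branch are illustrative, not part of
this commit):

import torch
from torch import nn
from torchvision import ops


class DeformBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super(DeformBlock, self).__init__()
        # one (dy, dx) pair per kernel position and output location
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size * kernel_size,
                                     kernel_size, padding=padding)
        # zero-init so training starts from a regular convolution
        nn.init.zeros_(self.offset_conv.weight)
        nn.init.zeros_(self.offset_conv.bias)
        self.deform_conv = ops.DeformConv2d(in_ch, out_ch, kernel_size,
                                            padding=padding)

    def forward(self, x):
        offset = self.offset_conv(x)
        return self.deform_conv(x, offset)


block = DeformBlock(3, 8)
out = block(torch.rand(1, 3, 16, 16))  # -> (1, 8, 16, 16)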