Unverified commit 5a4bb19d authored by Licht Takeuchi, committed by GitHub

Add modulation input for DeformConv2D (#2791)

* Add modulation input for DeformConv2D

* lint

* Patch for GPU CI

* Remove bad cache on CI
parent 2b39f1e8
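
For orientation before the diff: this PR threads a DCNv2-style modulation mask through `torchvision.ops.deform_conv2d` and `DeformConv2d`. A minimal usage sketch (shapes taken from the updated docstring further down; assumes a torchvision build that includes this change):

```python
# Hedged sketch of the new API: modulated deformable convolution (DCNv2).
import torch
from torchvision import ops

kh, kw = 3, 3
x = torch.rand(4, 3, 10, 10)
weight = torch.rand(5, 3, kh, kw)
# offset and mask share the convolution output's spatial size:
# input 10, kernel 3, stride 1, no padding -> output 8.
offset = torch.rand(4, 2 * kh * kw, 8, 8)   # (dy, dx) per kernel tap
mask = torch.rand(4, kh * kw, 8, 8)         # one modulation scalar per tap

out = ops.deform_conv2d(x, offset, weight, mask=mask)  # modulated path
plain = ops.deform_conv2d(x, offset, weight)           # mask=None keeps old behavior
print(out.shape)  # torch.Size([4, 5, 8, 8])
```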
......@@ -455,6 +455,7 @@ jobs:
resource_class: gpu.small
environment:
image_name: "pytorch/manylinux-cuda101"
PYTHON_VERSION: << parameters.python_version >>
steps:
- checkout
- designate_upload_channel
......@@ -462,14 +463,9 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:
keys:
- env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
- run:
name: Setup
command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
- save_cache:
key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
......@@ -533,6 +529,7 @@ jobs:
name: windows-gpu
environment:
CUDA_VERSION: "10.1"
PYTHON_VERSION: << parameters.python_version >>
steps:
- checkout
- designate_upload_channel
......@@ -540,11 +537,6 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:
keys:
- env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
- run:
name: Setup
command: .circleci/unittest/windows/scripts/setup_env.sh
......
......@@ -458,7 +458,7 @@ class NewEmptyTensorTester(unittest.TestCase):
class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
......@@ -489,12 +489,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
c_in = weight_grp * in_c_per_weight_grp + c
offset_grp = c_in // in_c_per_offset_grp
offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
offset_idx = 2 * mask_idx
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
mask_value = 1.0
if mask is not None:
mask_value = mask[b, mask_idx, i, j]
out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
out += bias.view(1, n_out_channels, 1, 1)
return out
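
For reference, the `expected_fn` above is a direct transcription of modulated deformable convolution (DCNv2, Zhu et al.): for each output location $\mathbf{p}_0$,

$$
y(\mathbf{p}_0) = \sum_{k=1}^{K} w_k \cdot m_k \cdot x\!\left(\mathbf{p}_0 + \mathbf{p}_k + \Delta\mathbf{p}_k\right),
$$

where $\mathbf{p}_k$ ranges over the $K = k_h \times k_w$ kernel taps, $\Delta\mathbf{p}_k$ is the learned offset, $m_k$ is the modulation scalar read from `mask` (fixed to $1$ when `mask is None`, as in the `mask_value = 1.0` branch above), and $x(\cdot)$ is evaluated with bilinear interpolation. The stride/padding/dilation terms and the offset-group bookkeeping in the loop are the concrete indexing of this sum.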
......@@ -523,6 +528,9 @@ class DeformConvTester(OpTester, unittest.TestCase):
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
device=device, dtype=dtype, requires_grad=True)
mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
device=device, dtype=dtype, requires_grad=True)
......@@ -531,9 +539,10 @@ class DeformConvTester(OpTester, unittest.TestCase):
if not contiguous:
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
return x, weight, offset, bias, stride, pad, dilation
return x, weight, offset, mask, bias, stride, pad, dilation
def _test_forward(self, device, contiguous, dtype=None):
dtype = self.dtype if dtype is None else dtype
......@@ -541,21 +550,28 @@ class DeformConvTester(OpTester, unittest.TestCase):
self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)
def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
in_channels = 6
out_channels = 2
kernel_size = (3, 2)
groups = 2
tol = 1e-3 if dtype is torch.half else 1e-5
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
res = layer(x, offset)
res = layer(x, offset, mask)
weight = layer.weight.data
bias = layer.bias.data
expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
# no modulation test
res = layer(x, offset)
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
tol = 1e-3 if dtype is torch.half else 1e-5
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
......@@ -564,24 +580,46 @@ class DeformConvTester(OpTester, unittest.TestCase):
wrong_offset = torch.rand_like(offset[:, :2])
res = layer(x, wrong_offset)
with self.assertRaises(RuntimeError):
wrong_mask = torch.rand_like(mask[:, :2])
res = layer(x, offset, wrong_mask)
def _test_backward(self, device, contiguous):
for batch_sz in [0, 33]:
self._test_backward_with_batchsize(device, contiguous, batch_sz)
def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous,
batch_sz, self.dtype)
def func(x_, offset_, mask_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
padding=padding, dilation=dilation, mask=mask_)
def func(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)
def func_no_mask(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
padding=padding, dilation=dilation, mask=None)
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)
@torch.jit.script
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
# type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
padding=pad_, dilation=dilation_, mask=mask_)
gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
(x, offset, mask, weight, bias), nondet_tol=1e-5)
@torch.jit.script
def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
padding=pad_, dilation=dilation_, mask=None)
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5)
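
The backward tests above run `torch.autograd.gradcheck` over both eager and scripted call paths, with and without the mask. A standalone sketch of the same pattern (double precision, which gradcheck's finite differences need; shapes are illustrative):

```python
# Hedged, minimal gradcheck sketch mirroring the tests above.
import torch
from torch.autograd import gradcheck
from torchvision import ops

x = torch.rand(1, 2, 5, 5, dtype=torch.double, requires_grad=True)
weight = torch.rand(3, 2, 3, 3, dtype=torch.double, requires_grad=True)
offset = torch.rand(1, 2 * 3 * 3, 3, 3, dtype=torch.double, requires_grad=True)
mask = torch.rand(1, 3 * 3, 3, 3, dtype=torch.double, requires_grad=True)

# Gradients flow through input, offset, mask and weight alike.
gradcheck(lambda x_, o_, m_, w_: ops.deform_conv2d(x_, o_, w_, mask=m_),
          (x, offset, mask, weight), nondet_tol=1e-5)
```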
# Test from https://github.com/pytorch/vision/issues/2598
......@@ -593,17 +631,19 @@ class DeformConvTester(OpTester, unittest.TestCase):
init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
img = torch.randn(8, 9, 1000, 110)
offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
mask = torch.rand(8, 3 * 3, 1000, 110)
if not contiguous:
img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
else:
weight = init_weight
for d in ["cpu", "cuda"]:
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
out.mean().backward()
if true_cpu_grads is None:
true_cpu_grads = init_weight.grad
......
......@@ -17,6 +17,7 @@ at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -25,7 +26,8 @@ at::Tensor deform_conv2d(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d)>();
......@@ -33,6 +35,7 @@ at::Tensor deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -41,7 +44,8 @@ at::Tensor deform_conv2d(
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
}
#if defined(WITH_CUDA) || defined(WITH_HIP)
......@@ -49,6 +53,7 @@ at::Tensor DeformConv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -57,12 +62,14 @@ at::Tensor DeformConv2d_autocast(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
at::autocast::cached_cast(at::kFloat, mask),
at::autocast::cached_cast(at::kFloat, bias),
stride_h,
stride_w,
......@@ -71,17 +78,19 @@ at::Tensor DeformConv2d_autocast(
dilation_h,
dilation_w,
groups,
offset_groups)
offset_groups,
use_mask)
.to(input.scalar_type());
}
#endif
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -90,7 +99,8 @@ _deform_conv2d_backward(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
......@@ -100,6 +110,7 @@ _deform_conv2d_backward(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -108,7 +119,8 @@ _deform_conv2d_backward(
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
}
class DeformConv2dFunction
......@@ -119,6 +131,7 @@ class DeformConv2dFunction
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -127,12 +140,14 @@ class DeformConv2dFunction
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto output = deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -141,9 +156,10 @@ class DeformConv2dFunction
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
ctx->save_for_backward({input, weight, offset, bias});
ctx->save_for_backward({input, weight, offset, mask, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
......@@ -152,6 +168,7 @@ class DeformConv2dFunction
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
ctx->saved_data["use_mask"] = use_mask;
return {
output,
......@@ -165,7 +182,8 @@ class DeformConv2dFunction
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto bias = saved[3];
auto mask = saved[3];
auto bias = saved[4];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
......@@ -175,12 +193,14 @@ class DeformConv2dFunction
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();
auto grads = _deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -189,16 +209,19 @@ class DeformConv2dFunction
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_bias = std::get<3>(grads);
auto grad_mask = std::get<3>(grads);
auto grad_bias = std::get<4>(grads);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
torch::autograd::Variable(),
torch::autograd::Variable(),
......@@ -208,6 +231,7 @@ class DeformConv2dFunction
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
};
}
};
......@@ -222,6 +246,7 @@ class DeformConv2dBackwardFunction
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -230,13 +255,15 @@ class DeformConv2dBackwardFunction
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -245,17 +272,20 @@ class DeformConv2dBackwardFunction
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_bias = std::get<3>(result);
auto grad_mask = std::get<3>(result);
auto grad_bias = std::get<4>(result);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
};
}
......@@ -271,6 +301,7 @@ at::Tensor DeformConv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -279,11 +310,13 @@ at::Tensor DeformConv2d_autograd(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -292,15 +325,17 @@ at::Tensor DeformConv2d_autograd(
dilation_h,
dilation_w,
groups,
offset_groups)[0];
offset_groups,
use_mask)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -309,12 +344,14 @@ DeformConv2d_backward_autograd(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -323,7 +360,8 @@ DeformConv2d_backward_autograd(
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
return std::make_tuple(result[0], result[1], result[2], result[3]);
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
......@@ -120,6 +120,7 @@ static void deformable_im2col_kernel(
int n,
const scalar_t* input,
const scalar_t* offset,
const scalar_t* mask,
int height,
int width,
int weight_h,
......@@ -135,6 +136,7 @@ static void deformable_im2col_kernel(
int n_offset_grps,
int out_h,
int out_w,
bool use_mask,
scalar_t* columns) {
for (int index = 0; index != n; ++index) {
const int out_x = index % out_w;
......@@ -157,16 +159,31 @@ static void deformable_im2col_kernel(
(out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h *
out_w;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w *
out_h * out_w;
}
for (int i = 0; i < weight_h; ++i) {
for (int j = 0; j < weight_w; ++j) {
const int offset_idx = 2 * (i * weight_w + j);
const int mask_idx = i * weight_w + j;
const int offset_idx = 2 * mask_idx;
scalar_t mask_value = 1;
if (use_mask) {
mask_value =
mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x];
}
const scalar_t offset_h =
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t offset_w = offset_ptr
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x);
*columns_ptr =
mask_value * bilinear_interpolate(input_ptr, height, width, y, x);
columns_ptr += batch_sz * out_h * out_w;
}
}
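
As a reading aid, here is a hedged Python restatement of the per-tap work this kernel does (the `bilinear` helper below is a simplified stand-in for the file's `bilinear_interpolate`); the only behavioral change in this PR is the `m *` factor:

```python
import math

def bilinear(x, y, xx):
    # Simplified stand-in for bilinear_interpolate: sample 2-D grid `x`
    # at fractional (y, xx), treating out-of-bounds neighbors as zero.
    h, w = len(x), len(x[0])
    y0, x0 = math.floor(y), math.floor(xx)
    val = 0.0
    for dy in (0, 1):
        for dx in (0, 1):
            yp, xp = y0 + dy, x0 + dx
            if 0 <= yp < h and 0 <= xp < w:
                val += x[yp][xp] * (1 - abs(y - yp)) * (1 - abs(xx - xp))
    return val

def column_entry(x, off_h, off_w, m, i, j, oy, ox,
                 stride=(1, 1), pad=(0, 0), dil=(1, 1)):
    # What the kernel writes into `columns` for one tap (i, j) at output
    # pixel (oy, ox): a bilinear sample at the offset location, scaled by
    # the mask value m (m == 1 when use_mask is false).
    y = oy * stride[0] - pad[0] + i * dil[0] + off_h
    xx = ox * stride[1] - pad[1] + j * dil[1] + off_w
    return m * bilinear(x, y, xx)
```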
......@@ -176,6 +193,7 @@ static void deformable_im2col_kernel(
static void deformable_im2col(
const at::Tensor& input,
const at::Tensor& data_offset,
const at::Tensor& data_mask,
int n_in_channels,
int height,
int width,
......@@ -191,6 +209,7 @@ static void deformable_im2col(
int out_w,
int parallel_imgs,
int deformable_group,
bool use_mask,
at::Tensor data_col) {
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
......@@ -200,6 +219,7 @@ static void deformable_im2col(
num_kernels,
input.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(),
data_mask.data_ptr<scalar_t>(),
height,
width,
weight_h,
......@@ -215,6 +235,7 @@ static void deformable_im2col(
deformable_group,
out_h,
out_w,
use_mask,
data_col.data_ptr<scalar_t>());
}));
}
......@@ -232,6 +253,7 @@ at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
......@@ -240,14 +262,17 @@ at::Tensor DeformConv2d_forward_cpu(
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
int64_t n_offset_grps,
bool use_mask) {
at::Tensor input = input_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor mask = mask_param.contiguous();
at::Tensor bias = bias_param.contiguous();
TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(!use_mask || mask.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
......@@ -292,6 +317,12 @@ at::Tensor DeformConv2d_forward_cpu(
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(
(!use_mask || mask.size(1) == n_offset_grps * weight_h * weight_w),
"mask.shape[1] is not valid: got: ",
mask.size(1),
" expected: ",
n_offset_grps * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
......@@ -308,6 +339,19 @@ at::Tensor DeformConv2d_forward_cpu(
", ",
out_w,
")");
TORCH_CHECK((mask.size(0) == input.size(0)), "invalid batch size of mask");
TORCH_CHECK(
(!use_mask || (mask.size(2) == out_h && mask.size(3) == out_w)),
"offset output dims: (",
mask.size(2),
", ",
mask.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
......@@ -328,11 +372,21 @@ at::Tensor DeformConv2d_forward_cpu(
out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
if (use_mask) {
mask = mask.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
......@@ -360,6 +414,7 @@ at::Tensor DeformConv2d_forward_cpu(
deformable_im2col(
input[b],
offset[b],
mask[b],
n_in_channels,
in_h,
in_w,
......@@ -375,6 +430,7 @@ at::Tensor DeformConv2d_forward_cpu(
out_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
columns);
columns = columns.view(
......@@ -406,6 +462,7 @@ static void deformable_col2im_kernel(
int n,
const scalar_t* col,
const scalar_t* offset,
const scalar_t* mask,
int channels,
int height,
int width,
......@@ -421,6 +478,7 @@ static void deformable_col2im_kernel(
int n_offset_grps,
int out_h,
int out_w,
bool use_mask,
scalar_t* grad_im) {
for (int index = 0; index != n; ++index) {
const int out_x = index % out_w;
......@@ -436,12 +494,27 @@ static void deformable_col2im_kernel(
auto offset_ptr = offset +
(b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h *
out_w;
const int offset_h_ptr =
((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x;
const int offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w *
out_h * out_w;
}
const int mask_idx = i * kernel_w + j;
const int offset_idx = 2 * mask_idx;
const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
scalar_t mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x];
}
const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
......@@ -453,7 +526,7 @@ static void deformable_col2im_kernel(
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
int grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
grad_im[grad_pos] += weight * col[index];
grad_im[grad_pos] += mask_value * weight * col[index];
}
}
}
......@@ -463,6 +536,7 @@ static void deformable_col2im_kernel(
static void compute_grad_input(
const at::Tensor& columns,
const at::Tensor& offset,
const at::Tensor& mask,
int channels,
int height,
int width,
......@@ -476,6 +550,7 @@ static void compute_grad_input(
int dilation_w,
int parallel_imgs,
int n_offset_grps,
bool use_mask,
at::Tensor grad_im) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
......@@ -490,6 +565,7 @@ static void compute_grad_input(
num_kernels,
columns.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
mask.data_ptr<scalar_t>(),
channels,
height,
width,
......@@ -505,6 +581,7 @@ static void compute_grad_input(
n_offset_grps,
out_h,
out_w,
use_mask,
grad_im.data_ptr<scalar_t>());
}));
}
......@@ -548,6 +625,7 @@ static void deformable_col2im_coord_kernel(
const scalar_t* col,
const scalar_t* im,
const scalar_t* offset,
const scalar_t* mask,
int channels,
int height,
int width,
......@@ -564,11 +642,17 @@ static void deformable_col2im_coord_kernel(
int n_offset_grps,
int out_h,
int out_w,
scalar_t* grad_offset) {
bool use_mask,
scalar_t* grad_offset,
scalar_t* grad_mask) {
for (int index = 0; index != n; ++index) {
scalar_t val = 0;
scalar_t grad_offset_val = 0;
scalar_t grad_mask_val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int w_w = (index / (out_w * out_h * 2)) % weight_w;
int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
......@@ -586,6 +670,12 @@ static void deformable_col2im_coord_kernel(
(b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h *
out_w;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w *
out_h * out_w;
}
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const bool is_y_direction = offset_c % 2 == 0;
......@@ -598,30 +688,55 @@ static void deformable_col2im_coord_kernel(
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int mask_idx = i * weight_w + j;
const int offset_h_idx =
(((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x);
(((2 * mask_idx) * out_h + out_y) * out_w + out_x);
const int offset_w_idx =
(((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x);
(((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_idx];
const scalar_t offset_w = offset_ptr[offset_w_idx];
scalar_t mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x];
}
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
val += weight * col_ptr[col_pos];
grad_offset_val += mask_value * weight * col_ptr[col_pos];
if (use_mask && is_y_direction) {
grad_mask_val += col_ptr[col_pos] *
bilinear_interpolate(im_ptr, height, width, y, x);
}
im_ptr += height * width;
}
grad_offset[index] = val;
grad_offset[index] = grad_offset_val;
if (use_mask && is_y_direction) {
const int idx =
((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
w_w) *
out_h +
h) *
out_w +
w;
grad_mask[idx] = grad_mask_val;
}
}
}
static void compute_grad_offset(
static void compute_grad_offset_and_mask(
const at::Tensor& columns,
const at::Tensor& input,
const at::Tensor& offset,
const at::Tensor& mask,
int channels,
int height,
int width,
......@@ -635,7 +750,9 @@ static void compute_grad_offset(
int dilation_w,
int parallel_imgs,
int n_offset_grps,
at::Tensor grad_offset) {
bool use_mask,
at::Tensor grad_offset,
at::Tensor grad_mask) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
......@@ -650,6 +767,7 @@ static void compute_grad_offset(
columns.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
mask.data_ptr<scalar_t>(),
channels,
height,
width,
......@@ -666,14 +784,18 @@ static void compute_grad_offset(
n_offset_grps,
out_h,
out_w,
grad_offset.data_ptr<scalar_t>());
use_mask,
grad_offset.data_ptr<scalar_t>(),
grad_mask.data_ptr<scalar_t>());
}));
}
static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
static std::tuple<at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_input_cpu(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor mask,
at::Tensor grad_out,
int stride_h,
int stride_w,
......@@ -683,7 +805,8 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int n_parallel_imgs,
bool use_mask) {
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
......@@ -700,9 +823,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
auto grad_mask = at::zeros_like(mask);
if (batch_sz == 0) {
return std::make_tuple(grad_input, grad_offset);
return std::make_tuple(grad_input, grad_offset, grad_mask);
}
auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
......@@ -712,6 +838,7 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
......@@ -723,6 +850,19 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
out_h,
out_w});
if (use_mask) {
grad_mask = grad_mask.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
mask = mask.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
grad_out = grad_out
.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
......@@ -749,10 +889,11 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
compute_grad_offset(
compute_grad_offset_and_mask(
columns,
input[elt],
offset[elt],
mask[elt],
n_in_channels,
in_h,
in_w,
......@@ -766,11 +907,14 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_offset[elt]);
use_mask,
grad_offset[elt],
grad_mask[elt]);
compute_grad_input(
columns,
offset[elt],
mask[elt],
n_in_channels,
in_h,
in_w,
......@@ -784,6 +928,7 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
dil_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
grad_input[elt]);
}
......@@ -791,13 +936,19 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset);
if (use_mask) {
grad_mask = grad_mask.view(
{batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w});
}
return std::make_tuple(grad_input, grad_offset, grad_mask);
}
static at::Tensor deform_conv2d_backward_parameters_cpu(
at::Tensor input,
const at::Tensor& weight,
at::Tensor offset,
at::Tensor mask,
const at::Tensor& grad_out,
int stride_h,
int stride_w,
......@@ -807,7 +958,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int n_parallel_imgs,
bool use_mask) {
int batch_sz = input.size(0);
int n_in_channels = input.size(1);
int in_h = input.size(2);
......@@ -839,12 +991,21 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
if (use_mask) {
mask = mask.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
......@@ -861,6 +1022,7 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
deformable_im2col(
input[elt],
offset[elt],
mask[elt],
n_in_channels,
in_h,
in_w,
......@@ -876,6 +1038,7 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
out_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
columns);
for (int g = 0; g < n_weight_grps; g++) {
......@@ -895,12 +1058,13 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
return grad_weight;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
......@@ -909,21 +1073,24 @@ DeformConv2d_backward_cpu(
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
int64_t n_offset_grps,
bool use_mask) {
at::Tensor grad_out = grad_out_param.contiguous();
at::Tensor input = input_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor mask = mask_param.contiguous();
at::Tensor bias = bias_param.contiguous();
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv2d_backward_input_cpu(
auto grad_input_and_offset_and_mask = deform_conv2d_backward_input_cpu(
input,
weight,
offset,
mask,
grad_out,
stride_h,
stride_w,
......@@ -933,15 +1100,18 @@ DeformConv2d_backward_cpu(
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
n_parallel_imgs,
use_mask);
auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_input = std::get<0>(grad_input_and_offset_and_mask);
auto grad_offset = std::get<1>(grad_input_and_offset_and_mask);
auto grad_mask = std::get<2>(grad_input_and_offset_and_mask);
auto grad_weight = deform_conv2d_backward_parameters_cpu(
input,
weight,
offset,
mask,
grad_out,
stride_h,
stride_w,
......@@ -951,9 +1121,11 @@ DeformConv2d_backward_cpu(
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
n_parallel_imgs,
use_mask);
auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3});
return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias);
return std::make_tuple(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}
......@@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -14,14 +15,17 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
int64_t deformable_groups,
bool use_mask);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -30,7 +34,8 @@ DeformConv2d_backward_cpu(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
int64_t deformable_groups,
bool use_mask);
VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
......
......@@ -78,12 +78,19 @@
#include <iostream>
#include <tuple>
const unsigned int CUDA_NUM_THREADS = 1024;
const int kMaxParallelImgs = 32;
inline unsigned int GET_BLOCKS(const unsigned int N) {
unsigned int kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
inline unsigned int GET_THREADS() {
if (at::cuda::getCurrentDeviceProperties()->major >= 6) {
return 1024;
}
return 512;
}
inline unsigned int GET_BLOCKS(const unsigned int THREADS, const unsigned int N) {
unsigned int kMaxGridNum =
at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
return std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
}
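
Besides the mask plumbing, the CUDA file replaces the fixed `CUDA_NUM_THREADS = 1024` with a capability-dependent block size and clamps the grid to the device limit. A hedged restatement:

```python
# Hedged restatement of GET_THREADS / GET_BLOCKS above.
def get_threads(sm_major: int) -> int:
    # 1024 threads per block on compute capability >= 6.0, else 512.
    return 1024 if sm_major >= 6 else 512

def get_blocks(threads: int, n: int, max_grid_x: int) -> int:
    # ceil(n / threads), clamped to the device's maxGridSize[0].
    return min(max_grid_x, (n + threads - 1) // threads)
```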
template <typename scalar_t>
......@@ -130,6 +137,7 @@ __global__ void deformable_im2col_gpu_kernel(
int n,
const scalar_t* input_ptr,
const scalar_t* offset_ptr,
const scalar_t* mask_ptr,
int height,
int width,
int weight_h,
......@@ -145,6 +153,7 @@ __global__ void deformable_im2col_gpu_kernel(
int n_offset_grps,
int out_h,
int out_w,
bool use_mask,
scalar_t* columns_ptr) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w;
......@@ -166,16 +175,30 @@ __global__ void deformable_im2col_gpu_kernel(
offset_ptr += (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w *
out_h * out_w;
if (use_mask) {
mask_ptr += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w *
out_h * out_w;
}
for (int i = 0; i < weight_h; ++i) {
for (int j = 0; j < weight_w; ++j) {
const int offset_idx = 2 * (i * weight_w + j);
const int mask_idx = i * weight_w + j;
const int offset_idx = 2 * mask_idx;
scalar_t mask_value = 1;
if (use_mask) {
mask_value =
mask_ptr[mask_idx * (out_h * out_w) + out_y * out_w + out_x];
}
const scalar_t offset_h =
offset_ptr[offset_idx * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t offset_w = offset_ptr
[(offset_idx + 1) * (out_h * out_w) + out_y * out_w + out_x];
const scalar_t y = (out_y * stride_h - pad_h) + i * dil_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dil_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, height, width, y, x);
*columns_ptr =
mask_value * bilinear_interpolate(input_ptr, height, width, y, x);
columns_ptr += batch_sz * out_h * out_w;
}
}
......@@ -185,6 +208,7 @@ __global__ void deformable_im2col_gpu_kernel(
static void deformable_im2col(
const at::Tensor& input,
const at::Tensor& data_offset,
const at::Tensor& data_mask,
int n_in_channels,
int height,
int width,
......@@ -200,17 +224,22 @@ static void deformable_im2col(
int out_w,
int parallel_imgs,
int deformable_group,
bool use_mask,
at::Tensor data_col) {
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
const unsigned int threads = GET_THREADS();
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "deformable_im2col_gpu", ([&] {
deformable_im2col_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
blocks,
threads>>>(
num_kernels,
input.data_ptr<scalar_t>(),
data_offset.data_ptr<scalar_t>(),
data_mask.data_ptr<scalar_t>(),
height,
width,
weight_h,
......@@ -226,6 +255,7 @@ static void deformable_im2col(
deformable_group,
out_h,
out_w,
use_mask,
data_col.data_ptr<scalar_t>());
}));
......@@ -248,6 +278,7 @@ at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
......@@ -256,14 +287,17 @@ at::Tensor DeformConv2d_forward_cuda(
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
int64_t n_offset_grps,
bool use_mask) {
at::Tensor input = input_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor mask = mask_param.contiguous();
at::Tensor bias = bias_param.contiguous();
TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(!use_mask || mask.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
......@@ -309,6 +343,12 @@ at::Tensor DeformConv2d_forward_cuda(
offset.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(
(!use_mask || mask.size(1) == n_offset_grps * weight_h * weight_w),
"mask.shape[1] is not valid: got: ",
mask.size(1),
" expected: ",
n_offset_grps * weight_h * weight_w);
TORCH_CHECK(input.size(1) % n_offset_grps == 0);
TORCH_CHECK(
......@@ -325,6 +365,19 @@ at::Tensor DeformConv2d_forward_cuda(
", ",
out_w,
")");
TORCH_CHECK((mask.size(0) == input.size(0)), "invalid batch size of mask");
TORCH_CHECK(
(!use_mask || (mask.size(2) == out_h && mask.size(3) == out_w)),
"mask output dims: (",
mask.size(2),
", ",
mask.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
......@@ -345,11 +398,21 @@ at::Tensor DeformConv2d_forward_cuda(
out_w});
input = input.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
if (use_mask) {
mask = mask.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
......@@ -377,6 +440,7 @@ at::Tensor DeformConv2d_forward_cuda(
deformable_im2col(
input[b],
offset[b],
mask[b],
in_channels,
in_h,
in_w,
......@@ -392,6 +456,7 @@ at::Tensor DeformConv2d_forward_cuda(
out_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
columns);
columns = columns.view(
......@@ -402,8 +467,8 @@ at::Tensor DeformConv2d_forward_cuda(
.addmm_(weight[g].flatten(1), columns[g])
.view_as(out_buf[b][g]);
}
columns = columns.view(
{columns.size(0) * columns.size(1), columns.size(2)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
out_buf = out_buf.view({batch_sz / n_parallel_imgs,
......@@ -423,6 +488,7 @@ __global__ void deformable_col2im_gpu_kernel(
int n,
const scalar_t* col,
const scalar_t* offset_ptr,
const scalar_t* mask_ptr,
int channels,
int height,
int width,
......@@ -438,6 +504,7 @@ __global__ void deformable_col2im_gpu_kernel(
int n_offset_grps,
int out_h,
int out_w,
bool use_mask,
scalar_t* grad_im) {
CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w;
......@@ -452,12 +519,26 @@ __global__ void deformable_col2im_gpu_kernel(
offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w *
out_h * out_w;
const int offset_h_ptr =
((2 * (i * kernel_w + j)) * out_h + out_y) * out_w + out_x;
const int offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * out_h + out_y) * out_w + out_x;
if (use_mask) {
mask_ptr += (b * n_offset_grps + offset_grp) * kernel_h * kernel_w *
out_h * out_w;
}
const int mask_idx = i * kernel_w + j;
const int offset_idx = 2 * mask_idx;
const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
scalar_t mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x];
}
const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
......@@ -469,7 +550,7 @@ __global__ void deformable_col2im_gpu_kernel(
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
int grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
atomicAdd(grad_im + grad_pos, weight * col[index]);
atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
}
}
}
......@@ -479,6 +560,7 @@ __global__ void deformable_col2im_gpu_kernel(
static void compute_grad_input(
const at::Tensor& columns,
const at::Tensor& offset,
const at::Tensor& mask,
int channels,
int height,
int width,
......@@ -492,6 +574,7 @@ static void compute_grad_input(
int dilation_w,
int parallel_imgs,
int n_offset_grps,
bool use_mask,
at::Tensor grad_im) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
......@@ -500,14 +583,18 @@ static void compute_grad_input(
int num_kernels =
channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
const unsigned int threads = GET_THREADS();
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_gpu", ([&] {
deformable_col2im_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
blocks,
threads>>>(
num_kernels,
columns.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
mask.data_ptr<scalar_t>(),
channels,
height,
width,
......@@ -523,6 +610,7 @@ static void compute_grad_input(
n_offset_grps,
out_h,
out_w,
use_mask,
grad_im.data_ptr<scalar_t>());
}));
......@@ -571,6 +659,7 @@ __global__ void deformable_col2im_coord_gpu_kernel(
const scalar_t* col_ptr,
const scalar_t* im_ptr,
const scalar_t* offset_ptr,
const scalar_t* mask_ptr,
int channels,
int height,
int width,
......@@ -587,11 +676,17 @@ __global__ void deformable_col2im_coord_gpu_kernel(
int n_offset_grps,
int out_h,
int out_w,
scalar_t* grad_offset) {
const bool use_mask,
scalar_t* grad_offset,
scalar_t* grad_mask) {
CUDA_1D_KERNEL_LOOP(index, n) {
scalar_t val = 0;
scalar_t grad_offset_val = 0;
scalar_t grad_mask_val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int w_w = (index / (out_w * out_h * 2)) % weight_w;
int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
......@@ -607,6 +702,11 @@ __global__ void deformable_col2im_coord_gpu_kernel(
offset_ptr += (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w *
out_h * out_w;
if (use_mask) {
mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w *
out_h * out_w;
}
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const bool is_y_direction = offset_c % 2 == 0;
......@@ -619,30 +719,55 @@ __global__ void deformable_col2im_coord_gpu_kernel(
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int mask_idx = i * weight_w + j;
const int offset_h_ptr =
(((2 * (i * weight_w + j)) * out_h + out_y) * out_w + out_x);
(((2 * mask_idx) * out_h + out_y) * out_w + out_x);
const int offset_w_ptr =
(((2 * (i * weight_w + j) + 1) * out_h + out_y) * out_w + out_x);
(((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_ptr];
const scalar_t offset_w = offset_ptr[offset_w_ptr];
scalar_t mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x];
}
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
val += weight * col_ptr[col_pos];
grad_offset_val += mask_value * weight * col_ptr[col_pos];
if (use_mask && is_y_direction) {
grad_mask_val += col_ptr[col_pos] *
bilinear_interpolate(im_ptr, height, width, y, x);
}
im_ptr += height * width;
}
grad_offset[index] = val;
grad_offset[index] = grad_offset_val;
if (use_mask && is_y_direction) {
const int idx =
((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
w_w) *
out_h +
h) *
out_w +
w;
grad_mask[idx] = grad_mask_val;
}
}
}
static void compute_grad_offset(
static void compute_grad_offset_and_mask(
const at::Tensor& columns,
const at::Tensor& input,
const at::Tensor& offset,
const at::Tensor& mask,
int channels,
int height,
int width,
......@@ -656,7 +781,9 @@ static void compute_grad_offset(
int dilation_w,
int parallel_imgs,
int n_offset_grps,
at::Tensor grad_offset) {
bool use_mask,
at::Tensor grad_offset,
at::Tensor grad_mask) {
int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
int out_w =
......@@ -664,15 +791,19 @@ static void compute_grad_offset(
int num_kernels =
out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
const unsigned int threads = GET_THREADS();
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
columns.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
deformable_col2im_coord_gpu_kernel<<<
GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
blocks,
threads>>>(
num_kernels,
columns.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
offset.data_ptr<scalar_t>(),
mask.data_ptr<scalar_t>(),
channels,
height,
width,
......@@ -689,19 +820,23 @@ static void compute_grad_offset(
n_offset_grps,
out_h,
out_w,
grad_offset.data_ptr<scalar_t>());
use_mask,
grad_offset.data_ptr<scalar_t>(),
grad_mask.data_ptr<scalar_t>());
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in compute_grad_offset: %s\n", cudaGetErrorString(err));
printf(
"error in compute_grad_offset_and_mask: %s\n", cudaGetErrorString(err));
}
}
static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
static std::tuple<at::Tensor, at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
at::Tensor input,
at::Tensor weight,
at::Tensor offset,
at::Tensor mask,
at::Tensor grad_out,
int stride_h,
int stride_w,
......@@ -711,7 +846,8 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int n_parallel_imgs,
bool use_mask) {
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
......@@ -730,9 +866,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
auto grad_mask = at::zeros_like(mask);
if (batch_sz == 0) {
return std::make_tuple(grad_input, grad_offset);
return std::make_tuple(grad_input, grad_offset, grad_mask);
}
auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());
......@@ -742,6 +881,7 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
......@@ -753,12 +893,27 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
out_h,
out_w});
grad_out = grad_out.reshape({batch_sz / n_parallel_imgs,
if (use_mask) {
grad_mask = grad_mask.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
mask = mask.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
grad_out = grad_out
.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w}).permute({0, 2, 3, 1, 4, 5});
out_w})
.permute({0, 2, 3, 1, 4, 5});
weight = weight.reshape({n_weight_grps,
weight.size(0) / n_weight_grps,
......@@ -776,10 +931,11 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
compute_grad_offset(
compute_grad_offset_and_mask(
columns,
input[elt],
offset[elt],
mask[elt],
n_in_channels,
in_h,
in_w,
......@@ -793,11 +949,14 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
dil_w,
n_parallel_imgs,
n_offset_grps,
grad_offset[elt]);
use_mask,
grad_offset[elt],
grad_mask[elt]);
compute_grad_input(
columns,
offset[elt],
mask[elt],
n_in_channels,
in_h,
in_w,
......@@ -811,21 +970,27 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
dil_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
grad_input[elt]);
}
grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
return std::make_tuple(grad_input, grad_offset);
if (use_mask) {
grad_mask = grad_mask.view(
{batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w});
}
return std::make_tuple(grad_input, grad_offset, grad_mask);
}
static at::Tensor deform_conv2d_backward_parameters_cuda(
at::Tensor input,
const at::Tensor& weight,
at::Tensor offset,
at::Tensor mask,
const at::Tensor& grad_out,
int stride_h,
int stride_w,
......@@ -835,7 +1000,8 @@ static at::Tensor deform_conv2d_backward_parameters_cuda(
int dil_w,
int n_weight_grps,
int n_offset_grps,
int n_parallel_imgs) {
int n_parallel_imgs,
bool use_mask) {
at::DeviceGuard guard(input.device());
int batch_sz = input.size(0);
......@@ -857,23 +1023,33 @@ static at::Tensor deform_conv2d_backward_parameters_cuda(
return grad_weight;
}
at::Tensor grad_out_buf = grad_out.reshape(
{batch_sz / n_parallel_imgs,
at::Tensor grad_out_buf = grad_out
.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w}
).permute({0, 2, 3, 1, 4, 5}).contiguous();
out_w})
.permute({0, 2, 3, 1, 4, 5})
.contiguous();
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
if (use_mask) {
mask = mask.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
grad_weight = grad_weight.reshape({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
......@@ -890,6 +1066,7 @@ static at::Tensor deform_conv2d_backward_parameters_cuda(
deformable_im2col(
input[elt],
offset[elt],
mask[elt],
n_in_channels,
in_h,
in_w,
......@@ -905,6 +1082,7 @@ static at::Tensor deform_conv2d_backward_parameters_cuda(
out_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
columns);
for (int g = 0; g < n_weight_grps; g++) {
......@@ -924,12 +1102,13 @@ static at::Tensor deform_conv2d_backward_parameters_cuda(
return grad_weight;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out_param,
const at::Tensor& input_param,
const at::Tensor& weight_param,
const at::Tensor& offset_param,
const at::Tensor& mask_param,
const at::Tensor& bias_param,
int64_t stride_h,
int64_t stride_w,
......@@ -938,21 +1117,24 @@ DeformConv2d_backward_cuda(
int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
int64_t n_offset_grps,
bool use_mask) {
at::Tensor grad_out = grad_out_param.contiguous();
at::Tensor input = input_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor mask = mask_param.contiguous();
at::Tensor bias = bias_param.contiguous();
const int batch_sz = input.size(0);
const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv2d_backward_input_cuda(
auto grad_input_and_offset_and_mask = deform_conv2d_backward_input_cuda(
input,
weight,
offset,
mask,
grad_out,
stride_h,
stride_w,
......@@ -962,15 +1144,18 @@ DeformConv2d_backward_cuda(
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
n_parallel_imgs,
use_mask);
auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_input = std::get<0>(grad_input_and_offset_and_mask);
auto grad_offset = std::get<1>(grad_input_and_offset_and_mask);
auto grad_mask = std::get<2>(grad_input_and_offset_and_mask);
auto grad_weight = deform_conv2d_backward_parameters_cuda(
input,
weight,
offset,
mask,
grad_out,
stride_h,
stride_w,
......@@ -980,10 +1165,12 @@ DeformConv2d_backward_cuda(
dil_w,
n_weight_grps,
n_offset_grps,
n_parallel_imgs);
n_parallel_imgs,
use_mask);
auto value = grad_out.sum({0, 2, 3});
auto grad_bias = at::ones_like(bias) * value;
return std::make_tuple(grad_input, grad_weight, grad_offset, grad_bias);
return std::make_tuple(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}
......@@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -14,14 +15,17 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
int64_t deformable_groups,
bool use_mask);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -30,7 +34,8 @@ DeformConv2d_backward_cuda(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
int64_t deformable_groups,
bool use_mask);
VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
......
......@@ -46,9 +46,9 @@ int64_t cuda_version() noexcept {
TORCH_LIBRARY(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
......
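
With the updated schema strings above, the raw dispatcher op takes the mask and the `use_mask` flag directly. A hedged sketch of a direct call (the public `ops.deform_conv2d` wrapper in the next file is the supported entry point):

```python
# Hedged: calling the registered op through torch.ops, matching the new schema.
import torch
import torchvision  # noqa: F401  (loads the torchvision::deform_conv2d op)

x = torch.rand(1, 3, 8, 8)
weight = torch.rand(5, 3, 3, 3)
offset = torch.rand(1, 2 * 3 * 3, 6, 6)  # output is 6x6: kernel 3, no padding
mask = torch.rand(1, 3 * 3, 6, 6)
bias = torch.zeros(5)

out = torch.ops.torchvision.deform_conv2d(
    x, weight, offset, mask, bias,
    1, 1,   # stride_h, stride_w
    0, 0,   # pad_h, pad_w
    1, 1,   # dilation_h, dilation_w
    1, 1,   # groups, offset_groups
    True)   # use_mask
```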
......@@ -17,6 +17,7 @@ def deform_conv2d(
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
mask: Optional[Tensor] = None,
) -> Tensor:
"""
Performs Deformable Convolution, described in Deformable Convolutional Networks
......@@ -33,6 +34,9 @@ def deform_conv2d(
padding (int or Tuple[int, int]): height/width of padding of zeroes around
each image. Default: 0
dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
out_height, out_width]): masks to be applied for each position in the
convolution kernel. Default: None
Returns:
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
......@@ -42,11 +46,12 @@ def deform_conv2d(
>>> input = torch.rand(4, 3, 10, 10)
>>> kh, kw = 3, 3
>>> weight = torch.rand(5, 3, kh, kw)
>>> # offset should have the same spatial size as the output
>>> # offset and mask should have the same spatial size as the output
>>> # of the convolution. In this case, for an input of 10, stride of 1
>>> # and kernel size of 3, without padding, the output size is 8
>>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
>>> out = deform_conv2d(input, offset, weight)
>>> mask = torch.rand(4, kh * kw, 8, 8)
>>> out = deform_conv2d(input, offset, weight, mask=mask)
>>> print(out.shape)
>>> # returns
>>> torch.Size([4, 5, 8, 8])
......@@ -54,6 +59,12 @@ def deform_conv2d(
_assert_has_ops()
out_channels = weight.shape[0]
use_mask = mask is not None
if mask is None:
mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
......@@ -77,18 +88,21 @@ def deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps)
n_offset_grps,
use_mask,)
class DeformConv2d(nn.Module):
"""
See deform_conv2d
"""
def __init__(
self,
in_channels: int,
......@@ -127,21 +141,25 @@ class DeformConv2d(nn.Module):
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor, offset: Tensor) -> Tensor:
def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
out_height, out_width]): masks to be applied for each position in the
convolution kernel.
"""
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
padding=self.padding, dilation=self.dilation)
padding=self.padding, dilation=self.dilation, mask=mask)
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
......
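
And the module form, matching the updated `forward` signature (the mask is optional; DCNv2 typically squashes it through a sigmoid, which is left to the caller here):

```python
# Hedged sketch of the module API with and without modulation.
import torch
from torchvision import ops

layer = ops.DeformConv2d(in_channels=3, out_channels=5, kernel_size=3)
x = torch.rand(1, 3, 8, 8)
offset = torch.rand(1, 2 * 3 * 3, 6, 6)
mask = torch.sigmoid(torch.rand(1, 3 * 3, 6, 6))  # modulation in (0, 1)

out = layer(x, offset, mask)  # modulated
out = layer(x, offset)        # plain deformable convolution
```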