"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "d70dee879fc046ead28072305ab716dcb59c7d84"
Unverified Commit 5a4bb19d authored by Licht Takeuchi's avatar Licht Takeuchi Committed by GitHub
Browse files

Add modulation input for DeformConv2D (#2791)

* Add modulation input for DeformConv2D

* lint

* Patch for GPU CI

* Remove bad cache on CI
parent 2b39f1e8
......@@ -455,6 +455,7 @@ jobs:
resource_class: gpu.small
environment:
image_name: "pytorch/manylinux-cuda101"
PYTHON_VERSION: << parameters.python_version >>
steps:
- checkout
- designate_upload_channel
......@@ -462,14 +463,9 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:
keys:
- env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
- run:
name: Setup
command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
- save_cache:
key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
......@@ -533,6 +529,7 @@ jobs:
name: windows-gpu
environment:
CUDA_VERSION: "10.1"
PYTHON_VERSION: << parameters.python_version >>
steps:
- checkout
- designate_upload_channel
......@@ -540,11 +537,6 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:
keys:
- env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
- run:
name: Setup
command: .circleci/unittest/windows/scripts/setup_env.sh
......
......@@ -455,6 +455,7 @@ jobs:
resource_class: gpu.small
environment:
image_name: "pytorch/manylinux-cuda101"
PYTHON_VERSION: << parameters.python_version >>
steps:
- checkout
- designate_upload_channel
......@@ -462,14 +463,9 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:
{% raw %}
keys:
- env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
{% endraw %}
- run:
name: Setup
command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
- save_cache:
{% raw %}
key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
......@@ -533,6 +529,7 @@ jobs:
name: windows-gpu
environment:
CUDA_VERSION: "10.1"
PYTHON_VERSION: << parameters.python_version >>
steps:
- checkout
- designate_upload_channel
......@@ -540,11 +537,6 @@ jobs:
name: Generate cache key
# This will refresh cache on Sundays, nightly build should generate new cache.
command: echo "$(date +"%Y-%U")" > .circleci-weekly
- restore_cache:
{% raw %}
keys:
- env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
{% endraw %}
- run:
name: Setup
command: .circleci/unittest/windows/scripts/setup_env.sh
......
......@@ -458,7 +458,7 @@ class NewEmptyTensorTester(unittest.TestCase):
class DeformConvTester(OpTester, unittest.TestCase):
def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
pad_h, pad_w = _pair(padding)
dil_h, dil_w = _pair(dilation)
......@@ -489,12 +489,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
c_in = weight_grp * in_c_per_weight_grp + c
offset_grp = c_in // in_c_per_offset_grp
offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
offset_idx = 2 * mask_idx
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
mask_value = 1.0
if mask is not None:
mask_value = mask[b, mask_idx, i, j]
out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
bilinear_interpolate(x[b, c_in, :, :], pi, pj))
out += bias.view(1, n_out_channels, 1, 1)
return out
......@@ -523,6 +528,9 @@ class DeformConvTester(OpTester, unittest.TestCase):
offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
device=device, dtype=dtype, requires_grad=True)
mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
device=device, dtype=dtype, requires_grad=True)
......@@ -531,9 +539,10 @@ class DeformConvTester(OpTester, unittest.TestCase):
if not contiguous:
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
return x, weight, offset, bias, stride, pad, dilation
return x, weight, offset, mask, bias, stride, pad, dilation
def _test_forward(self, device, contiguous, dtype=None):
dtype = self.dtype if dtype is None else dtype
......@@ -541,21 +550,28 @@ class DeformConvTester(OpTester, unittest.TestCase):
self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)
def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
in_channels = 6
out_channels = 2
kernel_size = (3, 2)
groups = 2
tol = 1e-3 if dtype is torch.half else 1e-5
layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
res = layer(x, offset)
res = layer(x, offset, mask)
weight = layer.weight.data
bias = layer.bias.data
expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
# no modulation test
res = layer(x, offset)
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
tol = 1e-3 if dtype is torch.half else 1e-5
self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
'\nres:\n{}\nexpected:\n{}'.format(res, expected))
......@@ -564,24 +580,46 @@ class DeformConvTester(OpTester, unittest.TestCase):
wrong_offset = torch.rand_like(offset[:, :2])
res = layer(x, wrong_offset)
with self.assertRaises(RuntimeError):
wrong_mask = torch.rand_like(mask[:, :2])
res = layer(x, offset, wrong_mask)
def _test_backward(self, device, contiguous):
for batch_sz in [0, 33]:
self._test_backward_with_batchsize(device, contiguous, batch_sz)
def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous,
batch_sz, self.dtype)
def func(x_, offset_, mask_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
padding=padding, dilation=dilation, mask=mask_)
def func(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)
def func_no_mask(x_, offset_, weight_, bias_):
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
padding=padding, dilation=dilation, mask=None)
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)
@torch.jit.script
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
# type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
padding=pad_, dilation=dilation_, mask=mask_)
gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
(x, offset, mask, weight, bias), nondet_tol=1e-5)
@torch.jit.script
def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
# type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
padding=pad_, dilation=dilation_, mask=None)
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5)
# Test from https://github.com/pytorch/vision/issues/2598
......@@ -593,17 +631,19 @@ class DeformConvTester(OpTester, unittest.TestCase):
init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
img = torch.randn(8, 9, 1000, 110)
offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
mask = torch.rand(8, 3 * 3, 1000, 110)
if not contiguous:
img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
else:
weight = init_weight
for d in ["cpu", "cuda"]:
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
out.mean().backward()
if true_cpu_grads is None:
true_cpu_grads = init_weight.grad
......
......@@ -17,6 +17,7 @@ at::Tensor deform_conv2d(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -25,7 +26,8 @@ at::Tensor deform_conv2d(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d)>();
......@@ -33,6 +35,7 @@ at::Tensor deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -41,7 +44,8 @@ at::Tensor deform_conv2d(
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
}
#if defined(WITH_CUDA) || defined(WITH_HIP)
......@@ -49,6 +53,7 @@ at::Tensor DeformConv2d_autocast(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -57,12 +62,14 @@ at::Tensor DeformConv2d_autocast(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, weight),
at::autocast::cached_cast(at::kFloat, offset),
at::autocast::cached_cast(at::kFloat, mask),
at::autocast::cached_cast(at::kFloat, bias),
stride_h,
stride_w,
......@@ -71,17 +78,19 @@ at::Tensor DeformConv2d_autocast(
dilation_h,
dilation_w,
groups,
offset_groups)
offset_groups,
use_mask)
.to(input.scalar_type());
}
#endif
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -90,7 +99,8 @@ _deform_conv2d_backward(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
......@@ -100,6 +110,7 @@ _deform_conv2d_backward(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -108,7 +119,8 @@ _deform_conv2d_backward(
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
}
class DeformConv2dFunction
......@@ -119,6 +131,7 @@ class DeformConv2dFunction
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -127,12 +140,14 @@ class DeformConv2dFunction
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto output = deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -141,9 +156,10 @@ class DeformConv2dFunction
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
ctx->save_for_backward({input, weight, offset, bias});
ctx->save_for_backward({input, weight, offset, mask, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
......@@ -152,6 +168,7 @@ class DeformConv2dFunction
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
ctx->saved_data["use_mask"] = use_mask;
return {
output,
......@@ -165,7 +182,8 @@ class DeformConv2dFunction
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto bias = saved[3];
auto mask = saved[3];
auto bias = saved[4];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
......@@ -175,12 +193,14 @@ class DeformConv2dFunction
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();
auto grads = _deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -189,16 +209,19 @@ class DeformConv2dFunction
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_bias = std::get<3>(grads);
auto grad_mask = std::get<3>(grads);
auto grad_bias = std::get<4>(grads);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
torch::autograd::Variable(),
torch::autograd::Variable(),
......@@ -208,6 +231,7 @@ class DeformConv2dFunction
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
};
}
};
......@@ -222,6 +246,7 @@ class DeformConv2dBackwardFunction
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -230,13 +255,15 @@ class DeformConv2dBackwardFunction
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -245,17 +272,20 @@ class DeformConv2dBackwardFunction
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_bias = std::get<3>(result);
auto grad_mask = std::get<3>(result);
auto grad_bias = std::get<4>(result);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
};
}
......@@ -271,6 +301,7 @@ at::Tensor DeformConv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -279,11 +310,13 @@ at::Tensor DeformConv2d_autograd(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -292,15 +325,17 @@ at::Tensor DeformConv2d_autograd(
dilation_h,
dilation_w,
groups,
offset_groups)[0];
offset_groups,
use_mask)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -309,12 +344,14 @@ DeformConv2d_backward_autograd(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups) {
int64_t offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
......@@ -323,7 +360,8 @@ DeformConv2d_backward_autograd(
dilation_h,
dilation_w,
groups,
offset_groups);
offset_groups,
use_mask);
return std::make_tuple(result[0], result[1], result[2], result[3]);
}
\ No newline at end of file
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
This diff is collapsed.
......@@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -14,23 +15,27 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
int64_t deformable_groups,
bool use_mask);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups,
bool use_mask);
VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
......
This diff is collapsed.
......@@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
......@@ -14,23 +15,27 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
int64_t deformable_groups,
bool use_mask);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
VISION_API std::
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda(
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups,
bool use_mask);
VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
......
......@@ -46,9 +46,9 @@ int64_t cuda_version() noexcept {
TORCH_LIBRARY(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
......
......@@ -17,6 +17,7 @@ def deform_conv2d(
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
dilation: Tuple[int, int] = (1, 1),
mask: Optional[Tensor] = None,
) -> Tensor:
"""
Performs Deformable Convolution, described in Deformable Convolutional Networks
......@@ -33,6 +34,9 @@ def deform_conv2d(
padding (int or Tuple[int, int]): height/width of padding of zeroes around
each image. Default: 0
dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
out_height, out_width]): masks to be applied for each position in the
convolution kernel.
Returns:
output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
......@@ -42,11 +46,12 @@ def deform_conv2d(
>>> input = torch.rand(4, 3, 10, 10)
>>> kh, kw = 3, 3
>>> weight = torch.rand(5, 3, kh, kw)
>>> # offset should have the same spatial size as the output
>>> # offset and mask should have the same spatial size as the output
>>> # of the convolution. In this case, for an input of 10, stride of 1
>>> # and kernel size of 3, without padding, the output size is 8
>>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
>>> out = deform_conv2d(input, offset, weight)
>>> mask = torch.rand(4, kh * kw, 8, 8)
>>> out = deform_conv2d(input, offset, weight, mask=mask)
>>> print(out.shape)
>>> # returns
>>> torch.Size([4, 5, 8, 8])
......@@ -54,6 +59,12 @@ def deform_conv2d(
_assert_has_ops()
out_channels = weight.shape[0]
use_mask = mask is not None
if mask is None:
mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
......@@ -77,18 +88,21 @@ def deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h, stride_w,
pad_h, pad_w,
dil_h, dil_w,
n_weight_grps,
n_offset_grps)
n_offset_grps,
use_mask,)
class DeformConv2d(nn.Module):
"""
See deform_conv2d
"""
def __init__(
self,
in_channels: int,
......@@ -127,21 +141,25 @@ class DeformConv2d(nn.Module):
def reset_parameters(self) -> None:
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tensor, offset: Tensor) -> Tensor:
def forward(self, input: Tensor, offset: Tensor, mask: Tensor = None) -> Tensor:
"""
Arguments:
input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
out_height, out_width]): offsets to be applied for each position in the
convolution kernel.
mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
out_height, out_width]): masks to be applied for each position in the
convolution kernel.
"""
return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
padding=self.padding, dilation=self.dilation)
padding=self.padding, dilation=self.dilation, mask=mask)
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment