Unverified commit 5a4bb19d authored by Licht Takeuchi, committed by GitHub

Add modulation input for DeformConv2D (#2791)

* Add modulation input for DeformConv2D

* lint

* Patch for GPU CI

* Remove bad cache on CI
parent 2b39f1e8
@@ -455,6 +455,7 @@ jobs:
     resource_class: gpu.small
     environment:
       image_name: "pytorch/manylinux-cuda101"
+      PYTHON_VERSION: << parameters.python_version >>
     steps:
       - checkout
       - designate_upload_channel
@@ -462,14 +463,9 @@ jobs:
           name: Generate cache key
           # This will refresh cache on Sundays, nightly build should generate new cache.
           command: echo "$(date +"%Y-%U")" > .circleci-weekly
-      - restore_cache:
-          keys:
-            - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
       - run:
           name: Setup
-          command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
+          command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
       - save_cache:
           key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
@@ -533,6 +529,7 @@ jobs:
     name: windows-gpu
     environment:
       CUDA_VERSION: "10.1"
+      PYTHON_VERSION: << parameters.python_version >>
     steps:
       - checkout
       - designate_upload_channel
@@ -540,11 +537,6 @@ jobs:
          name: Generate cache key
          # This will refresh cache on Sundays, nightly build should generate new cache.
          command: echo "$(date +"%Y-%U")" > .circleci-weekly
-      - restore_cache:
-          keys:
-            - env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
       - run:
          name: Setup
          command: .circleci/unittest/windows/scripts/setup_env.sh
......
@@ -455,6 +455,7 @@ jobs:
     resource_class: gpu.small
     environment:
       image_name: "pytorch/manylinux-cuda101"
+      PYTHON_VERSION: << parameters.python_version >>
     steps:
       - checkout
       - designate_upload_channel
@@ -462,14 +463,9 @@ jobs:
           name: Generate cache key
           # This will refresh cache on Sundays, nightly build should generate new cache.
           command: echo "$(date +"%Y-%U")" > .circleci-weekly
-      - restore_cache:
-          {% raw %}
-          keys:
-            - env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-          {% endraw %}
       - run:
           name: Setup
-          command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
+          command: docker run -e PYTHON_VERSION -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux/scripts/setup_env.sh
       - save_cache:
          {% raw %}
           key: env-v2-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
@@ -533,6 +529,7 @@ jobs:
     name: windows-gpu
     environment:
       CUDA_VERSION: "10.1"
+      PYTHON_VERSION: << parameters.python_version >>
     steps:
       - checkout
       - designate_upload_channel
@@ -540,11 +537,6 @@ jobs:
          name: Generate cache key
          # This will refresh cache on Sundays, nightly build should generate new cache.
          command: echo "$(date +"%Y-%U")" > .circleci-weekly
-      - restore_cache:
-          {% raw %}
-          keys:
-            - env-v1-windows-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/windows/scripts/environment.yml" }}-{{ checksum ".circleci-weekly" }}
-          {% endraw %}
       - run:
          name: Setup
          command: .circleci/unittest/windows/scripts/setup_env.sh
......
@@ -458,7 +458,7 @@ class NewEmptyTensorTester(unittest.TestCase):
 class DeformConvTester(OpTester, unittest.TestCase):
-    def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
+    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
         stride_h, stride_w = _pair(stride)
         pad_h, pad_w = _pair(padding)
         dil_h, dil_w = _pair(dilation)
@@ -489,12 +489,17 @@ class DeformConvTester(OpTester, unittest.TestCase):
                                     c_in = weight_grp * in_c_per_weight_grp + c
                                     offset_grp = c_in // in_c_per_offset_grp
-                                    offset_idx = 2 * (offset_grp * (weight_h * weight_w) + di * weight_w + dj)
+                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
+                                    offset_idx = 2 * mask_idx
                                     pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
                                     pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
-                                    out[b, c_out, i, j] += (weight[c_out, c, di, dj] *
+                                    mask_value = 1.0
+                                    if mask is not None:
+                                        mask_value = mask[b, mask_idx, i, j]
+                                    out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] *
                                                             bilinear_interpolate(x[b, c_in, :, :], pi, pj))
         out += bias.view(1, n_out_channels, 1, 1)
         return out
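For readers skimming the diff: the reference implementation above now computes the DCNv2-style modulated deformable convolution, where each bilinearly sampled tap is scaled by its mask value before being multiplied by the kernel weight, and a missing mask behaves like all ones. A small self-contained sketch of a single tap, with illustrative helper names rather than the test's own:

```python
import math
import torch

def bilinear(img, y, x):
    # Sample a 2D tensor img at fractional (y, x); points outside count as 0.
    h, w = img.shape
    y0, x0 = math.floor(y), math.floor(x)
    val = 0.0
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            if 0 <= yy < h and 0 <= xx < w:
                val += (1 - abs(y - yy)) * (1 - abs(x - xx)) * float(img[yy, xx])
    return val

def modulated_tap(x_chan, w_tap, base_y, base_x, off_y, off_x, mask_val=1.0):
    # One kernel tap of DCNv2: mask * weight * x(base position + learned offset).
    # With mask_val=1.0 this reduces to the original (unmodulated) DeformConv2d tap.
    return mask_val * w_tap * bilinear(x_chan, base_y + off_y, base_x + off_x)

# e.g. one tap sampled near output position (2, 2) of a 5x5 feature map
img = torch.arange(25.0).reshape(5, 5)
print(modulated_tap(img, w_tap=0.5, base_y=2, base_x=2, off_y=0.3, off_x=-0.2, mask_val=0.7))
```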
@@ -523,6 +528,9 @@ class DeformConvTester(OpTester, unittest.TestCase):
         offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
                              device=device, dtype=dtype, requires_grad=True)
+        mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w,
+                           device=device, dtype=dtype, requires_grad=True)
         weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
                              device=device, dtype=dtype, requires_grad=True)
@@ -531,9 +539,10 @@ class DeformConvTester(OpTester, unittest.TestCase):
         if not contiguous:
             x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
             weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
-        return x, weight, offset, bias, stride, pad, dilation
+        return x, weight, offset, mask, bias, stride, pad, dilation
     def _test_forward(self, device, contiguous, dtype=None):
         dtype = self.dtype if dtype is None else dtype
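The mask tensor created above has exactly half as many channels as the offset tensor, since each kernel position needs a (y, x) offset pair but only one modulation scalar. A hedged helper showing the shape bookkeeping the test relies on (names are illustrative, not part of the test suite):

```python
def expected_aux_shapes(batch, offset_groups, kh, kw, in_h, in_w,
                        stride=(1, 1), pad=(0, 0), dil=(1, 1)):
    # Standard convolution output-size arithmetic
    out_h = (in_h + 2 * pad[0] - (dil[0] * (kh - 1) + 1)) // stride[0] + 1
    out_w = (in_w + 2 * pad[1] - (dil[1] * (kw - 1) + 1)) // stride[1] + 1
    # Two offset channels (y, x) per offset group and kernel tap,
    # one mask channel per offset group and kernel tap.
    offset_shape = (batch, 2 * offset_groups * kh * kw, out_h, out_w)
    mask_shape = (batch, offset_groups * kh * kw, out_h, out_w)
    return offset_shape, mask_shape

# e.g. expected_aux_shapes(4, 1, 3, 3, 10, 10) -> ((4, 18, 8, 8), (4, 9, 8, 8))
```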
@@ -541,21 +550,28 @@ class DeformConvTester(OpTester, unittest.TestCase):
             self._test_forward_with_batchsize(device, contiguous, batch_sz, dtype)
     def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
-        x, _, offset, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
+        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
         in_channels = 6
         out_channels = 2
         kernel_size = (3, 2)
         groups = 2
+        tol = 1e-3 if dtype is torch.half else 1e-5
         layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                  dilation=dilation, groups=groups).to(device=x.device, dtype=dtype)
-        res = layer(x, offset)
+        res = layer(x, offset, mask)
         weight = layer.weight.data
         bias = layer.bias.data
-        expected = self.expected_fn(x, weight, offset, bias, stride=stride, padding=padding, dilation=dilation)
+        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
+        self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
+                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))
+        # no modulation test
+        res = layer(x, offset)
+        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
-        tol = 1e-3 if dtype is torch.half else 1e-5
         self.assertTrue(torch.allclose(res.to(expected.dtype), expected, rtol=tol, atol=tol),
                         '\nres:\n{}\nexpected:\n{}'.format(res, expected))
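Taken together, the forward test exercises both call forms the new signature allows. A minimal usage sketch of the public API added in this commit (arbitrary values, default stride and padding, torchvision build with this change assumed):

```python
import torch
from torchvision import ops

x = torch.rand(1, 3, 10, 10)
kh, kw = 3, 3
layer = ops.DeformConv2d(3, 5, (kh, kw))
offset = torch.rand(1, 2 * kh * kw, 8, 8)   # one (y, x) offset per kernel tap
mask = torch.rand(1, kh * kw, 8, 8)         # one modulation scalar per kernel tap

out_modulated = layer(x, offset, mask)      # DCNv2-style modulated deformable conv
out_plain = layer(x, offset)                # mask omitted -> behaves as before
print(out_modulated.shape, out_plain.shape)  # torch.Size([1, 5, 8, 8]) twice
```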
@@ -564,24 +580,46 @@ class DeformConvTester(OpTester, unittest.TestCase):
             wrong_offset = torch.rand_like(offset[:, :2])
             res = layer(x, wrong_offset)
+        with self.assertRaises(RuntimeError):
+            wrong_mask = torch.rand_like(mask[:, :2])
+            res = layer(x, offset, wrong_mask)
     def _test_backward(self, device, contiguous):
         for batch_sz in [0, 33]:
             self._test_backward_with_batchsize(device, contiguous, batch_sz)
     def _test_backward_with_batchsize(self, device, contiguous, batch_sz):
-        x, weight, offset, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, self.dtype)
+        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous,
+                                                                                    batch_sz, self.dtype)
-        def func(x_, offset_, weight_, bias_):
-            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation)
+        def func(x_, offset_, mask_, weight_, bias_):
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
+                                     padding=padding, dilation=dilation, mask=mask_)
-        gradcheck(func, (x, offset, weight, bias), nondet_tol=1e-5)
+        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5)
+        def func_no_mask(x_, offset_, weight_, bias_):
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride,
+                                     padding=padding, dilation=dilation, mask=None)
+        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5)
         @torch.jit.script
-        def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
-            # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
-            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_)
+        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
+            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
+                                     padding=pad_, dilation=dilation_, mask=mask_)
-        gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
-                  (x, offset, weight, bias), nondet_tol=1e-5)
+        gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
+                  (x, offset, mask, weight, bias), nondet_tol=1e-5)
+        @torch.jit.script
+        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
+            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
+            return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_,
+                                     padding=pad_, dilation=dilation_, mask=None)
+        gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
+                  (x, offset, weight, bias), nondet_tol=1e-5)
         # Test from https://github.com/pytorch/vision/issues/2598
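Outside the test harness, the same gradient check can be reproduced in a few lines. This sketch assumes a torchvision build that already contains this change and uses double precision, as gradcheck expects:

```python
import torch
from torch.autograd import gradcheck
from torchvision import ops

x = torch.rand(1, 2, 5, 5, dtype=torch.double, requires_grad=True)
weight = torch.rand(2, 2, 3, 3, dtype=torch.double, requires_grad=True)
offset = torch.rand(1, 2 * 3 * 3, 3, 3, dtype=torch.double, requires_grad=True)
mask = torch.rand(1, 3 * 3, 3, 3, dtype=torch.double, requires_grad=True)

def fn(x_, off_, msk_, w_):
    return ops.deform_conv2d(x_, off_, w_, mask=msk_)

# Raises if analytical and numerical gradients (including grad_mask) disagree.
gradcheck(fn, (x, offset, mask, weight), nondet_tol=1e-5)
```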
@@ -593,17 +631,19 @@ class DeformConvTester(OpTester, unittest.TestCase):
         init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
         img = torch.randn(8, 9, 1000, 110)
         offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
+        mask = torch.rand(8, 3 * 3, 1000, 110)
         if not contiguous:
             img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
             offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
+            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
             weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
         else:
             weight = init_weight
         for d in ["cpu", "cuda"]:
-            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
+            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
             out.mean().backward()
             if true_cpu_grads is None:
                 true_cpu_grads = init_weight.grad
......
@@ -17,6 +17,7 @@ at::Tensor deform_conv2d(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -25,7 +26,8 @@ at::Tensor deform_conv2d(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t offset_groups) {
+    int64_t offset_groups,
+    bool use_mask) {
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("torchvision::deform_conv2d", "")
                        .typed<decltype(deform_conv2d)>();
@@ -33,6 +35,7 @@ at::Tensor deform_conv2d(
       input,
       weight,
       offset,
+      mask,
       bias,
       stride_h,
       stride_w,
@@ -41,7 +44,8 @@ at::Tensor deform_conv2d(
       dilation_h,
       dilation_w,
       groups,
-      offset_groups);
+      offset_groups,
+      use_mask);
 }
 #if defined(WITH_CUDA) || defined(WITH_HIP)
@@ -49,6 +53,7 @@ at::Tensor DeformConv2d_autocast(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -57,12 +62,14 @@ at::Tensor DeformConv2d_autocast(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t offset_groups) {
+    int64_t offset_groups,
+    bool use_mask) {
   c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
   return deform_conv2d(
       at::autocast::cached_cast(at::kFloat, input),
       at::autocast::cached_cast(at::kFloat, weight),
       at::autocast::cached_cast(at::kFloat, offset),
+      at::autocast::cached_cast(at::kFloat, mask),
       at::autocast::cached_cast(at::kFloat, bias),
       stride_h,
       stride_w,
@@ -71,17 +78,19 @@ at::Tensor DeformConv2d_autocast(
       dilation_h,
       dilation_w,
       groups,
-      offset_groups)
+      offset_groups,
+      use_mask)
       .to(input.scalar_type());
 }
 #endif
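The autocast wrapper above casts every tensor argument, now including the mask, to float32 and casts the result back to the input's dtype. From the Python side that means the op can be used inside an AMP region without manual casts; a hedged sketch (CUDA only, torchvision build with this change assumed):

```python
import torch
from torchvision import ops

if torch.cuda.is_available():
    layer = ops.DeformConv2d(3, 5, 3, padding=1).cuda()
    x = torch.rand(1, 3, 8, 8, device="cuda")
    offset = torch.rand(1, 2 * 3 * 3, 8, 8, device="cuda")
    mask = torch.rand(1, 3 * 3, 8, 8, device="cuda")
    with torch.cuda.amp.autocast():
        out = layer(x, offset, mask)
    # The wrapper's final .to(input.scalar_type()) makes the output dtype
    # match the input tensor's dtype (float32 here).
    print(out.dtype)
```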
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 _deform_conv2d_backward(
     const at::Tensor& grad,
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -90,7 +99,8 @@ _deform_conv2d_backward(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t offset_groups) {
+    int64_t offset_groups,
+    bool use_mask) {
   static auto op =
       c10::Dispatcher::singleton()
           .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
@@ -100,6 +110,7 @@ _deform_conv2d_backward(
       input,
       weight,
       offset,
+      mask,
       bias,
       stride_h,
       stride_w,
@@ -108,7 +119,8 @@ _deform_conv2d_backward(
       dilation_h,
       dilation_w,
       groups,
-      offset_groups);
+      offset_groups,
+      use_mask);
 }
 class DeformConv2dFunction
@@ -119,6 +131,7 @@ class DeformConv2dFunction
       const torch::autograd::Variable& input,
       const torch::autograd::Variable& weight,
       const torch::autograd::Variable& offset,
+      const torch::autograd::Variable& mask,
       const torch::autograd::Variable& bias,
       int64_t stride_h,
       int64_t stride_w,
@@ -127,12 +140,14 @@ class DeformConv2dFunction
       int64_t dilation_h,
       int64_t dilation_w,
       int64_t groups,
-      int64_t offset_groups) {
+      int64_t offset_groups,
+      bool use_mask) {
     at::AutoNonVariableTypeMode g;
     auto output = deform_conv2d(
         input,
         weight,
         offset,
+        mask,
         bias,
         stride_h,
         stride_w,
@@ -141,9 +156,10 @@ class DeformConv2dFunction
         dilation_h,
         dilation_w,
         groups,
-        offset_groups);
+        offset_groups,
+        use_mask);
-    ctx->save_for_backward({input, weight, offset, bias});
+    ctx->save_for_backward({input, weight, offset, mask, bias});
     ctx->saved_data["stride_h"] = stride_h;
     ctx->saved_data["stride_w"] = stride_w;
     ctx->saved_data["pad_h"] = pad_h;
@@ -152,6 +168,7 @@ class DeformConv2dFunction
     ctx->saved_data["dilation_w"] = dilation_w;
     ctx->saved_data["groups"] = groups;
     ctx->saved_data["offset_groups"] = offset_groups;
+    ctx->saved_data["use_mask"] = use_mask;
     return {
         output,
@@ -165,7 +182,8 @@ class DeformConv2dFunction
     auto input = saved[0];
     auto weight = saved[1];
     auto offset = saved[2];
-    auto bias = saved[3];
+    auto mask = saved[3];
+    auto bias = saved[4];
     auto stride_h = ctx->saved_data["stride_h"].toInt();
     auto stride_w = ctx->saved_data["stride_w"].toInt();
@@ -175,12 +193,14 @@ class DeformConv2dFunction
     auto dilation_w = ctx->saved_data["dilation_w"].toInt();
     auto groups = ctx->saved_data["groups"].toInt();
     auto offset_groups = ctx->saved_data["offset_groups"].toInt();
+    auto use_mask = ctx->saved_data["use_mask"].toBool();
     auto grads = _deform_conv2d_backward(
         grad_output[0],
         input,
         weight,
         offset,
+        mask,
         bias,
         stride_h,
         stride_w,
@@ -189,16 +209,19 @@ class DeformConv2dFunction
         dilation_h,
         dilation_w,
         groups,
-        offset_groups);
+        offset_groups,
+        use_mask);
     auto grad_input = std::get<0>(grads);
     auto grad_weight = std::get<1>(grads);
     auto grad_offset = std::get<2>(grads);
-    auto grad_bias = std::get<3>(grads);
+    auto grad_mask = std::get<3>(grads);
+    auto grad_bias = std::get<4>(grads);
     return {
         grad_input,
         grad_weight,
         grad_offset,
+        grad_mask,
         grad_bias,
         torch::autograd::Variable(),
         torch::autograd::Variable(),
@@ -208,6 +231,7 @@ class DeformConv2dFunction
         torch::autograd::Variable(),
         torch::autograd::Variable(),
         torch::autograd::Variable(),
+        torch::autograd::Variable(),
     };
   }
 };
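The net effect of the extra saved tensor and the extra returned gradient is visible directly from Python: after a backward pass the mask receives its own gradient. A quick sketch, assuming a torchvision build that includes this change:

```python
import torch
from torchvision import ops

x = torch.rand(1, 3, 6, 6)
weight = torch.rand(4, 3, 3, 3, requires_grad=True)
offset = torch.rand(1, 2 * 3 * 3, 4, 4, requires_grad=True)
mask = torch.rand(1, 3 * 3, 4, 4, requires_grad=True)

out = ops.deform_conv2d(x, offset, weight, mask=mask)
out.sum().backward()
# grad_mask produced by the backward plumbing above
print(mask.grad.shape)  # torch.Size([1, 9, 4, 4])
```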
@@ -222,6 +246,7 @@ class DeformConv2dBackwardFunction
       const torch::autograd::Variable& input,
       const torch::autograd::Variable& weight,
       const torch::autograd::Variable& offset,
+      const torch::autograd::Variable& mask,
       const torch::autograd::Variable& bias,
       int64_t stride_h,
       int64_t stride_w,
@@ -230,13 +255,15 @@ class DeformConv2dBackwardFunction
       int64_t dilation_h,
       int64_t dilation_w,
       int64_t groups,
-      int64_t offset_groups) {
+      int64_t offset_groups,
+      bool use_mask) {
     at::AutoNonVariableTypeMode g;
     auto result = _deform_conv2d_backward(
         grad,
         input,
         weight,
         offset,
+        mask,
         bias,
         stride_h,
         stride_w,
@@ -245,17 +272,20 @@ class DeformConv2dBackwardFunction
         dilation_h,
         dilation_w,
         groups,
-        offset_groups);
+        offset_groups,
+        use_mask);
     auto grad_input = std::get<0>(result);
     auto grad_weight = std::get<1>(result);
     auto grad_offset = std::get<2>(result);
-    auto grad_bias = std::get<3>(result);
+    auto grad_mask = std::get<3>(result);
+    auto grad_bias = std::get<4>(result);
     return {
         grad_input,
         grad_weight,
         grad_offset,
+        grad_mask,
         grad_bias,
     };
   }
@@ -271,6 +301,7 @@ at::Tensor DeformConv2d_autograd(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -279,11 +310,13 @@ at::Tensor DeformConv2d_autograd(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t offset_groups) {
+    int64_t offset_groups,
+    bool use_mask) {
   return DeformConv2dFunction::apply(
       input,
       weight,
       offset,
+      mask,
      bias,
      stride_h,
      stride_w,
@@ -292,15 +325,17 @@ at::Tensor DeformConv2d_autograd(
      dilation_h,
      dilation_w,
      groups,
-      offset_groups)[0];
+      offset_groups,
+      use_mask)[0];
 }
-std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 DeformConv2d_backward_autograd(
     const at::Tensor& grad,
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -309,12 +344,14 @@ DeformConv2d_backward_autograd(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t offset_groups) {
+    int64_t offset_groups,
+    bool use_mask) {
   auto result = DeformConv2dBackwardFunction::apply(
       grad,
       input,
      weight,
      offset,
+      mask,
      bias,
      stride_h,
      stride_w,
@@ -323,7 +360,8 @@ DeformConv2d_backward_autograd(
      dilation_h,
      dilation_w,
      groups,
-      offset_groups);
+      offset_groups,
+      use_mask);
-  return std::make_tuple(result[0], result[1], result[2], result[3]);
+  return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
 }
This diff is collapsed.
@@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -14,14 +15,17 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t deformable_groups);
+    int64_t deformable_groups,
+    bool use_mask);
-VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-DeformConv2d_backward_cpu(
+VISION_API std::
+    tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+    DeformConv2d_backward_cpu(
     const at::Tensor& grad_out,
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -30,7 +34,8 @@ DeformConv2d_backward_cpu(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t deformable_groups);
+    int64_t deformable_groups,
+    bool use_mask);
 VISION_API at::Tensor nms_cpu(
     const at::Tensor& dets,
......
This diff is collapsed.
@@ -6,6 +6,7 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -14,14 +15,17 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t deformable_groups);
+    int64_t deformable_groups,
+    bool use_mask);
-VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-DeformConv2d_backward_cuda(
+VISION_API std::
+    tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+    DeformConv2d_backward_cuda(
     const at::Tensor& grad_out,
     const at::Tensor& input,
     const at::Tensor& weight,
     const at::Tensor& offset,
+    const at::Tensor& mask,
     const at::Tensor& bias,
     int64_t stride_h,
     int64_t stride_w,
@@ -30,7 +34,8 @@ DeformConv2d_backward_cuda(
     int64_t dilation_h,
     int64_t dilation_w,
     int64_t groups,
-    int64_t deformable_groups);
+    int64_t deformable_groups,
+    bool use_mask);
 VISION_API at::Tensor nms_cuda(
     const at::Tensor& dets,
......
@@ -46,9 +46,9 @@ int64_t cuda_version() noexcept {
 TORCH_LIBRARY(torchvision, m) {
   m.def(
-      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
+      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
   m.def(
-      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
+      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
   m.def(
       "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
......
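With the schema above, the raw dispatcher op can also be called directly, which is essentially what torchvision.ops.deform_conv2d does after its argument shuffling. A hedged sketch using positional arguments in the order of the registration string (torchvision build with this change assumed):

```python
import torch
import torchvision  # importing torchvision makes the ops visible under torch.ops

x = torch.rand(1, 3, 6, 6)
weight = torch.rand(4, 3, 3, 3)
offset = torch.rand(1, 2 * 3 * 3, 4, 4)
mask = torch.rand(1, 3 * 3, 4, 4)
bias = torch.zeros(4)

out = torch.ops.torchvision.deform_conv2d(
    x, weight, offset, mask, bias,
    1, 1,   # stride_h, stride_w
    0, 0,   # pad_h, pad_w
    1, 1,   # dilation_h, dilation_w
    1,      # groups
    1,      # offset_groups
    True)   # use_mask
print(out.shape)  # torch.Size([1, 4, 4, 4])
```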
@@ -17,6 +17,7 @@ def deform_conv2d(
     stride: Tuple[int, int] = (1, 1),
     padding: Tuple[int, int] = (0, 0),
     dilation: Tuple[int, int] = (1, 1),
+    mask: Optional[Tensor] = None,
 ) -> Tensor:
     """
     Performs Deformable Convolution, described in Deformable Convolutional Networks
@@ -33,6 +34,9 @@ def deform_conv2d(
         padding (int or Tuple[int, int]): height/width of padding of zeroes around
             each image. Default: 0
         dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
+        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
+            out_height, out_width]): masks to be applied for each position in the
+            convolution kernel.
     Returns:
         output (Tensor[batch_sz, out_channels, out_h, out_w]): result of convolution
@@ -42,11 +46,12 @@ def deform_conv2d(
         >>> input = torch.rand(4, 3, 10, 10)
         >>> kh, kw = 3, 3
         >>> weight = torch.rand(5, 3, kh, kw)
-        >>> # offset should have the same spatial size as the output
+        >>> # offset and mask should have the same spatial size as the output
         >>> # of the convolution. In this case, for an input of 10, stride of 1
         >>> # and kernel size of 3, without padding, the output size is 8
         >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
+        >>> mask = torch.rand(4, kh * kw, 8, 8)
-        >>> out = deform_conv2d(input, offset, weight)
+        >>> out = deform_conv2d(input, offset, weight, mask=mask)
         >>> print(out.shape)
         >>> # returns
         >>> torch.Size([4, 5, 8, 8])
@@ -54,6 +59,12 @@ def deform_conv2d(
     _assert_has_ops()
     out_channels = weight.shape[0]
+    use_mask = mask is not None
+    if mask is None:
+        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
     if bias is None:
         bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
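A design note on the lines above: the C++ schema always takes a mask Tensor, so the Python wrapper cannot forward None. Instead it passes a zero-sized placeholder plus the use_mask flag, which keeps the op TorchScript-friendly. A tiny sketch of that pattern in isolation (helper name is illustrative, not part of torchvision):

```python
import torch

def pack_mask(mask, input):
    # Mirror of the wrapper's trick: signal "no modulation" via a flag
    # and a (N, 0) tensor that only carries batch size, device and dtype.
    use_mask = mask is not None
    if mask is None:
        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
    return mask, use_mask

x = torch.rand(2, 3, 4, 4)
m, flag = pack_mask(None, x)
print(m.shape, flag)  # torch.Size([2, 0]) False
```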
@@ -77,18 +88,21 @@ def deform_conv2d(
         input,
         weight,
         offset,
+        mask,
         bias,
         stride_h, stride_w,
         pad_h, pad_w,
         dil_h, dil_w,
         n_weight_grps,
-        n_offset_grps)
+        n_offset_grps,
+        use_mask,)
 class DeformConv2d(nn.Module):
     """
     See deform_conv2d
     """
     def __init__(
         self,
         in_channels: int,
@@ -127,21 +141,25 @@ class DeformConv2d(nn.Module):
     def reset_parameters(self) -> None:
         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
         if self.bias is not None:
             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
             bound = 1 / math.sqrt(fan_in)
             init.uniform_(self.bias, -bound, bound)
-    def forward(self, input: Tensor, offset: Tensor) -> Tensor:
+    def forward(self, input: Tensor, offset: Tensor, mask: Tensor = None) -> Tensor:
         """
         Arguments:
             input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
             offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width,
                 out_height, out_width]): offsets to be applied for each position in the
                 convolution kernel.
+            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width,
+                out_height, out_width]): masks to be applied for each position in the
+                convolution kernel.
         """
         return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride,
-                             padding=self.padding, dilation=self.dilation)
+                             padding=self.padding, dilation=self.dilation, mask=mask)
     def __repr__(self) -> str:
         s = self.__class__.__name__ + '('
......
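In practice the offsets and masks are usually predicted from the input by an ordinary convolution and the mask is squashed with a sigmoid, as in the Deformable ConvNets v2 paper. A hedged sketch of that wiring on top of the updated module (class and layer names here are illustrative, not part of torchvision):

```python
import torch
import torch.nn as nn
from torchvision import ops

class ModulatedDeformBlock(nn.Module):
    # Illustrative wrapper: a plain conv predicts 2*k*k offset channels and
    # k*k mask channels; the mask goes through a sigmoid before DeformConv2d.
    def __init__(self, in_ch, out_ch, k=3, padding=1):
        super().__init__()
        self.offset_mask = nn.Conv2d(in_ch, 3 * k * k, k, padding=padding)
        self.deform = ops.DeformConv2d(in_ch, out_ch, k, padding=padding)
        self.k = k

    def forward(self, x):
        o = self.offset_mask(x)
        offset, mask = o[:, :2 * self.k * self.k], o[:, 2 * self.k * self.k:]
        return self.deform(x, offset, torch.sigmoid(mask))

block = ModulatedDeformBlock(3, 8)
print(block(torch.rand(2, 3, 16, 16)).shape)  # torch.Size([2, 8, 16, 16])
```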