Unverified Commit e8b6e3f0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Port DeformConv to use the Dispatcher and support Autocast (#2898)

* Splitting tuples of stride, padding and dilation of DeformConv.

* Fixing types.

* Dispatcher + Autocast.

* + Autograd.

* Moving contiguous() convertions away dispatcher and into the implementations.

* Removing rvalue references.
parent f9e31a6d
#pragma once #pragma once
#include "cpu/vision_cpu.h" #if defined(WITH_CUDA) || defined(WITH_HIP)
#include "autocast.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif #endif
at::Tensor DeformConv2d_forward( // TODO: put this stuff in torchvision namespace
at::Tensor deform_conv2d(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const std::pair<int, int>& stride, const int64_t stride_h,
const std::pair<int, int>& padding, const int64_t stride_w,
const std::pair<int, int>& dilation, const int64_t pad_h,
const int groups, const int64_t pad_w,
const int offset_groups) { const int64_t dilation_h,
if (input.is_cuda()) { const int64_t dilation_w,
#if defined(WITH_CUDA) || defined(WITH_HIP) const int64_t groups,
return DeformConv2d_forward_cuda( const int64_t offset_groups) {
input.contiguous(), static auto op = c10::Dispatcher::singleton()
weight.contiguous(), .findSchemaOrThrow("torchvision::deform_conv2d", "")
offset.contiguous(), .typed<decltype(deform_conv2d)>();
bias.contiguous(), return op.call(
stride, input,
padding, weight,
dilation, offset,
groups, bias,
offset_groups); stride_h,
#else stride_w,
TORCH_CHECK(false, "Not compiled with GPU support"); pad_h,
#endif pad_w,
} dilation_h,
return DeformConv2d_forward_cpu( dilation_w,
input.contiguous(),
weight.contiguous(),
offset.contiguous(),
bias.contiguous(),
stride,
padding,
dilation,
groups, groups,
offset_groups); offset_groups);
} }
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward( #if defined(WITH_CUDA) || defined(WITH_HIP)
const at::Tensor& grad, at::Tensor DeformConv2d_autocast(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const std::pair<int, int>& stride, const int64_t stride_h,
const std::pair<int, int>& padding, const int64_t stride_w,
const std::pair<int, int>& dilation, const int64_t pad_h,
const int groups, const int64_t pad_w,
const int offset_groups) { const int64_t dilation_h,
if (grad.is_cuda()) { const int64_t dilation_w,
#if defined(WITH_CUDA) || defined(WITH_HIP) const int64_t groups,
return DeformConv2d_backward_cuda( const int64_t offset_groups) {
grad.contiguous(), c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
input.contiguous(), return deform_conv2d(
weight.contiguous(), at::autocast::cached_cast(at::kFloat, input),
offset.contiguous(), at::autocast::cached_cast(at::kFloat, weight),
bias.contiguous(), at::autocast::cached_cast(at::kFloat, offset),
stride, at::autocast::cached_cast(at::kFloat, bias),
padding, stride_h,
dilation, stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups, groups,
offset_groups); offset_groups)
#else .to(input.scalar_type());
TORCH_CHECK(false, "Not compiled with GPU support"); }
#endif #endif
}
return DeformConv2d_backward_cpu( std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
grad.contiguous(), _deform_conv2d_backward(
input.contiguous(), const at::Tensor& grad,
weight.contiguous(), const at::Tensor& input,
offset.contiguous(), const at::Tensor& weight,
bias.contiguous(), const at::Tensor& offset,
stride, const at::Tensor& bias,
padding, const int64_t stride_h,
dilation, const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
.typed<decltype(_deform_conv2d_backward)>();
return op.call(
grad,
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups, groups,
offset_groups); offset_groups);
} }
...@@ -105,14 +121,18 @@ class DeformConv2dFunction ...@@ -105,14 +121,18 @@ class DeformConv2dFunction
int64_t dilation_w, int64_t dilation_w,
int64_t groups, int64_t groups,
int64_t offset_groups) { int64_t offset_groups) {
auto output = DeformConv2d_forward( at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary
auto output = deform_conv2d(
input, input,
weight, weight,
offset, offset,
bias, bias,
{stride_h, stride_w}, stride_h,
{pad_h, pad_w}, stride_w,
{dilation_h, dilation_w}, pad_h,
pad_w,
dilation_h,
dilation_w,
groups, groups,
offset_groups); offset_groups);
...@@ -149,15 +169,18 @@ class DeformConv2dFunction ...@@ -149,15 +169,18 @@ class DeformConv2dFunction
auto groups = ctx->saved_data["groups"].toInt(); auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt(); auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto grads = DeformConv2d_backward( auto grads = _deform_conv2d_backward(
grad_output[0], grad_output[0],
input, input,
weight, weight,
offset, offset,
bias, bias,
{stride_h, stride_w}, stride_h,
{pad_h, pad_w}, stride_w,
{dilation_h, dilation_w}, pad_h,
pad_w,
dilation_h,
dilation_w,
groups, groups,
offset_groups); offset_groups);
auto grad_input = std::get<0>(grads); auto grad_input = std::get<0>(grads);
...@@ -182,20 +205,106 @@ class DeformConv2dFunction ...@@ -182,20 +205,106 @@ class DeformConv2dFunction
} }
}; };
at::Tensor deform_conv2d( // TODO: There should be an easier way to do this
class DeformConv2dBackwardFunction
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable input,
torch::autograd::Variable weight,
torch::autograd::Variable offset,
torch::autograd::Variable bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward(
grad,
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups);
auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_bias = std::get<3>(result);
return {
grad_input,
grad_weight,
grad_offset,
grad_bias,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
}
};
at::Tensor DeformConv2d_autograd(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
int64_t stride_h, const int64_t stride_h,
int64_t stride_w, const int64_t stride_w,
int64_t pad_h, const int64_t pad_h,
int64_t pad_w, const int64_t pad_w,
int64_t dilation_h, const int64_t dilation_h,
int64_t dilation_w, const int64_t dilation_w,
int64_t groups, const int64_t groups,
int64_t offset_groups) { const int64_t offset_groups) {
auto result = DeformConv2dFunction::apply( return DeformConv2dFunction::apply(
input,
weight,
offset,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& bias,
const int64_t stride_h,
const int64_t stride_w,
const int64_t pad_h,
const int64_t pad_w,
const int64_t dilation_h,
const int64_t dilation_w,
const int64_t groups,
const int64_t offset_groups) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input, input,
weight, weight,
offset, offset,
...@@ -208,5 +317,5 @@ at::Tensor deform_conv2d( ...@@ -208,5 +317,5 @@ at::Tensor deform_conv2d(
dilation_w, dilation_w,
groups, groups,
offset_groups); offset_groups);
return result[0]; return std::make_tuple(result[0], result[1], result[2], result[3]);
} }
\ No newline at end of file
...@@ -232,22 +232,23 @@ at::Tensor DeformConv2d_forward_cpu( ...@@ -232,22 +232,23 @@ at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input_param, const at::Tensor& input_param,
const at::Tensor& weight_param, const at::Tensor& weight_param,
const at::Tensor& offset_param, const at::Tensor& offset_param,
const at::Tensor& bias, const at::Tensor& bias_param,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int n_weight_grps, int64_t pad_w,
int n_offset_grps) { int64_t dil_h,
at::Tensor input = input_param; int64_t dil_w,
at::Tensor offset = offset_param; int64_t n_weight_grps,
at::Tensor weight = weight_param; int64_t n_offset_grps) {
at::Tensor input = input_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor bias = bias_param.contiguous();
TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4); TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(offset.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
int batch_sz = input.size(0); int batch_sz = input.size(0);
...@@ -263,15 +264,6 @@ at::Tensor DeformConv2d_forward_cpu( ...@@ -263,15 +264,6 @@ at::Tensor DeformConv2d_forward_cpu(
int weight_h = weight.size(2); int weight_h = weight.size(2);
int weight_w = weight.size(3); int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
int ker_h = dil_h * (weight_h - 1) + 1; int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1; int ker_w = dil_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
...@@ -683,9 +675,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu( ...@@ -683,9 +675,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
at::Tensor weight, at::Tensor weight,
at::Tensor offset, at::Tensor offset,
at::Tensor grad_out, at::Tensor grad_out,
std::pair<int, int> stride, int stride_h,
std::pair<int, int> pad, int stride_w,
std::pair<int, int> dilation, int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps, int n_weight_grps,
int n_offset_grps, int n_offset_grps,
int n_parallel_imgs) { int n_parallel_imgs) {
...@@ -700,15 +695,6 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu( ...@@ -700,15 +695,6 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
int weight_h = weight.size(2); int weight_h = weight.size(2);
int weight_w = weight.size(3); int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;
...@@ -813,9 +799,12 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -813,9 +799,12 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
const at::Tensor& weight, const at::Tensor& weight,
at::Tensor offset, at::Tensor offset,
const at::Tensor& grad_out, const at::Tensor& grad_out,
std::pair<int, int> stride, int stride_h,
std::pair<int, int> pad, int stride_w,
std::pair<int, int> dilation, int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps, int n_weight_grps,
int n_offset_grps, int n_offset_grps,
int n_parallel_imgs) { int n_parallel_imgs) {
...@@ -830,15 +819,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -830,15 +819,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
int weight_h = weight.size(2); int weight_h = weight.size(2);
int weight_w = weight.size(3); int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_h = grad_out.size(2); long out_h = grad_out.size(2);
long out_w = grad_out.size(3); long out_w = grad_out.size(3);
...@@ -917,16 +897,25 @@ static at::Tensor deform_conv2d_backward_parameters_cpu( ...@@ -917,16 +897,25 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu( DeformConv2d_backward_cpu(
const at::Tensor& grad_out, const at::Tensor& grad_out_param,
const at::Tensor& input, const at::Tensor& input_param,
const at::Tensor& weight, const at::Tensor& weight_param,
const at::Tensor& offset, const at::Tensor& offset_param,
const at::Tensor& bias, const at::Tensor& bias_param,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int n_weight_grps, int64_t pad_w,
int n_offset_grps) { int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor grad_out = grad_out_param.contiguous();
at::Tensor input = input_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor bias = bias_param.contiguous();
const int batch_sz = input.size(0); const int batch_sz = input.size(0);
const int n_parallel_imgs = const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
...@@ -936,9 +925,12 @@ DeformConv2d_backward_cpu( ...@@ -936,9 +925,12 @@ DeformConv2d_backward_cpu(
weight, weight,
offset, offset,
grad_out, grad_out,
stride, stride_h,
pad, stride_w,
dilation, pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps, n_weight_grps,
n_offset_grps, n_offset_grps,
n_parallel_imgs); n_parallel_imgs);
...@@ -951,9 +943,12 @@ DeformConv2d_backward_cpu( ...@@ -951,9 +943,12 @@ DeformConv2d_backward_cpu(
weight, weight,
offset, offset,
grad_out, grad_out,
stride, stride_h,
pad, stride_w,
dilation, pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps, n_weight_grps,
n_offset_grps, n_offset_grps,
n_parallel_imgs); n_parallel_imgs);
......
...@@ -93,11 +93,14 @@ VISION_API at::Tensor DeformConv2d_forward_cpu( ...@@ -93,11 +93,14 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int groups, int64_t pad_w,
int deformable_groups); int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cpu( DeformConv2d_backward_cpu(
...@@ -106,8 +109,11 @@ DeformConv2d_backward_cpu( ...@@ -106,8 +109,11 @@ DeformConv2d_backward_cpu(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int groups, int64_t pad_w,
int deformable_groups); int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
...@@ -248,22 +248,23 @@ at::Tensor DeformConv2d_forward_cuda( ...@@ -248,22 +248,23 @@ at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input_param, const at::Tensor& input_param,
const at::Tensor& weight_param, const at::Tensor& weight_param,
const at::Tensor& offset_param, const at::Tensor& offset_param,
const at::Tensor& bias, const at::Tensor& bias_param,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int n_weight_grps, int64_t pad_w,
int n_offset_grps) { int64_t dil_h,
at::Tensor input = input_param; int64_t dil_w,
at::Tensor weight = weight_param; int64_t n_weight_grps,
at::Tensor offset = offset_param; int64_t n_offset_grps) {
at::Tensor input = input_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor bias = bias_param.contiguous();
TORCH_CHECK(input.ndimension() == 4); TORCH_CHECK(input.ndimension() == 4);
TORCH_CHECK(offset.ndimension() == 4); TORCH_CHECK(offset.ndimension() == 4);
TORCH_CHECK(weight.ndimension() == 4); TORCH_CHECK(weight.ndimension() == 4);
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(offset.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
at::DeviceGuard guard(input.device()); at::DeviceGuard guard(input.device());
...@@ -280,15 +281,6 @@ at::Tensor DeformConv2d_forward_cuda( ...@@ -280,15 +281,6 @@ at::Tensor DeformConv2d_forward_cuda(
int weight_h = weight.size(2); int weight_h = weight.size(2);
int weight_w = weight.size(3); int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
int ker_h = dil_h * (weight_h - 1) + 1; int ker_h = dil_h * (weight_h - 1) + 1;
int ker_w = dil_w * (weight_w - 1) + 1; int ker_w = dil_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1; int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
...@@ -711,9 +703,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -711,9 +703,12 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
at::Tensor weight, at::Tensor weight,
at::Tensor offset, at::Tensor offset,
at::Tensor grad_out, at::Tensor grad_out,
std::pair<int, int> stride, int stride_h,
std::pair<int, int> pad, int stride_w,
std::pair<int, int> dilation, int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps, int n_weight_grps,
int n_offset_grps, int n_offset_grps,
int n_parallel_imgs) { int n_parallel_imgs) {
...@@ -730,15 +725,6 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -730,15 +725,6 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
int weight_h = weight.size(2); int weight_h = weight.size(2);
int weight_w = weight.size(3); int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1; long out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) / stride_w + 1;
long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1; long out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) / stride_h + 1;
...@@ -841,9 +827,12 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -841,9 +827,12 @@ static at::Tensor deform_conv_backward_parameters_cuda(
const at::Tensor& weight, const at::Tensor& weight,
at::Tensor offset, at::Tensor offset,
const at::Tensor& grad_out, const at::Tensor& grad_out,
std::pair<int, int> stride, int stride_h,
std::pair<int, int> pad, int stride_w,
std::pair<int, int> dilation, int pad_h,
int pad_w,
int dil_h,
int dil_w,
int n_weight_grps, int n_weight_grps,
int n_offset_grps, int n_offset_grps,
int n_parallel_imgs) { int n_parallel_imgs) {
...@@ -860,15 +849,6 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -860,15 +849,6 @@ static at::Tensor deform_conv_backward_parameters_cuda(
int weight_h = weight.size(2); int weight_h = weight.size(2);
int weight_w = weight.size(3); int weight_w = weight.size(3);
int stride_h = stride.first;
int stride_w = stride.second;
int pad_h = pad.first;
int pad_w = pad.second;
int dil_h = dilation.first;
int dil_w = dilation.second;
long out_h = grad_out.size(2); long out_h = grad_out.size(2);
long out_w = grad_out.size(3); long out_w = grad_out.size(3);
...@@ -946,16 +926,25 @@ static at::Tensor deform_conv_backward_parameters_cuda( ...@@ -946,16 +926,25 @@ static at::Tensor deform_conv_backward_parameters_cuda(
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda( DeformConv2d_backward_cuda(
const at::Tensor& grad_out, const at::Tensor& grad_out_param,
const at::Tensor& input, const at::Tensor& input_param,
const at::Tensor& weight, const at::Tensor& weight_param,
const at::Tensor& offset, const at::Tensor& offset_param,
const at::Tensor& bias, const at::Tensor& bias_param,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int n_weight_grps, int64_t pad_w,
int n_offset_grps) { int64_t dil_h,
int64_t dil_w,
int64_t n_weight_grps,
int64_t n_offset_grps) {
at::Tensor grad_out = grad_out_param.contiguous();
at::Tensor input = input_param.contiguous();
at::Tensor weight = weight_param.contiguous();
at::Tensor offset = offset_param.contiguous();
at::Tensor bias = bias_param.contiguous();
const int batch_sz = input.size(0); const int batch_sz = input.size(0);
const int n_parallel_imgs = const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
...@@ -965,9 +954,12 @@ DeformConv2d_backward_cuda( ...@@ -965,9 +954,12 @@ DeformConv2d_backward_cuda(
weight, weight,
offset, offset,
grad_out, grad_out,
stride, stride_h,
pad, stride_w,
dilation, pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps, n_weight_grps,
n_offset_grps, n_offset_grps,
n_parallel_imgs); n_parallel_imgs);
...@@ -980,9 +972,12 @@ DeformConv2d_backward_cuda( ...@@ -980,9 +972,12 @@ DeformConv2d_backward_cuda(
weight, weight,
offset, offset,
grad_out, grad_out,
stride, stride_h,
pad, stride_w,
dilation, pad_h,
pad_w,
dil_h,
dil_w,
n_weight_grps, n_weight_grps,
n_offset_grps, n_offset_grps,
n_parallel_imgs); n_parallel_imgs);
......
...@@ -93,11 +93,14 @@ VISION_API at::Tensor DeformConv2d_forward_cuda( ...@@ -93,11 +93,14 @@ VISION_API at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int groups, int64_t pad_w,
int deformable_groups); int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_cuda( DeformConv2d_backward_cuda(
...@@ -106,8 +109,11 @@ DeformConv2d_backward_cuda( ...@@ -106,8 +109,11 @@ DeformConv2d_backward_cuda(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
std::pair<int, int> stride, int64_t stride_h,
std::pair<int, int> pad, int64_t stride_w,
std::pair<int, int> dilation, int64_t pad_h,
int groups, int64_t pad_w,
int deformable_groups); int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t deformable_groups);
...@@ -54,13 +54,18 @@ TORCH_LIBRARY(torchvision, m) { ...@@ -54,13 +54,18 @@ TORCH_LIBRARY(torchvision, m) {
m.def("_new_empty_tensor_op", &new_empty_tensor); m.def("_new_empty_tensor_op", &new_empty_tensor);
m.def("ps_roi_align", &ps_roi_align); m.def("ps_roi_align", &ps_roi_align);
m.def("ps_roi_pool", &ps_roi_pool); m.def("ps_roi_pool", &ps_roi_pool);
m.def("deform_conv2d", &deform_conv2d); m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
m.def(
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
m.def("_cuda_version", &vision::cuda_version); m.def("_cuda_version", &vision::cuda_version);
} }
TORCH_LIBRARY_IMPL(torchvision, CPU, m) { TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", ROIAlign_forward_cpu); m.impl("roi_align", ROIAlign_forward_cpu);
m.impl("_roi_align_backward", ROIAlign_backward_cpu); m.impl("_roi_align_backward", ROIAlign_backward_cpu);
m.impl("deform_conv2d", DeformConv2d_forward_cpu);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
m.impl("nms", nms_cpu); m.impl("nms", nms_cpu);
} }
...@@ -69,6 +74,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { ...@@ -69,6 +74,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", ROIAlign_forward_cuda); m.impl("roi_align", ROIAlign_forward_cuda);
m.impl("_roi_align_backward", ROIAlign_backward_cuda); m.impl("_roi_align_backward", ROIAlign_backward_cuda);
m.impl("deform_conv2d", DeformConv2d_forward_cuda);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
m.impl("nms", nms_cuda); m.impl("nms", nms_cuda);
} }
#endif #endif
...@@ -77,6 +84,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { ...@@ -77,6 +84,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
#if defined(WITH_CUDA) || defined(WITH_HIP) #if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", ROIAlign_autocast); m.impl("roi_align", ROIAlign_autocast);
m.impl("deform_conv2d", DeformConv2d_autocast);
m.impl("nms", nms_autocast); m.impl("nms", nms_autocast);
} }
#endif #endif
...@@ -84,4 +92,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { ...@@ -84,4 +92,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", ROIAlign_autograd); m.impl("roi_align", ROIAlign_autograd);
m.impl("_roi_align_backward", ROIAlign_backward_autograd); m.impl("_roi_align_backward", ROIAlign_backward_autograd);
m.impl("deform_conv2d", DeformConv2d_autograd);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment