Unverified Commit 0e5aee46 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

DeformConv code cleanup (#2905)

* Clean up and refactor DeformConv implementation:
- Remove primitive const declaration from method names.
- Passing as const ref instead of value where possible.
- Aligning method names between cpu and cuda.

* Adding newline.

* Adding back include for cpu.

* Restoring method names of private methods to avoid conflicts.

* Restore include headers.
parent 45e027c7
#pragma once #pragma once
#if defined(WITH_CUDA) || defined(WITH_HIP) #include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h" #include "autocast.h"
#include "hip/vision_cuda.h"
#endif #endif
// TODO: put this stuff in torchvision namespace // TODO: put this stuff in torchvision namespace
...@@ -11,14 +18,14 @@ at::Tensor deform_conv2d( ...@@ -11,14 +18,14 @@ at::Tensor deform_conv2d(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const int64_t stride_h, int64_t stride_h,
const int64_t stride_w, int64_t stride_w,
const int64_t pad_h, int64_t pad_h,
const int64_t pad_w, int64_t pad_w,
const int64_t dilation_h, int64_t dilation_h,
const int64_t dilation_w, int64_t dilation_w,
const int64_t groups, int64_t groups,
const int64_t offset_groups) { int64_t offset_groups) {
static auto op = c10::Dispatcher::singleton() static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "") .findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d)>(); .typed<decltype(deform_conv2d)>();
...@@ -43,14 +50,14 @@ at::Tensor DeformConv2d_autocast( ...@@ -43,14 +50,14 @@ at::Tensor DeformConv2d_autocast(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const int64_t stride_h, int64_t stride_h,
const int64_t stride_w, int64_t stride_w,
const int64_t pad_h, int64_t pad_h,
const int64_t pad_w, int64_t pad_w,
const int64_t dilation_h, int64_t dilation_h,
const int64_t dilation_w, int64_t dilation_w,
const int64_t groups, int64_t groups,
const int64_t offset_groups) { int64_t offset_groups) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return deform_conv2d( return deform_conv2d(
at::autocast::cached_cast(at::kFloat, input), at::autocast::cached_cast(at::kFloat, input),
...@@ -76,14 +83,14 @@ _deform_conv2d_backward( ...@@ -76,14 +83,14 @@ _deform_conv2d_backward(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const int64_t stride_h, int64_t stride_h,
const int64_t stride_w, int64_t stride_w,
const int64_t pad_h, int64_t pad_h,
const int64_t pad_w, int64_t pad_w,
const int64_t dilation_h, int64_t dilation_h,
const int64_t dilation_w, int64_t dilation_w,
const int64_t groups, int64_t groups,
const int64_t offset_groups) { int64_t offset_groups) {
static auto op = static auto op =
c10::Dispatcher::singleton() c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
...@@ -109,10 +116,10 @@ class DeformConv2dFunction ...@@ -109,10 +116,10 @@ class DeformConv2dFunction
public: public:
static torch::autograd::variable_list forward( static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input, const torch::autograd::Variable& input,
torch::autograd::Variable weight, const torch::autograd::Variable& weight,
torch::autograd::Variable offset, const torch::autograd::Variable& offset,
torch::autograd::Variable bias, const torch::autograd::Variable& bias,
int64_t stride_h, int64_t stride_h,
int64_t stride_w, int64_t stride_w,
int64_t pad_h, int64_t pad_h,
...@@ -121,7 +128,7 @@ class DeformConv2dFunction ...@@ -121,7 +128,7 @@ class DeformConv2dFunction
int64_t dilation_w, int64_t dilation_w,
int64_t groups, int64_t groups,
int64_t offset_groups) { int64_t offset_groups) {
at::AutoNonVariableTypeMode g; // TODO_vv: check if necessary at::AutoNonVariableTypeMode g;
auto output = deform_conv2d( auto output = deform_conv2d(
input, input,
weight, weight,
...@@ -153,7 +160,7 @@ class DeformConv2dFunction ...@@ -153,7 +160,7 @@ class DeformConv2dFunction
static torch::autograd::variable_list backward( static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) { const torch::autograd::variable_list& grad_output) {
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto input = saved[0]; auto input = saved[0];
auto weight = saved[1]; auto weight = saved[1];
...@@ -211,19 +218,19 @@ class DeformConv2dBackwardFunction ...@@ -211,19 +218,19 @@ class DeformConv2dBackwardFunction
public: public:
static torch::autograd::variable_list forward( static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad, const torch::autograd::Variable& grad,
torch::autograd::Variable input, const torch::autograd::Variable& input,
torch::autograd::Variable weight, const torch::autograd::Variable& weight,
torch::autograd::Variable offset, const torch::autograd::Variable& offset,
torch::autograd::Variable bias, const torch::autograd::Variable& bias,
const int64_t stride_h, int64_t stride_h,
const int64_t stride_w, int64_t stride_w,
const int64_t pad_h, int64_t pad_h,
const int64_t pad_w, int64_t pad_w,
const int64_t dilation_h, int64_t dilation_h,
const int64_t dilation_w, int64_t dilation_w,
const int64_t groups, int64_t groups,
const int64_t offset_groups) { int64_t offset_groups) {
at::AutoNonVariableTypeMode g; at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward( auto result = _deform_conv2d_backward(
grad, grad,
...@@ -255,7 +262,7 @@ class DeformConv2dBackwardFunction ...@@ -255,7 +262,7 @@ class DeformConv2dBackwardFunction
static torch::autograd::variable_list backward( static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) { const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported"); TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
} }
}; };
...@@ -265,14 +272,14 @@ at::Tensor DeformConv2d_autograd( ...@@ -265,14 +272,14 @@ at::Tensor DeformConv2d_autograd(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const int64_t stride_h, int64_t stride_h,
const int64_t stride_w, int64_t stride_w,
const int64_t pad_h, int64_t pad_h,
const int64_t pad_w, int64_t pad_w,
const int64_t dilation_h, int64_t dilation_h,
const int64_t dilation_w, int64_t dilation_w,
const int64_t groups, int64_t groups,
const int64_t offset_groups) { int64_t offset_groups) {
return DeformConv2dFunction::apply( return DeformConv2dFunction::apply(
input, input,
weight, weight,
...@@ -295,14 +302,14 @@ DeformConv2d_backward_autograd( ...@@ -295,14 +302,14 @@ DeformConv2d_backward_autograd(
const at::Tensor& weight, const at::Tensor& weight,
const at::Tensor& offset, const at::Tensor& offset,
const at::Tensor& bias, const at::Tensor& bias,
const int64_t stride_h, int64_t stride_h,
const int64_t stride_w, int64_t stride_w,
const int64_t pad_h, int64_t pad_h,
const int64_t pad_w, int64_t pad_w,
const int64_t dilation_h, int64_t dilation_h,
const int64_t dilation_w, int64_t dilation_w,
const int64_t groups, int64_t groups,
const int64_t offset_groups) { int64_t offset_groups) {
auto result = DeformConv2dBackwardFunction::apply( auto result = DeformConv2dBackwardFunction::apply(
grad, grad,
input, input,
...@@ -317,5 +324,6 @@ DeformConv2d_backward_autograd( ...@@ -317,5 +324,6 @@ DeformConv2d_backward_autograd(
dilation_w, dilation_w,
groups, groups,
offset_groups); offset_groups);
return std::make_tuple(result[0], result[1], result[2], result[3]); return std::make_tuple(result[0], result[1], result[2], result[3]);
} }
\ No newline at end of file
...@@ -79,8 +79,8 @@ const int kMaxParallelImgs = 32; ...@@ -79,8 +79,8 @@ const int kMaxParallelImgs = 32;
template <typename scalar_t> template <typename scalar_t>
static scalar_t bilinear_interpolate( static scalar_t bilinear_interpolate(
const scalar_t* in, const scalar_t* in,
const int height, int height,
const int width, int width,
scalar_t h, scalar_t h,
scalar_t w) { scalar_t w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) { if (h <= -1 || height <= h || w <= -1 || width <= w) {
...@@ -117,24 +117,24 @@ static scalar_t bilinear_interpolate( ...@@ -117,24 +117,24 @@ static scalar_t bilinear_interpolate(
template <typename scalar_t> template <typename scalar_t>
static void deformable_im2col_kernel( static void deformable_im2col_kernel(
const int n, int n,
const scalar_t* input, const scalar_t* input,
const scalar_t* offset, const scalar_t* offset,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dil_h, int dil_h,
const int dil_w, int dil_w,
const int batch_sz, int batch_sz,
const int n_in_channels, int n_in_channels,
const int n_offset_grps, int n_offset_grps,
const int out_h, int out_h,
const int out_w, int out_w,
scalar_t* columns) { scalar_t* columns) {
for (int index = 0; index != n; ++index) { for (int index = 0; index != n; ++index) {
const int out_x = index % out_w; const int out_x = index % out_w;
...@@ -174,8 +174,8 @@ static void deformable_im2col_kernel( ...@@ -174,8 +174,8 @@ static void deformable_im2col_kernel(
} }
static void deformable_im2col( static void deformable_im2col(
const at::Tensor input, const at::Tensor& input,
const at::Tensor data_offset, const at::Tensor& data_offset,
int n_in_channels, int n_in_channels,
int height, int height,
int width, int width,
...@@ -403,24 +403,24 @@ at::Tensor DeformConv2d_forward_cpu( ...@@ -403,24 +403,24 @@ at::Tensor DeformConv2d_forward_cpu(
template <typename scalar_t> template <typename scalar_t>
static void deformable_col2im_kernel( static void deformable_col2im_kernel(
const int n, int n,
const scalar_t* col, const scalar_t* col,
const scalar_t* offset, const scalar_t* offset,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int kernel_h, int kernel_h,
const int kernel_w, int kernel_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int batch_sz, int batch_sz,
const int n_offset_grps, int n_offset_grps,
const int out_h, int out_h,
const int out_w, int out_w,
scalar_t* grad_im) { scalar_t* grad_im) {
for (int index = 0; index != n; ++index) { for (int index = 0; index != n; ++index) {
const int out_x = index % out_w; const int out_x = index % out_w;
...@@ -461,21 +461,21 @@ static void deformable_col2im_kernel( ...@@ -461,21 +461,21 @@ static void deformable_col2im_kernel(
} }
static void compute_grad_input( static void compute_grad_input(
const at::Tensor columns, const at::Tensor& columns,
const at::Tensor offset, const at::Tensor& offset,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int parallel_imgs, int parallel_imgs,
const int n_offset_grps, int n_offset_grps,
at::Tensor grad_im) { at::Tensor grad_im) {
int out_h = int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
...@@ -512,8 +512,8 @@ static void compute_grad_input( ...@@ -512,8 +512,8 @@ static void compute_grad_input(
template <typename scalar_t> template <typename scalar_t>
static scalar_t get_coordinate_weight( static scalar_t get_coordinate_weight(
const scalar_t* im_data, const scalar_t* im_data,
const int height, int height,
const int width, int width,
scalar_t y, scalar_t y,
scalar_t x, scalar_t x,
bool is_y_direction) { bool is_y_direction) {
...@@ -544,26 +544,26 @@ static scalar_t get_coordinate_weight( ...@@ -544,26 +544,26 @@ static scalar_t get_coordinate_weight(
template <typename scalar_t> template <typename scalar_t>
static void deformable_col2im_coord_kernel( static void deformable_col2im_coord_kernel(
const int n, int n,
const scalar_t* col, const scalar_t* col,
const scalar_t* im, const scalar_t* im,
const scalar_t* offset, const scalar_t* offset,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int batch_sz, int batch_sz,
const int offset_channels, int offset_channels,
const int n_offset_grps, int n_offset_grps,
const int out_h, int out_h,
const int out_w, int out_w,
scalar_t* grad_offset) { scalar_t* grad_offset) {
for (int index = 0; index != n; ++index) { for (int index = 0; index != n; ++index) {
scalar_t val = 0; scalar_t val = 0;
...@@ -619,22 +619,22 @@ static void deformable_col2im_coord_kernel( ...@@ -619,22 +619,22 @@ static void deformable_col2im_coord_kernel(
} }
static void compute_grad_offset( static void compute_grad_offset(
const at::Tensor columns, const at::Tensor& columns,
const at::Tensor input, const at::Tensor& input,
const at::Tensor offset, const at::Tensor& offset,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int parallel_imgs, int parallel_imgs,
const int n_offset_grps, int n_offset_grps,
at::Tensor grad_offset) { at::Tensor grad_offset) {
int out_h = int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
......
...@@ -89,8 +89,8 @@ inline unsigned int GET_BLOCKS(const unsigned int N) { ...@@ -89,8 +89,8 @@ inline unsigned int GET_BLOCKS(const unsigned int N) {
template <typename scalar_t> template <typename scalar_t>
__device__ scalar_t bilinear_interpolate( __device__ scalar_t bilinear_interpolate(
const scalar_t* in, const scalar_t* in,
const int height, int height,
const int width, int width,
scalar_t h, scalar_t h,
scalar_t w) { scalar_t w) {
if (h <= -1 || height <= h || w <= -1 || width <= w) { if (h <= -1 || height <= h || w <= -1 || width <= w) {
...@@ -127,24 +127,24 @@ __device__ scalar_t bilinear_interpolate( ...@@ -127,24 +127,24 @@ __device__ scalar_t bilinear_interpolate(
template <typename scalar_t> template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel( __global__ void deformable_im2col_gpu_kernel(
const int n, int n,
const scalar_t* input_ptr, const scalar_t* input_ptr,
const scalar_t* offset_ptr, const scalar_t* offset_ptr,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dil_h, int dil_h,
const int dil_w, int dil_w,
const int batch_sz, int batch_sz,
const int n_in_channels, int n_in_channels,
const int n_offset_grps, int n_offset_grps,
const int out_h, int out_h,
const int out_w, int out_w,
scalar_t* columns_ptr) { scalar_t* columns_ptr) {
CUDA_1D_KERNEL_LOOP(index, n) { CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w; const int out_x = index % out_w;
...@@ -183,8 +183,8 @@ __global__ void deformable_im2col_gpu_kernel( ...@@ -183,8 +183,8 @@ __global__ void deformable_im2col_gpu_kernel(
} }
static void deformable_im2col( static void deformable_im2col(
const at::Tensor input, const at::Tensor& input,
const at::Tensor data_offset, const at::Tensor& data_offset,
int n_in_channels, int n_in_channels,
int height, int height,
int width, int width,
...@@ -420,24 +420,24 @@ at::Tensor DeformConv2d_forward_cuda( ...@@ -420,24 +420,24 @@ at::Tensor DeformConv2d_forward_cuda(
template <typename scalar_t> template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel( __global__ void deformable_col2im_gpu_kernel(
const int n, int n,
const scalar_t* col, const scalar_t* col,
const scalar_t* offset_ptr, const scalar_t* offset_ptr,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int kernel_h, int kernel_h,
const int kernel_w, int kernel_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int batch_sz, int batch_sz,
const int n_offset_grps, int n_offset_grps,
const int out_h, int out_h,
const int out_w, int out_w,
scalar_t* grad_im) { scalar_t* grad_im) {
CUDA_1D_KERNEL_LOOP(index, n) { CUDA_1D_KERNEL_LOOP(index, n) {
const int out_x = index % out_w; const int out_x = index % out_w;
...@@ -477,21 +477,21 @@ __global__ void deformable_col2im_gpu_kernel( ...@@ -477,21 +477,21 @@ __global__ void deformable_col2im_gpu_kernel(
} }
static void compute_grad_input( static void compute_grad_input(
const at::Tensor columns, const at::Tensor& columns,
const at::Tensor offset, const at::Tensor& offset,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int parallel_imgs, int parallel_imgs,
const int n_offset_grps, int n_offset_grps,
at::Tensor grad_im) { at::Tensor grad_im) {
int out_h = int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
...@@ -535,8 +535,8 @@ static void compute_grad_input( ...@@ -535,8 +535,8 @@ static void compute_grad_input(
template <typename scalar_t> template <typename scalar_t>
__device__ scalar_t get_coordinate_weight( __device__ scalar_t get_coordinate_weight(
const scalar_t* im_data, const scalar_t* im_data,
const int height, int height,
const int width, int width,
scalar_t y, scalar_t y,
scalar_t x, scalar_t x,
bool is_y_direction) { bool is_y_direction) {
...@@ -567,26 +567,26 @@ __device__ scalar_t get_coordinate_weight( ...@@ -567,26 +567,26 @@ __device__ scalar_t get_coordinate_weight(
template <typename scalar_t> template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel( __global__ void deformable_col2im_coord_gpu_kernel(
const int n, int n,
const scalar_t* col_ptr, const scalar_t* col_ptr,
const scalar_t* im_ptr, const scalar_t* im_ptr,
const scalar_t* offset_ptr, const scalar_t* offset_ptr,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int batch_sz, int batch_sz,
const int offset_channels, int offset_channels,
const int n_offset_grps, int n_offset_grps,
const int out_h, int out_h,
const int out_w, int out_w,
scalar_t* grad_offset) { scalar_t* grad_offset) {
CUDA_1D_KERNEL_LOOP(index, n) { CUDA_1D_KERNEL_LOOP(index, n) {
scalar_t val = 0; scalar_t val = 0;
...@@ -640,22 +640,22 @@ __global__ void deformable_col2im_coord_gpu_kernel( ...@@ -640,22 +640,22 @@ __global__ void deformable_col2im_coord_gpu_kernel(
} }
static void compute_grad_offset( static void compute_grad_offset(
const at::Tensor columns, const at::Tensor& columns,
const at::Tensor input, const at::Tensor& input,
const at::Tensor offset, const at::Tensor& offset,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int weight_h, int weight_h,
const int weight_w, int weight_w,
const int pad_h, int pad_h,
const int pad_w, int pad_w,
const int stride_h, int stride_h,
const int stride_w, int stride_w,
const int dilation_h, int dilation_h,
const int dilation_w, int dilation_w,
const int parallel_imgs, int parallel_imgs,
const int n_offset_grps, int n_offset_grps,
at::Tensor grad_offset) { at::Tensor grad_offset) {
int out_h = int out_h =
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
...@@ -698,7 +698,7 @@ static void compute_grad_offset( ...@@ -698,7 +698,7 @@ static void compute_grad_offset(
} }
} }
static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cuda(
at::Tensor input, at::Tensor input,
at::Tensor weight, at::Tensor weight,
at::Tensor offset, at::Tensor offset,
...@@ -822,7 +822,7 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda( ...@@ -822,7 +822,7 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv_backward_input_cuda(
return std::make_tuple(grad_input, grad_offset); return std::make_tuple(grad_input, grad_offset);
} }
static at::Tensor deform_conv_backward_parameters_cuda( static at::Tensor deform_conv2d_backward_parameters_cuda(
at::Tensor input, at::Tensor input,
const at::Tensor& weight, const at::Tensor& weight,
at::Tensor offset, at::Tensor offset,
...@@ -949,7 +949,7 @@ DeformConv2d_backward_cuda( ...@@ -949,7 +949,7 @@ DeformConv2d_backward_cuda(
const int n_parallel_imgs = const int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs); get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
auto grad_input_and_offset = deform_conv_backward_input_cuda( auto grad_input_and_offset = deform_conv2d_backward_input_cuda(
input, input,
weight, weight,
offset, offset,
...@@ -967,7 +967,7 @@ DeformConv2d_backward_cuda( ...@@ -967,7 +967,7 @@ DeformConv2d_backward_cuda(
auto grad_input = std::get<0>(grad_input_and_offset); auto grad_input = std::get<0>(grad_input_and_offset);
auto grad_offset = std::get<1>(grad_input_and_offset); auto grad_offset = std::get<1>(grad_input_and_offset);
auto grad_weight = deform_conv_backward_parameters_cuda( auto grad_weight = deform_conv2d_backward_parameters_cuda(
input, input,
weight, weight,
offset, offset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment