Unverified Commit 0963ff71 authored by Vasilis Vryniotis, committed by GitHub

Move autograd implementations to separate files. (#3154)

parent 07a9c956
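The new files under torchvision/csrc/ops/autograd all follow the same pattern: the operator schema stays declared once with TORCH_LIBRARY_FRAGMENT next to the op itself, while the torch::autograd::Function wrapper moves into its own file and is registered for the Autograd dispatch key with TORCH_LIBRARY_IMPL. The following minimal sketch is not part of the diff; the op name "my_op", its trivial kernel, and the "torchvision_example" namespace are hypothetical, and it only illustrates that registration pattern in isolation.

#include <torch/autograd.h>
#include <torch/library.h>
#include <torch/types.h>

namespace {

// Stand-in for the backend (CPU/CUDA) kernel the dispatcher would normally route to.
at::Tensor my_op_impl(const at::Tensor& input) {
  return input * 2;
}

class MyOpFunction : public torch::autograd::Function<MyOpFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input) {
    // Tensors needed by backward would be stashed here.
    ctx->save_for_backward({input});
    at::AutoNonVariableTypeMode g; // do not record the nested call
    return {my_op_impl(input)};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // d(2 * x)/dx = 2, so simply scale the incoming gradient.
    return {grad_output[0] * 2};
  }
};

at::Tensor my_op_autograd(const at::Tensor& input) {
  return MyOpFunction::apply(input)[0];
}

} // namespace

// Declare the schema once (in the real code this stays next to the op)...
TORCH_LIBRARY_FRAGMENT(torchvision_example, m) {
  m.def("my_op(Tensor input) -> Tensor");
}

// ...and bind the autograd wrapper to the Autograd dispatch key.
TORCH_LIBRARY_IMPL(torchvision_example, Autograd, m) {
  m.impl("my_op", my_op_autograd);
}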
@@ -52,7 +52,8 @@ include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
set(TVCPP torchvision/csrc)
-list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops ${TVCPP}/ops/cpu)
+list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
+    ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
@@ -136,7 +136,8 @@ def get_extensions():
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops',
'*.cpp'))
-source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
+source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob(
+    os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp'))
is_rocm_pytorch = False
if torch.__version__ >= '1.5':
#include "../deform_conv2d.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {
namespace {
class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
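// The guard below excludes the Autograd keys for the nested call, so
// deform_conv2d() redispatches straight to the backend (CPU/CUDA) kernel
// without creating a second autograd node.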
at::AutoNonVariableTypeMode g;
auto output = deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
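// Stash the tensors and scalar arguments that the backward pass will need.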
ctx->save_for_backward({input, weight, offset, mask, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
ctx->saved_data["pad_w"] = pad_w;
ctx->saved_data["dilation_h"] = dilation_h;
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
ctx->saved_data["use_mask"] = use_mask;
return {
output,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto mask = saved[3];
auto bias = saved[4];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();
auto grads = detail::_deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_mask = std::get<3>(grads);
auto grad_bias = std::get<4>(grads);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
};
}
};
// TODO: There should be an easier way to do this
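// This second Function only exists so that the backward op itself has an
// Autograd kernel; its backward() throws because double backward for
// deform_conv2d is not implemented.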
class DeformConv2dBackwardFunction
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto result = detail::_deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_mask = std::get<3>(result);
auto grad_bias = std::get<4>(result);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
}
};
at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
} // namespace
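// Register the wrappers above as the Autograd-key kernels of the two
// operators declared via TORCH_LIBRARY_FRAGMENT.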
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}
} // namespace ops
} // namespace vision
#include "../ps_roi_align.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {
namespace {
class PSROIAlignFunction
: public torch::autograd::Function<PSROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
at::AutoNonVariableTypeMode g;
auto result = ps_roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
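// channel_mapping is an index tensor produced by the forward kernel; the
// backward kernel needs it, but it carries no gradient, so it is saved and
// marked non-differentiable.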
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_ps_roi_align_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
ctx->saved_data["sampling_ratio"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class PSROIAlignBackwardFunction
: public torch::autograd::Function<PSROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = detail::_ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
}
};
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::make_tuple(result[0], result[1]);
}
at::Tensor ps_roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
}
} // namespace ops
} // namespace vision
#include "../ps_roi_pool.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {
namespace {
class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
at::AutoNonVariableTypeMode g;
auto result =
ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_ps_roi_pool_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class PSROIPoolBackwardFunction
: public torch::autograd::Function<PSROIPoolBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = detail::_ps_roi_pool_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_pool not supported");
}
};
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
auto result = PSROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::make_tuple(result[0], result[1]);
}
at::Tensor ps_roi_pool_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return PSROIPoolBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
}
} // namespace ops
} // namespace vision
#include "../roi_align.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {
namespace {
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
at::AutoNonVariableTypeMode g;
auto result = roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_roi_align_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class ROIAlignBackwardFunction
: public torch::autograd::Function<ROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
at::AutoNonVariableTypeMode g;
auto result = detail::_roi_align_backward(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on roi_align not supported");
}
};
at::Tensor roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignFunction::apply(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned)[0];
}
at::Tensor roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignBackwardFunction::apply(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
}
} // namespace ops
} // namespace vision
#include "../roi_pool.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {
namespace {
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
at::AutoNonVariableTypeMode g;
auto result =
roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
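// argmax records which input element produced each pooled output value;
// backward scatters gradients through it, so it is saved but marked
// non-differentiable.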
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_roi_pool_backward(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class ROIPoolBackwardFunction
: public torch::autograd::Function<ROIPoolBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = detail::_roi_pool_backward(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on roi_pool not supported");
}
};
std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
auto result = ROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::make_tuple(result[0], result[1]);
}
at::Tensor roi_pool_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return ROIPoolBackwardFunction::apply(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
}
} // namespace ops
} // namespace vision
#include "deform_conv2d.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
@@ -41,6 +40,8 @@ at::Tensor deform_conv2d(
use_mask);
}
namespace detail {
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
@@ -80,6 +81,8 @@ _deform_conv2d_backward(
use_mask);
}
} // namespace detail
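// Declare the operator schemas once; the kernels for the individual dispatch
// keys (CPU, CUDA, Autograd, ...) are registered in their own files.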
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
@@ -87,257 +90,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
"_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
}
namespace {
class DeformConv2dFunction
: public torch::autograd::Function<DeformConv2dFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto output = deform_conv2d(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
ctx->save_for_backward({input, weight, offset, mask, bias});
ctx->saved_data["stride_h"] = stride_h;
ctx->saved_data["stride_w"] = stride_w;
ctx->saved_data["pad_h"] = pad_h;
ctx->saved_data["pad_w"] = pad_w;
ctx->saved_data["dilation_h"] = dilation_h;
ctx->saved_data["dilation_w"] = dilation_w;
ctx->saved_data["groups"] = groups;
ctx->saved_data["offset_groups"] = offset_groups;
ctx->saved_data["use_mask"] = use_mask;
return {
output,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto weight = saved[1];
auto offset = saved[2];
auto mask = saved[3];
auto bias = saved[4];
auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();
auto grads = _deform_conv2d_backward(
grad_output[0],
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
auto grad_input = std::get<0>(grads);
auto grad_weight = std::get<1>(grads);
auto grad_offset = std::get<2>(grads);
auto grad_mask = std::get<3>(grads);
auto grad_bias = std::get<4>(grads);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
};
}
};
// TODO: There should be an easier way to do this
class DeformConv2dBackwardFunction
: public torch::autograd::Function<DeformConv2dBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& input,
const torch::autograd::Variable& weight,
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
at::AutoNonVariableTypeMode g;
auto result = _deform_conv2d_backward(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
auto grad_input = std::get<0>(result);
auto grad_weight = std::get<1>(result);
auto grad_offset = std::get<2>(result);
auto grad_mask = std::get<3>(result);
auto grad_bias = std::get<4>(result);
return {
grad_input,
grad_weight,
grad_offset,
grad_mask,
grad_bias,
};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
}
};
at::Tensor deform_conv2d_autograd(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_autograd(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("deform_conv2d", deform_conv2d_autograd);
m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}
} // namespace ops
} // namespace vision
@@ -22,5 +22,28 @@ VISION_API at::Tensor deform_conv2d(
int64_t offset_groups,
bool use_mask);
namespace detail {
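// Backward entry point used by the autograd wrappers in ops/autograd; it is
// kept in the detail namespace to signal that it is an implementation detail
// rather than part of the public API.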
VISION_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
bool use_mask);
} // namespace detail
} // namespace ops
} // namespace vision
#include "nms.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
#include "ps_roi_align.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
@@ -20,6 +19,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
namespace detail {
at::Tensor _ps_roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
@@ -50,6 +51,8 @@ at::Tensor _ps_roi_align_backward(
width);
}
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
@@ -57,157 +60,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
}
namespace {
class PSROIAlignFunction
: public torch::autograd::Function<PSROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
at::AutoNonVariableTypeMode g;
auto result = ps_roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = _ps_roi_align_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
ctx->saved_data["sampling_ratio"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class PSROIAlignBackwardFunction
: public torch::autograd::Function<PSROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
}
};
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::make_tuple(result[0], result[1]);
}
at::Tensor ps_roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_align", ps_roi_align_autograd);
m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
}
} // namespace ops
} // namespace vision
@@ -14,5 +14,22 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align(
int64_t pooled_width,
int64_t sampling_ratio);
namespace detail {
VISION_API at::Tensor _ps_roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
} // namespace detail
} // namespace ops
} // namespace vision
#include "ps_roi_pool.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
@@ -18,6 +17,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}
namespace detail {
at::Tensor _ps_roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
@@ -46,6 +47,8 @@ at::Tensor _ps_roi_pool_backward(
width);
}
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
@@ -53,142 +56,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
"_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
}
namespace {
class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
at::AutoNonVariableTypeMode g;
auto result =
ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = _ps_roi_pool_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class PSROIPoolBackwardFunction
: public torch::autograd::Function<PSROIPoolBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _ps_roi_pool_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_pool not supported");
}
};
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
auto result = PSROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::make_tuple(result[0], result[1]);
}
at::Tensor ps_roi_pool_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return PSROIPoolBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("ps_roi_pool", ps_roi_pool_autograd);
m.impl("_ps_roi_pool_backward", ps_roi_pool_backward_autograd);
}
} // namespace ops
} // namespace vision
@@ -13,5 +13,21 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
int64_t pooled_height,
int64_t pooled_width);
namespace detail {
VISION_API at::Tensor _ps_roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
} // namespace detail
} // namespace ops
} // namespace vision
#include "roi_align.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
@@ -30,6 +29,8 @@ at::Tensor roi_align(
aligned);
}
namespace detail {
at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
@@ -60,6 +61,8 @@ at::Tensor _roi_align_backward(
aligned);
}
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
@@ -67,157 +70,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
}
namespace {
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
at::AutoNonVariableTypeMode g;
auto result = roi_align(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = _roi_align_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
ctx->saved_data["sampling_ratio"].toInt(),
ctx->saved_data["aligned"].toBool());
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class ROIAlignBackwardFunction
: public torch::autograd::Function<ROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
at::AutoNonVariableTypeMode g;
auto result = _roi_align_backward(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on roi_align not supported");
}
};
at::Tensor roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignFunction::apply(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned)[0];
}
at::Tensor roi_align_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignBackwardFunction::apply(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", roi_align_autograd);
m.impl("_roi_align_backward", roi_align_backward_autograd);
}
} // namespace ops
} // namespace vision
@@ -15,5 +15,22 @@ VISION_API at::Tensor roi_align(
int64_t sampling_ratio,
bool aligned);
namespace detail {
VISION_API at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned);
} // namespace detail
} // namespace ops
} // namespace vision
#include "roi_pool.h"
#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
@@ -18,6 +17,8 @@ std::tuple<at::Tensor, at::Tensor> roi_pool(
return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}
namespace detail {
at::Tensor _roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
@@ -45,6 +46,8 @@ at::Tensor _roi_pool_backward(
width);
}
} // namespace detail
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
@@ -52,142 +55,5 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
}
namespace {
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
at::AutoNonVariableTypeMode g;
auto result =
roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = _roi_pool_backward(
grad_output[0],
rois,
argmax,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable(),
torch::autograd::Variable()};
}
};
// TODO: There should be an easier way to do this
class ROIPoolBackwardFunction
: public torch::autograd::Function<ROIPoolBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _roi_pool_backward(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on roi_pool not supported");
}
};
std::tuple<at::Tensor, at::Tensor> roi_pool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
auto result = ROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::make_tuple(result[0], result[1]);
}
at::Tensor roi_pool_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return ROIPoolBackwardFunction::apply(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width)[0];
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_pool", roi_pool_autograd);
m.impl("_roi_pool_backward", roi_pool_backward_autograd);
}
} // namespace ops
} // namespace vision
@@ -13,5 +13,21 @@ VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool(
int64_t pooled_height,
int64_t pooled_width);
namespace detail {
VISION_API at::Tensor _roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
} // namespace detail
} // namespace ops
} // namespace vision