Commit a91fe722 authored by Thomas Viehmann's avatar Thomas Viehmann Committed by Francisco Massa
Browse files

Make custom ops differentiable (#1314)

* Make custom ops differentiable

and replace autograd.Function. Use ops unconditionally.

We may consider removing the extension functions in a follow-up.

The code-path is tested by the existing tests for differentiability.

* add scripting gradchecks tests and use intlist

* fix implicit tuple conversion for gcc-5

* fix merge
parent cabca398
......@@ -188,6 +188,12 @@ class RoIPoolTester(unittest.TestCase):
assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CPU'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CPU'
@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_basic_cuda(self):
device = torch.device('cuda')
......@@ -274,6 +280,12 @@ class RoIPoolTester(unittest.TestCase):
assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CUDA'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CUDA'
@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA'
class RoIAlignTester(unittest.TestCase):
@classmethod
......@@ -428,6 +440,12 @@ class RoIAlignTester(unittest.TestCase):
assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CPU'
assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CPU'
@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_gradient_cuda(self):
"""
......@@ -462,6 +480,12 @@ class RoIAlignTester(unittest.TestCase):
assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CUDA'
assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CUDA'
@torch.jit.script
def script_func(input, rois):
return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'
class NMSTester(unittest.TestCase):
def reference_nms(self, boxes, scores, iou_threshold):
......
......@@ -25,9 +25,135 @@ PyMODINIT_FUNC PyInit__custom_ops(void) {
#endif
#endif
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
// Autograd-aware wrapper for the ROIAlign kernels. Deriving from
// torch::autograd::Function makes the custom op differentiable, including
// when invoked through the JIT via torch.ops.torchvision.roi_align.
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  // Runs the forward kernel and stashes what backward() needs:
  // scalar attributes go into saved_data, the rois tensor into
  // save_for_backward. The input tensor itself is not saved — backward
  // only needs its shape, recorded under "input_shape".
  static variable_list forward(
      AutogradContext* ctx,
      Variable input,
      Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t sampling_ratio) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});
    auto result = ROIAlign_forward(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);
    return {result};
  }

  // Computes the gradient w.r.t. the input by delegating to the
  // ROIAlign_backward kernel, reconstructing the input shape saved in
  // forward().
  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    // input_shape is {batch, channels, height, width} — saved as an int
    // list in forward().
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = ROIAlign_backward(
        grad_output[0],
        rois,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        ctx->saved_data["sampling_ratio"].toInt());
    // One entry per forward() argument (after ctx): only `input` gets a
    // gradient; the non-differentiable arguments get empty Variables.
    return {
        grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
  }
};
// Differentiable roi_align entry point registered as a custom op.
// Delegates to ROIAlignFunction so that gradients flow back to `input`.
Tensor roi_align(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio) {
  // apply() returns a variable_list with a single element; unwrap it.
  auto outputs = ROIAlignFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
  return outputs[0];
}
// Autograd-aware wrapper for the ROIPool kernels. Forward returns both the
// pooled output and the argmax index map; only the output is
// differentiable.
class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
 public:
  // Runs the forward kernel, saving the rois and the argmax map for
  // backward. The input is represented in backward only by its shape,
  // stored under "input_shape".
  static variable_list forward(
      AutogradContext* ctx,
      Variable input,
      Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["input_shape"] = input.sizes();
    auto result = ROIPool_forward(
        input, rois, spatial_scale, pooled_height, pooled_width);
    auto output = std::get<0>(result);
    auto argmax = std::get<1>(result);
    ctx->save_for_backward({rois, argmax});
    // argmax is an index map, not a differentiable output.
    ctx->mark_non_differentiable({argmax});
    return {output, argmax};
  }

  // Routes grad_output through the ROIPool_backward kernel using the
  // argmax map recorded in forward().
  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto argmax = saved[1];
    // input_shape is {batch, channels, height, width}.
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = ROIPool_backward(
        grad_output[0],
        rois,
        argmax,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3]);
    // One entry per forward() argument (after ctx): gradient only for
    // `input`.
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};
// Differentiable roi_pool entry point registered as a custom op.
// Returns (pooled output, argmax index map); only the output carries
// gradients.
std::tuple<Tensor, Tensor> roi_pool(
    const Tensor& input,
    const Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width) {
  auto outputs = ROIPoolFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width);
  // Explicit tuple construction (rather than brace-init) avoids the
  // implicit-conversion issue seen with gcc-5.
  return std::tuple<Tensor, Tensor>(outputs[0], outputs[1]);
}
static auto registry =
torch::RegisterOperators()
.op("torchvision::nms", &nms)
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
&ROIAlign_forward)
.op("torchvision::roi_pool", &ROIPool_forward);
&roi_align)
.op("torchvision::roi_pool", &roi_pool);
......@@ -10,35 +10,6 @@ from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format
class _RoIAlignFunction(Function):
    # Legacy autograd.Function wrapper around the compiled extension's
    # roi_align kernels. Superseded by the differentiable
    # torch.ops.torchvision.roi_align custom op.

    @staticmethod
    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
        # Only `roi` must be saved as a tensor; the input is represented
        # in backward solely by its shape.
        ctx.save_for_backward(roi)
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ctx.sampling_ratio = sampling_ratio
        ctx.input_shape = input.size()
        # Lazily load the compiled extension module.
        _C = _lazy_import()
        output = _C.roi_align_forward(
            input, roi, spatial_scale,
            output_size[0], output_size[1], sampling_ratio)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        rois, = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        sampling_ratio = ctx.sampling_ratio
        # Original input shape: (batch, channels, height, width).
        bs, ch, h, w = ctx.input_shape
        _C = _lazy_import()
        grad_input = _C.roi_align_backward(
            grad_output, rois, spatial_scale,
            output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
        # Gradient only for `input`; the remaining forward args are
        # non-differentiable.
        return grad_input, None, None, None, None
def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
"""
Performs Region of Interest (RoI) Align operator described in Mask R-CNN
......@@ -66,14 +37,10 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
rois = boxes
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
# TODO: Change this to support backwards, which we
# do not currently support when JIT tracing.
if torch._C._get_tracing_state():
_lazy_import()
return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
output_size[0], output_size[1],
sampling_ratio)
return _RoIAlignFunction.apply(input, rois, output_size, spatial_scale, sampling_ratio)
_lazy_import()
return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
output_size[0], output_size[1],
sampling_ratio)
class RoIAlign(nn.Module):
......
......@@ -10,33 +10,6 @@ from torchvision.extension import _lazy_import
from ._utils import convert_boxes_to_roi_format
class _RoIPoolFunction(Function):
    # Legacy autograd.Function wrapper around the compiled extension's
    # roi_pool kernels. Superseded by the differentiable
    # torch.ops.torchvision.roi_pool custom op.

    @staticmethod
    def forward(ctx, input, rois, output_size, spatial_scale):
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        # Backward needs only the input's shape, not the tensor itself.
        ctx.input_shape = input.size()
        # Lazily load the compiled extension module.
        _C = _lazy_import()
        output, argmax = _C.roi_pool_forward(
            input, rois, spatial_scale,
            output_size[0], output_size[1])
        # The argmax index map is required to route gradients in backward.
        ctx.save_for_backward(rois, argmax)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        rois, argmax = ctx.saved_tensors
        output_size = ctx.output_size
        spatial_scale = ctx.spatial_scale
        # Original input shape: (batch, channels, height, width).
        bs, ch, h, w = ctx.input_shape
        _C = _lazy_import()
        grad_input = _C.roi_pool_backward(
            grad_output, rois, argmax, spatial_scale,
            output_size[0], output_size[1], bs, ch, h, w)
        # Gradient only for `input`; the remaining forward args are
        # non-differentiable.
        return grad_input, None, None, None
def roi_pool(input, boxes, output_size, spatial_scale=1.0):
"""
Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
......@@ -59,14 +32,10 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
rois = boxes
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
# TODO: Change this to support backwards, which we
# do not currently support when JIT tracing.
if torch._C._get_tracing_state():
_lazy_import()
output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
output_size[0], output_size[1])
return output
return _RoIPoolFunction.apply(input, rois, output_size, spatial_scale)
_lazy_import()
output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
output_size[0], output_size[1])
return output
class RoIPool(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment