Commit a91fe722 authored by Thomas Viehmann, committed by Francisco Massa

Make custom ops differentiable (#1314)

* Make custom ops differentiable

and replace autograd.Function. Use ops unconditionally.

We may consider removing the extension functions in a follow-up.

The code path is tested by the existing tests for differentiability.

* add scripting gradcheck tests and use intlist

* fix implicit tuple conversion for gcc-5

* fix merge
parent cabca398
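For orientation before the diff: with this change the RoI ops are registered as differentiable custom operators, so they can be called through `torch.ops.torchvision` from eager mode or TorchScript and verified with `gradcheck`. A minimal usage sketch follows (not part of the commit); the tensor shapes are illustrative, and the explicit `_lazy_import()` call mirrors what the Python wrappers in this diff do before touching `torch.ops`.

import torch
from torch.autograd import gradcheck
from torchvision.extension import _lazy_import

# Loading the compiled extension registers the torchvision custom ops.
_lazy_import()

@torch.jit.script
def script_roi_align(input, rois):
    # spatial_scale=0.5, 5x5 pooled output, sampling_ratio=1
    return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)

# Double precision keeps gradcheck's numerical Jacobian stable;
# each RoI row is (batch_index, x1, y1, x2, y2).
x = torch.rand(1, 1, 10, 10, dtype=torch.float64, requires_grad=True)
rois = torch.tensor([[0., 0., 0., 9., 9.]], dtype=torch.float64)

assert gradcheck(lambda t: script_roi_align(t, rois), (x,))

Because the backward now lives behind the registered op itself (via the `torch::autograd::Function` classes added below) rather than in a Python `autograd.Function`, scripted and traced graphs get gradients without the tracing special case that this commit removes.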
@@ -188,6 +188,12 @@ class RoIPoolTester(unittest.TestCase):
         assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CPU'
         assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CPU'
+
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool'
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_roi_pool_basic_cuda(self):
         device = torch.device('cuda')
@@ -274,6 +280,12 @@ class RoIPoolTester(unittest.TestCase):
         assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CUDA'
         assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CUDA'
+
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_pool(input, rois, 1.0, 5, 5)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_pool on CUDA'
 
 
 class RoIAlignTester(unittest.TestCase):
     @classmethod
@@ -428,6 +440,12 @@ class RoIAlignTester(unittest.TestCase):
         assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CPU'
         assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CPU'
+
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align'
 
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
     def test_roi_align_gradient_cuda(self):
         """
@@ -462,6 +480,12 @@ class RoIAlignTester(unittest.TestCase):
         assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CUDA'
         assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CUDA'
+
+        @torch.jit.script
+        def script_func(input, rois):
+            return torch.ops.torchvision.roi_align(input, rois, 0.5, 5, 5, 1)[0]
+
+        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'
 
 
 class NMSTester(unittest.TestCase):
     def reference_nms(self, boxes, scores, iou_threshold):
...
@@ -25,9 +25,135 @@ PyMODINIT_FUNC PyInit__custom_ops(void) {
 #endif
 #endif
+
+using torch::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t sampling_ratio) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["input_shape"] = input.sizes();
+    ctx->save_for_backward({rois});
+    auto result = ROIAlign_forward(
+        input,
+        rois,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        sampling_ratio);
+    return {result};
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIAlign_backward(
+        grad_output[0],
+        rois,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3],
+        ctx->saved_data["sampling_ratio"].toInt());
+    return {
+        grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+Tensor roi_align(
+    const Tensor& input,
+    const Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t sampling_ratio) {
+  return ROIAlignFunction::apply(
+      input,
+      rois,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      sampling_ratio)[0];
+}
+
+class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["input_shape"] = input.sizes();
+    auto result = ROIPool_forward(
+        input, rois, spatial_scale, pooled_height, pooled_width);
+    auto output = std::get<0>(result);
+    auto argmax = std::get<1>(result);
+    ctx->save_for_backward({rois, argmax});
+    ctx->mark_non_differentiable({argmax});
+    return {output, argmax};
+  }
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto argmax = saved[1];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIPool_backward(
+        grad_output[0],
+        rois,
+        argmax,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0],
+        input_shape[1],
+        input_shape[2],
+        input_shape[3]);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+std::tuple<Tensor, Tensor> roi_pool(
+    const Tensor& input,
+    const Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width) {
+  auto result = ROIPoolFunction::apply(
+      input, rois, spatial_scale, pooled_height, pooled_width);
+  return std::tuple<Tensor, Tensor>(result[0], result[1]);
+}
+
 static auto registry =
     torch::RegisterOperators()
         .op("torchvision::nms", &nms)
         .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
-            &ROIAlign_forward)
-        .op("torchvision::roi_pool", &ROIPool_forward);
+            &roi_align)
+        .op("torchvision::roi_pool", &roi_pool);
@@ -10,35 +10,6 @@ from torchvision.extension import _lazy_import
 from ._utils import convert_boxes_to_roi_format
 
 
-class _RoIAlignFunction(Function):
-    @staticmethod
-    def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio):
-        ctx.save_for_backward(roi)
-        ctx.output_size = _pair(output_size)
-        ctx.spatial_scale = spatial_scale
-        ctx.sampling_ratio = sampling_ratio
-        ctx.input_shape = input.size()
-        _C = _lazy_import()
-        output = _C.roi_align_forward(
-            input, roi, spatial_scale,
-            output_size[0], output_size[1], sampling_ratio)
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, grad_output):
-        rois, = ctx.saved_tensors
-        output_size = ctx.output_size
-        spatial_scale = ctx.spatial_scale
-        sampling_ratio = ctx.sampling_ratio
-        bs, ch, h, w = ctx.input_shape
-        _C = _lazy_import()
-        grad_input = _C.roi_align_backward(
-            grad_output, rois, spatial_scale,
-            output_size[0], output_size[1], bs, ch, h, w, sampling_ratio)
-        return grad_input, None, None, None, None
 def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
     """
     Performs Region of Interest (RoI) Align operator described in Mask R-CNN
@@ -66,14 +37,10 @@ def roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
     rois = boxes
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    # TODO: Change this to support backwards, which we
-    # do not currently support when JIT tracing.
-    if torch._C._get_tracing_state():
-        _lazy_import()
-        return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
-                                               output_size[0], output_size[1],
-                                               sampling_ratio)
-    return _RoIAlignFunction.apply(input, rois, output_size, spatial_scale, sampling_ratio)
+    _lazy_import()
+    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
+                                           output_size[0], output_size[1],
+                                           sampling_ratio)
 
 
 class RoIAlign(nn.Module):
...
@@ -10,33 +10,6 @@ from torchvision.extension import _lazy_import
 from ._utils import convert_boxes_to_roi_format
 
 
-class _RoIPoolFunction(Function):
-    @staticmethod
-    def forward(ctx, input, rois, output_size, spatial_scale):
-        ctx.output_size = _pair(output_size)
-        ctx.spatial_scale = spatial_scale
-        ctx.input_shape = input.size()
-        _C = _lazy_import()
-        output, argmax = _C.roi_pool_forward(
-            input, rois, spatial_scale,
-            output_size[0], output_size[1])
-        ctx.save_for_backward(rois, argmax)
-        return output
-
-    @staticmethod
-    @once_differentiable
-    def backward(ctx, grad_output):
-        rois, argmax = ctx.saved_tensors
-        output_size = ctx.output_size
-        spatial_scale = ctx.spatial_scale
-        bs, ch, h, w = ctx.input_shape
-        _C = _lazy_import()
-        grad_input = _C.roi_pool_backward(
-            grad_output, rois, argmax, spatial_scale,
-            output_size[0], output_size[1], bs, ch, h, w)
-        return grad_input, None, None, None
 def roi_pool(input, boxes, output_size, spatial_scale=1.0):
     """
     Performs Region of Interest (RoI) Pool operator described in Fast R-CNN
@@ -59,14 +32,10 @@ def roi_pool(input, boxes, output_size, spatial_scale=1.0):
     rois = boxes
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
-    # TODO: Change this to support backwards, which we
-    # do not currently support when JIT tracing.
-    if torch._C._get_tracing_state():
-        _lazy_import()
-        output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
-                                                   output_size[0], output_size[1])
-        return output
-    return _RoIPoolFunction.apply(input, rois, output_size, spatial_scale)
+    _lazy_import()
+    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale,
+                                               output_size[0], output_size[1])
+    return output
 
 
 class RoIPool(nn.Module):
...