Unverified Commit 44806038 authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Port roi_align to actually use dispatcher (#2366)



* Switch torchvision registrations to new operator registration API.

This is still registering everything as catchalls, so we're really just
moving deck chairs around, but payoff is coming soon.
Signed-off-by: default avatarEdward Z. Yang <ezyang@fb.com>

* Port roi_align to actually use dispatcher
Signed-off-by: default avatarEdward Z. Yang <ezyang@fb.com>
parent 7fd24918
......@@ -9,8 +9,9 @@
#include "hip/vision_cuda.h"
#endif
// Interface for Python
at::Tensor ROIAlign_forward(
// TODO: put this stuff in torchvision namespace
at::Tensor roi_align(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
const double spatial_scale, // The scale of the image features. ROIs will be
......@@ -21,21 +22,10 @@ at::Tensor ROIAlign_forward(
const bool aligned) // The flag for pixel shift
// along each axis.
{
if (input.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return ROIAlign_forward_cuda(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
aligned);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIAlign_forward_cpu(
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::roi_align", "")
.typed<decltype(roi_align)>();
return op.call(
input,
rois,
spatial_scale,
......@@ -45,37 +35,23 @@ at::Tensor ROIAlign_forward(
aligned);
}
at::Tensor ROIAlign_backward(
at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return ROIAlign_backward_cuda(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIAlign_backward_cpu(
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_roi_align_backward", "")
.typed<decltype(_roi_align_backward)>();
return op.call(
grad,
rois,
spatial_scale,
......@@ -107,7 +83,8 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
auto result = ROIAlign_forward(
at::AutoNonVariableTypeMode g;
auto result = roi_align(
input,
rois,
spatial_scale,
......@@ -125,7 +102,7 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIAlign_backward(
auto grad_in = _roi_align_backward(
grad_output[0],
rois,
ctx->saved_data["spatial_scale"].toDouble(),
......@@ -147,7 +124,47 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
}
};
at::Tensor roi_align(
// TODO: There should be an easier way to do this
class ROIAlignBackwardFunction
: public torch::autograd::Function<ROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
at::AutoNonVariableTypeMode g;
auto result = _roi_align_backward(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned);
return {result};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
TORCH_CHECK(0, "double backwards on roi_align not supported");
}
};
at::Tensor ROIAlign_autograd(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
......@@ -164,3 +181,29 @@ at::Tensor roi_align(
sampling_ratio,
aligned)[0];
}
at::Tensor ROIAlign_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
return ROIAlignBackwardFunction::apply(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio,
aligned)[0];
}
......@@ -381,10 +381,10 @@ void ROIAlignBackward(
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
......@@ -430,14 +430,14 @@ at::Tensor ROIAlign_forward_cpu(
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
......
......@@ -23,23 +23,23 @@ at::Tensor ROIPool_backward_cpu(
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned);
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned);
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
......
......@@ -307,10 +307,10 @@ __global__ void RoIAlignBackward(
at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");
......@@ -368,14 +368,14 @@ at::Tensor ROIAlign_forward_cuda(
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
AT_ASSERTM(grad.is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");
......
......@@ -9,23 +9,23 @@
at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned);
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned);
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
......
......@@ -42,14 +42,34 @@ int64_t _cuda_version() {
#endif
}
static auto registry =
torch::RegisterOperators()
.op("torchvision::nms", &nms)
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor",
&roi_align)
.op("torchvision::roi_pool", &roi_pool)
.op("torchvision::_new_empty_tensor_op", &new_empty_tensor)
.op("torchvision::ps_roi_align", &ps_roi_align)
.op("torchvision::ps_roi_pool", &ps_roi_pool)
.op("torchvision::deform_conv2d", &deform_conv2d)
.op("torchvision::_cuda_version", &_cuda_version);
TORCH_LIBRARY(torchvision, m) {
m.def("nms", &nms);
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
m.def("roi_pool", &roi_pool);
m.def("_new_empty_tensor_op", &new_empty_tensor);
m.def("ps_roi_align", &ps_roi_align);
m.def("ps_roi_pool", &ps_roi_pool);
m.def("deform_conv2d", &deform_conv2d);
m.def("_cuda_version", &_cuda_version);
}
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", ROIAlign_forward_cpu);
m.impl("_roi_align_backward", ROIAlign_backward_cpu);
}
// TODO: Place this in a hypothetical separate torchvision_cuda library
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", ROIAlign_forward_cuda);
m.impl("_roi_align_backward", ROIAlign_backward_cuda);
}
#endif
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", ROIAlign_autograd);
m.impl("_roi_align_backward", ROIAlign_backward_autograd);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment