#include #include #include "ROIAlign.h" #include "ROIPool.h" #include "nms.h" using namespace at; // If we are in a Windows environment, we need to define // initialization functions for the _custom_ops extension #ifdef _WIN32 #if PY_MAJOR_VERSION < 3 PyMODINIT_FUNC init_custom_ops(void) { // No need to do anything. // _custom_ops.py will run on load return NULL; } #else PyMODINIT_FUNC PyInit__custom_ops(void) { // No need to do anything. // _custom_ops.py will run on load return NULL; } #endif #endif using torch::Tensor; using torch::autograd::AutogradContext; using torch::autograd::Variable; using torch::autograd::variable_list; class ROIAlignFunction : public torch::autograd::Function { public: static variable_list forward( AutogradContext* ctx, Variable input, Variable rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, const int64_t sampling_ratio) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["sampling_ratio"] = sampling_ratio; ctx->saved_data["input_shape"] = input.sizes(); ctx->save_for_backward({rois}); auto result = ROIAlign_forward( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); return {result}; } static variable_list backward( AutogradContext* ctx, variable_list grad_output) { // Use data saved in forward auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); auto grad_in = ROIAlign_backward( grad_output[0], rois, ctx->saved_data["spatial_scale"].toDouble(), ctx->saved_data["pooled_height"].toInt(), ctx->saved_data["pooled_width"].toInt(), input_shape[0], input_shape[1], input_shape[2], input_shape[3], ctx->saved_data["sampling_ratio"].toInt()); return { grad_in, Variable(), Variable(), Variable(), Variable(), Variable()}; } }; Tensor roi_align( const Tensor& input, const Tensor& rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width, const int64_t sampling_ratio) { return ROIAlignFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio)[0]; } class ROIPoolFunction : public torch::autograd::Function { public: static variable_list forward( AutogradContext* ctx, Variable input, Variable rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["input_shape"] = input.sizes(); auto result = ROIPool_forward( input, rois, spatial_scale, pooled_height, pooled_width); auto output = std::get<0>(result); auto argmax = std::get<1>(result); ctx->save_for_backward({rois, argmax}); ctx->mark_non_differentiable({argmax}); return {output, argmax}; } static variable_list backward( AutogradContext* ctx, variable_list grad_output) { // Use data saved in forward auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto argmax = saved[1]; auto input_shape = ctx->saved_data["input_shape"].toIntList(); auto grad_in = ROIPool_backward( grad_output[0], rois, argmax, ctx->saved_data["spatial_scale"].toDouble(), ctx->saved_data["pooled_height"].toInt(), ctx->saved_data["pooled_width"].toInt(), input_shape[0], input_shape[1], input_shape[2], input_shape[3]); return {grad_in, Variable(), Variable(), Variable(), Variable()}; } }; std::tuple roi_pool( const Tensor& input, const Tensor& rois, const double spatial_scale, const int64_t pooled_height, const int64_t pooled_width) { auto result = ROIPoolFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width); return std::tuple(result[0], result[1]); } static auto registry = torch::RegisterOperators() .op("torchvision::nms", &nms) .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor", &roi_align) .op("torchvision::roi_pool", &roi_pool);