Unverified Commit b06e43d6 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

PSROIAlign + Dispatcher + Autocast + Code Cleanup (#2928)

* Fixing types.

* Dispatcher + Autocast.

* + Autograd.

* Clean up and refactor PSROIAlign implementation:
- Remove primitive const declaration from method names.
- Using references when possible.
- Sync naming of internal methods with other ops.

* Restoring names of internal methods to avoid conflicts.

* Restore include headers.
parent 0e5aee46
...@@ -3,72 +3,75 @@ ...@@ -3,72 +3,75 @@
#include "cpu/vision_cpu.h" #include "cpu/vision_cpu.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h" #include "cuda/vision_cuda.h"
#endif #endif
#ifdef WITH_HIP #ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h" #include "hip/vision_cuda.h"
#endif #endif
#include <iostream> #include <iostream>
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward( // TODO: put this stuff in torchvision namespace
std::tuple<at::Tensor, at::Tensor> ps_roi_align(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio) { int64_t sampling_ratio) {
if (input.is_cuda()) { static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::ps_roi_align", "")
.typed<decltype(ps_roi_align)>();
return op.call(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
#if defined(WITH_CUDA) || defined(WITH_HIP) #if defined(WITH_CUDA) || defined(WITH_HIP)
return PSROIAlign_forward_cuda( std::tuple<at::Tensor, at::Tensor> PSROIAlign_autocast(
input, const at::Tensor& input,
rois, const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto result = ps_roi_align(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, rois),
spatial_scale, spatial_scale,
pooled_height, pooled_height,
pooled_width, pooled_width,
sampling_ratio); sampling_ratio);
#else
TORCH_CHECK(false, "Not compiled with GPU support"); return std::make_tuple(
#endif std::get<0>(result).to(input.scalar_type()),
} std::get<1>(result).to(input.scalar_type()));
return PSROIAlign_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
} }
#endif
at::Tensor PSROIAlign_backward( at::Tensor _ps_roi_align_backward(
const at::Tensor& grad, const at::Tensor& grad,
const at::Tensor& rois, const at::Tensor& rois,
const at::Tensor& mapping_channel, const at::Tensor& channel_mapping,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio, int64_t sampling_ratio,
const int batch_size, int64_t batch_size,
const int channels, int64_t channels,
const int height, int64_t height,
const int width) { int64_t width) {
if (grad.is_cuda()) { static auto op =
#if defined(WITH_CUDA) || defined(WITH_HIP) c10::Dispatcher::singleton()
return PSROIAlign_backward_cuda( .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
.typed<decltype(_ps_roi_align_backward)>();
return op.call(
grad, grad,
rois, rois,
mapping_channel, channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return PSROIAlign_backward_cpu(
grad,
rois,
mapping_channel,
spatial_scale, spatial_scale,
pooled_height, pooled_height,
pooled_width, pooled_width,
...@@ -84,40 +87,43 @@ class PSROIAlignFunction ...@@ -84,40 +87,43 @@ class PSROIAlignFunction
public: public:
static torch::autograd::variable_list forward( static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input, const torch::autograd::Variable& input,
torch::autograd::Variable rois, const torch::autograd::Variable& rois,
const double spatial_scale, double spatial_scale,
const int64_t pooled_height, int64_t pooled_height,
const int64_t pooled_width, int64_t pooled_width,
const int64_t sampling_ratio) { int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width; ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio; ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes(); ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIAlign_forward( at::AutoNonVariableTypeMode g;
auto result = ps_roi_align(
input, input,
rois, rois,
spatial_scale, spatial_scale,
pooled_height, pooled_height,
pooled_width, pooled_width,
sampling_ratio); sampling_ratio);
auto output = std::get<0>(result); auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result); auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping}); ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping}); ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping}; return {output, channel_mapping};
} }
static torch::autograd::variable_list backward( static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx, torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) { const torch::autograd::variable_list& grad_output) {
// Use data saved in forward // Use data saved in forward
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto rois = saved[0]; auto rois = saved[0];
auto channel_mapping = saved[1]; auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList(); auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIAlign_backward( auto grad_in = _ps_roi_align_backward(
grad_output[0], grad_output[0],
rois, rois,
channel_mapping, channel_mapping,
...@@ -129,6 +135,7 @@ class PSROIAlignFunction ...@@ -129,6 +135,7 @@ class PSROIAlignFunction
input_shape[1], input_shape[1],
input_shape[2], input_shape[2],
input_shape[3]); input_shape[3]);
return {grad_in, return {grad_in,
torch::autograd::Variable(), torch::autograd::Variable(),
torch::autograd::Variable(), torch::autograd::Variable(),
...@@ -138,14 +145,82 @@ class PSROIAlignFunction ...@@ -138,14 +145,82 @@ class PSROIAlignFunction
} }
}; };
std::tuple<at::Tensor, at::Tensor> ps_roi_align( // TODO: There should be an easier way to do this
class PSROIAlignBackwardFunction
: public torch::autograd::Function<PSROIAlignBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
}
};
std::tuple<at::Tensor, at::Tensor> PSROIAlign_autograd(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
const double spatial_scale, double spatial_scale,
const int64_t pooled_height, int64_t pooled_height,
const int64_t pooled_width, int64_t pooled_width,
const int64_t sampling_ratio) { int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply( auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
return std::make_tuple(result[0], result[1]);
}
at::Tensor PSROIAlign_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width)[0];
} }
\ No newline at end of file
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
template <typename T> template <typename T>
T bilinear_interpolate( T bilinear_interpolate(
const T* input, const T* input,
const int height, int height,
const int width, int width,
T y, T y,
T x, T x,
const int index /* index for debug only*/) { int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary // deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) { if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty // empty
...@@ -58,17 +58,17 @@ T bilinear_interpolate( ...@@ -58,17 +58,17 @@ T bilinear_interpolate(
template <typename T> template <typename T>
void PSROIAlignForwardCPU( void PSROIAlignForwardCPU(
const int nthreads, int nthreads,
const T* input, const T* input,
const T spatial_scale, const T spatial_scale,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int pooled_height, int pooled_height,
const int pooled_width, int pooled_width,
const int sampling_ratio, int sampling_ratio,
const T* rois, const T* rois,
const int channels_out, int channels_out,
T* output, T* output,
int* channel_mapping) { int* channel_mapping) {
int num_rois = nthreads / channels_out / pooled_width / pooled_height; int num_rois = nthreads / channels_out / pooled_width / pooled_height;
...@@ -139,8 +139,8 @@ void PSROIAlignForwardCPU( ...@@ -139,8 +139,8 @@ void PSROIAlignForwardCPU(
template <typename T> template <typename T>
void bilinear_interpolate_gradient( void bilinear_interpolate_gradient(
const int height, int height,
const int width, int width,
T y, T y,
T x, T x,
T& w1, T& w1,
...@@ -151,7 +151,7 @@ void bilinear_interpolate_gradient( ...@@ -151,7 +151,7 @@ void bilinear_interpolate_gradient(
int& x_high, int& x_high,
int& y_low, int& y_low,
int& y_high, int& y_high,
const int index /* index for debug only*/) { int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary // deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) { if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty // empty
...@@ -203,18 +203,18 @@ inline void add(T* address, const T& val) { ...@@ -203,18 +203,18 @@ inline void add(T* address, const T& val) {
template <typename T> template <typename T>
void PSROIAlignBackwardCPU( void PSROIAlignBackwardCPU(
const int nthreads, int nthreads,
const T* grad_output, const T* grad_output,
const int* channel_mapping, const int* channel_mapping,
const int num_rois, int num_rois,
const T spatial_scale, const T spatial_scale,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int pooled_height, int pooled_height,
const int pooled_width, int pooled_width,
const int sampling_ratio, int sampling_ratio,
const int channels_out, int channels_out,
T* grad_input, T* grad_input,
const T* rois) { const T* rois) {
for (int index = 0; index < nthreads; index++) { for (int index = 0; index < nthreads; index++) {
...@@ -301,10 +301,10 @@ void PSROIAlignBackwardCPU( ...@@ -301,10 +301,10 @@ void PSROIAlignBackwardCPU(
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu( std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio) { int64_t sampling_ratio) {
// Check if input tensors are CPU tensors // Check if input tensors are CPU tensors
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor"); TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
...@@ -361,14 +361,14 @@ at::Tensor PSROIAlign_backward_cpu( ...@@ -361,14 +361,14 @@ at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad, const at::Tensor& grad,
const at::Tensor& rois, const at::Tensor& rois,
const at::Tensor& channel_mapping, const at::Tensor& channel_mapping,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio, int64_t sampling_ratio,
const int batch_size, int64_t batch_size,
const int channels, int64_t channels,
const int height, int64_t height,
const int width) { int64_t width) {
// Check if input tensors are CPU tensors // Check if input tensors are CPU tensors
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor"); TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor"); TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
......
...@@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cpu( ...@@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cpu(
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu( VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio); int64_t sampling_ratio);
VISION_API at::Tensor PSROIAlign_backward_cpu( VISION_API at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad, const at::Tensor& grad,
const at::Tensor& rois, const at::Tensor& rois,
const at::Tensor& mapping_channel, const at::Tensor& channel_mapping,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio, int64_t sampling_ratio,
const int batch_size, int64_t batch_size,
const int channels, int64_t channels,
const int height, int64_t height,
const int width); int64_t width);
VISION_API at::Tensor nms_cpu( VISION_API at::Tensor nms_cpu(
const at::Tensor& dets, const at::Tensor& dets,
......
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
template <typename T> template <typename T>
__device__ T bilinear_interpolate( __device__ T bilinear_interpolate(
const T* input, const T* input,
const int height, int height,
const int width, int width,
T y, T y,
T x, T x,
const int index /* index for debug only*/) { int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary // deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) { if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty // empty
...@@ -63,17 +63,17 @@ __device__ T bilinear_interpolate( ...@@ -63,17 +63,17 @@ __device__ T bilinear_interpolate(
template <typename T> template <typename T>
__global__ void PSROIAlignForwardCUDA( __global__ void PSROIAlignForwardCUDA(
const int nthreads, int nthreads,
const T* input, const T* input,
const T spatial_scale, const T spatial_scale,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int pooled_height, int pooled_height,
const int pooled_width, int pooled_width,
const int sampling_ratio, int sampling_ratio,
const T* rois, const T* rois,
const int channels_out, int channels_out,
T* output, T* output,
int* channel_mapping) { int* channel_mapping) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { CUDA_1D_KERNEL_LOOP(index, nthreads) {
...@@ -137,8 +137,8 @@ __global__ void PSROIAlignForwardCUDA( ...@@ -137,8 +137,8 @@ __global__ void PSROIAlignForwardCUDA(
template <typename T> template <typename T>
__device__ void bilinear_interpolate_gradient( __device__ void bilinear_interpolate_gradient(
const int height, int height,
const int width, int width,
T y, T y,
T x, T x,
T& w1, T& w1,
...@@ -149,7 +149,7 @@ __device__ void bilinear_interpolate_gradient( ...@@ -149,7 +149,7 @@ __device__ void bilinear_interpolate_gradient(
int& x_high, int& x_high,
int& y_low, int& y_low,
int& y_high, int& y_high,
const int index /* index for debug only*/) { int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary // deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) { if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty // empty
...@@ -196,18 +196,18 @@ __device__ void bilinear_interpolate_gradient( ...@@ -196,18 +196,18 @@ __device__ void bilinear_interpolate_gradient(
template <typename T> template <typename T>
__global__ void PSROIAlignBackwardCUDA( __global__ void PSROIAlignBackwardCUDA(
const int nthreads, int nthreads,
const T* grad_output, const T* grad_output,
const int* channel_mapping, const int* channel_mapping,
const int num_rois, int num_rois,
const T spatial_scale, const T spatial_scale,
const int channels, int channels,
const int height, int height,
const int width, int width,
const int pooled_height, int pooled_height,
const int pooled_width, int pooled_width,
const int sampling_ratio, int sampling_ratio,
const int channels_out, int channels_out,
T* grad_input, T* grad_input,
const T* rois) { const T* rois) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { CUDA_1D_KERNEL_LOOP(index, nthreads) {
...@@ -295,10 +295,10 @@ __global__ void PSROIAlignBackwardCUDA( ...@@ -295,10 +295,10 @@ __global__ void PSROIAlignBackwardCUDA(
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda( std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio) { int64_t sampling_ratio) {
// Check if input tensors are CUDA tensors // Check if input tensors are CUDA tensors
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
...@@ -369,14 +369,14 @@ at::Tensor PSROIAlign_backward_cuda( ...@@ -369,14 +369,14 @@ at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad, const at::Tensor& grad,
const at::Tensor& rois, const at::Tensor& rois,
const at::Tensor& channel_mapping, const at::Tensor& channel_mapping,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio, int64_t sampling_ratio,
const int batch_size, int64_t batch_size,
const int channels, int64_t channels,
const int height, int64_t height,
const int width) { int64_t width) {
// Check if input tensors are CUDA tensors // Check if input tensors are CUDA tensors
TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor"); TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
......
...@@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cuda( ...@@ -65,23 +65,23 @@ VISION_API at::Tensor PSROIPool_backward_cuda(
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda( VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input, const at::Tensor& input,
const at::Tensor& rois, const at::Tensor& rois,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio); int64_t sampling_ratio);
VISION_API at::Tensor PSROIAlign_backward_cuda( VISION_API at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad, const at::Tensor& grad,
const at::Tensor& rois, const at::Tensor& rois,
const at::Tensor& mapping_channel, const at::Tensor& channel_mapping,
const float spatial_scale, double spatial_scale,
const int pooled_height, int64_t pooled_height,
const int pooled_width, int64_t pooled_width,
const int sampling_ratio, int64_t sampling_ratio,
const int batch_size, int64_t batch_size,
const int channels, int64_t channels,
const int height, int64_t height,
const int width); int64_t width);
VISION_API at::Tensor nms_cuda( VISION_API at::Tensor nms_cuda(
const at::Tensor& dets, const at::Tensor& dets,
......
...@@ -52,7 +52,10 @@ TORCH_LIBRARY(torchvision, m) { ...@@ -52,7 +52,10 @@ TORCH_LIBRARY(torchvision, m) {
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor"); "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
m.def("roi_pool", &roi_pool); m.def("roi_pool", &roi_pool);
m.def("_new_empty_tensor_op", &new_empty_tensor); m.def("_new_empty_tensor_op", &new_empty_tensor);
m.def("ps_roi_align", &ps_roi_align); m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
m.def(
"_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
m.def("ps_roi_pool", &ps_roi_pool); m.def("ps_roi_pool", &ps_roi_pool);
m.def( m.def(
"deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor"); "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
...@@ -67,6 +70,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) { ...@@ -67,6 +70,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("deform_conv2d", DeformConv2d_forward_cpu); m.impl("deform_conv2d", DeformConv2d_forward_cpu);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
m.impl("nms", nms_cpu); m.impl("nms", nms_cpu);
m.impl("ps_roi_align", PSROIAlign_forward_cpu);
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
} }
// TODO: Place this in a hypothetical separate torchvision_cuda library // TODO: Place this in a hypothetical separate torchvision_cuda library
...@@ -77,6 +82,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) { ...@@ -77,6 +82,8 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("deform_conv2d", DeformConv2d_forward_cuda); m.impl("deform_conv2d", DeformConv2d_forward_cuda);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda); m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
m.impl("nms", nms_cuda); m.impl("nms", nms_cuda);
m.impl("ps_roi_align", PSROIAlign_forward_cuda);
m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
} }
#endif #endif
...@@ -86,6 +93,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) { ...@@ -86,6 +93,7 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", ROIAlign_autocast); m.impl("roi_align", ROIAlign_autocast);
m.impl("deform_conv2d", DeformConv2d_autocast); m.impl("deform_conv2d", DeformConv2d_autocast);
m.impl("nms", nms_autocast); m.impl("nms", nms_autocast);
m.impl("ps_roi_align", PSROIAlign_autocast);
} }
#endif #endif
...@@ -94,4 +102,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) { ...@@ -94,4 +102,6 @@ TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("_roi_align_backward", ROIAlign_backward_autograd); m.impl("_roi_align_backward", ROIAlign_backward_autograd);
m.impl("deform_conv2d", DeformConv2d_autograd); m.impl("deform_conv2d", DeformConv2d_autograd);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd); m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
m.impl("ps_roi_align", PSROIAlign_autograd);
m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment