Unverified commit b93d5ee2, authored by Vasilis Vryniotis, committed by GitHub

PSROIPool + Dispatcher + Autocast + Code Cleanup (#2926)

* Fixing types.

* Dispatcher + Autocast.

* + Autograd.

* Formatting.

* Clean up and refactor the PSROIPool implementation:
- Remove const qualifiers from primitive-typed parameters in method signatures.
- Use references where possible.
- Fix variable naming.

* Restore include headers.

* New line at end of file.

* Resolve conflicts, final cleanup, order methods consistently across files.
parent 0125a7dc
[File: PSROIAlign op, tail of the autograd wrapper]

@@ -223,4 +223,4 @@ at::Tensor PSROIAlign_backward_autograd(
       channels,
       height,
       width)[0];
-}
\ No newline at end of file
+}

[File: PSROIPool op, dispatcher, autocast and autograd wrappers]

@@ -3,62 +3,68 @@
 #include "cpu/vision_cpu.h"
 
 #ifdef WITH_CUDA
+#include "autocast.h"
 #include "cuda/vision_cuda.h"
 #endif
 #ifdef WITH_HIP
+#include "autocast.h"
 #include "hip/vision_cuda.h"
 #endif
 
-std::tuple<at::Tensor, at::Tensor> PSROIPool_forward(
+// TODO: put this stuff in torchvision namespace
+
+std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
-  if (input.is_cuda()) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::ps_roi_pool", "")
+                       .typed<decltype(ps_roi_pool)>();
+  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
+}
+
 #if defined(WITH_CUDA) || defined(WITH_HIP)
-    return PSROIPool_forward_cuda(
-        input, rois, spatial_scale, pooled_height, pooled_width);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return PSROIPool_forward_cpu(
-      input, rois, spatial_scale, pooled_height, pooled_width);
+std::tuple<at::Tensor, at::Tensor> PSROIPool_autocast(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  auto result = ps_roi_pool(
+      at::autocast::cached_cast(at::kFloat, input),
+      at::autocast::cached_cast(at::kFloat, rois),
+      spatial_scale,
+      pooled_height,
+      pooled_width);
+  return std::make_tuple(
+      std::get<0>(result).to(input.scalar_type()),
+      std::get<1>(result).to(input.scalar_type()));
 }
+#endif
 
-at::Tensor PSROIPool_backward(
+at::Tensor _ps_roi_pool_backward(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& mapping_channel,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
-  if (grad.is_cuda()) {
-#if defined(WITH_CUDA) || defined(WITH_HIP)
-    return PSROIPool_backward_cuda(
-        grad,
-        rois,
-        mapping_channel,
-        spatial_scale,
-        pooled_height,
-        pooled_width,
-        batch_size,
-        channels,
-        height,
-        width);
-#else
-    TORCH_CHECK(false, "Not compiled with GPU support");
-#endif
-  }
-  return PSROIPool_backward_cpu(
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "")
+          .typed<decltype(_ps_roi_pool_backward)>();
+  return op.call(
       grad,
       rois,
-      mapping_channel,
+      channel_mapping,
       spatial_scale,
       pooled_height,
       pooled_width,
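
Two things happen in the hunk above. First, the hand-rolled device branch (if (input.is_cuda()) ... PSROIPool_forward_cuda ... else ..._forward_cpu) is replaced by a boxed call into the PyTorch dispatcher, which routes to whichever kernel is registered for the dispatch keys of the inputs. Second, the scalar types change because a kernel sitting behind a dispatcher schema must use the fixed C++ bindings of the JIT schema types. A condensed sketch of the same pattern, restating the diff rather than adding to it:

// Schema type          ->  C++ binding
// "Tensor input"       ->  const at::Tensor&
// "float scale"        ->  double   (JIT float is 64-bit; C++ float won't match)
// "int pooled_height"  ->  int64_t  (JIT int is 64-bit; C++ int won't match)
// "(Tensor, Tensor)"   ->  std::tuple<at::Tensor, at::Tensor>
#include <ATen/core/dispatch/Dispatcher.h>

std::tuple<at::Tensor, at::Tensor> call_ps_roi_pool(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  // One schema lookup (static local), then one dispatch per call, keyed on
  // the tensors' dispatch-key sets; no manual is_cuda() branching.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::ps_roi_pool", "")
          .typed<std::tuple<at::Tensor, at::Tensor>(
              const at::Tensor&,
              const at::Tensor&,
              double,
              int64_t,
              int64_t)>();
  return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}

The same reasoning accounts for every const float -> double and const int -> int64_t change in this commit; dropping const on by-value scalars is pure cleanup, since it never affected callers.
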
@@ -72,33 +78,36 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
  public:
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
-      torch::autograd::Variable input,
-      torch::autograd::Variable rois,
-      const double spatial_scale,
-      const int64_t pooled_height,
-      const int64_t pooled_width) {
+      const torch::autograd::Variable& input,
+      const torch::autograd::Variable& rois,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width) {
     ctx->saved_data["spatial_scale"] = spatial_scale;
     ctx->saved_data["pooled_height"] = pooled_height;
     ctx->saved_data["pooled_width"] = pooled_width;
     ctx->saved_data["input_shape"] = input.sizes();
-    auto result = PSROIPool_forward(
-        input, rois, spatial_scale, pooled_height, pooled_width);
+    at::AutoNonVariableTypeMode g;
+    auto result =
+        ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
     auto output = std::get<0>(result);
     auto channel_mapping = std::get<1>(result);
     ctx->save_for_backward({rois, channel_mapping});
     ctx->mark_non_differentiable({channel_mapping});
     return {output, channel_mapping};
   }
 
   static torch::autograd::variable_list backward(
       torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_output) {
+      const torch::autograd::variable_list& grad_output) {
     // Use data saved in forward
     auto saved = ctx->get_saved_variables();
     auto rois = saved[0];
     auto channel_mapping = saved[1];
     auto input_shape = ctx->saved_data["input_shape"].toIntList();
-    auto grad_in = PSROIPool_backward(
+    auto grad_in = _ps_roi_pool_backward(
         grad_output[0],
         rois,
         channel_mapping,
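
PSROIPoolFunction above is the Autograd-key kernel: it saves what the backward pass needs, then re-dispatches the op with autograd stripped so the inner call reaches the CPU/CUDA kernel instead of recursing into this Function. A minimal sketch of the pattern with a hypothetical unary op my_op and matching _my_op_backward, both assumed to be dispatcher stubs like ps_roi_pool:

#include <torch/autograd.h>

// my_op and _my_op_backward are hypothetical dispatcher stubs, declared
// elsewhere; they stand in for ps_roi_pool / _ps_roi_pool_backward.
at::Tensor my_op(const at::Tensor& input);
at::Tensor _my_op_backward(const at::Tensor& grad, const at::Tensor& input);

class MyOpFunction : public torch::autograd::Function<MyOpFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input) {
    // Strip the autograd key for the inner call, so the re-dispatch lands
    // on the CPU/CUDA kernel rather than looping back into this Function.
    at::AutoNonVariableTypeMode g;
    auto output = my_op(input);
    ctx->save_for_backward({input});
    return {output};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    auto saved = ctx->get_saved_variables();
    // One gradient per differentiable forward input.
    return {_my_op_backward(grad_output[0], saved[0])};
  }
};

// Registered for the Autograd key, e.g. m.impl("my_op", my_op_autograd);
at::Tensor my_op_autograd(const at::Tensor& input) {
  return MyOpFunction::apply(input)[0];
}

Passing Variables by const reference and dropping const from by-value scalars, as the hunk does, changes nothing semantically; it avoids refcount bumps and matches the schema bindings.
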
@@ -109,6 +118,7 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
         input_shape[1],
         input_shape[2],
         input_shape[3]);
+
     return {grad_in,
             torch::autograd::Variable(),
             torch::autograd::Variable(),
@@ -117,13 +127,77 @@ class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
   }
 };
 
-std::tuple<at::Tensor, at::Tensor> ps_roi_pool(
+// TODO: There should be an easier way to do this
+class PSROIPoolBackwardFunction
+    : public torch::autograd::Function<PSROIPoolBackwardFunction> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::Variable& grad,
+      const torch::autograd::Variable& rois,
+      const torch::autograd::Variable& channel_mapping,
+      double spatial_scale,
+      int64_t pooled_height,
+      int64_t pooled_width,
+      int64_t batch_size,
+      int64_t channels,
+      int64_t height,
+      int64_t width) {
+    at::AutoNonVariableTypeMode g;
+    auto grad_in = _ps_roi_pool_backward(
+        grad,
+        rois,
+        channel_mapping,
+        spatial_scale,
+        pooled_height,
+        pooled_width,
+        batch_size,
+        channels,
+        height,
+        width);
+
+    return {grad_in};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::autograd::variable_list& grad_output) {
+    TORCH_CHECK(0, "double backwards on ps_roi_pool not supported");
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> PSROIPool_autograd(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const double spatial_scale,
-    const int64_t pooled_height,
-    const int64_t pooled_width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
   auto result = PSROIPoolFunction::apply(
       input, rois, spatial_scale, pooled_height, pooled_width);
-  return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
+
+  return std::make_tuple(result[0], result[1]);
+}
+
+at::Tensor PSROIPool_backward_autograd(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
+  return PSROIPoolBackwardFunction::apply(
+      grad,
+      rois,
+      channel_mapping,
+      spatial_scale,
+      pooled_height,
+      pooled_width,
+      batch_size,
+      channels,
+      height,
+      width)[0];
 }
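
The second Function above exists so that gradients still flow when _ps_roi_pool_backward is itself invoked under autograd, while an attempted second differentiation fails loudly via the TORCH_CHECK instead of silently returning wrong gradients. A usage sketch, assuming a built torchvision extension and the declarations from this file; shapes are chosen so that channels equals channels_out * pooled_height * pooled_width (here 4 * 2 * 2 = 16):

#include <torch/torch.h>

void ps_roi_pool_example() {
  // 1 image, 16 channels, 8x8 spatial size; gradients enabled on the input.
  auto input = torch::rand({1, 16, 8, 8}, torch::requires_grad());
  // One ROI: (batch_index, x1, y1, x2, y2), same float dtype as the input.
  auto rois = torch::tensor({{0.f, 0.f, 0.f, 7.f, 7.f}});

  auto out = std::get<0>(ps_roi_pool(
      input,
      rois,
      /*spatial_scale=*/1.0,
      /*pooled_height=*/2,
      /*pooled_width=*/2));  // shape [1, 4, 2, 2]

  out.sum().backward();  // runs PSROIPoolFunction::backward -> grad_in
  // Differentiating through that backward again would reach
  // PSROIPoolBackwardFunction::backward and raise
  // "double backwards on ps_roi_pool not supported".
}
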
[File: PSROIPool CPU implementation]

@@ -12,14 +12,14 @@ template <typename T>
 void PSROIPoolForward(
     const T* input,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
     const T* rois,
-    const int channels_out,
-    const int num_rois,
+    int channels_out,
+    int num_rois,
     T* output,
     int* channel_mapping) {
   for (int n = 0; n < num_rois; ++n) {
@@ -82,14 +82,14 @@ template <typename T>
 void PSROIPoolBackward(
     const T* grad_output,
     const int* channel_mapping,
-    const int num_rois,
+    int num_rois,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int channels_out,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    int channels_out,
     T* grad_input,
     const T* rois) {
   for (int n = 0; n < num_rois; ++n) {
@@ -146,9 +146,9 @@ void PSROIPoolBackward(
 std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
   // Check if input tensors are CPU tensors
   TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
   TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
@@ -204,13 +204,13 @@ at::Tensor PSROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& channel_mapping,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
   // Check if input tensors are CPU tensors
   TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
   TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
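
What makes the pooling "position-sensitive" is the channel arithmetic in the kernels above: the input carries channels_out * pooled_height * pooled_width channels, and each output bin (c_out, ph, pw) averages exactly one dedicated input channel, whose index is recorded in channel_mapping so the backward pass can route gradients. A sketch of the index computation (the standard PSROIPool formulation, which the kernel in this file follows):

// Input channel feeding output channel c_out at spatial bin (ph, pw).
int ps_channel(
    int c_out, int ph, int pw, int pooled_height, int pooled_width) {
  return (c_out * pooled_height + ph) * pooled_width + pw;
}
// Example: channels_out = 8 with a 7x7 grid needs 8 * 7 * 7 = 392 input
// channels, and bin (c_out=2, ph=3, pw=5) reads input channel
// (2 * 7 + 3) * 7 + 5 = 124.
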
[File: cpu/vision_cpu.h]

@@ -2,17 +2,73 @@
 #include <torch/extension.h>
 #include "../macros.h"
 
-VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
+VISION_API at::Tensor DeformConv2d_forward_cpu(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+DeformConv2d_backward_cpu(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API at::Tensor nms_cpu(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold);
+
+VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio);
+
+VISION_API at::Tensor PSROIAlign_backward_cpu(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t sampling_ratio,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
+
+VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width);
 
-VISION_API at::Tensor ROIPool_backward_cpu(
+VISION_API at::Tensor PSROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& argmax,
+    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
@@ -43,77 +99,21 @@ VISION_API at::Tensor ROIAlign_backward_cpu(
     int64_t sampling_ratio,
     bool aligned);
 
-VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width);
-
-VISION_API at::Tensor PSROIPool_backward_cpu(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& mapping_channel,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width);
-
-VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
+VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
-    int64_t pooled_width,
-    int64_t sampling_ratio);
+    int64_t pooled_width);
 
-VISION_API at::Tensor PSROIAlign_backward_cpu(
+VISION_API at::Tensor ROIPool_backward_cpu(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& channel_mapping,
+    const at::Tensor& argmax,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
     int64_t width);
-
-VISION_API at::Tensor nms_cpu(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold);
-
-VISION_API at::Tensor DeformConv2d_forward_cpu(
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
-
-VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-DeformConv2d_backward_cpu(
-    const at::Tensor& grad_out,
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);

[File: PSROIPool CUDA implementation]

@@ -8,16 +8,16 @@
 template <typename T>
 __global__ void PSROIPoolForward(
-    const int nthreads,
+    int nthreads,
     const T* input,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
     const T* rois,
-    const int channels_out,
+    int channels_out,
     T* output,
     int* channel_mapping) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
@@ -74,17 +74,17 @@ __global__ void PSROIPoolForward(
 template <typename T>
 __global__ void PSROIPoolBackward(
-    const int nthreads,
+    int nthreads,
     const T* grad_output,
     const int* channel_mapping,
-    const int num_rois,
+    int num_rois,
     const T spatial_scale,
-    const int channels,
-    const int height,
-    const int width,
-    const int pooled_height,
-    const int pooled_width,
-    const int channels_out,
+    int channels,
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    int channels_out,
     T* grad_input,
     const T* rois) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
@@ -135,9 +135,9 @@ __global__ void PSROIPoolBackward(
 std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width) {
   // Check if input tensors are CUDA tensors
   TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
   TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
@@ -206,13 +206,13 @@ at::Tensor PSROIPool_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const at::Tensor& channel_mapping,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width) {
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width) {
   // Check if input tensors are CUDA tensors
   TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
   TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
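
Both CUDA kernels above iterate with CUDA_1D_KERNEL_LOOP, torchvision's grid-stride loop macro. For reference, a sketch of its usual definition (it lives in cuda_helpers.h in this tree; treat the exact spelling here as an assumption):

// Each thread starts at its global index and strides by the total number of
// launched threads, so all nthreads elements are covered even when the grid
// is smaller than the problem size.
#define CUDA_1D_KERNEL_LOOP(i, n)                                \
  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
       i += (blockDim.x * gridDim.x))
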
[File: cuda/vision_cuda.h]

@@ -2,118 +2,118 @@
 #include <torch/extension.h>
 #include "../macros.h"
 
-VISION_API at::Tensor ROIAlign_forward_cuda(
+VISION_API at::Tensor DeformConv2d_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+DeformConv2d_backward_cuda(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& bias,
+    int64_t stride_h,
+    int64_t stride_w,
+    int64_t pad_h,
+    int64_t pad_w,
+    int64_t dilation_h,
+    int64_t dilation_w,
+    int64_t groups,
+    int64_t deformable_groups);
+
+VISION_API at::Tensor nms_cuda(
+    const at::Tensor& dets,
+    const at::Tensor& scores,
+    double iou_threshold);
+
+VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio,
-    bool aligned);
+    int64_t sampling_ratio);
 
-VISION_API at::Tensor ROIAlign_backward_cuda(
+VISION_API at::Tensor PSROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
+    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width,
-    int64_t sampling_ratio,
-    bool aligned);
-
-VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
-    const at::Tensor& input,
-    const at::Tensor& rois,
-    const double spatial_scale,
-    const int64_t pooled_height,
-    const int64_t pooled_width);
-
-VISION_API at::Tensor ROIPool_backward_cuda(
-    const at::Tensor& grad,
-    const at::Tensor& rois,
-    const at::Tensor& argmax,
-    const double spatial_scale,
-    const int64_t pooled_height,
-    const int64_t pooled_width,
-    const int64_t batch_size,
-    const int64_t channels,
-    const int64_t height,
-    const int64_t width);
+    int64_t width);
 
 VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width);
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width);
 
 VISION_API at::Tensor PSROIPool_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& mapping_channel,
-    const float spatial_scale,
-    const int pooled_height,
-    const int pooled_width,
-    const int batch_size,
-    const int channels,
-    const int height,
-    const int width);
+    const at::Tensor& channel_mapping,
+    double spatial_scale,
+    int64_t pooled_height,
+    int64_t pooled_width,
+    int64_t batch_size,
+    int64_t channels,
+    int64_t height,
+    int64_t width);
 
-VISION_API std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
+VISION_API at::Tensor ROIAlign_forward_cuda(
     const at::Tensor& input,
     const at::Tensor& rois,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio);
+    int64_t sampling_ratio,
+    bool aligned);
 
-VISION_API at::Tensor PSROIAlign_backward_cuda(
+VISION_API at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
-    const at::Tensor& channel_mapping,
     double spatial_scale,
     int64_t pooled_height,
     int64_t pooled_width,
-    int64_t sampling_ratio,
     int64_t batch_size,
     int64_t channels,
     int64_t height,
-    int64_t width);
+    int64_t width,
+    int64_t sampling_ratio,
+    bool aligned);
 
-VISION_API at::Tensor nms_cuda(
-    const at::Tensor& dets,
-    const at::Tensor& scores,
-    double iou_threshold);
-
-VISION_API at::Tensor DeformConv2d_forward_cuda(
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
-
-VISION_API std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
-DeformConv2d_backward_cuda(
-    const at::Tensor& grad_out,
-    const at::Tensor& input,
-    const at::Tensor& weight,
-    const at::Tensor& offset,
-    const at::Tensor& bias,
-    int64_t stride_h,
-    int64_t stride_w,
-    int64_t pad_h,
-    int64_t pad_w,
-    int64_t dilation_h,
-    int64_t dilation_w,
-    int64_t groups,
-    int64_t deformable_groups);
+VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width);
+
+VISION_API at::Tensor ROIPool_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& argmax,
+    const double spatial_scale,
+    const int64_t pooled_height,
+    const int64_t pooled_width,
+    const int64_t batch_size,
+    const int64_t channels,
+    const int64_t height,
+    const int64_t width);

[File: operator registration (TORCH_LIBRARY)]

@@ -45,73 +45,83 @@ int64_t cuda_version() noexcept {
 } // namespace vision
 
 TORCH_LIBRARY(torchvision, m) {
+  m.def(
+      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
+  m.def(
+      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
   m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
   m.def(
-      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
+      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
   m.def(
-      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
+      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
   m.def(
-      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
+      "ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
   m.def(
-      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
-  m.def("_new_empty_tensor_op", &new_empty_tensor);
+      "_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
   m.def(
-      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
+      "roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
   m.def(
-      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
-  m.def("ps_roi_pool", &ps_roi_pool);
+      "_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
   m.def(
-      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> Tensor");
+      "roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
   m.def(
-      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups) -> (Tensor, Tensor, Tensor, Tensor)");
+      "_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
   m.def("_cuda_version", &vision::cuda_version);
+  m.def("_new_empty_tensor_op", &new_empty_tensor);
 }
 
 TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
-  m.impl("roi_align", ROIAlign_forward_cpu);
-  m.impl("_roi_align_backward", ROIAlign_backward_cpu);
-  m.impl("roi_pool", ROIPool_forward_cpu);
-  m.impl("_roi_pool_backward", ROIPool_backward_cpu);
   m.impl("deform_conv2d", DeformConv2d_forward_cpu);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
   m.impl("nms", nms_cpu);
   m.impl("ps_roi_align", PSROIAlign_forward_cpu);
   m.impl("_ps_roi_align_backward", PSROIAlign_backward_cpu);
+  m.impl("ps_roi_pool", PSROIPool_forward_cpu);
+  m.impl("_ps_roi_pool_backward", PSROIPool_backward_cpu);
+  m.impl("roi_align", ROIAlign_forward_cpu);
+  m.impl("_roi_align_backward", ROIAlign_backward_cpu);
+  m.impl("roi_pool", ROIPool_forward_cpu);
+  m.impl("_roi_pool_backward", ROIPool_backward_cpu);
 }
 
 // TODO: Place this in a hypothetical separate torchvision_cuda library
 #if defined(WITH_CUDA) || defined(WITH_HIP)
 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
-  m.impl("roi_align", ROIAlign_forward_cuda);
-  m.impl("_roi_align_backward", ROIAlign_backward_cuda);
-  m.impl("roi_pool", ROIPool_forward_cuda);
-  m.impl("_roi_pool_backward", ROIPool_backward_cuda);
   m.impl("deform_conv2d", DeformConv2d_forward_cuda);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
   m.impl("nms", nms_cuda);
   m.impl("ps_roi_align", PSROIAlign_forward_cuda);
   m.impl("_ps_roi_align_backward", PSROIAlign_backward_cuda);
+  m.impl("ps_roi_pool", PSROIPool_forward_cuda);
+  m.impl("_ps_roi_pool_backward", PSROIPool_backward_cuda);
+  m.impl("roi_align", ROIAlign_forward_cuda);
+  m.impl("_roi_align_backward", ROIAlign_backward_cuda);
+  m.impl("roi_pool", ROIPool_forward_cuda);
+  m.impl("_roi_pool_backward", ROIPool_backward_cuda);
 }
 #endif
 
 // Autocast only needs to wrap forward pass ops.
 #if defined(WITH_CUDA) || defined(WITH_HIP)
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
-  m.impl("roi_align", ROIAlign_autocast);
-  m.impl("roi_pool", ROIPool_autocast);
   m.impl("deform_conv2d", DeformConv2d_autocast);
   m.impl("nms", nms_autocast);
   m.impl("ps_roi_align", PSROIAlign_autocast);
+  m.impl("ps_roi_pool", PSROIPool_autocast);
+  m.impl("roi_align", ROIAlign_autocast);
+  m.impl("roi_pool", ROIPool_autocast);
 }
 #endif
 
 TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
-  m.impl("roi_align", ROIAlign_autograd);
-  m.impl("_roi_align_backward", ROIAlign_backward_autograd);
-  m.impl("roi_pool", ROIPool_autograd);
-  m.impl("_roi_pool_backward", ROIPool_backward_autograd);
   m.impl("deform_conv2d", DeformConv2d_autograd);
   m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
   m.impl("ps_roi_align", PSROIAlign_autograd);
   m.impl("_ps_roi_align_backward", PSROIAlign_backward_autograd);
+  m.impl("ps_roi_pool", PSROIPool_autograd);
+  m.impl("_ps_roi_pool_backward", PSROIPool_backward_autograd);
+  m.impl("roi_align", ROIAlign_autograd);
+  m.impl("_roi_align_backward", ROIAlign_backward_autograd);
+  m.impl("roi_pool", ROIPool_autograd);
+  m.impl("_roi_pool_backward", ROIPool_backward_autograd);
 }
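
The registration is deliberately layered: TORCH_LIBRARY declares each schema exactly once, and each TORCH_LIBRARY_IMPL block attaches one kernel per dispatch key (CPU, CUDA, Autocast, Autograd). Note that ps_roi_pool previously bypassed the schema path entirely (m.def("ps_roi_pool", &ps_roi_pool) registered a plain function); this commit gives it a schema plus per-key kernels like the other ops. A self-contained sketch of the two layers for a hypothetical op torchvision::_example_scale (illustrative only; not a real torchvision op):

#include <torch/library.h>
#include <ATen/ATen.h>

// CPU kernel for the hypothetical op.
at::Tensor example_scale_cpu(const at::Tensor& self, double factor) {
  return self * factor;
}

// TORCH_LIBRARY_FRAGMENT adds to the existing torchvision library from a
// separate translation unit (the file above owns the one TORCH_LIBRARY).
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def("_example_scale(Tensor self, float factor) -> Tensor");  // schema only
}

TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
  m.impl("_example_scale", example_scale_cpu);  // kernel for the CPU key
}

With only these two registrations, calling the op on CUDA tensors fails with a clear "no kernel registered" dispatcher error; adding CUDA, Autocast, and Autograd kernels, as the listing above does for the real ops, completes the table.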