"packaging/vscode:/vscode.git/clone" did not exist on "a664dd0a5ceb75ca853e276dbc063a45925850fc"
Unverified Commit 0125a7dc authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

ROIPool + Dispatcher + Autocast + Code Cleanup (#2922)

* Fixing types.

* Dispatcher + Autocast.

* + Autograd.

* Formating.

* Fixing return casting with autocast.

* Clean up and refactor ROIPool implementation:
- Remove primitive const declaration from method names.
- Using references when possible.

* Restore include headers.

* New line at end of file.
parent f0c92d85
......@@ -3,59 +3,64 @@
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
// TODO: put this stuff in torchvision namespace
std::tuple<at::Tensor, at::Tensor> roi_pool(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
if (input.is_cuda()) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::roi_pool", "")
.typed<decltype(roi_pool)>();
return op.call(input, rois, spatial_scale, pooled_height, pooled_width);
}
#if defined(WITH_CUDA) || defined(WITH_HIP)
return ROIPool_forward_cuda(
input, rois, spatial_scale, pooled_height, pooled_width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return ROIPool_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width);
std::tuple<at::Tensor, at::Tensor> ROIPool_autocast(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto result = roi_pool(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, rois),
spatial_scale,
pooled_height,
pooled_width);
return std::make_tuple(
std::get<0>(result).to(input.scalar_type()),
std::get<1>(result).to(input.scalar_type()));
}
#endif
at::Tensor ROIPool_backward(
at::Tensor _roi_pool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.is_cuda()) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
return ROIPool_backward_cuda(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
#else
TORCH_CHECK(false, "Not compiled with GPU support");
#endif
}
return ROIPool_backward_cpu(
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_roi_pool_backward", "")
.typed<decltype(_roi_pool_backward)>();
return op.call(
grad,
rois,
argmax,
......@@ -72,33 +77,36 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
auto result = ROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
at::AutoNonVariableTypeMode g;
auto result =
roi_pool(input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto argmax = std::get<1>(result);
ctx->save_for_backward({rois, argmax});
ctx->mark_non_differentiable({argmax});
return {output, argmax};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto argmax = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = ROIPool_backward(
auto grad_in = _roi_pool_backward(
grad_output[0],
rois,
argmax,
......@@ -109,6 +117,7 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in,
torch::autograd::Variable(),
torch::autograd::Variable(),
......@@ -117,13 +126,77 @@ class ROIPoolFunction : public torch::autograd::Function<ROIPoolFunction> {
}
};
std::tuple<at::Tensor, at::Tensor> roi_pool(
// TODO: There should be an easier way to do this
class ROIPoolBackwardFunction
: public torch::autograd::Function<ROIPoolBackwardFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
const torch::autograd::Variable& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
at::AutoNonVariableTypeMode g;
auto grad_in = _roi_pool_backward(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
return {grad_in};
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on roi_pool not supported");
}
};
std::tuple<at::Tensor, at::Tensor> ROIPool_autograd(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
auto result = ROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::tuple<at::Tensor, at::Tensor>(result[0], result[1]);
return std::make_tuple(result[0], result[1]);
}
at::Tensor ROIPool_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
return ROIPoolBackwardFunction::apply(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width)[0];
}
......@@ -12,13 +12,13 @@ template <typename T>
void RoIPoolForward(
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
const T* rois,
const int num_rois,
int num_rois,
T* output,
int* argmax_data) {
for (int n = 0; n < num_rois; ++n) {
......@@ -81,18 +81,18 @@ template <typename T>
void RoIPoolBackward(
const T* grad_output,
const int* argmax_data,
const int num_rois,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
int num_rois,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
......@@ -123,9 +123,9 @@ void RoIPoolBackward(
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
......@@ -172,13 +172,13 @@ at::Tensor ROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
// Check if input tensors are CPU tensors
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
......
......@@ -5,21 +5,21 @@
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width);
VISION_API at::Tensor ROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width);
VISION_API at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
......
......@@ -8,14 +8,14 @@
template <typename T>
__global__ void RoIPoolForward(
const int nthreads,
int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
const T* rois,
T* output,
int* argmax_data) {
......@@ -73,22 +73,22 @@ __global__ void RoIPoolForward(
template <typename T>
__global__ void RoIPoolBackward(
const int nthreads,
int nthreads,
const T* grad_output,
const int* argmax_data,
const int num_rois,
int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
......@@ -118,9 +118,9 @@ __global__ void RoIPoolBackward(
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
TORCH_CHECK(
......@@ -182,13 +182,13 @@ at::Tensor ROIPool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
// Check if input tensors are CUDA tensors
TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
......
......@@ -27,21 +27,21 @@ VISION_API at::Tensor ROIAlign_backward_cuda(
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width);
VISION_API at::Tensor ROIPool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width);
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
const at::Tensor& input,
......
......@@ -50,7 +50,10 @@ TORCH_LIBRARY(torchvision, m) {
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
"_roi_align_backward(Tensor grad, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width, int sampling_ratio, bool aligned) -> Tensor");
m.def("roi_pool", &roi_pool);
m.def(
"roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)");
m.def(
"_roi_pool_backward(Tensor grad, Tensor rois, Tensor argmax, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor");
m.def("_new_empty_tensor_op", &new_empty_tensor);
m.def(
"ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
......@@ -67,6 +70,8 @@ TORCH_LIBRARY(torchvision, m) {
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", ROIAlign_forward_cpu);
m.impl("_roi_align_backward", ROIAlign_backward_cpu);
m.impl("roi_pool", ROIPool_forward_cpu);
m.impl("_roi_pool_backward", ROIPool_backward_cpu);
m.impl("deform_conv2d", DeformConv2d_forward_cpu);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cpu);
m.impl("nms", nms_cpu);
......@@ -79,6 +84,8 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", ROIAlign_forward_cuda);
m.impl("_roi_align_backward", ROIAlign_backward_cuda);
m.impl("roi_pool", ROIPool_forward_cuda);
m.impl("_roi_pool_backward", ROIPool_backward_cuda);
m.impl("deform_conv2d", DeformConv2d_forward_cuda);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_cuda);
m.impl("nms", nms_cuda);
......@@ -91,6 +98,7 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", ROIAlign_autocast);
m.impl("roi_pool", ROIPool_autocast);
m.impl("deform_conv2d", DeformConv2d_autocast);
m.impl("nms", nms_autocast);
m.impl("ps_roi_align", PSROIAlign_autocast);
......@@ -100,6 +108,8 @@ TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
m.impl("roi_align", ROIAlign_autograd);
m.impl("_roi_align_backward", ROIAlign_backward_autograd);
m.impl("roi_pool", ROIPool_autograd);
m.impl("_roi_pool_backward", ROIPool_backward_autograd);
m.impl("deform_conv2d", DeformConv2d_autograd);
m.impl("_deform_conv2d_backward", DeformConv2d_backward_autograd);
m.impl("ps_roi_align", PSROIAlign_autograd);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment