"git@developer.sourcefind.cn:modelzoo/InfiniteYou_pytorch.git" did not exist on "bc281f4d113545207a1c884e7a87aeae4e1846ca"
Unverified Commit f0c92d85 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

ROIAlign code cleanup (#2906)

* Clean up and refactor ROIAlign implementation:
- Remove primitive const declaration from method names.
- Passing as const ref instead of value where possible.
- Remove unnecessary headers.

* Adding back include for cpu.

* Restore include headers.
parent b06e43d6
......@@ -17,12 +17,12 @@
at::Tensor roi_align(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
const double spatial_scale, // The scale of the image features. ROIs will be
double spatial_scale, // The scale of the image features. ROIs will be
// scaled to this.
const int64_t pooled_height, // The height of the pooled feature map.
const int64_t pooled_width, // The width of the pooled feature
const int64_t sampling_ratio, // The number of points to sample in each bin
const bool aligned) // The flag for pixel shift
int64_t pooled_height, // The height of the pooled feature map.
int64_t pooled_width, // The width of the pooled feature
int64_t sampling_ratio, // The number of points to sample in each bin
bool aligned) // The flag for pixel shift
// along each axis.
{
static auto op = c10::Dispatcher::singleton()
......@@ -42,11 +42,11 @@ at::Tensor roi_align(
at::Tensor ROIAlign_autocast(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return roi_align(
at::autocast::cached_cast(at::kFloat, input),
......@@ -63,15 +63,15 @@ at::Tensor ROIAlign_autocast(
at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_roi_align_backward", "")
......@@ -94,13 +94,13 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable input,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
......@@ -122,7 +122,7 @@ class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
const torch::autograd::variable_list& grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
......@@ -155,17 +155,17 @@ class ROIAlignBackwardFunction
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
torch::autograd::Variable grad,
torch::autograd::Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
const torch::autograd::Variable& grad,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
at::AutoNonVariableTypeMode g;
auto result = _roi_align_backward(
grad,
......@@ -184,7 +184,7 @@ class ROIAlignBackwardFunction
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output) {
const torch::autograd::variable_list& grad_output) {
TORCH_CHECK(0, "double backwards on roi_align not supported");
}
};
......@@ -192,11 +192,11 @@ class ROIAlignBackwardFunction
at::Tensor ROIAlign_autograd(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignFunction::apply(
input,
rois,
......@@ -210,15 +210,15 @@ at::Tensor ROIAlign_autograd(
at::Tensor ROIAlign_backward_autograd(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
return ROIAlignBackwardFunction::apply(
grad,
rois,
......
......@@ -16,12 +16,12 @@ struct PreCalc {
template <typename T>
void pre_calc_for_bilinear_interpolate(
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
int height,
int width,
int pooled_height,
int pooled_width,
int iy_upper,
int ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
......@@ -112,16 +112,16 @@ void pre_calc_for_bilinear_interpolate(
template <typename T>
void ROIAlignForward(
const int nthreads,
int nthreads,
const T* input,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
int sampling_ratio,
bool aligned,
const T* rois,
T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
......@@ -214,8 +214,8 @@ void ROIAlignForward(
template <typename T>
void bilinear_interpolate_gradient(
const int height,
const int width,
int height,
int width,
T y,
T x,
T& w1,
......@@ -226,7 +226,7 @@ void bilinear_interpolate_gradient(
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
......@@ -278,22 +278,22 @@ inline void add(T* address, const T& val) {
template <typename T>
void ROIAlignBackward(
const int nthreads,
int nthreads,
const T* grad_output,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
int sampling_ratio,
bool aligned,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
for (int index = 0; index < nthreads; index++) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
......@@ -387,11 +387,11 @@ void ROIAlignBackward(
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
......@@ -437,15 +437,15 @@ at::Tensor ROIAlign_forward_cpu(
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
......
......@@ -24,24 +24,24 @@ VISION_API at::Tensor ROIPool_backward_cpu(
VISION_API at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned);
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned);
VISION_API at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned);
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned);
VISION_API std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
......
......@@ -9,11 +9,11 @@
template <typename T>
__device__ T bilinear_interpolate(
const T* input,
const int height,
const int width,
int height,
int width,
T y,
T x,
const int index /* index for debug only*/) {
int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
......@@ -62,16 +62,16 @@ __device__ T bilinear_interpolate(
template <typename T>
__global__ void RoIAlignForward(
const int nthreads,
int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
int sampling_ratio,
bool aligned,
const T* rois,
T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
......@@ -139,8 +139,8 @@ __global__ void RoIAlignForward(
template <typename T>
__device__ void bilinear_interpolate_gradient(
const int height,
const int width,
int height,
int width,
T y,
T x,
T& w1,
......@@ -151,7 +151,7 @@ __device__ void bilinear_interpolate_gradient(
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
......@@ -198,22 +198,22 @@ __device__ void bilinear_interpolate_gradient(
template <typename T>
__global__ void RoIAlignBackward(
const int nthreads,
int nthreads,
const T* grad_output,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const bool aligned,
int channels,
int height,
int width,
int pooled_height,
int pooled_width,
int sampling_ratio,
bool aligned,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
......@@ -313,11 +313,11 @@ __global__ void RoIAlignBackward(
at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
TORCH_CHECK(
......@@ -376,15 +376,15 @@ at::Tensor ROIAlign_forward_cuda(
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned) {
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
......
......@@ -5,24 +5,24 @@
VISION_API at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio,
const bool aligned);
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned);
VISION_API at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t batch_size,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t sampling_ratio,
const bool aligned);
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned);
VISION_API std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
const at::Tensor& input,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment