"vscode:/vscode.git/clone" did not exist on "b0a9d16f25a30eed4459c462a6ff5ca977645f15"
Unverified Commit 02a1918a authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Minor cleanup of roi_align_forward_kernel_impl (#3619)



* minor clean up

* do same for ps_roialign
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 591c899c
...@@ -62,7 +62,7 @@ T bilinear_interpolate( ...@@ -62,7 +62,7 @@ T bilinear_interpolate(
template <typename T> template <typename T>
void ps_roi_align_forward_kernel_impl( void ps_roi_align_forward_kernel_impl(
int nthreads, int num_rois,
const T* input, const T* input,
const T spatial_scale, const T spatial_scale,
int channels, int channels,
...@@ -75,7 +75,6 @@ void ps_roi_align_forward_kernel_impl( ...@@ -75,7 +75,6 @@ void ps_roi_align_forward_kernel_impl(
int channels_out, int channels_out,
T* output, T* output,
int* channel_mapping) { int* channel_mapping) {
int num_rois = nthreads / channels_out / pooled_width / pooled_height;
for (int n = 0; n < num_rois; n++) { for (int n = 0; n < num_rois; n++) {
// [start, end) interval for spatial sampling // [start, end) interval for spatial sampling
const T* offset_rois = rois + n * 5; const T* offset_rois = rois + n * 5;
...@@ -335,8 +334,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel( ...@@ -335,8 +334,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
auto channel_mapping = auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt)); at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel(); if (output.numel() == 0) {
if (output_size == 0) {
return std::make_tuple(output, channel_mapping); return std::make_tuple(output, channel_mapping);
} }
...@@ -344,7 +342,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel( ...@@ -344,7 +342,7 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "ps_roi_align_forward_kernel", [&] { input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
ps_roi_align_forward_kernel_impl<scalar_t>( ps_roi_align_forward_kernel_impl<scalar_t>(
output_size, num_rois,
input_.data_ptr<scalar_t>(), input_.data_ptr<scalar_t>(),
spatial_scale, spatial_scale,
channels, channels,
......
...@@ -117,7 +117,7 @@ void pre_calc_for_bilinear_interpolate( ...@@ -117,7 +117,7 @@ void pre_calc_for_bilinear_interpolate(
template <typename T> template <typename T>
void roi_align_forward_kernel_impl( void roi_align_forward_kernel_impl(
int nthreads, int n_rois,
const T* input, const T* input,
const T& spatial_scale, const T& spatial_scale,
int channels, int channels,
...@@ -129,7 +129,6 @@ void roi_align_forward_kernel_impl( ...@@ -129,7 +129,6 @@ void roi_align_forward_kernel_impl(
bool aligned, bool aligned,
const T* rois, const T* rois,
T* output) { T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output // (n, c, ph, pw) is an element in the pooled output
// can be parallelized using omp // can be parallelized using omp
// #pragma omp parallel for num_threads(32) // #pragma omp parallel for num_threads(32)
...@@ -414,8 +413,6 @@ at::Tensor roi_align_forward_kernel( ...@@ -414,8 +413,6 @@ at::Tensor roi_align_forward_kernel(
at::Tensor output = at::zeros( at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options()); {num_rois, channels, pooled_height, pooled_width}, input.options());
auto output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) if (output.numel() == 0)
return output; return output;
...@@ -423,7 +420,7 @@ at::Tensor roi_align_forward_kernel( ...@@ -423,7 +420,7 @@ at::Tensor roi_align_forward_kernel(
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "roi_align_forward_kernel", [&] { input.scalar_type(), "roi_align_forward_kernel", [&] {
roi_align_forward_kernel_impl<scalar_t>( roi_align_forward_kernel_impl<scalar_t>(
output_size, num_rois,
input_.data_ptr<scalar_t>(), input_.data_ptr<scalar_t>(),
spatial_scale, spatial_scale,
channels, channels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment