Unverified commit bd4471cc authored by vfdev, committed by GitHub

Replaced gpuAtomicAdd by fastAtomicAdd (#7596)

parent 6ccc712b
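For readers unfamiliar with the two APIs, a minimal illustrative sketch of the call-pattern change follows (it is not part of the commit). gpuAtomicAdd from <ATen/cuda/Atomic.cuh> is handed the final address, whereas at::native::fastAtomicAdd from <ATen/native/cuda/KernelUtils.cuh> is handed the tensor's base pointer, a linear index, the tensor's total element count, the value, and a flag that enables the specialized fast path; as far as I understand, that path can use packed half/bfloat16 atomics when bounds and alignment allow, and otherwise falls back to an ordinary atomic add. The kernel and parameter names below (scatter_add_sketch, dst_numel, ...) are invented for illustration.

// Illustrative sketch only; assumes nvcc plus the ATen header used in this commit.
#include <ATen/native/cuda/KernelUtils.cuh>

template <typename scalar_t, typename index_t>
__global__ void scatter_add_sketch(
    const scalar_t* src,   // values to accumulate
    const index_t* idx,    // destination offset for each value
    scalar_t* dst,         // destination buffer (base pointer)
    index_t n,             // number of source elements
    index_t dst_numel) {   // total element count of dst (the "memory span")
  for (index_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<index_t>(blockDim.x) * gridDim.x) {
    // Old pattern: compute the final address and add to it.
    //   gpuAtomicAdd(dst + idx[i], src[i]);
    // New pattern: pass the base pointer, the offset, and the element count so
    // the half/bfloat16 specialization can check bounds before packing.
    at::native::fastAtomicAdd(dst, idx[i], dst_numel, src[i], /*fast_atomics=*/true);
  }
}

In the diff below, that element count comes either from grad_input.numel(), passed down as the new memory_span argument, or, for deformable_col2im_kernel, from grad_im_numel computed inside the kernel.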
@@ -70,7 +70,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "cuda_helpers.h"
@@ -335,6 +335,8 @@ __global__ void deformable_col2im_kernel(
index_t out_w,
bool use_mask,
scalar_t* grad_im) {
const index_t grad_im_numel = width * height * channels * batch_sz;
CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
const index_t out_x = index % out_w;
const index_t out_y = (index / out_w) % out_h;
@@ -381,7 +383,12 @@ __global__ void deformable_col2im_kernel(
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
at::native::fastAtomicAdd(
grad_im,
grad_pos,
grad_im_numel,
mask_value * weight * col[index],
true);
}
}
}
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "cuda_helpers.h"
@@ -212,7 +212,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
int sampling_ratio,
int channels_out,
T* grad_input,
const T* rois) {
const T* rois,
const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, *, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -235,8 +236,6 @@ __global__ void ps_roi_align_backward_kernel_impl(
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
// Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
@@ -252,6 +251,8 @@ __global__ void ps_roi_align_backward_kernel_impl(
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const int offset = (roi_batch_ind * channels + c_in) * height * width;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
@@ -285,10 +286,30 @@ __global__ void ps_roi_align_backward_kernel_impl(
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
at::native::fastAtomicAdd(
grad_input,
offset + y_low * width + x_low,
memory_span,
static_cast<T>(g1),
true);
at::native::fastAtomicAdd(
grad_input,
offset + y_low * width + x_high,
memory_span,
static_cast<T>(g2),
true);
at::native::fastAtomicAdd(
grad_input,
offset + y_high * width + x_low,
memory_span,
static_cast<T>(g3),
true);
at::native::fastAtomicAdd(
grad_input,
offset + y_high * width + x_high,
memory_span,
static_cast<T>(g4),
true);
} // if
} // ix
} // iy
@@ -430,7 +451,8 @@ at::Tensor ps_roi_align_backward_kernel(
sampling_ratio,
channels_out,
grad_input.data_ptr<scalar_t>(),
rois_.data_ptr<scalar_t>());
rois_.data_ptr<scalar_t>(),
grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "cuda_helpers.h"
@@ -91,7 +91,8 @@ __global__ void ps_roi_pool_backward_kernel_impl(
int pooled_width,
int channels_out,
T* grad_input,
const T* rois) {
const T* rois,
const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, *, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -124,14 +125,15 @@ __global__ void ps_roi_pool_backward_kernel_impl(
bool is_empty = (hend <= hstart) || (wend <= wstart);
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
const int offset = (roi_batch_ind * channels + c_in) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int grad_input_index = h * width + w;
gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
at::native::fastAtomicAdd(
grad_input, offset + grad_input_index, memory_span, diff_val, true);
}
}
}
@@ -269,7 +271,8 @@ at::Tensor ps_roi_pool_backward_kernel(
pooled_width,
channels_out,
grad_input.data_ptr<scalar_t>(),
rois_.data_ptr<scalar_t>());
rois_.data_ptr<scalar_t>(),
grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
@@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "cuda_helpers.h"
@@ -218,7 +218,8 @@ __global__ void roi_align_backward_kernel_impl(
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
int w_stride,
const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -247,12 +248,9 @@ __global__ void roi_align_backward_kernel_impl(
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
// We need to index the gradient using the tensor strides to access the
// correct values.
int output_offset = n * n_stride + c * c_stride;
const int output_offset = n * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
@@ -267,6 +265,8 @@ __global__ void roi_align_backward_kernel_impl(
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const int input_offset = (roi_batch_ind * channels + c) * height * width;
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
@@ -301,14 +301,30 @@ __global__ void roi_align_backward_kernel_impl(
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
gpuAtomicAdd(
offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
gpuAtomicAdd(
offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
gpuAtomicAdd(
offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
gpuAtomicAdd(
offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
at::native::fastAtomicAdd(
grad_input,
input_offset + y_low * width + x_low,
memory_span,
static_cast<T>(g1),
true);
at::native::fastAtomicAdd(
grad_input,
input_offset + y_low * width + x_high,
memory_span,
static_cast<T>(g2),
true);
at::native::fastAtomicAdd(
grad_input,
input_offset + y_high * width + x_low,
memory_span,
static_cast<T>(g3),
true);
at::native::fastAtomicAdd(
grad_input,
input_offset + y_high * width + x_high,
memory_span,
static_cast<T>(g4),
true);
} // if
} // ix
} // iy
@@ -442,7 +458,8 @@ at::Tensor roi_align_backward_kernel(
n_stride,
c_stride,
h_stride,
w_stride);
w_stride,
grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
@@ -3,7 +3,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <torch/library.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "cuda_helpers.h"
@@ -94,7 +94,8 @@ __global__ void roi_pool_backward_kernel_impl(
int n_stride,
int c_stride,
int h_stride,
int w_stride) {
int w_stride,
const int memory_span) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
@@ -104,19 +105,21 @@ __global__ void roi_pool_backward_kernel_impl(
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
T* grad_input_offset =
grad_input + ((roi_batch_ind * channels + c) * height * width);
int output_offset = n * n_stride + c * c_stride;
const int output_offset = n * n_stride + c * c_stride;
const int* argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
int argmax = argmax_data_offset[ph * pooled_width + pw];
const int argmax = argmax_data_offset[ph * pooled_width + pw];
const int offset = (roi_batch_ind * channels + c) * height * width;
if (argmax != -1) {
gpuAtomicAdd(
grad_input_offset + argmax,
at::native::fastAtomicAdd(
grad_input,
offset + argmax,
memory_span,
static_cast<T>(
grad_output[output_offset + ph * h_stride + pw * w_stride]));
grad_output[output_offset + ph * h_stride + pw * w_stride]),
true);
}
}
}
@@ -253,7 +256,8 @@ at::Tensor roi_pool_backward_kernel(
n_stride,
c_stride,
h_stride,
w_stride);
w_stride,
grad_input.numel());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;