Unverified commit e9fb2a35 authored by puhuk, committed by GitHub

Replace usages of atomicAdd with gpuAtomicAdd (#5823)



Resolves issue #5815.
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent b6ab6563
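
For background on the change itself: `gpuAtomicAdd` is ATen's wrapper around `atomicAdd` (declared in `ATen/cuda/Atomic.cuh` in recent PyTorch releases, formerly in `THC/THCAtomics.cuh`). It resolves to an overload appropriate for the element type, including half-precision types that raw `atomicAdd` does not handle uniformly across architectures. A minimal sketch of the pattern, assuming that header layout; the kernel and buffer names below are illustrative and are not taken from torchvision:

#include <ATen/cuda/Atomic.cuh> // assumption: header that declares gpuAtomicAdd

// Illustrative scatter-add kernel (not part of this diff): every thread adds
// src[i] into grad[index[i]]. Concurrent writes to the same slot are
// serialized by the atomic, and gpuAtomicAdd selects an overload that also
// works for types such as at::Half where plain atomicAdd may not.
template <typename scalar_t>
__global__ void scatter_add_kernel(
    const scalar_t* src,
    const int64_t* index,
    scalar_t* grad,
    int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    gpuAtomicAdd(grad + index[i], src[i]);
  }
}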
@@ -385,7 +385,7 @@ __global__ void deformable_col2im_kernel(
         std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
       index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
       scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
-      atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
+      gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
     }
   }
 }
@@ -285,10 +285,10 @@ __global__ void ps_roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          atomicAdd(grad_input_offset + y_low * width + x_low, g1);
-          atomicAdd(grad_input_offset + y_low * width + x_high, g2);
-          atomicAdd(grad_input_offset + y_high * width + x_low, g3);
-          atomicAdd(grad_input_offset + y_high * width + x_high, g4);
+          gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
+          gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
+          gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
+          gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
         } // if
       } // ix
     } // iy
@@ -131,7 +131,7 @@ __global__ void ps_roi_pool_backward_kernel_impl(
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int grad_input_index = h * width + w;
-        atomicAdd(grad_input_offset + grad_input_index, diff_val);
+        gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
       }
     }
   }
@@ -301,13 +301,13 @@ __global__ void roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          atomicAdd(
+          gpuAtomicAdd(
               offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
-          atomicAdd(
+          gpuAtomicAdd(
               offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
-          atomicAdd(
+          gpuAtomicAdd(
               offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
-          atomicAdd(
+          gpuAtomicAdd(
               offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
@@ -113,7 +113,7 @@ __global__ void roi_pool_backward_kernel_impl(
     int argmax = argmax_data_offset[ph * pooled_width + pw];
 
     if (argmax != -1) {
-      atomicAdd(
+      gpuAtomicAdd(
           grad_input_offset + argmax,
           static_cast<T>(
               grad_output[output_offset + ph * h_stride + pw * w_stride]));