#include <algorithm>
#include <cstdint>
#include <tuple>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>

#include "cuda_helpers.h"

namespace {

// Input window of one pooled bin, plus the batch index of its ROI.
// The window is [hstart, hend) x [wstart, wend); end coordinates are
// exclusive.
struct PSROIPoolBin {
  int batch_ind;
  int hstart;
  int hend;
  int wstart;
  int wend;
};

// Computes the clipped input window that pooled cell (ph, pw) of one ROI
// reads from.
//
// `offset_rois` points at five values read as
// (batch_index, x_start, y_start, x_end, y_end); the coordinates are scaled
// by `spatial_scale` and rounded onto the feature-map grid.
//
// Both the forward and the backward kernel call this. The window the forward
// pass averages over is therefore exactly the window the backward pass
// scatters gradient into. Because the end bounds are exclusive, they are
// clipped to `height` / `width` (not `height - 1` / `width - 1`). Otherwise
// the last row/column of the feature map could never be pooled, and a 1x1
// map would always produce an empty bin.
template <typename T>
__device__ __forceinline__ PSROIPoolBin psroi_pool_bin(
    const T* offset_rois,
    const T spatial_scale,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int ph,
    const int pw) {
  PSROIPoolBin bin;
  bin.batch_ind = offset_rois[0];
  int roi_start_w = roundf(offset_rois[1] * spatial_scale);
  int roi_start_h = roundf(offset_rois[2] * spatial_scale);
  int roi_end_w = roundf(offset_rois[3] * spatial_scale);
  int roi_end_h = roundf(offset_rois[4] * spatial_scale);

  // Force too small ROIs to be 1x1
  int roi_width = max(roi_end_w - roi_start_w, 1);
  int roi_height = max(roi_end_h - roi_start_h, 1);
  T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
  T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

  int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
  int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
  int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
  int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

  // Add roi offsets and clip to input boundaries
  bin.hstart = min(max(hstart + roi_start_h, 0), height);
  bin.hend = min(max(hend + roi_start_h, 0), height);
  bin.wstart = min(max(wstart + roi_start_w, 0), width);
  bin.wend = min(max(wend + roi_start_w, 0), width);
  return bin;
}

} // namespace

// Forward position-sensitive ROI pooling.
//
// Each work item `index` in [0, nthreads) produces one output element
// (n, c_out, ph, pw) of the (num_rois, channels_out, pooled_height,
// pooled_width) output. The iteration is driven by CUDA_1D_KERNEL_LOOP from
// cuda_helpers.h (presumably a grid-stride loop; confirm there).
//
// Output element (n, c_out, ph, pw) averages input channel
// c_in = (c_out * pooled_height + ph) * pooled_width + pw over its bin
// window. `input` is indexed as a contiguous (batch, channels, height, width)
// array. The chosen c_in is recorded in `channel_mapping` for the backward
// pass. Empty bins produce 0.
template <typename T>
__global__ void PSROIPoolForward(
    const int nthreads,
    const T* input,
    const T spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const T* rois,
    const int channels_out,
    T* output,
    int* channel_mapping) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c_out, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c_out = (index / pooled_width / pooled_height) % channels_out;
    int n = index / pooled_width / pooled_height / channels_out;

    // (n, c_in, ph, pw) is the associated element in the input
    int c_in = (c_out * pooled_height + ph) * pooled_width + pw;

    const PSROIPoolBin bin = psroi_pool_bin(
        rois + n * 5,
        spatial_scale,
        height,
        width,
        pooled_height,
        pooled_width,
        ph,
        pw);
    bool is_empty = (bin.hend <= bin.hstart) || (bin.wend <= bin.wstart);

    // 64-bit plane offset: batch * channels * height * width can exceed
    // INT_MAX for large feature maps.
    const T* offset_input = input +
        (static_cast<int64_t>(bin.batch_ind) * channels + c_in) * height *
            width;
    T out_sum = 0;
    for (int h = bin.hstart; h < bin.hend; ++h) {
      for (int w = bin.wstart; w < bin.wend; ++w) {
        int input_index = h * width + w;
        out_sum += offset_input[input_index];
      }
    }

    T bin_area = (bin.hend - bin.hstart) * (bin.wend - bin.wstart);
    output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
    channel_mapping[index] = c_in;
  }
}

// Backward position-sensitive ROI pooling.
//
// Each work item `index` in [0, nthreads) takes one element of `grad_output`.
// It divides the value by the bin area and atomically adds the result to
// every input position of the same bin window used by the forward pass. The
// target is input channel channel_mapping[index] of batch element
// batch_ind. `grad_input` is indexed as a contiguous
// (batch, channels, height, width) array and must be zero-initialised by the
// caller. `num_rois` is not read by the kernel.
template <typename T>
__global__ void PSROIPoolBackward(
    const int nthreads,
    const T* grad_output,
    const int* channel_mapping,
    const int num_rois,
    const T spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int channels_out,
    T* grad_input,
    const T* rois) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, *, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int n = index / pooled_width / pooled_height / channels_out;

    const PSROIPoolBin bin = psroi_pool_bin(
        rois + n * 5,
        spatial_scale,
        height,
        width,
        pooled_height,
        pooled_width,
        ph,
        pw);
    bool is_empty = (bin.hend <= bin.hstart) || (bin.wend <= bin.wstart);

    int c_in = channel_mapping[index];
    T* grad_input_offset = grad_input +
        (static_cast<int64_t>(bin.batch_ind) * channels + c_in) * height *
            width;
    T bin_area = (bin.hend - bin.hstart) * (bin.wend - bin.wstart);
    T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
    for (int h = bin.hstart; h < bin.hend; ++h) {
      for (int w = bin.wstart; w < bin.wend; ++w) {
        int grad_input_index = h * width + w;
        // Overlapping ROIs can target the same input element, hence atomic.
        atomicAdd(grad_input_offset + grad_input_index, diff_val);
      }
    }
  }
}

// Host entry point for the forward pass.
//
// Returns (output, channel_mapping):
//   - output has shape (num_rois, channels / (pooled_height * pooled_width),
//     pooled_height, pooled_width) and the dtype of `input`;
//   - channel_mapping is an int32 tensor of the same shape holding the input
//     channel each output element was pooled from.
// Requirements:
//   - input channels must be divisible by pooled_height * pooled_width;
//   - `rois` must be 2-D with 5 columns, read per row as
//     (batch_index, x_start, y_start, x_end, y_end).
// The kernel runs on the current CUDA stream of `input`'s device.
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width) {
  // Check if input tensors are CUDA tensors
  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
  AT_ASSERTM(
      rois.dim() == 2 && rois.size(1) == 5,
      "rois must be a 2D tensor with 5 columns");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "PSROIPool_forward_cuda";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  AT_ASSERTM(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int channels_out = channels / (pooled_height * pooled_width);

  auto output = at::zeros(
      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping =
      at::zeros(output.sizes(), input.options().dtype(at::kInt));

  auto output_size = output.numel();
  if (output_size == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(output, channel_mapping);
  }

  // Named contiguous tensors: the kernels index by flat offsets, and naming
  // them makes it explicit that the buffers outlive the launch call.
  auto input_ = input.contiguous();
  auto rois_ = rois.contiguous();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
  dim3 block(512);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "PSROIPool_forward", [&] {
        PSROIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            rois_.data<scalar_t>(),
            channels_out,
            output.data<scalar_t>(),
            channel_mapping.data<int>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(output, channel_mapping);
}

// Host entry point for the backward pass.
//
// Scatters `grad` back into a freshly zeroed tensor of shape
// (batch_size, channels, height, width). It uses the `channel_mapping`
// produced by PSROIPool_forward_cuda and the same ROI bin windows as the
// forward kernel. An empty `grad` returns the zero tensor without launching.
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width);

at::Tensor PSROIPool_backward_cuda(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width) {
  // Check if input tensors are CUDA tensors
  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
  AT_ASSERTM(
      channel_mapping.type().is_cuda(),
      "channel_mapping must be a CUDA tensor");
  AT_ASSERTM(
      rois.dim() == 2 && rois.size(1) == 5,
      "rois must be a 2D tensor with 5 columns");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
      channel_mapping_t{channel_mapping, "channel_mapping", 3};

  at::CheckedFrom c = "PSROIPool_backward_cuda";
  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  at::cuda::CUDAGuard device_guard(grad.device());

  auto num_rois = rois.size(0);
  auto grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

  // The kernel derives the ROI index from channels_out; a non-divisible
  // channel count would make it read past the end of `rois`.
  AT_ASSERTM(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int channels_out = channels / (pooled_height * pooled_width);

  auto grad_ = grad.contiguous();
  auto rois_ = rois.contiguous();
  auto channel_mapping_ = channel_mapping.contiguous();

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
  dim3 block(512);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "PSROIPool_backward", [&] {
        PSROIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad_.data<scalar_t>(),
            channel_mapping_.data<int>(),
            num_rois,
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            channels_out,
            grad_input.data<scalar_t>(),
            rois_.data<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}