#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <float.h>
#include <THC/THCAtomics.cuh>

#include "cuda_helpers.h"
#include "roi_pool_kernel.h"

namespace vision {
namespace ops {

namespace {

template <typename T>
__global__ void roi_pool_forward_kernel_impl(
    int nthreads,
    const T* input,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    const T* rois,
    T* output,
    int* argmax_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

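    // rois has shape (K, 5): (batch_index, x1, y1, x2, y2) in input-image
    // coordinates; spatial_scale maps the box onto the feature map.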
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    int roi_start_w = round(offset_rois[1] * spatial_scale);
    int roi_start_h = round(offset_rois[2] * spatial_scale);
    int roi_end_w = round(offset_rois[3] * spatial_scale);
    int roi_end_h = round(offset_rois[4] * spatial_scale);

    // Force malformed ROIs to be 1x1
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

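    // Bin (ph, pw) spans [ph * bin_size_h, (ph + 1) * bin_size_h) x
    // [pw * bin_size_w, (pw + 1) * bin_size_w); floor/ceil round the bin
    // outward so every cell of the ROI falls into at least one bin.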
    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Define an empty pooling region to be zero
    T maxval = is_empty ? 0 : -FLT_MAX;
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    int maxidx = -1;
    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * width + w;
        if (offset_input[input_index] > maxval) {
          maxval = offset_input[input_index];
          maxidx = input_index;
        }
      }
    }
    output[index] = maxval;
    argmax_data[index] = maxidx;
  }
}

template <typename T>
__global__ void roi_pool_backward_kernel_impl(
    int nthreads,
    const T* grad_output,
    const int* argmax_data,
    int num_rois,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    T* grad_input,
    const T* rois,
    int n_stride,
    int c_stride,
    int h_stride,
    int w_stride) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    T* grad_input_offset =
        grad_input + ((roi_batch_ind * channels + c) * height * width);

    int output_offset = n * n_stride + c * c_stride;
    const int* argmax_data_offset =
        argmax_data + (n * channels + c) * pooled_height * pooled_width;
    int argmax = argmax_data_offset[ph * pooled_width + pw];

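    // Bins from different (possibly overlapping) ROIs can select the same
    // input element, so the gradient scatter must be atomic.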
    if (argmax != -1) {
      atomicAdd(
          grad_input_offset + argmax,
          static_cast<T>(
              grad_output[output_offset + ph * h_stride + pw * w_stride]));
    }
  }
}

} // namespace

std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(
      rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "roi_pool_forward_cuda";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  at::Tensor output = at::zeros(
      {num_rois, channels, pooled_height, pooled_width}, input.options());
  at::Tensor argmax = at::zeros(
      {num_rois, channels, pooled_height, pooled_width},
      input.options().dtype(at::kInt));

  auto output_size = num_rois * pooled_height * pooled_width * channels;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

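  // Launch at most 4096 blocks of 512 threads; CUDA_1D_KERNEL_LOOP is a
  // grid-stride loop, so every output element is still visited when
  // output_size exceeds grid.x * block.x.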
  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  if (output.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(output, argmax);
  }

  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_pool_forward_cuda", [&] {
        roi_pool_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            rois_.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            argmax.data_ptr<int>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(output, argmax);
}

at::Tensor roi_pool_backward_cuda(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& argmax,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  // Check if input tensors are CUDA tensors
  TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(argmax.is_cuda(), "argmax must be a CUDA tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
      argmax_t{argmax, "argmax", 3};

  at::CheckedFrom c = "roi_pool_backward_cuda";
  at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  at::cuda::CUDAGuard device_guard(grad.device());

  auto num_rois = rois.size(0);

  at::Tensor grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

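  // grad may be non-contiguous, so its strides are passed to the kernel
  // instead of forcing a contiguous copy.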
  int n_stride = grad.stride(0);
  int c_stride = grad.stride(1);
  int h_stride = grad.stride(2);
  int w_stride = grad.stride(3);

  auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "roi_pool_backward_cuda", [&] {
        roi_pool_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad.data_ptr<scalar_t>(),
            argmax_.data_ptr<int>(),
            num_rois,
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>(),
            n_stride,
            c_stride,
            h_stride,
            w_stride);
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}
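
// A minimal sketch of how these entry points are typically wired into the
// dispatcher; this assumes the "torchvision::roi_pool" and
// "torchvision::_roi_pool_backward" schemas are declared elsewhere, so it is
// shown here only as a commented-out illustration:
//
//   TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
//     m.impl(
//         TORCH_SELECTIVE_NAME("torchvision::roi_pool"),
//         TORCH_FN(roi_pool_forward_cuda));
//     m.impl(
//         TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"),
//         TORCH_FN(roi_pool_backward_cuda));
//   }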

} // namespace ops
} // namespace vision