// ps_roi_pool_kernel.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>

#include "cuda_helpers.h"
#include "ps_roi_pool_kernel.h"

namespace vision {
namespace ops {

namespace {

// Forward kernel for position-sensitive RoI (PS-RoI) average pooling.
//
// Each iteration of CUDA_1D_KERNEL_LOOP (helper from cuda_helpers.h) handles
// one flat `index` in [0, nthreads). The index is decomposed row-major as
// (n, c_out, ph, pw) over [num_rois, channels_out, pooled_height,
// pooled_width], so nthreads is expected to equal that product.
//
// For that element the kernel averages input channel
//   c_in = (c_out * pooled_height + ph) * pooled_width + pw
// over spatial bin (ph, pw) of RoI n. It writes the mean to output[index] and
// records c_in in channel_mapping[index] for the backward pass. Bins that are
// empty after clipping to the feature map produce 0.
//
// Layout assumptions implied by the indexing below:
//   input: contiguous [batch, channels, height, width]
//   rois:  contiguous [num_rois, 5] with rows (batch_index, x1, y1, x2, y2);
//          coordinates are multiplied by spatial_scale and rounded.
template <typename T>
__global__ void ps_roi_pool_forward_kernel_impl(
    int nthreads,
    const T* input,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    const T* rois,
    int channels_out,
    T* output,
    int* channel_mapping) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c_out, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c_out = (index / pooled_width / pooled_height) % channels_out;
    int n = index / pooled_width / pooled_height / channels_out;

    // (n, c_in, ph, pw) is the associated element in the input: every output
    // bin position reads from its own dedicated input channel.
    int c_in = (c_out * pooled_height + ph) * pooled_width + pw;

    // RoI corners scaled to feature-map coordinates and rounded to integers.
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    int roi_start_w = roundf(offset_rois[1] * spatial_scale);
    int roi_start_h = roundf(offset_rois[2] * spatial_scale);
    int roi_end_w = roundf(offset_rois[3] * spatial_scale);
    int roi_end_h = roundf(offset_rois[4] * spatial_scale);

    // Force too small ROIs to be 1x1
    int roi_width = max(roi_end_w - roi_start_w, 1);
    int roi_height = max(roi_end_h - roi_start_h, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // [start, end) interval of this bin, relative to the RoI origin.
    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries. hend/wend are exclusive
    // bounds, so the valid range is [0, height] / [0, width]. Clipping to
    // height - 1 / width - 1 would silently drop the last row/column of the
    // feature map and disagree with the bins used by the backward kernel.
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    const T* offset_input =
        input + (roi_batch_ind * channels + c_in) * height * width;
    T out_sum = 0;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_index = h * width + w;
        out_sum += offset_input[input_index];
      }
    }

    // The division is only evaluated for non-empty bins, where bin_area > 0.
    T bin_area = (hend - hstart) * (wend - wstart);
    output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
    channel_mapping[index] = c_in;
  }
}

// Backward kernel for position-sensitive RoI (PS-RoI) average pooling.
//
// Each iteration of CUDA_1D_KERNEL_LOOP (helper from cuda_helpers.h) handles
// one flat `index` into grad_output. The index is decomposed as
// (n, *, ph, pw) with the same row-major layout the forward kernel uses for
// its output. The kernel recomputes the clipped bin of RoI n at (ph, pw) and
// spreads grad_output[index] / bin_area uniformly over that bin. The target
// is input channel channel_mapping[index] of image roi_batch_ind in
// grad_input. Empty bins contribute nothing.
//
// atomicAdd is used because bins of different RoIs (or overlapping bins of
// the same RoI) can map to the same grad_input element.
//
// Layout assumptions implied by the indexing below:
//   grad_input:      contiguous [batch, channels, height, width], pre-zeroed
//                    by the caller (this kernel only accumulates into it)
//   channel_mapping: at least nthreads ints, indexed like grad_output
//   rois:            contiguous [num_rois, 5], rows (batch_index, x1, y1, x2, y2)
// num_rois is accepted for interface compatibility but is not read here.
template <typename T>
__global__ void ps_roi_pool_backward_kernel_impl(
    int nthreads,
    const T* grad_output,
    const int* channel_mapping,
    int num_rois,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int channels_out,
    T* grad_input,
    const T* rois) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, *, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int n = index / pooled_width / pooled_height / channels_out;

    // RoI corners scaled to feature-map coordinates and rounded to integers.
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    int roi_start_w = roundf(offset_rois[1] * spatial_scale);
    int roi_start_h = roundf(offset_rois[2] * spatial_scale);
    int roi_end_w = roundf(offset_rois[3] * spatial_scale);
    int roi_end_h = roundf(offset_rois[4] * spatial_scale);

    // Force too small ROIs to be 1x1
    int roi_width = max(roi_end_w - roi_start_w, 1);
    int roi_height = max(roi_end_h - roi_start_h, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // [start, end) interval of this bin, relative to the RoI origin.
    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    // Add roi offsets and clip to input boundaries (end bounds are exclusive).
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Route the gradient to the same input channel the forward pass read.
    int c_in = channel_mapping[index];
    T* grad_input_offset =
        grad_input + (roi_batch_ind * channels + c_in) * height * width;

    // The division is only evaluated for non-empty bins, where bin_area > 0.
    // For empty bins the loops below do not execute.
    T bin_area = (hend - hstart) * (wend - wstart);
    T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int grad_input_index = h * width + w;
        atomicAdd(grad_input_offset + grad_input_index, diff_val);
      }
    }
  }
}

} // namespace

// Host entry point for the PS-RoI pooling forward pass on CUDA.
//
// Args:
//   input: 4-D CUDA tensor [batch, channels, height, width]. channels must be
//     a multiple of pooled_height * pooled_width.
//   rois: 2-D CUDA tensor [K, 5] on the same device and with the same dtype as
//     input. Rows are (batch_index, x1, y1, x2, y2).
//   spatial_scale: factor applied to RoI coordinates before rounding.
//   pooled_height, pooled_width: size of the output bin grid; must be positive.
// Returns:
//   (output, channel_mapping). Both have shape
//   [K, channels / (pooled_height * pooled_width), pooled_height, pooled_width].
//   channel_mapping is int32 and records which input channel produced each
//   output element; ps_roi_pool_backward_cuda consumes it.
// Throws (via TORCH_CHECK / at::check*) on device, dtype, or shape mismatch.
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  // Check if input tensors are CUDA tensors
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  // Validate ranks before size(1)/size(3) are read, so callers get a clear
  // message instead of an index error or silently misread dimensions.
  TORCH_CHECK(input.dim() == 4, "input must be a 4-D tensor [N, C, H, W]");
  TORCH_CHECK(
      rois.dim() == 2 && rois.size(1) == 5,
      "Tensor rois should have shape as Tensor[K, 5]");
  // The product is used as a divisor/modulus below.
  TORCH_CHECK(
      pooled_height > 0 && pooled_width > 0,
      "pooled_height and pooled_width must be positive");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ps_roi_pool_forward_cuda";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  TORCH_CHECK(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int channels_out = channels / (pooled_height * pooled_width);

  auto output = at::zeros(
      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping =
      at::zeros(output.sizes(), input.options().dtype(at::kInt));

  // Nothing to compute (e.g. no RoIs); skip the launch.
  auto output_size = output.numel();
  if (output_size == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(output, channel_mapping);
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // 512 threads per block, capped at 4096 blocks; the kernel's
  // CUDA_1D_KERNEL_LOOP is expected to cover any remaining elements.
  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  // The kernel indexes input and rois as dense row-major buffers.
  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "ps_roi_pool_forward_cuda", [&] {
        ps_roi_pool_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            rois_.data_ptr<scalar_t>(),
            channels_out,
            output.data_ptr<scalar_t>(),
            channel_mapping.data_ptr<int>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(output, channel_mapping);
}

// Host entry point for the PS-RoI pooling backward pass on CUDA.
//
// Args:
//   grad: CUDA tensor holding the gradient w.r.t. the forward output. The
//     kernel reads it in the same flat order as the forward output.
//   rois: CUDA tensor [K, 5] with the same dtype as grad.
//   channel_mapping: int32 CUDA tensor as returned by the forward pass. It
//     must have as many elements as grad.
//   spatial_scale, pooled_height, pooled_width: same values used in forward.
//   batch_size, channels, height, width: shape of the forward input.
// Returns:
//   grad_input of shape [batch_size, channels, height, width], zero wherever
//   no pooled bin touched the input.
// Throws (via TORCH_CHECK / at::check*) on device, dtype, or size mismatch.
at::Tensor ps_roi_pool_backward_cuda(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  // Check if input tensors are CUDA tensors
  TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(
      channel_mapping.is_cuda(), "channel_mapping must be a CUDA tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
      channel_mapping_t{channel_mapping, "channel_mapping", 3};

  at::CheckedFrom c = "ps_roi_pool_backward_cuda";
  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  at::cuda::CUDAGuard device_guard(grad.device());

  auto num_rois = rois.size(0);
  // Zero-initialised: the kernel only accumulates into this buffer.
  auto grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // 512 threads per block, capped at 4096 blocks; the kernel's
  // CUDA_1D_KERNEL_LOOP is expected to cover any remaining elements.
  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

  // The product is used as a divisor below.
  TORCH_CHECK(
      pooled_height > 0 && pooled_width > 0,
      "pooled_height and pooled_width must be positive");
  // The kernel reads channel_mapping[index] for every index < grad.numel().
  TORCH_CHECK(
      channel_mapping.numel() == grad.numel(),
      "channel_mapping must have the same number of elements as grad");

  int channels_out = channels / (pooled_height * pooled_width);

  // The kernel indexes grad and channel_mapping with the same flat index,
  // so both (and rois) must be dense row-major buffers.
  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
  auto channel_mapping_ = channel_mapping.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "ps_roi_pool_backward_cuda", [&] {
        ps_roi_pool_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad_.data_ptr<scalar_t>(),
            channel_mapping_.data_ptr<int>(),
            num_rois,
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            channels_out,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}

} // namespace ops
} // namespace vision