// PSROIPool_cpu.cpp
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <algorithm>
#include <cmath>
#include <cstdint>

// Accumulates `val` into the object that `address` points to.
template <class T>
inline void add(T* address, const T& val) {
  T& target = *address;
  target += val;
}

/// Forward pass of position-sensitive ROI average pooling (plain CPU loops).
///
/// Every ROI occupies five consecutive values in `rois`: element 0 is used as
/// the batch index, elements 1/2 as the start column/row and elements 3/4 as
/// the end column/row. The four coordinates are multiplied by `spatial_scale`
/// and rounded to the nearest integer before use.
///
/// The ROI is divided into `pooled_height` x `pooled_width` bins. For each
/// (c_out, ph, pw) triple a distinct input channel `c_in` is consumed, counting
/// up from 0 in iteration order, and the mean of that channel over the bin is
/// stored in `output`. The channel that was used is recorded in
/// `channel_mapping` at the same flat index.
///
/// @param input            Input values, indexed as [batch][channels][height][width].
/// @param spatial_scale    Factor applied to the ROI coordinates.
/// @param channels         Number of input channels per batch element.
/// @param height           Input height.
/// @param width            Input width.
/// @param pooled_height    Number of bins along the vertical axis.
/// @param pooled_width     Number of bins along the horizontal axis.
/// @param rois             `num_rois * 5` ROI values as described above.
/// @param channels_out     Number of output channels per ROI.
/// @param num_rois         Number of ROIs.
/// @param output           Written at
///                         [num_rois][channels_out][pooled_height][pooled_width].
/// @param channel_mapping  Same indexing as `output`; receives the input
///                         channel each output element was pooled from.
template <typename T>
void PSROIPoolForward(
    const T* input,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    const T* rois,
    int channels_out,
    int num_rois,
    T* output,
    int* channel_mapping) {
  // Elements in one (height x width) plane. Kept in 64 bits so the plane
  // offset computed below cannot overflow `int` for large inputs.
  const int64_t plane_size = static_cast<int64_t>(height) * width;

  for (int n = 0; n < num_rois; ++n) {
    const T* offset_rois = rois + n * 5;
    const int roi_batch_ind = static_cast<int>(offset_rois[0]);
    const int roi_start_w =
        static_cast<int>(std::round(offset_rois[1] * spatial_scale));
    const int roi_start_h =
        static_cast<int>(std::round(offset_rois[2] * spatial_scale));
    const int roi_end_w =
        static_cast<int>(std::round(offset_rois[3] * spatial_scale));
    const int roi_end_h =
        static_cast<int>(std::round(offset_rois[4] * spatial_scale));

    // Force too small ROIs to be 1x1
    const int roi_width = std::max(roi_end_w - roi_start_w, 1);
    const int roi_height = std::max(roi_end_h - roi_start_h, 1);
    const T bin_size_h =
        static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    const T bin_size_w =
        static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Input channel consumed by the next output element of this ROI.
    int c_in = 0;
    for (int c_out = 0; c_out < channels_out; ++c_out) {
      for (int ph = 0; ph < pooled_height; ++ph) {
        for (int pw = 0; pw < pooled_width; ++pw) {
          // Bin extent relative to the ROI origin: [start, end).
          int hstart =
              static_cast<int>(std::floor(static_cast<T>(ph) * bin_size_h));
          int wstart =
              static_cast<int>(std::floor(static_cast<T>(pw) * bin_size_w));
          int hend =
              static_cast<int>(std::ceil(static_cast<T>(ph + 1) * bin_size_h));
          int wend =
              static_cast<int>(std::ceil(static_cast<T>(pw + 1) * bin_size_w));

          // Add roi offsets and clip to input boundaries
          hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1);
          hend = std::min(std::max(hend + roi_start_h, 0), height - 1);
          wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1);
          wend = std::min(std::max(wend + roi_start_w, 0), width - 1);
          const bool is_empty = (hend <= hstart) || (wend <= wstart);

          const T* offset_input = input +
              (static_cast<int64_t>(roi_batch_ind) * channels + c_in) *
                  plane_size;

          T out_sum = 0;
          for (int h = hstart; h < hend; ++h) {
            for (int w = wstart; w < wend; ++w) {
              out_sum += offset_input[h * width + w];
            }
          }

          const int index =
              ((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
              pw;
          const T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));
          // An empty bin yields 0 instead of dividing by a zero area.
          output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
          channel_mapping[index] = c_in;
          ++c_in;
        }
      }
    }
  }
}

/// Backward pass of position-sensitive ROI average pooling (plain CPU loops).
///
/// For every ROI and every output element, the incoming gradient is divided by
/// the bin area and accumulated into each pixel of that bin, in the input
/// channel recorded in `channel_mapping`. `grad_input` is only accumulated
/// into; it is never cleared here.
///
/// ROI layout in `rois` is five values per ROI: element 0 is used as the batch
/// index, elements 1/2 as start column/row, elements 3/4 as end column/row.
///
/// The bin geometry deliberately repeats what PSROIPoolForward does: the same
/// rounding function for the scaled ROI corners and the same clamp to
/// `height - 1` / `width - 1`. If it differed, gradient would be sent to
/// pixels that did not contribute to the forward output, or be scaled by a
/// different bin area.
///
/// @param grad_output      Gradient w.r.t. the pooled output, indexed as
///                         [num_rois][channels_out][pooled_height][pooled_width].
/// @param channel_mapping  Same indexing as `grad_output`; input channel that
///                         each output element was pooled from.
/// @param num_rois         Number of ROIs.
/// @param spatial_scale    Factor applied to the ROI coordinates.
/// @param channels         Number of input channels per batch element.
/// @param height           Input height.
/// @param width            Input width.
/// @param pooled_height    Number of bins along the vertical axis.
/// @param pooled_width     Number of bins along the horizontal axis.
/// @param channels_out     Number of output channels per ROI.
/// @param grad_input       Accumulator indexed as
///                         [batch][channels][height][width].
/// @param rois             `num_rois * 5` ROI values as described above.
template <typename T>
void PSROIPoolBackward(
    const T* grad_output,
    const int* channel_mapping,
    int num_rois,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int channels_out,
    T* grad_input,
    const T* rois) {
  // Elements in one (height x width) plane. Kept in 64 bits so the plane
  // offset computed below cannot overflow `int` for large inputs.
  const int64_t plane_size = static_cast<int64_t>(height) * width;

  for (int n = 0; n < num_rois; ++n) {
    const T* offset_rois = rois + n * 5;
    const int roi_batch_ind = static_cast<int>(offset_rois[0]);
    // std::round keeps the precision of T; a float-only rounding function
    // could pick a different integer than the forward pass for double inputs.
    const int roi_start_w =
        static_cast<int>(std::round(offset_rois[1] * spatial_scale));
    const int roi_start_h =
        static_cast<int>(std::round(offset_rois[2] * spatial_scale));
    const int roi_end_w =
        static_cast<int>(std::round(offset_rois[3] * spatial_scale));
    const int roi_end_h =
        static_cast<int>(std::round(offset_rois[4] * spatial_scale));

    // Force too small ROIs to be 1x1
    const int roi_width = std::max(roi_end_w - roi_start_w, 1);
    const int roi_height = std::max(roi_end_h - roi_start_h, 1);
    const T bin_size_h =
        static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    const T bin_size_w =
        static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    for (int ph = 0; ph < pooled_height; ++ph) {
      for (int pw = 0; pw < pooled_width; ++pw) {
        // Bin extent relative to the ROI origin: [start, end).
        int hstart =
            static_cast<int>(std::floor(static_cast<T>(ph) * bin_size_h));
        int wstart =
            static_cast<int>(std::floor(static_cast<T>(pw) * bin_size_w));
        int hend =
            static_cast<int>(std::ceil(static_cast<T>(ph + 1) * bin_size_h));
        int wend =
            static_cast<int>(std::ceil(static_cast<T>(pw + 1) * bin_size_w));

        // Add roi offsets and clip to input boundaries (same bounds as the
        // forward pass).
        hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1);
        hend = std::min(std::max(hend + roi_start_h, 0), height - 1);
        wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1);
        wend = std::min(std::max(wend + roi_start_w, 0), width - 1);

        // An empty bin produced a constant 0 in the forward pass, so it
        // contributes no gradient; skipping also avoids a zero bin area.
        if (hend <= hstart || wend <= wstart) {
          continue;
        }
        const T bin_area = static_cast<T>((hend - hstart) * (wend - wstart));

        for (int c_out = 0; c_out < channels_out; ++c_out) {
          const int index =
              ((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
              pw;
          const int c_in = channel_mapping[index];

          T* grad_input_offset = grad_input +
              (static_cast<int64_t>(roi_batch_ind) * channels + c_in) *
                  plane_size;
          // Each pixel of the bin had weight 1 / bin_area in the average.
          const T diff_val = grad_output[index] / bin_area;
          for (int h = hstart; h < hend; ++h) {
            for (int w = wstart; w < wend; ++w) {
              grad_input_offset[h * width + w] += diff_val;
            }
          }
        }
      }
    }
  }
}

/// CPU entry point for the position-sensitive ROI pooling forward pass.
///
/// Validates the arguments, allocates the result tensors and dispatches
/// PSROIPoolForward on the scalar type of `input`.
///
/// @param input          CPU tensor; sizes 1, 2 and 3 are read as channels,
///                       height and width.
/// @param rois           CPU tensor whose size(1) must be 5; same dtype as
///                       `input`.
/// @param spatial_scale  Factor applied to the ROI coordinates.
/// @param pooled_height  Number of output bins along the vertical axis (> 0).
/// @param pooled_width   Number of output bins along the horizontal axis (> 0).
/// @return Tuple (output, channel_mapping), both of shape
///         {num_rois, channels / (pooled_height * pooled_width),
///          pooled_height, pooled_width}. `output` uses the options of
///         `input`; `channel_mapping` is an int tensor.
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width) {
  // Check if input tensors are CPU tensors
  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
  TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
  TORCH_CHECK(
      rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
  // The pooled sizes are used as an integer divisor below; reject
  // non-positive values instead of dividing by zero.
  TORCH_CHECK(
      pooled_height > 0 && pooled_width > 0,
      "pooled_height and pooled_width must be positive");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "PSROIPool_forward_cpu";
  at::checkAllSameType(c, {input_t, rois_t});

  int num_rois = rois.size(0);
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  TORCH_CHECK(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int channels_out = channels / (pooled_height * pooled_width);

  auto output = at::zeros(
      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping =
      at::zeros(output.sizes(), input.options().dtype(at::kInt));

  // Nothing to pool: return the (empty) result tensors as they are.
  auto output_size = output.numel();
  if (output_size == 0) {
    return std::make_tuple(output, channel_mapping);
  }

  // The kernel indexes raw pointers, so it needs densely packed storage.
  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "PSROIPool_forward", [&] {
        PSROIPoolForward<scalar_t>(
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            rois_.data_ptr<scalar_t>(),
            channels_out,
            num_rois,
            output.data_ptr<scalar_t>(),
            channel_mapping.data_ptr<int>());
      });
  return std::make_tuple(output, channel_mapping);
}

/// CPU entry point for the position-sensitive ROI pooling backward pass.
///
/// Validates the arguments, allocates a zero-filled gradient tensor of shape
/// {batch_size, channels, height, width} and dispatches PSROIPoolBackward on
/// the scalar type of `grad`.
///
/// @param grad             CPU tensor with the gradient w.r.t. the pooled
///                         output.
/// @param rois             CPU tensor whose size(1) must be 5; same dtype as
///                         `grad`.
/// @param channel_mapping  CPU int tensor; read as the per-output-element
///                         input channel index.
/// @param spatial_scale    Factor applied to the ROI coordinates.
/// @param pooled_height    Number of bins along the vertical axis (> 0).
/// @param pooled_width     Number of bins along the horizontal axis (> 0).
/// @param batch_size       First dimension of the returned gradient.
/// @param channels         Second dimension of the returned gradient.
/// @param height           Third dimension of the returned gradient.
/// @param width            Fourth dimension of the returned gradient.
/// @return Gradient tensor of shape {batch_size, channels, height, width}
///         created with the options of `grad`.
at::Tensor PSROIPool_backward_cpu(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  // Check if input tensors are CPU tensors
  TORCH_CHECK(grad.device().is_cpu(), "grad must be a CPU tensor");
  TORCH_CHECK(rois.device().is_cpu(), "rois must be a CPU tensor");
  TORCH_CHECK(
      channel_mapping.device().is_cpu(),
      "channel_mapping must be a CPU tensor");
  // The kernel reads five values per ROI through a raw pointer.
  TORCH_CHECK(
      rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");
  // The pooled sizes are used as an integer divisor below; reject
  // non-positive values instead of dividing by zero.
  TORCH_CHECK(
      pooled_height > 0 && pooled_width > 0,
      "pooled_height and pooled_width must be positive");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "PSROIPool_backward_cpu";
  at::checkAllSameType(c, {grad_t, rois_t});

  auto num_rois = rois.size(0);
  auto grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    return grad_input;
  }

  int channels_out = channels / (pooled_height * pooled_width);

  // The kernel indexes raw pointers, so every tensor it reads must be densely
  // packed, including channel_mapping.
  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
  auto channel_mapping_ = channel_mapping.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "PSROIPool_backward", [&] {
        PSROIPoolBackward<scalar_t>(
            grad_.data_ptr<scalar_t>(),
            channel_mapping_.data_ptr<int>(),
            num_rois,
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            channels_out,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>());
      });
  return grad_input;
}