#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif

// TODO: put this stuff in torchvision namespace

// roi_align dispatch nexus: looks up the registered torchvision::roi_align
// schema and forwards the call to the backend kernel picked by the dispatcher.
at::Tensor roi_align(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    const double spatial_scale, // The scale of the image features. ROIs will be
    // scaled to this.
    const int64_t pooled_height, // The height of the pooled feature map.
    const int64_t pooled_width, // The width of the pooled feature map.
    const int64_t sampling_ratio, // The number of points to sample in each
                                  // bin along each axis.
    const bool aligned) // The flag for pixel shift.
{
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::roi_align", "")
                       .typed<decltype(roi_align)>();
  return op.call(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}
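
// Example (a minimal sketch; assumes the torchvision operators have been
// registered with the dispatcher, e.g. by loading this extension):
//
//   at::Tensor input = at::rand({1, 256, 50, 50});
//   // Each ROI row is (batch_index, x1, y1, x2, y2) in input-image coords.
//   at::Tensor rois = at::tensor({0.f, 4.f, 4.f, 24.f, 24.f}).unsqueeze(0);
//   at::Tensor out = roi_align(
//       input, rois, /*spatial_scale=*/0.25, /*pooled_height=*/7,
//       /*pooled_width=*/7, /*sampling_ratio=*/2, /*aligned=*/true);
//   // out has shape [num_rois, channels, pooled_height, pooled_width].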

#ifdef WITH_CUDA
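// Autocast wrapper: under mixed precision, run roi_align in fp32 for accuracy
// and cast the result back to the input's original dtype.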
at::Tensor ROIAlign_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return roi_align(
             autocast::_cast(at::kFloat, input),
             autocast::_cast(at::kFloat, rois),
             spatial_scale,
             pooled_height,
             pooled_width,
             sampling_ratio,
             aligned)
      .to(input.scalar_type());
}
#endif

at::Tensor _roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio,
    const bool aligned) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<decltype(_roi_align_backward)>();
  return op.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

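// Autograd support for roi_align: forward stashes the ROIs and the input
// shape; backward reconstructs the input gradient via _roi_align_backward.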
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable input,
      torch::autograd::Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t sampling_ratio,
      const bool aligned) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["aligned"] = aligned;
    ctx->saved_data["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});
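    // Call the underlying op without re-entering the autograd layer; the
    // gradient is supplied explicitly by backward() below.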
    at::AutoNonVariableTypeMode g;
    auto result = roi_align(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = _roi_align_backward(
        grad_output[0],
        rois,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        ctx->saved_data["sampling_ratio"].toInt(),
        ctx->saved_data["aligned"].toBool());
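    // Only `input` receives a gradient; the remaining arguments are
    // non-differentiable, so their slots are empty Variables.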
    return {grad_in,
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()};
  }
};

// TODO: There should be an easier way to do this
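// Wraps the backward op in its own autograd Function so it can be called
// under autograd; double backward is unsupported and raises an error.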
class ROIAlignBackwardFunction
    : public torch::autograd::Function<ROIAlignBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable grad,
      torch::autograd::Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t batch_size,
      const int64_t channels,
      const int64_t height,
      const int64_t width,
      const int64_t sampling_ratio,
      const bool aligned) {
    at::AutoNonVariableTypeMode g;
    auto result = _roi_align_backward(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    TORCH_CHECK(0, "double backwards on roi_align not supported");
  }
};

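// Thin wrappers that route the forward and backward ops through the autograd
// Functions above.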
at::Tensor ROIAlign_autograd(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  return ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned)[0];
}

at::Tensor ROIAlign_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio,
    const bool aligned) {
  return ROIAlignBackwardFunction::apply(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned)[0];
}