#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif

// TODO: put this stuff in torchvision namespace

// roi_align dispatch nexus
at::Tensor roi_align(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    const double spatial_scale, // The scale of the image features. ROIs will
                                // be scaled to this.
    const int64_t pooled_height, // The height of the pooled feature map.
    const int64_t pooled_width, // The width of the pooled feature map.
    const int64_t sampling_ratio, // The number of points to sample in each
                                  // bin along each axis.
    const bool aligned) // The flag for pixel shift.
{
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::roi_align", "")
                       .typed<decltype(roi_align)>();
  return op.call(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}
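
// The dispatcher lookup above assumes the "torchvision::roi_align" schema is
// registered elsewhere (in torchvision this lives in a .cpp translation
// unit). A minimal registration sketch, assuming a TORCH_LIBRARY block in
// e.g. vision.cpp:
//
//   TORCH_LIBRARY(torchvision, m) {
//     m.def(
//         "roi_align(Tensor input, Tensor rois, float spatial_scale, "
//         "int pooled_height, int pooled_width, int sampling_ratio, "
//         "bool aligned) -> Tensor");
//   }
//
// Example call (input is [N, C, H, W]; rois is [K, 5] with rows of
// (batch_index, x1, y1, x2, y2) in input-image coordinates):
//
//   auto out = roi_align(input, rois, 0.25, 7, 7, 2, true);
//   // out has shape [K, C, 7, 7]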

#if defined(WITH_CUDA) || defined(WITH_HIP)
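// Autocast wrapper: execute roi_align in float32 (casting half-precision
// inputs up), then cast the result back to the input's original dtype. The
// ExcludeDispatchKeyGuard keeps the redispatch from re-entering autocast.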
at::Tensor ROIAlign_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return roi_align(
             autocast::_cast(at::kFloat, input),
             autocast::_cast(at::kFloat, rois),
             spatial_scale,
             pooled_height,
             pooled_width,
             sampling_ratio,
             aligned)
      .to(input.scalar_type());
}
#endif
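
// A sketch of how this wrapper is presumably hooked into the dispatcher
// (assuming a TORCH_LIBRARY_IMPL block in a .cpp file; the op and function
// names mirror the ones in this header):
//
//   TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
//     m.impl("roi_align", ROIAlign_autocast);
//   }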

at::Tensor _roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio,
    const bool aligned) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<decltype(_roi_align_backward)>();
  return op.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

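// Autograd wrapper for the forward op: saves the scalar arguments and the
// rois tensor in forward(), then routes the incoming gradient through
// torchvision::_roi_align_backward in backward().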
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable input,
      torch::autograd::Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t sampling_ratio,
      const bool aligned) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["aligned"] = aligned;
    ctx->saved_data["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});
    // Run the dispatched forward without recording autograd history.
    at::AutoNonVariableTypeMode g;
    auto result = roi_align(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = _roi_align_backward(
        grad_output[0],
        rois,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        ctx->saved_data["sampling_ratio"].toInt(),
        ctx->saved_data["aligned"].toBool());
    return {grad_in,
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()};
  }
};

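// The backward op gets its own autograd Function so it can be registered
// with the autograd dispatch key as well; its backward() refuses to run,
// since double backward of roi_align is not implemented.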
// TODO: There should be an easier way to do this
class ROIAlignBackwardFunction
    : public torch::autograd::Function<ROIAlignBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable grad,
      torch::autograd::Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t batch_size,
      const int64_t channels,
      const int64_t height,
      const int64_t width,
      const int64_t sampling_ratio,
      const bool aligned) {
    // Run the dispatched backward without recording autograd history.
    at::AutoNonVariableTypeMode g;
    auto result = _roi_align_backward(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    TORCH_CHECK(0, "double backwards on roi_align not supported");
  }
};

at::Tensor ROIAlign_autograd(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  return ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned)[0];
}

at::Tensor ROIAlign_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio,
    const bool aligned) {
  return ROIAlignBackwardFunction::apply(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned)[0];
}
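
// Presumably these two wrappers are registered as the Autograd kernels for
// the ops (a sketch, assuming a TORCH_LIBRARY_IMPL block in a .cpp file):
//
//   TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
//     m.impl("roi_align", ROIAlign_autograd);
//     m.impl("_roi_align_backward", ROIAlign_backward_autograd);
//   }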