ROIAlign.h 6.3 KB
Newer Older
1
2
#pragma once

3
#include "cpu/vision_cpu.h"
4
5

#ifdef WITH_CUDA
6
#include "autocast.h"
7
#include "cuda/vision_cuda.h"
8
#endif
9
#ifdef WITH_HIP
10
#include "autocast.h"
11
12
#include "hip/vision_cuda.h"
#endif
13

14
15
// TODO: put this stuff in torchvision namespace

16
// roi_align dispatch nexus
17
at::Tensor roi_align(
18
19
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
20
    double spatial_scale, // The scale of the image features. ROIs will be
21
    // scaled to this.
22
23
24
25
    int64_t pooled_height, // The height of the pooled feature map.
    int64_t pooled_width, // The width of the pooled feature
    int64_t sampling_ratio, // The number of points to sample in each bin
    bool aligned) // The flag for pixel shift
26
27
// along each axis.
{
28
29
30
31
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::roi_align", "")
                       .typed<decltype(roi_align)>();
  return op.call(
Francisco Massa's avatar
Francisco Massa committed
32
33
34
35
36
37
38
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
39
40
}

41
#if defined(WITH_CUDA) || defined(WITH_HIP)
// Autocast wrapper for roi_align (only built for CUDA/HIP).
//
// While the Autocast dispatch key is excluded (so the nested roi_align() call
// dispatches past Autocast instead of re-entering this wrapper), it does three
// things:
// - casts `input` and `rois` to float32 with at::autocast::cached_cast;
// - runs roi_align;
// - casts the result back to input's original scalar type.
//
// Marked `inline` because it is defined in a header; a non-inline definition
// here would be emitted by every including translation unit and violate the
// One Definition Rule.
inline at::Tensor ROIAlign_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return roi_align(
             at::autocast::cached_cast(at::kFloat, input),
             at::autocast::cached_cast(at::kFloat, rois),
             spatial_scale,
             pooled_height,
             pooled_width,
             sampling_ratio,
             aligned)
      .to(input.scalar_type());
}
#endif

// Backward of roi_align, dispatched through the c10 dispatcher.
//
// Looks up "torchvision::_roi_align_backward" once (cached in a function-local
// static) and calls it.
// - `grad` is the incoming gradient for roi_align's output.
// - batch_size/channels/height/width are the dimensions of the original forward
//   input; ROIAlignFunction::backward passes input.sizes()[0..3].
// - Presumably the kernel uses these dimensions to size the returned input
//   gradient. TODO confirm against the kernel.
// The remaining arguments mirror roi_align's.
//
// Marked `inline` because it is defined in a header; a non-inline definition
// here would be emitted by every including translation unit and violate the
// One Definition Rule.
inline at::Tensor _roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<decltype(_roi_align_backward)>();
  return op.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

// Autograd function for roi_align.
//
// forward():
// - stores the scalar arguments and the shape of `input` in ctx->saved_data;
// - saves `rois` with save_for_backward;
// - evaluates roi_align under AutoNonVariableTypeMode.
//
// backward() feeds those stored values into _roi_align_backward. It returns a
// gradient only for `input`; the remaining six forward arguments receive
// undefined tensors.
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& rois,
      double spatial_scale,
      int64_t pooled_height,
      int64_t pooled_width,
      int64_t sampling_ratio,
      bool aligned) {
    // Non-tensor arguments live in saved_data; the one tensor needed later
    // (rois) goes through save_for_backward.
    auto& stash = ctx->saved_data;
    stash["spatial_scale"] = spatial_scale;
    stash["pooled_height"] = pooled_height;
    stash["pooled_width"] = pooled_width;
    stash["sampling_ratio"] = sampling_ratio;
    stash["aligned"] = aligned;
    // Only the shape of `input` is kept, not the tensor itself.
    stash["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});

    at::AutoNonVariableTypeMode non_variable_guard;
    auto pooled = roi_align(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
    return {pooled};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    auto& stash = ctx->saved_data;
    auto saved_tensors = ctx->get_saved_variables();
    auto saved_rois = saved_tensors[0];
    auto shape = stash["input_shape"].toIntList();

    auto grad_input = _roi_align_backward(
        grad_output[0],
        saved_rois,
        stash["spatial_scale"].toDouble(),
        stash["pooled_height"].toInt(),
        stash["pooled_width"].toInt(),
        shape[0],
        shape[1],
        shape[2],
        shape[3],
        stash["sampling_ratio"].toInt(),
        stash["aligned"].toBool());

    // One slot per forward() argument after ctx. Default-constructed entries
    // are undefined tensors, i.e. "no gradient" for rois and the five scalars.
    torch::autograd::variable_list grads(7);
    grads[0] = grad_input;
    return grads;
  }
};

// TODO: There should be an easier way to do this
// Autograd function wrapping _roi_align_backward, applied by
// ROIAlign_backward_autograd. Its forward() runs _roi_align_backward under
// AutoNonVariableTypeMode. Its backward() always raises, because
// differentiating roi_align twice is not supported.
class ROIAlignBackwardFunction
    : public torch::autograd::Function<ROIAlignBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& rois,
      double spatial_scale,
      int64_t pooled_height,
      int64_t pooled_width,
      int64_t batch_size,
      int64_t channels,
      int64_t height,
      int64_t width,
      int64_t sampling_ratio,
      bool aligned) {
    at::AutoNonVariableTypeMode non_variable_guard;
    return {_roi_align_backward(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned)};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Unconditional failure: second-order gradients are not implemented.
    TORCH_CHECK(false, "double backwards on roi_align not supported");
  }
};

// Autograd entry point for roi_align. It runs ROIAlignFunction, so the result
// carries a backward that produces a gradient for `input`, and returns that
// function's single output.
//
// Marked `inline` because it is defined in a header; a non-inline definition
// here would be emitted by every including translation unit and violate the
// One Definition Rule.
inline at::Tensor ROIAlign_autograd(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    bool aligned) {
  return ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned)[0];
}

// Autograd entry point for _roi_align_backward. It runs
// ROIAlignBackwardFunction, whose own backward rejects double-backward, and
// returns that function's single output.
//
// Marked `inline` because it is defined in a header; a non-inline definition
// here would be emitted by every including translation unit and violate the
// One Definition Rule.
inline at::Tensor ROIAlign_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width,
    int64_t sampling_ratio,
    bool aligned) {
  return ROIAlignBackwardFunction::apply(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned)[0];
}