#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif

// TODO: put this stuff in torchvision namespace

at::Tensor roi_align(
    const at::Tensor& input, // Input feature map.
    const at::Tensor& rois, // List of ROIs to pool over.
    const double spatial_scale, // The scale of the image features. ROIs will be
    // scaled to this.
    const int64_t pooled_height, // The height of the pooled feature map.
    const int64_t pooled_width, // The width of the pooled feature map.
    const int64_t sampling_ratio, // The number of points to sample in each bin
    const bool aligned) // The flag for pixel shift along each axis.
{
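  // Look up the operator registered under "torchvision::roi_align" in the
  // c10 dispatcher and call it; the concrete CPU/CUDA kernel is selected at
  // runtime based on the inputs.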
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::roi_align", "")
                       .typed<decltype(roi_align)>();
  return op.call(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}

at::Tensor _roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio,
    const bool aligned) {
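  // Dispatch to the kernel registered under "torchvision::_roi_align_backward".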
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_roi_align_backward", "")
          .typed<decltype(_roi_align_backward)>();
  return op.call(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}

class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable input,
      torch::autograd::Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t sampling_ratio,
      const bool aligned) {
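    // Stash the pooling parameters and the input shape so backward() can call
    // _roi_align_backward without saving the input tensor itself.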
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["aligned"] = aligned;
    ctx->saved_data["input_shape"] = input.sizes();
    ctx->save_for_backward({rois});
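    // Run the underlying op without autograd tracking; gradients are supplied
    // by this Function's backward() instead.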
    at::AutoNonVariableTypeMode g;
    auto result = roi_align(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = _roi_align_backward(
        grad_output[0],
        rois,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
        ctx->saved_data["sampling_ratio"].toInt(),
        ctx->saved_data["aligned"].toBool());
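    // Only the input gets a gradient; rois and the scalar parameters receive
    // empty Variables.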
    return {grad_in,
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()};
  }
};

// TODO: There should be an easier way to do this
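// Autograd wrapper around _roi_align_backward: forward re-dispatches the
// backward kernel without tracking gradients; double backward is not
// supported (see backward() below).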
class ROIAlignBackwardFunction
    : public torch::autograd::Function<ROIAlignBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::Variable grad,
      torch::autograd::Variable rois,
      const double spatial_scale,
      const int64_t pooled_height,
      const int64_t pooled_width,
      const int64_t batch_size,
      const int64_t channels,
      const int64_t height,
      const int64_t width,
      const int64_t sampling_ratio,
      const bool aligned) {
    at::AutoNonVariableTypeMode g;
    auto result = _roi_align_backward(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned);
    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    TORCH_CHECK(0, "double backwards on roi_align not supported");
  }
};

at::Tensor ROIAlign_autograd(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
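  // Route through ROIAlignFunction so the forward pass records what its
  // backward needs.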
  return ROIAlignFunction::apply(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned)[0];
}

at::Tensor ROIAlign_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t batch_size,
    const int64_t channels,
    const int64_t height,
    const int64_t width,
    const int64_t sampling_ratio,
    const bool aligned) {
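  // Route through ROIAlignBackwardFunction (double backward is not supported).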
  return ROIAlignBackwardFunction::apply(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned)[0];
}