// ps_roi_align.cpp
#include "ps_roi_align.h"
2
3
4

#include <torch/autograd.h>
#include <torch/types.h>
5

6
7
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
8
#endif
9

10
11
namespace vision {
namespace ops {
12
13

std::tuple<at::Tensor, at::Tensor> ps_roi_align(
14
15
    const at::Tensor& input,
    const at::Tensor& rois,
16
17
18
19
20
21
22
23
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::ps_roi_align", "")
                       .typed<decltype(ps_roi_align)>();
  return op.call(
24
25
26
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

27
#if defined(WITH_CUDA) || defined(WITH_HIP)
28
std::tuple<at::Tensor, at::Tensor> ps_roi_align_autocast(
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  auto result = ps_roi_align(
      at::autocast::cached_cast(at::kFloat, input),
      at::autocast::cached_cast(at::kFloat, rois),
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio);

  return std::make_tuple(
      std::get<0>(result).to(input.scalar_type()),
      std::get<1>(result).to(input.scalar_type()));
}
48
49
50
51

// Register `ps_roi_align_autocast` as the implementation of
// `torchvision::ps_roi_align` for the Autocast dispatch key.
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl("ps_roi_align", ps_roi_align_autocast);
}
52
#endif
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

// Dispatcher entry point for the `torchvision::_ps_roi_align_backward`
// operator.
//
// Resolves the operator schema through the c10 dispatcher (once; the typed
// handle is kept in a function-local static) and forwards every argument
// unchanged.
//
// `batch_size`, `channels`, `height` and `width` are passed as separate
// integers; the autograd wrapper in this file fills them from the four entries
// of the forward input's saved shape, in that order.
//
// Returns the single tensor produced by the operator. The lookup goes through
// `findSchemaOrThrow`, so it raises if the schema has not been registered.
at::Tensor _ps_roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
          .typed<decltype(_ps_roi_align_backward)>();
  return op.call(
      grad,
      rois,
      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      batch_size,
      channels,
      height,
      width);
}

84
85
86
87
88
89
90
// Declare the operator schemas in the `torchvision` library. This block only
// defines the signatures; implementations are registered separately with
// TORCH_LIBRARY_IMPL.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  // Forward operator: returns two tensors.
  m.def(
      "ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)");
  // Backward operator: returns a single tensor.
  m.def(
      "_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor");
}

91
92
namespace {

93
94
95
// Autograd function for `ps_roi_align`.
//
// forward() runs the operator and records what backward() needs; backward()
// delegates the gradient computation to `_ps_roi_align_backward`.
class PSROIAlignFunction
    : public torch::autograd::Function<PSROIAlignFunction> {
 public:
  // Runs `ps_roi_align` and returns `{output, channel_mapping}`.
  //
  // The scalar arguments and the shape of `input` are stored in
  // `ctx->saved_data`; `rois` and `channel_mapping` are stored with
  // `save_for_backward`. `channel_mapping` is marked non-differentiable.
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& rois,
      double spatial_scale,
      int64_t pooled_height,
      int64_t pooled_width,
      int64_t sampling_ratio) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    // Only the shape of `input` is kept, not the tensor itself: backward()
    // reads nothing else from it.
    ctx->saved_data["input_shape"] = input.sizes();
    // Guard held for the rest of this scope, so the nested `ps_roi_align`
    // call below is dispatched past the autograd layer rather than back into
    // this function.
    at::AutoNonVariableTypeMode g;
    auto result = ps_roi_align(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);

    auto output = std::get<0>(result);
    auto channel_mapping = std::get<1>(result);
    ctx->save_for_backward({rois, channel_mapping});
    ctx->mark_non_differentiable({channel_mapping});

    return {output, channel_mapping};
  }

  // Computes the gradient with respect to `input` from `grad_output[0]`.
  //
  // Returns six entries, one per forward argument after `ctx`. Only the first
  // (for `input`) is defined; the entries for `rois` and the four scalar
  // arguments are default-constructed (undefined) variables.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto channel_mapping = saved[1];
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    // The four shape entries are passed as batch_size, channels, height and
    // width, in that order.
    auto grad_in = _ps_roi_align_backward(
        grad_output[0],
        rois,
        channel_mapping,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        ctx->saved_data["sampling_ratio"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3]);

    return {grad_in,
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()};
  }
};

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
// TODO: There should be an easier way to do this
//
// Autograd function wrapping `_ps_roi_align_backward`. Its forward() runs the
// backward operator; its backward() always raises, so differentiating through
// the gradient computation (double backward) is rejected.
class PSROIAlignBackwardFunction
    : public torch::autograd::Function<PSROIAlignBackwardFunction> {
 public:
  // Calls `_ps_roi_align_backward` with the arguments unchanged and returns
  // its result as a one-element list. Nothing is stored in the context.
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* /*ctx*/,
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& rois,
      const torch::autograd::Variable& channel_mapping,
      double spatial_scale,
      int64_t pooled_height,
      int64_t pooled_width,
      int64_t sampling_ratio,
      int64_t batch_size,
      int64_t channels,
      int64_t height,
      int64_t width) {
    // Guard held for the rest of this scope while the nested operator call
    // is made.
    at::AutoNonVariableTypeMode non_variable_guard;
    auto grad_input = _ps_roi_align_backward(
        grad,
        rois,
        channel_mapping,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        batch_size,
        channels,
        height,
        width);

    return {grad_input};
  }

  // Unconditionally fails: the check condition is the constant 0.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* /*ctx*/,
      const torch::autograd::variable_list& /*grad_output*/) {
    TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
  }
};

197
198
199
} // namespace

std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
200
201
    const at::Tensor& input,
    const at::Tensor& rois,
202
203
204
205
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
206
207
  auto result = PSROIAlignFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
208
209

  return std::make_tuple(result[0], result[1]);
210
}
211

212
// Autograd-key implementation of `_ps_roi_align_backward`.
//
// Routes the call through `PSROIAlignBackwardFunction::apply` and returns the
// first (and only) element of the list it produces. That function's own
// backward() always raises, so a second differentiation through this result
// fails with an error.
at::Tensor ps_roi_align_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  return PSROIAlignBackwardFunction::apply(
      grad,
      rois,
      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      batch_size,
      channels,
      height,
      width)[0];
}
237

238
239
240
241
242
// Register the autograd wrappers for the Autograd dispatch key: the forward
// operator maps to `ps_roi_align_autograd` and the backward operator to
// `ps_roi_align_backward_autograd`.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("ps_roi_align", ps_roi_align_autograd);
  m.impl("_ps_roi_align_backward", ps_roi_align_backward_autograd);
}

243
244
} // namespace ops
} // namespace vision