#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif

#include <iostream>

// TODO: put this stuff in torchvision namespace

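// Forwards to the "torchvision::ps_roi_align" operator registered with the
// c10 dispatcher, so the call is routed to whichever kernel (CPU/CUDA/HIP)
// matches the inputs. By the torchvision PSRoIAlign convention, `input` is
// [N, C, H, W] and `rois` is [K, 5] (batch index followed by box
// coordinates); the result is the pooled output together with a
// channel_mapping tensor consumed by the backward pass.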
std::tuple<at::Tensor, at::Tensor> ps_roi_align(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::ps_roi_align", "")
                       .typed<decltype(ps_roi_align)>();
  return op.call(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
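// Autocast wrapper: disables autocasting, runs the op in float32, and casts
// the results back to the input's original dtype so callers inside an
// autocast region see the dtype they passed in.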
std::tuple<at::Tensor, at::Tensor> PSROIAlign_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  auto result = ps_roi_align(
      at::autocast::cached_cast(at::kFloat, input),
      at::autocast::cached_cast(at::kFloat, rois),
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio);

  return std::make_tuple(
      std::get<0>(result).to(input.scalar_type()),
      std::get<1>(result).to(input.scalar_type()));
}
#endif

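// Forwards to the registered "torchvision::_ps_roi_align_backward" operator.
// `channel_mapping` is the second output of the forward pass and records
// which input channel each pooled value was drawn from.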
at::Tensor _ps_roi_align_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
          .typed<decltype(_ps_roi_align_backward)>();
  return op.call(
      grad,
      rois,
      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      batch_size,
      channels,
      height,
      width);
}

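// Autograd wrapper: the forward pass records the pooling parameters and the
// tensors needed later, and the backward pass routes gradients through
// _ps_roi_align_backward.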
class PSROIAlignFunction
    : public torch::autograd::Function<PSROIAlignFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& rois,
      double spatial_scale,
      int64_t pooled_height,
      int64_t pooled_width,
      int64_t sampling_ratio) {
    ctx->saved_data["spatial_scale"] = spatial_scale;
    ctx->saved_data["pooled_height"] = pooled_height;
    ctx->saved_data["pooled_width"] = pooled_width;
    ctx->saved_data["sampling_ratio"] = sampling_ratio;
    ctx->saved_data["input_shape"] = input.sizes();
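
    // Skip the autograd handling for the nested call so it reaches the
    // backend kernel directly instead of re-entering this Function.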
    at::AutoNonVariableTypeMode g;
    auto result = ps_roi_align(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);

    auto output = std::get<0>(result);
    auto channel_mapping = std::get<1>(result);
    ctx->save_for_backward({rois, channel_mapping});
    ctx->mark_non_differentiable({channel_mapping});

    return {output, channel_mapping};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Use data saved in forward
    auto saved = ctx->get_saved_variables();
    auto rois = saved[0];
    auto channel_mapping = saved[1];
    auto input_shape = ctx->saved_data["input_shape"].toIntList();
    auto grad_in = _ps_roi_align_backward(
        grad_output[0],
        rois,
        channel_mapping,
        ctx->saved_data["spatial_scale"].toDouble(),
        ctx->saved_data["pooled_height"].toInt(),
        ctx->saved_data["pooled_width"].toInt(),
        ctx->saved_data["sampling_ratio"].toInt(),
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3]);

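    // One entry per forward argument: gradients flow only to `input`; `rois`
    // and the scalar arguments get undefined (empty) Variables.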
    return {grad_in,
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable(),
            torch::autograd::Variable()};
  }
};

// TODO: There should be an easier way to do this
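// Wrapping the backward op in its own autograd Function lets it be exposed
// under the autograd dispatch path as well; attempting a double backward
// raises an error.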
class PSROIAlignBackwardFunction
    : public torch::autograd::Function<PSROIAlignBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& rois,
      const torch::autograd::Variable& channel_mapping,
      double spatial_scale,
      int64_t pooled_height,
      int64_t pooled_width,
      int64_t sampling_ratio,
      int64_t batch_size,
      int64_t channels,
      int64_t height,
      int64_t width) {
    at::AutoNonVariableTypeMode g;
    auto grad_in = _ps_roi_align_backward(
        grad,
        rois,
        channel_mapping,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        batch_size,
        channels,
        height,
        width);

    return {grad_in};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    TORCH_CHECK(0, "double backwards on ps_roi_align not supported");
  }
};

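// Autograd entry point for the forward op: runs it through PSROIAlignFunction
// so that gradients are tracked.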
std::tuple<at::Tensor, at::Tensor> PSROIAlign_autograd(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  auto result = PSROIAlignFunction::apply(
      input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);

  return std::make_tuple(result[0], result[1]);
}

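// Autograd entry point for the backward operator, implemented via
// PSROIAlignBackwardFunction above.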
at::Tensor PSROIAlign_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  return PSROIAlignBackwardFunction::apply(
      grad,
      rois,
      channel_mapping,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      batch_size,
      channels,
      height,
      width)[0];
}