// deform_conv2d.cpp
#include "deform_conv2d.h"
#include <torch/extension.h>
#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
#endif

namespace vision {
namespace ops {
// Deformable (v2) 2-D convolution, forward pass.
//
// Thin trampoline: looks up the backend kernel registered under
// "torchvision::deform_conv2d" in the c10 dispatcher and forwards every
// argument unchanged, so the device-appropriate implementation (CPU/CUDA/...)
// is picked from the input tensors.
at::Tensor deform_conv2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  // Resolve the schema once; `static` caches the handle across calls.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
                       .typed<decltype(deform_conv2d)>();
  return op.call(
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask);
}
#if defined(WITH_CUDA) || defined(WITH_HIP)
// Autocast (mixed-precision) wrapper for deform_conv2d.
//
// Excludes the Autocast dispatch key so the redispatch below does not
// re-enter this wrapper, casts every tensor argument to float32 via
// at::autocast::cached_cast, runs the op in full precision, and finally
// casts the result back to the dtype of `input`.
at::Tensor deform_conv2d_autocast(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return deform_conv2d(
             at::autocast::cached_cast(at::kFloat, input),
             at::autocast::cached_cast(at::kFloat, weight),
             at::autocast::cached_cast(at::kFloat, offset),
             at::autocast::cached_cast(at::kFloat, mask),
             at::autocast::cached_cast(at::kFloat, bias),
             stride_h,
             stride_w,
             pad_h,
             pad_w,
             dilation_h,
             dilation_w,
             groups,
             offset_groups,
             use_mask)
      .to(input.scalar_type());
}
#endif
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
83
84
85
86
87
_deform_conv2d_backward(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
88
    const at::Tensor& mask,
89
    const at::Tensor& bias,
90
91
92
93
94
95
96
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
97
98
    int64_t offset_groups,
    bool use_mask) {
99
100
101
102
103
104
105
106
107
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
          .typed<decltype(_deform_conv2d_backward)>();
  return op.call(
      grad,
      input,
      weight,
      offset,
108
      mask,
109
110
111
112
113
114
115
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
116
      groups,
117
118
      offset_groups,
      use_mask);
119
120
}

121
122
namespace {

// Autograd node for deform_conv2d: records the tensors and hyper-parameters
// needed by backward(), and wires the backward kernel into the graph.
class DeformConv2dFunction
    : public torch::autograd::Function<DeformConv2dFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
      const torch::autograd::Variable& mask,
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
      int64_t offset_groups,
      bool use_mask) {
    // Run the op below the autograd dispatch key so we don't recurse into
    // this Function.
    // NOTE(review): at::AutoNonVariableTypeMode is deprecated in newer
    // PyTorch (replaced by at::AutoDispatchBelowADInplaceOrView) — confirm
    // against the targeted libtorch version.
    at::AutoNonVariableTypeMode g;
    auto output = deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
        offset_groups,
        use_mask);

    // Stash tensors and the scalar hyper-parameters for backward().
    ctx->save_for_backward({input, weight, offset, mask, bias});
    ctx->saved_data["stride_h"] = stride_h;
    ctx->saved_data["stride_w"] = stride_w;
    ctx->saved_data["pad_h"] = pad_h;
    ctx->saved_data["pad_w"] = pad_w;
    ctx->saved_data["dilation_h"] = dilation_h;
    ctx->saved_data["dilation_w"] = dilation_w;
    ctx->saved_data["groups"] = groups;
    ctx->saved_data["offset_groups"] = offset_groups;
    ctx->saved_data["use_mask"] = use_mask;

    return {
        output,
    };
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Recover everything forward() saved.
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto offset = saved[2];
    auto mask = saved[3];
    auto bias = saved[4];

    auto stride_h = ctx->saved_data["stride_h"].toInt();
    auto stride_w = ctx->saved_data["stride_w"].toInt();
    auto pad_h = ctx->saved_data["pad_h"].toInt();
    auto pad_w = ctx->saved_data["pad_w"].toInt();
    auto dilation_h = ctx->saved_data["dilation_h"].toInt();
    auto dilation_w = ctx->saved_data["dilation_w"].toInt();
    auto groups = ctx->saved_data["groups"].toInt();
    auto offset_groups = ctx->saved_data["offset_groups"].toInt();
    auto use_mask = ctx->saved_data["use_mask"].toBool();

    auto grads = _deform_conv2d_backward(
        grad_output[0],
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
        offset_groups,
        use_mask);
    auto grad_input = std::get<0>(grads);
    auto grad_weight = std::get<1>(grads);
    auto grad_offset = std::get<2>(grads);
    auto grad_mask = std::get<3>(grads);
    auto grad_bias = std::get<4>(grads);

    // One entry per forward() input (ctx excluded): five tensor gradients,
    // then nine undefined Variables for the non-tensor arguments
    // (strides, padding, dilation, groups, offset_groups, use_mask).
    return {
        grad_input,
        grad_weight,
        grad_offset,
        grad_mask,
        grad_bias,
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
    };
  }
};
241
// TODO: There should be an easier way to do this
// Autograd node wrapping the backward kernel itself, so that calling the
// backward op through the autograd dispatcher works. Double backward is
// not implemented and raises.
class DeformConv2dBackwardFunction
    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
      const torch::autograd::Variable& mask,
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
      int64_t offset_groups,
      bool use_mask) {
    // Run the op below the autograd dispatch key so we don't recurse into
    // this Function.
    // NOTE(review): at::AutoNonVariableTypeMode is deprecated in newer
    // PyTorch (replaced by at::AutoDispatchBelowADInplaceOrView) — confirm
    // against the targeted libtorch version.
    at::AutoNonVariableTypeMode g;
    auto result = _deform_conv2d_backward(
        grad,
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
        offset_groups,
        use_mask);

    auto grad_input = std::get<0>(result);
    auto grad_weight = std::get<1>(result);
    auto grad_offset = std::get<2>(result);
    auto grad_mask = std::get<3>(result);
    auto grad_bias = std::get<4>(result);

    return {
        grad_input,
        grad_weight,
        grad_offset,
        grad_mask,
        grad_bias,
    };
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Double backward of deform_conv2d is intentionally unsupported.
    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
  }
};
} // namespace

// Autograd entry point for deform_conv2d: routes the call through
// DeformConv2dFunction so gradients are recorded. Returns the single
// output tensor produced by forward().
at::Tensor deform_conv2d_autograd(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  return DeformConv2dFunction::apply(
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask)[0];
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
332
deform_conv2d_backward_autograd(
333
334
335
336
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
337
    const at::Tensor& mask,
338
    const at::Tensor& bias,
339
340
341
342
343
344
345
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
346
347
    int64_t offset_groups,
    bool use_mask) {
348
349
350
351
352
  auto result = DeformConv2dBackwardFunction::apply(
      grad,
      input,
      weight,
      offset,
353
      mask,
354
355
356
357
358
359
360
361
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
362
363
      offset_groups,
      use_mask);
364

365
366
  return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
367
368
369

} // namespace ops
} // namespace vision