#include "deform_conv2d.h"

#include <torch/autograd.h>
#include <torch/types.h>

#if defined(WITH_CUDA) || defined(WITH_HIP)
#include <ATen/autocast_mode.h>
#endif

namespace vision {
namespace ops {
// Public entry point for the deformable convolution v2 forward pass.
// Resolves the registered "torchvision::deform_conv2d" operator through the
// c10 dispatcher so the call is routed to the backend kernel (CPU/CUDA,
// autograd, autocast, ...) matching the inputs.
at::Tensor deform_conv2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  // The schema lookup is done once; the typed handle is cached in a
  // function-local static for subsequent calls.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
                       .typed<decltype(deform_conv2d)>();
  return op.call(
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask);
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
// Autocast (mixed-precision) wrapper: casts every tensor argument to float32,
// runs the op with the Autocast dispatch key excluded (so this wrapper is not
// re-entered), and casts the result back to the input's original dtype —
// presumably because the deformable-conv kernels are not validated in reduced
// precision.
at::Tensor deform_conv2d_autocast(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return deform_conv2d(
             at::autocast::cached_cast(at::kFloat, input),
             at::autocast::cached_cast(at::kFloat, weight),
             at::autocast::cached_cast(at::kFloat, offset),
             at::autocast::cached_cast(at::kFloat, mask),
             at::autocast::cached_cast(at::kFloat, bias),
             stride_h,
             stride_w,
             pad_h,
             pad_w,
             dilation_h,
             dilation_w,
             groups,
             offset_groups,
             use_mask)
      .to(input.scalar_type());
}

// Route Autocast-key dispatches of deform_conv2d to the wrapper above.
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
  m.impl("deform_conv2d", deform_conv2d_autocast);
}
#endif

// Dispatcher entry point for the deform_conv2d backward pass. Looks up the
// registered "torchvision::_deform_conv2d_backward" kernel and returns the
// gradients w.r.t. (input, weight, offset, mask, bias), in that order.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  // Cached typed handle, resolved on first call.
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
          .typed<decltype(_deform_conv2d_backward)>();
  return op.call(
      grad,
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask);
}

// Register the operator schemas with the dispatcher. Kernel implementations
// (CPU/CUDA, autograd, autocast) are registered separately via
// TORCH_LIBRARY_IMPL blocks.
TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
  m.def(
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
}

namespace {

// Custom autograd Function for deform_conv2d. forward() redispatches to the
// registered backend kernel with autograd disabled and records everything the
// backward pass needs; backward() redispatches to
// torchvision::_deform_conv2d_backward.
class DeformConv2dFunction
    : public torch::autograd::Function<DeformConv2dFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
      const torch::autograd::Variable& mask,
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
      int64_t offset_groups,
      bool use_mask) {
    // Drop below the autograd dispatch key so the redispatch hits the backend
    // kernel instead of recursing into this Function.
    // NOTE(review): AutoNonVariableTypeMode is deprecated in newer PyTorch
    // (at::AutoDispatchBelowADInplaceOrView is the replacement) — confirm
    // against the minimum supported libtorch version before changing.
    at::AutoNonVariableTypeMode g;
    auto output = deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
        offset_groups,
        use_mask);

    // Save tensors and scalar hyper-parameters for backward().
    ctx->save_for_backward({input, weight, offset, mask, bias});
    ctx->saved_data["stride_h"] = stride_h;
    ctx->saved_data["stride_w"] = stride_w;
    ctx->saved_data["pad_h"] = pad_h;
    ctx->saved_data["pad_w"] = pad_w;
    ctx->saved_data["dilation_h"] = dilation_h;
    ctx->saved_data["dilation_w"] = dilation_w;
    ctx->saved_data["groups"] = groups;
    ctx->saved_data["offset_groups"] = offset_groups;
    ctx->saved_data["use_mask"] = use_mask;

    return {
        output,
    };
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Recover the tensors in the order they were saved in forward().
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto offset = saved[2];
    auto mask = saved[3];
    auto bias = saved[4];

    auto stride_h = ctx->saved_data["stride_h"].toInt();
    auto stride_w = ctx->saved_data["stride_w"].toInt();
    auto pad_h = ctx->saved_data["pad_h"].toInt();
    auto pad_w = ctx->saved_data["pad_w"].toInt();
    auto dilation_h = ctx->saved_data["dilation_h"].toInt();
    auto dilation_w = ctx->saved_data["dilation_w"].toInt();
    auto groups = ctx->saved_data["groups"].toInt();
    auto offset_groups = ctx->saved_data["offset_groups"].toInt();
    auto use_mask = ctx->saved_data["use_mask"].toBool();

    auto grads = _deform_conv2d_backward(
        grad_output[0],
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
        offset_groups,
        use_mask);
    auto grad_input = std::get<0>(grads);
    auto grad_weight = std::get<1>(grads);
    auto grad_offset = std::get<2>(grads);
    auto grad_mask = std::get<3>(grads);
    auto grad_bias = std::get<4>(grads);

    // One entry per forward() input: five tensor gradients followed by nine
    // undefined Variables for the non-differentiable scalar arguments.
    return {
        grad_input,
        grad_weight,
        grad_offset,
        grad_mask,
        grad_bias,
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
    };
  }
};

// TODO: There should be an easier way to do this
// Autograd wrapper for the backward op itself, so gradients can flow through
// a backward() call made under grad mode. Its forward() computes the
// first-order gradients; double backward is not implemented.
class DeformConv2dBackwardFunction
    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
      const torch::autograd::Variable& mask,
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
      int64_t offset_groups,
      bool use_mask) {
    // Run below the autograd key so the redispatch hits the backend kernel.
    at::AutoNonVariableTypeMode g;
    auto result = _deform_conv2d_backward(
        grad,
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
        offset_groups,
        use_mask);

    auto grad_input = std::get<0>(result);
    auto grad_weight = std::get<1>(result);
    auto grad_offset = std::get<2>(result);
    auto grad_mask = std::get<3>(result);
    auto grad_bias = std::get<4>(result);

    return {
        grad_input,
        grad_weight,
        grad_offset,
        grad_mask,
        grad_bias,
    };
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
  }
};

} // namespace

// Autograd-key kernel: runs the custom Function above and unwraps its
// single-element variable_list into the plain Tensor callers expect.
at::Tensor deform_conv2d_autograd(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  return DeformConv2dFunction::apply(
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask)[0];
}

// Autograd-key kernel for the backward op: runs the wrapper Function and
// repacks its five-element variable_list into the tuple the schema declares.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
deform_conv2d_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  auto result = DeformConv2dBackwardFunction::apply(
      grad,
      input,
      weight,
      offset,
      mask,
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
      offset_groups,
      use_mask);

  return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}

// Route Autograd-key dispatches of both ops to the custom autograd kernels.
TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("deform_conv2d", deform_conv2d_autograd);
  m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}

} // namespace ops
} // namespace vision