#include "deform_conv2d.h"

#include <tuple>

#include <torch/autograd.h>
#include <torch/types.h>
namespace vision {
namespace ops {

at::Tensor deform_conv2d(
10
11
12
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
13
    const at::Tensor& mask,
14
    const at::Tensor& bias,
15
16
17
18
19
20
21
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
22
23
    int64_t offset_groups,
    bool use_mask) {
24
25
26
27
28
29
30
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
                       .typed<decltype(deform_conv2d)>();
  return op.call(
      input,
      weight,
      offset,
31
      mask,
32
33
34
35
36
37
38
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
39
      groups,
40
41
      offset_groups,
      use_mask);
42
43
}

44
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
45
46
47
48
49
_deform_conv2d_backward(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
50
    const at::Tensor& mask,
51
    const at::Tensor& bias,
52
53
54
55
56
57
58
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
59
60
    int64_t offset_groups,
    bool use_mask) {
61
62
63
64
65
66
67
68
69
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
          .typed<decltype(_deform_conv2d_backward)>();
  return op.call(
      grad,
      input,
      weight,
      offset,
70
      mask,
71
72
73
74
75
76
77
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
78
      groups,
79
80
      offset_groups,
      use_mask);
81
82
}

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  m.def(
      "deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor");
  m.def(
      "_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
}

namespace {

class DeformConv2dFunction
    : public torch::autograd::Function<DeformConv2dFunction> {
 public:
95
96
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
97
98
99
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
100
      const torch::autograd::Variable& mask,
101
      const torch::autograd::Variable& bias,
102
103
104
105
106
107
108
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
109
110
      int64_t offset_groups,
      bool use_mask) {
111
    at::AutoNonVariableTypeMode g;
112
    auto output = deform_conv2d(
113
114
115
        input,
        weight,
        offset,
116
        mask,
117
        bias,
118
119
120
121
122
123
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
124
        groups,
125
126
        offset_groups,
        use_mask);
127

128
    ctx->save_for_backward({input, weight, offset, mask, bias});
129
130
131
132
133
134
135
136
    ctx->saved_data["stride_h"] = stride_h;
    ctx->saved_data["stride_w"] = stride_w;
    ctx->saved_data["pad_h"] = pad_h;
    ctx->saved_data["pad_w"] = pad_w;
    ctx->saved_data["dilation_h"] = dilation_h;
    ctx->saved_data["dilation_w"] = dilation_w;
    ctx->saved_data["groups"] = groups;
    ctx->saved_data["offset_groups"] = offset_groups;
137
    ctx->saved_data["use_mask"] = use_mask;
138
139
140
141
142
143

    return {
        output,
    };
  }

144
145
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
146
      const torch::autograd::variable_list& grad_output) {
147
148
149
150
    auto saved = ctx->get_saved_variables();
    auto input = saved[0];
    auto weight = saved[1];
    auto offset = saved[2];
151
152
    auto mask = saved[3];
    auto bias = saved[4];
153
154
155
156
157
158
159
160
161

    auto stride_h = ctx->saved_data["stride_h"].toInt();
    auto stride_w = ctx->saved_data["stride_w"].toInt();
    auto pad_h = ctx->saved_data["pad_h"].toInt();
    auto pad_w = ctx->saved_data["pad_w"].toInt();
    auto dilation_h = ctx->saved_data["dilation_h"].toInt();
    auto dilation_w = ctx->saved_data["dilation_w"].toInt();
    auto groups = ctx->saved_data["groups"].toInt();
    auto offset_groups = ctx->saved_data["offset_groups"].toInt();
162
    auto use_mask = ctx->saved_data["use_mask"].toBool();
163

164
    auto grads = _deform_conv2d_backward(
165
166
167
168
        grad_output[0],
        input,
        weight,
        offset,
169
        mask,
170
        bias,
171
172
173
174
175
176
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
177
        groups,
178
179
        offset_groups,
        use_mask);
180
181
182
    auto grad_input = std::get<0>(grads);
    auto grad_weight = std::get<1>(grads);
    auto grad_offset = std::get<2>(grads);
183
184
    auto grad_mask = std::get<3>(grads);
    auto grad_bias = std::get<4>(grads);
185
186
187
188
189

    return {
        grad_input,
        grad_weight,
        grad_offset,
190
        grad_mask,
191
        grad_bias,
192
193
194
195
196
197
198
199
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
        torch::autograd::Variable(),
200
        torch::autograd::Variable(),
201
202
203
204
    };
  }
};

// TODO: There should be an easier way to do this
class DeformConv2dBackwardFunction
    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
211
212
213
214
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
215
      const torch::autograd::Variable& mask,
216
217
218
219
220
221
222
223
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
224
225
      int64_t offset_groups,
      bool use_mask) {
226
227
228
229
230
231
    at::AutoNonVariableTypeMode g;
    auto result = _deform_conv2d_backward(
        grad,
        input,
        weight,
        offset,
232
        mask,
233
234
235
236
237
238
239
240
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dilation_h,
        dilation_w,
        groups,
241
242
        offset_groups,
        use_mask);
243
244
245
246

    auto grad_input = std::get<0>(result);
    auto grad_weight = std::get<1>(result);
    auto grad_offset = std::get<2>(result);
247
248
    auto grad_mask = std::get<3>(result);
    auto grad_bias = std::get<4>(result);
249
250
251
252
253

    return {
        grad_input,
        grad_weight,
        grad_offset,
254
        grad_mask,
255
256
257
258
259
260
        grad_bias,
    };
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
261
      const torch::autograd::variable_list& grad_output) {
262
263
264
265
    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
  }
};

at::Tensor deform_conv2d_autograd(
267
268
269
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
270
    const at::Tensor& mask,
271
    const at::Tensor& bias,
272
273
274
275
276
277
278
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
279
280
    int64_t offset_groups,
    bool use_mask) {
281
  return DeformConv2dFunction::apply(
282
283
284
      input,
      weight,
      offset,
285
      mask,
286
287
288
289
290
291
292
293
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
294
295
      offset_groups,
      use_mask)[0];
296
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
299
deform_conv2d_backward_autograd(
300
301
302
303
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
304
    const at::Tensor& mask,
305
    const at::Tensor& bias,
306
307
308
309
310
311
312
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
313
314
    int64_t offset_groups,
    bool use_mask) {
315
316
317
318
319
  auto result = DeformConv2dBackwardFunction::apply(
      grad,
      input,
      weight,
      offset,
320
      mask,
321
322
323
324
325
326
327
328
      bias,
      stride_h,
      stride_w,
      pad_h,
      pad_w,
      dilation_h,
      dilation_w,
      groups,
329
330
      offset_groups,
      use_mask);
331

332
333
  return std::make_tuple(result[0], result[1], result[2], result[3], result[4]);
}
} // namespace

TORCH_LIBRARY_IMPL(torchvision, Autograd, m) {
  m.impl("deform_conv2d", deform_conv2d_autograd);
  m.impl("_deform_conv2d_backward", deform_conv2d_backward_autograd);
}

} // namespace ops
} // namespace vision