#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "autocast.h"
#include "cuda/vision_cuda.h"
#endif
#ifdef WITH_HIP
#include "autocast.h"
#include "hip/vision_cuda.h"
#endif
// TODO: put this stuff in torchvision namespace

// Forwards a (modulated) deformable convolution to whatever implementation
// is registered with the dispatcher under "torchvision::deform_conv2d";
// the concrete kernel (CPU/CUDA/...) is picked from the tensor arguments.
at::Tensor deform_conv2d(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  // Schema lookup is done once and cached for subsequent calls.
  static auto dispatch_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::deform_conv2d", "")
          .typed<decltype(deform_conv2d)>();
  return dispatch_op.call(
      input, weight, offset, mask, bias,
      stride_h, stride_w, pad_h, pad_w,
      dilation_h, dilation_w, groups,
      offset_groups, use_mask);
}

#if defined(WITH_CUDA) || defined(WITH_HIP)
// Autocast wrapper: disables further autocasting, promotes every tensor
// argument to float32 through the autocast cast cache, runs the op, and
// converts the result back to the input's dtype.
at::Tensor DeformConv2d_autocast(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  auto to_fp32 = [](const at::Tensor& t) {
    return at::autocast::cached_cast(at::kFloat, t);
  };
  return deform_conv2d(
             to_fp32(input),
             to_fp32(weight),
             to_fp32(offset),
             to_fp32(mask),
             to_fp32(bias),
             stride_h, stride_w, pad_h, pad_w,
             dilation_h, dilation_w, groups,
             offset_groups, use_mask)
      .to(input.scalar_type());
}
#endif
// Forwards to the registered "torchvision::_deform_conv2d_backward"
// implementation. Returns gradients with respect to
// (input, weight, offset, mask, bias), in that order.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  // Schema lookup is done once and cached for subsequent calls.
  static auto dispatch_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
          .typed<decltype(_deform_conv2d_backward)>();
  return dispatch_op.call(
      grad, input, weight, offset, mask, bias,
      stride_h, stride_w, pad_h, pad_w,
      dilation_h, dilation_w, groups,
      offset_groups, use_mask);
}

// Autograd support for deform_conv2d: forward() calls into the dispatcher
// op and stashes the tensors / hyper-parameters the backward pass needs;
// backward() routes incoming gradients to _deform_conv2d_backward.
class DeformConv2dFunction
    : public torch::autograd::Function<DeformConv2dFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
      const torch::autograd::Variable& mask,
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
      int64_t offset_groups,
      bool use_mask) {
    // Run the op below the autograd layer so the call is not re-recorded.
    at::AutoNonVariableTypeMode g;
    auto result = deform_conv2d(
        input, weight, offset, mask, bias,
        stride_h, stride_w, pad_h, pad_w,
        dilation_h, dilation_w, groups,
        offset_groups, use_mask);

    // Everything backward() needs: tensors by position, scalars by key.
    ctx->save_for_backward({input, weight, offset, mask, bias});
    ctx->saved_data["stride_h"] = stride_h;
    ctx->saved_data["stride_w"] = stride_w;
    ctx->saved_data["pad_h"] = pad_h;
    ctx->saved_data["pad_w"] = pad_w;
    ctx->saved_data["dilation_h"] = dilation_h;
    ctx->saved_data["dilation_w"] = dilation_w;
    ctx->saved_data["groups"] = groups;
    ctx->saved_data["offset_groups"] = offset_groups;
    ctx->saved_data["use_mask"] = use_mask;

    return {result};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    // Recover the tensors saved by forward(), in the same order.
    const auto saved = ctx->get_saved_variables();
    const auto& input = saved[0];
    const auto& weight = saved[1];
    const auto& offset = saved[2];
    const auto& mask = saved[3];
    const auto& bias = saved[4];

    const auto stride_h = ctx->saved_data["stride_h"].toInt();
    const auto stride_w = ctx->saved_data["stride_w"].toInt();
    const auto pad_h = ctx->saved_data["pad_h"].toInt();
    const auto pad_w = ctx->saved_data["pad_w"].toInt();
    const auto dilation_h = ctx->saved_data["dilation_h"].toInt();
    const auto dilation_w = ctx->saved_data["dilation_w"].toInt();
    const auto groups = ctx->saved_data["groups"].toInt();
    const auto offset_groups = ctx->saved_data["offset_groups"].toInt();
    const auto use_mask = ctx->saved_data["use_mask"].toBool();

    auto grads = _deform_conv2d_backward(
        grad_output[0], input, weight, offset, mask, bias,
        stride_h, stride_w, pad_h, pad_w,
        dilation_h, dilation_w, groups,
        offset_groups, use_mask);

    torch::autograd::Variable grad_input, grad_weight, grad_offset,
        grad_mask, grad_bias;
    std::tie(grad_input, grad_weight, grad_offset, grad_mask, grad_bias) =
        grads;

    // One slot per forward() argument (after ctx); the nine non-tensor
    // hyper-parameters get undefined Variables.
    return {
        grad_input,
        grad_weight,
        grad_offset,
        grad_mask,
        grad_bias,
        torch::autograd::Variable(), // stride_h
        torch::autograd::Variable(), // stride_w
        torch::autograd::Variable(), // pad_h
        torch::autograd::Variable(), // pad_w
        torch::autograd::Variable(), // dilation_h
        torch::autograd::Variable(), // dilation_w
        torch::autograd::Variable(), // groups
        torch::autograd::Variable(), // offset_groups
        torch::autograd::Variable(), // use_mask
    };
  }
};

// TODO: There should be an easier way to do this
// Wraps _deform_conv2d_backward as an autograd Function so its outputs can
// flow back through the graph; double backward itself is not implemented.
class DeformConv2dBackwardFunction
    : public torch::autograd::Function<DeformConv2dBackwardFunction> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::Variable& grad,
      const torch::autograd::Variable& input,
      const torch::autograd::Variable& weight,
      const torch::autograd::Variable& offset,
      const torch::autograd::Variable& mask,
      const torch::autograd::Variable& bias,
      int64_t stride_h,
      int64_t stride_w,
      int64_t pad_h,
      int64_t pad_w,
      int64_t dilation_h,
      int64_t dilation_w,
      int64_t groups,
      int64_t offset_groups,
      bool use_mask) {
    // Run below the autograd layer so the call is not re-recorded.
    at::AutoNonVariableTypeMode g;
    auto grads = _deform_conv2d_backward(
        grad, input, weight, offset, mask, bias,
        stride_h, stride_w, pad_h, pad_w,
        dilation_h, dilation_w, groups,
        offset_groups, use_mask);

    torch::autograd::Variable grad_input, grad_weight, grad_offset,
        grad_mask, grad_bias;
    std::tie(grad_input, grad_weight, grad_offset, grad_mask, grad_bias) =
        grads;

    return {grad_input, grad_weight, grad_offset, grad_mask, grad_bias};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      const torch::autograd::variable_list& grad_output) {
    TORCH_CHECK(0, "double backwards on deform_conv2d not supported");
  }
};

// Autograd-enabled entry point for the forward pass; returns the single
// output tensor produced by DeformConv2dFunction.
at::Tensor DeformConv2d_autograd(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  auto outputs = DeformConv2dFunction::apply(
      input, weight, offset, mask, bias,
      stride_h, stride_w, pad_h, pad_w,
      dilation_h, dilation_w, groups,
      offset_groups, use_mask);
  return outputs[0];
}
// Autograd-enabled entry point for the backward pass; unpacks the five
// gradients produced by DeformConv2dBackwardFunction into a tuple
// (input, weight, offset, mask, bias).
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
DeformConv2d_backward_autograd(
    const at::Tensor& grad,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t groups,
    int64_t offset_groups,
    bool use_mask) {
  auto grads = DeformConv2dBackwardFunction::apply(
      grad, input, weight, offset, mask, bias,
      stride_h, stride_w, pad_h, pad_w,
      dilation_h, dilation_w, groups,
      offset_groups, use_mask);
  return std::make_tuple(grads[0], grads[1], grads[2], grads[3], grads[4]);
}