// deform_conv2d_kernel.cu
/*!
 ******************* BEGIN Caffe Copyright Notice and Disclaimer
 *****************
 *
 * COPYRIGHT
 *
 * All contributions by the University of California:
 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
 * All rights reserved.
 *
 * All other contributions:
 * Copyright (c) 2014-2017, the respective contributors
 * All rights reserved.
 *
 * Caffe uses a shared copyright model: each contributor holds copyright over
 * their contributions to Caffe. The project versioning records all such
 * contribution and copyright details. If a contributor wants to further mark
 * their specific copyright on a particular contribution, they should indicate
 * their copyright solely in the commit message of the change when it is
 * committed.
 *
 * LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
 *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * CONTRIBUTION AGREEMENT
 *
 * By contributing to the BVLC/caffe repository through pull-request, comment,
 * or otherwise, the contributor releases their content to the
 * license and copyright terms herein.
 *
 ***************** END Caffe Copyright Notice and Disclaimer
 *********************
 *
 * Copyright (c) 2018 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file modulated_deformable_im2col.cuh
 * \brief Function definitions of converting an image to
 * column matrix based on kernel, padding, dilation, and offset.
 * These functions are mainly used in deformable convolution operators.
 * \ref: https://arxiv.org/abs/1703.06211
 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
 */

// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu

// modified from
// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
71
#include <THC/THCAtomics.cuh>
72
73

#include "cuda_helpers.h"
74
#include "deform_conv2d_kernel.h"
75

76
77
78
79
namespace vision {
namespace ops {

namespace {
80
81
82

// Upper bound handed to get_greatest_divisor_below_bound() by the forward
// entry point when it chooses how many images of the batch to process together.
const int kMaxParallelImgs = 32;

83
// Returns the number of threads per block to use for a kernel launch:
// 256 when built for HIP; otherwise 1024 on devices whose compute-capability
// major version is at least 6, and 512 on older devices.
inline unsigned int GET_THREADS() {
#ifdef __HIP_PLATFORM_HCC__
  return 256;
#else
  const bool major_at_least_6 =
      at::cuda::getCurrentDeviceProperties()->major >= 6;
  return major_at_least_6 ? 1024 : 512;
#endif
}

93
94
95
// Returns the number of blocks needed to cover N work items with THREADS
// threads per block, capped at the current device's maxGridSize[0].
inline unsigned int GET_BLOCKS(
    const unsigned int THREADS,
    const unsigned int N) {
  const unsigned int max_grid_x =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  // Ceiling division written so the intermediate value cannot wrap around
  // for N close to the unsigned maximum.
  const unsigned int needed = N / THREADS + (N % THREADS != 0 ? 1u : 0u);
  return std::min(max_grid_x, needed);
}

// Bilinearly samples a single `height` x `width` row-major plane `in` at the
// fractional location (h, w).
//
// Returns 0 when the location lies at or beyond one pixel outside the plane
// (h <= -1, h >= height, w <= -1 or w >= width). Otherwise returns the
// weighted sum of the four surrounding pixels, where any neighbour that falls
// outside the plane contributes 0.
template <typename scalar_t>
__device__ scalar_t bilinear_interpolate(
    const scalar_t* in,
    int height,
    int width,
    scalar_t h,
    scalar_t w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    return 0;
  }

  // Integer corners surrounding the sample location.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional distances to the low corner (lh, lw) and to the high corner
  // (hh, hw).
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const bool h_low_ok = h_low >= 0;
  const bool w_low_ok = w_low >= 0;
  const bool h_high_ok = h_high <= height - 1;
  const bool w_high_ok = w_high <= width - 1;

  // Corner values; out-of-plane corners stay at zero.
  scalar_t v1 = 0;
  if (h_low_ok && w_low_ok)
    v1 = in[h_low * width + w_low];
  scalar_t v2 = 0;
  if (h_low_ok && w_high_ok)
    v2 = in[h_low * width + w_high];
  scalar_t v3 = 0;
  if (h_high_ok && w_low_ok)
    v3 = in[h_high * width + w_low];
  scalar_t v4 = 0;
  if (h_high_ok && w_high_ok)
    v4 = in[h_high * width + w_high];

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

// Deformable im2col: fills the column buffer with (optionally mask-modulated)
// bilinear samples of the input taken at the offset-shifted kernel taps.
//
// Each loop index handles one (input channel, image, out_y, out_x) tuple and
// writes weight_h * weight_w values, one per kernel tap.
//
// Layouts implied by the index arithmetic below:
//   input_ptr   : [batch_sz, n_in_channels, height, width]
//   offset_ptr  : [batch_sz, n_offset_grps * 2 * weight_h * weight_w,
//                  out_h, out_w], with the y offset stored before the x
//                  offset for every tap
//   mask_ptr    : [batch_sz, n_offset_grps * weight_h * weight_w,
//                  out_h, out_w] (only read when use_mask is true)
//   columns_ptr : [n_in_channels * weight_h * weight_w, batch_sz,
//                  out_h, out_w]
//
// `n` is the number of work items (n_in_channels * batch_sz * out_h * out_w
// in the launcher in this file).
template <typename scalar_t>
__global__ void deformable_im2col_kernel(
    int n,
    const scalar_t* input_ptr,
    const scalar_t* offset_ptr,
    const scalar_t* mask_ptr,
    int height,
    int width,
    int weight_h,
    int weight_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int batch_sz,
    int n_in_channels,
    int n_offset_grps,
    int out_h,
    int out_w,
    bool use_mask,
    scalar_t* columns_ptr) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    // Decompose the flat index (x fastest, then y, image, input channel).
    const int out_x = index % out_w;
    const int out_y = (index / out_w) % out_h;
    const int out_b = (index / (out_w * out_h)) % batch_sz;
    const int in_c = index / (out_w * out_h * batch_sz);
    const int out_c = in_c * weight_h * weight_w;

    const int c_per_offset_grp = n_in_channels / n_offset_grps;
    const int grp_idx = in_c / c_per_offset_grp;

    // Per-iteration base pointers. The kernel parameters themselves are left
    // untouched: CUDA_1D_KERNEL_LOOP is defined outside this file, and if it
    // runs this body more than once per thread, advancing the parameters in
    // place would carry one iteration's offset into the next.
    scalar_t* col = columns_ptr +
        (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
         out_y * out_w + out_x);

    const scalar_t* in = input_ptr +
        (out_b * (n_in_channels * height * width) + in_c * (height * width));

    const scalar_t* off = offset_ptr +
        (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h *
            out_w;

    const scalar_t* msk = mask_ptr;
    if (use_mask) {
      msk += (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h *
          out_w;
    }

    // Position of this output pixel inside one [out_h, out_w] plane.
    const int out_pos = out_y * out_w + out_x;

    for (int i = 0; i < weight_h; ++i) {
      for (int j = 0; j < weight_w; ++j) {
        const int mask_idx = i * weight_w + j;
        const int offset_idx = 2 * mask_idx;

        scalar_t mask_value = 1;
        if (use_mask) {
          mask_value = msk[mask_idx * (out_h * out_w) + out_pos];
        }

        const scalar_t offset_h = off[offset_idx * (out_h * out_w) + out_pos];
        const scalar_t offset_w =
            off[(offset_idx + 1) * (out_h * out_w) + out_pos];

        // Sampling location: regular convolution tap plus learned offset.
        const scalar_t y =
            (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
        const scalar_t x =
            (out_x * stride_w - pad_w) + j * dilation_w + offset_w;

        *col = mask_value * bilinear_interpolate(in, height, width, y, x);

        // Next kernel tap lives one full [batch_sz, out_h, out_w] slab later.
        col += batch_sz * out_h * out_w;
      }
    }
  }
}

215
// Host-side launcher for deformable_im2col_kernel.
//
// Launches one work item per (input channel, image, out_y, out_x) and writes
// the result into `data_col`. `parallel_imgs` is forwarded to the kernel as
// its batch size and `deformable_group` as its number of offset groups.
// The scalar type is dispatched on `input` (floating types and half), and the
// raw data pointers of all tensors are read with that same type.
//
// Raises (via TORCH_CHECK) if the kernel launch reports a CUDA error.
void deformable_im2col(
    const at::Tensor& input,
    const at::Tensor& data_offset,
    const at::Tensor& data_mask,
    int n_in_channels,
    int height,
    int width,
    int weight_h,
    int weight_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int out_h,
    int out_w,
    int parallel_imgs,
    int deformable_group,
    bool use_mask,
    at::Tensor data_col) {
  int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;

  // Nothing to compute; a launch with zero blocks would be reported as an
  // invalid configuration.
  if (num_kernels <= 0) {
    return;
  }

  const unsigned int threads = GET_THREADS();
  const unsigned int blocks = GET_BLOCKS(threads, num_kernels);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "deformable_im2col", ([&] {
        deformable_im2col_kernel<<<blocks, threads>>>(
            num_kernels,
            input.data_ptr<scalar_t>(),
            data_offset.data_ptr<scalar_t>(),
            data_mask.data_ptr<scalar_t>(),
            height,
            width,
            weight_h,
            weight_w,
            pad_h,
            pad_w,
            stride_h,
            stride_w,
            dilation_h,
            dilation_w,
            parallel_imgs,
            n_in_channels,
            deformable_group,
            out_h,
            out_w,
            use_mask,
            data_col.data_ptr<scalar_t>());
      }));

  // Surface launch failures to the caller instead of only printing them.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
      err == cudaSuccess,
      "error in deformable_im2col: ",
      cudaGetErrorString(err));
}

273
// Returns the largest k in [2, bound] that divides n evenly, or 1 when no
// such k exists (including when bound < 2).
int get_greatest_divisor_below_bound(int n, int bound) {
  for (int k = bound; k > 1; --k) {
    if (n % k == 0) {
      return k;
    }
  }
  return 1;
}

// Deformable col2im: scatters column-buffer gradients back onto the input
// gradient image.
//
// Each loop index handles one element of `col`, i.e. one
// (channel c, kernel row i, kernel col j, image b, out_y, out_x) tuple, and
// atomically adds its bilinearly weighted (and optionally mask-modulated)
// contribution to the input pixels within one pixel of the sampling location.
//
// Layouts implied by the index arithmetic below:
//   col        : [channels * kernel_h * kernel_w, batch_sz, out_h, out_w]
//   offset_ptr : [batch_sz, n_offset_grps * 2 * kernel_h * kernel_w,
//                 out_h, out_w]
//   mask_ptr   : [batch_sz, n_offset_grps * kernel_h * kernel_w,
//                 out_h, out_w] (only read when use_mask is true)
//   grad_im    : [batch_sz, channels, height, width]; accumulated into with
//                atomicAdd, so it is expected to be initialised by the caller
template <typename scalar_t>
__global__ void deformable_col2im_kernel(
    int n,
    const scalar_t* col,
    const scalar_t* offset_ptr,
    const scalar_t* mask_ptr,
    int channels,
    int height,
    int width,
    int kernel_h,
    int kernel_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int batch_sz,
    int n_offset_grps,
    int out_h,
    int out_w,
    bool use_mask,
    scalar_t* grad_im) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    // Decompose the flat column index (x fastest, channel slowest).
    const int out_x = index % out_w;
    const int out_y = (index / out_w) % out_h;
    const int b = (index / (out_w * out_h)) % batch_sz;
    const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
    const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
    const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);

    const int c_per_offset_grp = channels / n_offset_grps;
    const int offset_grp = c / c_per_offset_grp;

    // Per-iteration base pointer. The kernel parameters are not modified:
    // CUDA_1D_KERNEL_LOOP is defined outside this file, and if it runs this
    // body more than once per thread, advancing the parameters in place would
    // carry one iteration's offset into the next.
    const scalar_t* off = offset_ptr +
        (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w * out_h *
            out_w;

    const int mask_idx = i * kernel_w + j;
    const int offset_idx = 2 * mask_idx;

    const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
    const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;

    const scalar_t offset_h = off[offset_h_ptr];
    const scalar_t offset_w = off[offset_w_ptr];

    scalar_t mask_value = 1;
    if (use_mask) {
      const scalar_t* msk = mask_ptr +
          (b * n_offset_grps + offset_grp) * kernel_h * kernel_w * out_h *
              out_w;
      mask_value = msk[(mask_idx * out_h + out_y) * out_w + out_x];
    }

    // Sampling location used by the forward pass for this tap.
    const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
    const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;

    // Visit the 3x3 integer neighbourhood around the truncated location and
    // keep only in-bounds pixels closer than one pixel in both directions;
    // those are the pixels that received non-zero bilinear weight.
    for (int dy = -1; dy <= 1; dy++) {
      for (int dx = -1; dx <= 1; dx++) {
        int yp = int(y) + dy;
        int xp = int(x) + dx;
        if (0 <= yp && yp < height && 0 <= xp && xp < width &&
            std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
          int grad_pos = ((b * channels + c) * height + yp) * width + xp;
          scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
          atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
        }
      }
    }
  }
}

356
// Host-side launcher for deformable_col2im_kernel.
//
// Derives the output size from the input size, padding, dilation, kernel size
// and stride, then launches one work item per element of `columns`
// (channels * weight_h * weight_w * out_h * out_w * parallel_imgs) to
// accumulate the gradient with respect to the input into `grad_im`.
// The scalar type is dispatched on `columns` (floating types and half).
//
// Raises (via TORCH_CHECK) if the kernel launch reports a CUDA error.
void compute_grad_input(
    const at::Tensor& columns,
    const at::Tensor& offset,
    const at::Tensor& mask,
    int channels,
    int height,
    int width,
    int weight_h,
    int weight_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int parallel_imgs,
    int n_offset_grps,
    bool use_mask,
    at::Tensor grad_im) {
  int out_h =
      (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
  int out_w =
      (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
  int num_kernels =
      channels * weight_h * weight_w * out_h * out_w * parallel_imgs;

  // Nothing to compute; a launch with zero blocks would be reported as an
  // invalid configuration.
  if (num_kernels <= 0) {
    return;
  }

  const unsigned int threads = GET_THREADS();
  const unsigned int blocks = GET_BLOCKS(threads, num_kernels);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      columns.scalar_type(), "compute_grad_input", ([&] {
        deformable_col2im_kernel<<<blocks, threads>>>(
            num_kernels,
            columns.data_ptr<scalar_t>(),
            offset.data_ptr<scalar_t>(),
            mask.data_ptr<scalar_t>(),
            channels,
            height,
            width,
            weight_h,
            weight_w,
            pad_h,
            pad_w,
            stride_h,
            stride_w,
            dilation_h,
            dilation_w,
            parallel_imgs,
            n_offset_grps,
            out_h,
            out_w,
            use_mask,
            grad_im.data_ptr<scalar_t>());
      }));

  // Surface launch failures to the caller instead of only printing them.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
      err == cudaSuccess,
      "error in compute_grad_input: ",
      cudaGetErrorString(err));
}

// Returns the partial derivative of the bilinear sample of `im_data` at
// (y, x) with respect to y (when is_y_direction is true) or with respect to x
// (otherwise).
//
// `im_data` is a single `height` x `width` row-major plane. Any of the four
// surrounding pixels that falls outside the plane is treated as 0.
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(
    const scalar_t* im_data,
    int height,
    int width,
    scalar_t y,
    scalar_t x,
    bool is_y_direction) {
  // Integer corners surrounding the sample location.
  const int y_l = floor(y);
  const int x_l = floor(x);
  const int y_h = y_l + 1;
  const int x_h = x_l + 1;

  const bool valid_y_l = 0 <= y_l && y_l < height;
  const bool valid_y_h = 0 <= y_h && y_h < height;
  const bool valid_x_l = 0 <= x_l && x_l < width;
  const bool valid_x_h = 0 <= x_h && x_h < width;

  // Corner values; lower-case letter = low corner, upper-case = high corner.
  const scalar_t zero = 0;
  const scalar_t v_yx =
      (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;
  const scalar_t v_yX =
      (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;
  const scalar_t v_Yx =
      (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;
  const scalar_t v_YX =
      (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;

  if (is_y_direction) {
    // d/dy: difference between high and low rows, blended along x.
    const scalar_t dx = x - x_l;
    return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);
  } else {
    // d/dx: difference between high and low columns, blended along y.
    const scalar_t dy = y - y_l;
    return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);
  }
}

// Computes the gradients with respect to the sampling offsets and, when
// use_mask is true, with respect to the modulation mask.
//
// Each loop index corresponds to one element of `grad_offset`, i.e. one
// (image b, offset channel c, out_y h, out_x w) tuple. The thread sums the
// contributions of every input channel in the offset group for that kernel
// tap. Even offset channels hold the y component (is_y_direction); those
// threads additionally write the matching `grad_mask` element.
//
// Layouts implied by the index arithmetic below:
//   col_ptr     : [channels * weight_h * weight_w, batch_sz, out_h, out_w]
//   im_ptr      : [batch_sz, channels, height, width]
//   offset_ptr  : [batch_sz, n_offset_grps * 2 * weight_h * weight_w,
//                  out_h, out_w]
//   mask_ptr    : [batch_sz, n_offset_grps * weight_h * weight_w,
//                  out_h, out_w] (only read when use_mask is true)
//   grad_offset : same layout as offset_ptr, indexed directly by `index`
//   grad_mask   : same layout as mask_ptr (only written when use_mask is true)
template <typename scalar_t>
__global__ void deformable_col2im_coord_kernel(
    int n,
    const scalar_t* col_ptr,
    const scalar_t* im_ptr,
    const scalar_t* offset_ptr,
    const scalar_t* mask_ptr,
    int channels,
    int height,
    int width,
    int weight_h,
    int weight_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int batch_sz,
    int offset_channels,
    int n_offset_grps,
    int out_h,
    int out_w,
    const bool use_mask,
    scalar_t* grad_offset,
    scalar_t* grad_mask) {
  CUDA_1D_KERNEL_LOOP(index, n) {
    scalar_t grad_offset_val = 0;
    scalar_t grad_mask_val = 0;

    // Decompose the flat grad_offset index.
    const int w = index % out_w;
    const int h = (index / out_w) % out_h;
    // Kernel tap (row w_h, column w_w) this offset channel belongs to.
    const int w_w = (index / (out_w * out_h * 2)) % weight_w;
    const int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
    const int c = (index / (out_w * out_h)) % offset_channels;
    const int b = index / (out_w * out_h * offset_channels);

    const int offset_grp = c / (2 * weight_h * weight_w);
    const int col_step = weight_h * weight_w;

    const int c_per_offset_grp = channels / n_offset_grps;

    // Per-iteration base pointers. The kernel parameters are not modified:
    // CUDA_1D_KERNEL_LOOP is defined outside this file, and if it runs this
    // body more than once per thread, advancing the parameters in place would
    // carry one iteration's offset into the next.
    const scalar_t* col = col_ptr +
        offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz *
            out_w * out_h;
    const scalar_t* im = im_ptr +
        (b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;
    const scalar_t* off = offset_ptr +
        (b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h *
            out_w;

    const scalar_t* msk = mask_ptr;
    if (use_mask) {
      msk += (b * n_offset_grps + offset_grp) * weight_h * weight_w * out_h *
          out_w;
    }

    // Offset channel relative to the start of its group; even = y, odd = x.
    const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
    const bool is_y_direction = offset_c % 2 == 0;

    // Walk over the column rows of this group that belong to the same kernel
    // tap: one row per input channel, col_step rows apart.
    const int c_bound = c_per_offset_grp * weight_h * weight_w;
    for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
      const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;

      // Recover the output position and kernel tap from the column index.
      const int out_x = col_pos % out_w;
      const int out_y = (col_pos / out_w) % out_h;
      const int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
      const int i =
          (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;

      const int mask_idx = i * weight_w + j;

      const int offset_h_ptr =
          (((2 * mask_idx) * out_h + out_y) * out_w + out_x);
      const int offset_w_ptr =
          (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);
      const scalar_t offset_h = off[offset_h_ptr];
      const scalar_t offset_w = off[offset_w_ptr];

      scalar_t mask_value = 1;
      if (use_mask) {
        mask_value = msk[(mask_idx * out_h + out_y) * out_w + out_x];
      }

      // Sampling location used by the forward pass for this tap.
      const scalar_t y =
          (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
      const scalar_t x =
          (out_x * stride_w - pad_w) + j * dilation_w + offset_w;

      const scalar_t weight =
          get_coordinate_weight(im, height, width, y, x, is_y_direction);
      grad_offset_val += mask_value * weight * col[col_pos];

      // The mask gradient is accumulated once per tap, by the y-direction
      // thread only.
      if (use_mask && is_y_direction) {
        grad_mask_val +=
            col[col_pos] * bilinear_interpolate(im, height, width, y, x);
      }

      // Move on to the next input channel plane of this image.
      im += height * width;
    }

    grad_offset[index] = grad_offset_val;

    if (use_mask && is_y_direction) {
      const int idx =
          ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
            w_w) *
               out_h +
           h) *
              out_w +
          w;
      grad_mask[idx] = grad_mask_val;
    }
  }
}

560
// Host-side launcher for deformable_col2im_coord_kernel.
//
// Derives the output size from the input size, padding, dilation, kernel size
// and stride, then launches one work item per offset element
// (out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs).
// Results are written to `grad_offset` and, when use_mask is true, to
// `grad_mask`. The scalar type is dispatched on `columns` (floating types and
// half).
//
// Raises (via TORCH_CHECK) if the kernel launch reports a CUDA error.
void compute_grad_offset_and_mask(
    const at::Tensor& columns,
    const at::Tensor& input,
    const at::Tensor& offset,
    const at::Tensor& mask,
    int channels,
    int height,
    int width,
    int weight_h,
    int weight_w,
    int pad_h,
    int pad_w,
    int stride_h,
    int stride_w,
    int dilation_h,
    int dilation_w,
    int parallel_imgs,
    int n_offset_grps,
    bool use_mask,
    at::Tensor grad_offset,
    at::Tensor grad_mask) {
  int out_h =
      (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
  int out_w =
      (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
  int num_kernels =
      out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;

  // Nothing to compute; a launch with zero blocks would be reported as an
  // invalid configuration.
  if (num_kernels <= 0) {
    return;
  }

  const unsigned int threads = GET_THREADS();
  const unsigned int blocks = GET_BLOCKS(threads, num_kernels);

  // Number of offset channels per image (y and x for every tap and group).
  const int offset_channels = 2 * weight_h * weight_w * n_offset_grps;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
        deformable_col2im_coord_kernel<<<blocks, threads>>>(
            num_kernels,
            columns.data_ptr<scalar_t>(),
            input.data_ptr<scalar_t>(),
            offset.data_ptr<scalar_t>(),
            mask.data_ptr<scalar_t>(),
            channels,
            height,
            width,
            weight_h,
            weight_w,
            pad_h,
            pad_w,
            stride_h,
            stride_w,
            dilation_h,
            dilation_w,
            parallel_imgs,
            offset_channels,
            n_offset_grps,
            out_h,
            out_w,
            use_mask,
            grad_offset.data_ptr<scalar_t>(),
            grad_mask.data_ptr<scalar_t>());
      }));

  // Surface launch failures to the caller instead of only printing them.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
      err == cudaSuccess,
      "error in compute_grad_offset_and_mask: ",
      cudaGetErrorString(err));
}

627
// Computes the gradients of the deformable convolution with respect to the
// input, the offsets and the mask.
//
// The batch is processed in blocks of `n_parallel_imgs` images (clamped to the
// batch size). For every block the column gradient is formed as
// weight^T * grad_out per weight group, and then handed to
// compute_grad_offset_and_mask and compute_grad_input.
//
// Returns (grad_input, grad_offset, grad_mask), each shaped like the
// corresponding argument. For an empty batch all three are returned as zeros
// without launching any work.
//
// NOTE(review): batch_sz is assumed to be divisible by n_parallel_imgs (the
// reshapes below require it) -- confirm against the caller.
std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_gradient_inputs(
    at::Tensor input,
    at::Tensor weight,
    at::Tensor offset,
    at::Tensor mask,
    at::Tensor grad_out,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    int dilation_h,
    int dilation_w,
    int n_weight_grps,
    int n_offset_grps,
    int n_parallel_imgs,
    bool use_mask) {
  at::DeviceGuard guard(input.device());

  int batch_sz = input.size(0);
  long n_in_channels = input.size(1);
  long in_h = input.size(2);
  long in_w = input.size(3);

  n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);

  long n_out_channels = weight.size(0);
  int weight_h = weight.size(2);
  int weight_w = weight.size(3);

  long out_w =
      (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
  long out_h =
      (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;

  auto grad_input = at::zeros_like(input);
  auto grad_offset = at::zeros_like(offset);
  auto grad_mask = at::zeros_like(mask);

  // Must come before any division by n_parallel_imgs, which is 0 here.
  if (batch_sz == 0) {
    return std::make_tuple(grad_input, grad_offset, grad_mask);
  }

  // Number of image blocks the batch is split into.
  const int n_blocks = batch_sz / n_parallel_imgs;

  auto columns = at::empty(
      {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
      input.options());

  // Separate into blocks
  grad_input = grad_input.reshape(
      {n_blocks, n_parallel_imgs, n_in_channels, in_h, in_w});
  input =
      input.reshape({n_blocks, n_parallel_imgs, n_in_channels, in_h, in_w});

  grad_offset = grad_offset.reshape({n_blocks,
                                     n_parallel_imgs,
                                     n_offset_grps * 2 * weight_h * weight_w,
                                     out_h,
                                     out_w});
  offset = offset.reshape({n_blocks,
                           n_parallel_imgs,
                           n_offset_grps * 2 * weight_h * weight_w,
                           out_h,
                           out_w});

  if (use_mask) {
    grad_mask = grad_mask.reshape({n_blocks,
                                   n_parallel_imgs,
                                   n_offset_grps * weight_h * weight_w,
                                   out_h,
                                   out_w});
    mask = mask.reshape({n_blocks,
                         n_parallel_imgs,
                         n_offset_grps * weight_h * weight_w,
                         out_h,
                         out_w});
  }

  // [block, weight group, out channel in group, image in block, out_h, out_w]
  grad_out = grad_out
                 .reshape({n_blocks,
                           n_parallel_imgs,
                           n_weight_grps,
                           n_out_channels / n_weight_grps,
                           out_h,
                           out_w})
                 .permute({0, 2, 3, 1, 4, 5});

  weight = weight.reshape({n_weight_grps,
                           weight.size(0) / n_weight_grps,
                           weight.size(1),
                           weight.size(2),
                           weight.size(3)});

  columns = columns.view(
      {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});

  for (int elt = 0; elt < n_blocks; elt++) {
    columns.zero_();
    // Separate into weight groups
    for (int g = 0; g < n_weight_grps; g++) {
      // columns[g] is a view into `columns`, so the in-place addmm_ already
      // updates the buffer; no copy back is needed.
      columns[g].addmm_(
          weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
    }

    compute_grad_offset_and_mask(
        columns,
        input[elt],
        offset[elt],
        mask[elt],
        n_in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        grad_offset[elt],
        grad_mask[elt]);

    compute_grad_input(
        columns,
        offset[elt],
        mask[elt],
        n_in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        grad_input[elt]);
  }

  // Restore the original (un-blocked) shapes.
  grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
  grad_offset = grad_offset.view(
      {batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});

  if (use_mask) {
    grad_mask = grad_mask.view(
        {batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w});
  }

  return std::make_tuple(grad_input, grad_offset, grad_mask);
}

783
// Computes the gradient of the deformable convolution with respect to the
// weight.
//
// The batch is processed in blocks of `n_parallel_imgs` images (clamped to the
// batch size). For every block the deformable im2col columns are rebuilt and
// grad_out * columns^T is accumulated into the weight gradient per weight
// group.
//
// Returns a tensor shaped like `weight`; for an empty batch it is all zeros.
//
// NOTE(review): batch_sz is assumed to be divisible by n_parallel_imgs (the
// reshapes below require it) -- confirm against the caller.
at::Tensor backward_gradient_parameters(
    at::Tensor input,
    const at::Tensor& weight,
    at::Tensor offset,
    at::Tensor mask,
    const at::Tensor& grad_out,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    int dilation_h,
    int dilation_w,
    int n_weight_grps,
    int n_offset_grps,
    int n_parallel_imgs,
    bool use_mask) {
  at::DeviceGuard guard(input.device());

  int batch_sz = input.size(0);
  long n_in_channels = input.size(1);
  long in_h = input.size(2);
  long in_w = input.size(3);

  n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);

  long n_out_channels = weight.size(0);
  int weight_h = weight.size(2);
  int weight_w = weight.size(3);

  long out_h = grad_out.size(2);
  long out_w = grad_out.size(3);

  auto grad_weight = at::zeros_like(weight);
  // Must come before any division by n_parallel_imgs, which is 0 here.
  if (batch_sz == 0) {
    return grad_weight;
  }

  // Number of image blocks the batch is split into.
  const int n_blocks = batch_sz / n_parallel_imgs;

  // [block, weight group, out channel in group, image in block, out_h, out_w]
  at::Tensor grad_out_buf = grad_out
                                .reshape({n_blocks,
                                          n_parallel_imgs,
                                          n_weight_grps,
                                          n_out_channels / n_weight_grps,
                                          out_h,
                                          out_w})
                                .permute({0, 2, 3, 1, 4, 5})
                                .contiguous();

  input =
      input.reshape({n_blocks, n_parallel_imgs, n_in_channels, in_h, in_w});

  offset = offset.reshape({n_blocks,
                           n_parallel_imgs,
                           n_offset_grps * 2 * weight_h * weight_w,
                           out_h,
                           out_w});

  if (use_mask) {
    mask = mask.reshape({n_blocks,
                         n_parallel_imgs,
                         n_offset_grps * weight_h * weight_w,
                         out_h,
                         out_w});
  }

  grad_weight = grad_weight.reshape({n_weight_grps,
                                     grad_weight.size(0) / n_weight_grps,
                                     grad_weight.size(1),
                                     grad_weight.size(2),
                                     grad_weight.size(3)});

  // Column buffer reused for every block, already split by weight group.
  auto columns = at::empty(
      {n_weight_grps,
       n_in_channels * weight_w * weight_h / n_weight_grps,
       n_parallel_imgs * out_h * out_w},
      input.options());

  for (int elt = 0; elt < n_blocks; elt++) {
    deformable_im2col(
        input[elt],
        offset[elt],
        mask[elt],
        n_in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        out_h,
        out_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        columns);

    for (int g = 0; g < n_weight_grps; g++) {
      // The result is assigned back because flatten() may return a copy
      // rather than a view of grad_weight[g].
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(
                  grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
              .view_as(grad_weight[g]);
    }
  }

  // Merge the weight-group dimension back into the output-channel dimension.
  grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                  grad_weight.size(2),
                                  grad_weight.size(3),
                                  grad_weight.size(4)});
  return grad_weight;
}

899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
} // namespace

// Forward pass of the deformable 2D convolution for CUDA tensors.
//
// Arguments:
//   input   - 4-D tensor; must live on a CUDA device.
//   weight  - 4-D tensor; size(0) is used as out_channels, size(2)/size(3) as
//             the kernel height/width, and size(1) * n_weight_grps must equal
//             input.size(1).
//   offset  - 4-D tensor; size(1) must equal
//             n_offset_grps * 2 * weight_h * weight_w and its spatial dims
//             must equal the computed (out_h, out_w).
//   mask    - when use_mask is true: 4-D tensor with size(1) equal to
//             n_offset_grps * weight_h * weight_w and spatial dims
//             (out_h, out_w). Its size(0) must match the batch size in
//             either case.
//   bias    - viewed as {1, out_channels, 1, 1} and added to the result.
//   stride_*, pad_*, dilation_* - convolution geometry; strides and dilations
//             must be > 0, paddings >= 0.
//   n_weight_grps, n_offset_grps - group counts; both must be > 0.
//   use_mask - whether `mask` is validated and reshaped per batch block.
//
// Returns a tensor of shape {batch_sz, out_channels, out_h, out_w} (plus the
// broadcast bias). Raises via TORCH_CHECK when any precondition fails.
at::Tensor deform_conv2d_forward_cuda(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t n_weight_grps,
    int64_t n_offset_grps,
    bool use_mask) {
  at::Tensor input_c = input.contiguous();
  at::Tensor offset_c = offset.contiguous();
  at::Tensor weight_c = weight.contiguous();
  at::Tensor mask_c = mask.contiguous();
  at::Tensor bias_c = bias.contiguous();

  TORCH_CHECK(input_c.ndimension() == 4);
  TORCH_CHECK(offset_c.ndimension() == 4);
  TORCH_CHECK(!use_mask || mask_c.ndimension() == 4);
  TORCH_CHECK(weight_c.ndimension() == 4);
  TORCH_CHECK(input_c.is_cuda(), "input must be a CUDA tensor");

  at::DeviceGuard guard(input_c.device());

  int batch_sz = input_c.size(0);
  int in_channels = input_c.size(1);
  int in_h = input_c.size(2);
  int in_w = input_c.size(3);

  int n_parallel_imgs =
      get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);

  int out_channels = weight_c.size(0);
  int weight_h = weight_c.size(2);
  int weight_w = weight_c.size(3);

  // Validate the scalar parameters BEFORE they are used as divisors below:
  // a zero stride or group count would otherwise trigger an integer
  // division/modulo by zero instead of a descriptive error.
  TORCH_CHECK(
      weight_h > 0 && weight_w > 0,
      "weight_h: ",
      weight_h,
      " weight_w: ",
      weight_w);
  TORCH_CHECK(
      stride_h > 0 && stride_w > 0,
      "stride_h: ",
      stride_h,
      " stride_w: ",
      stride_w);
  TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
  TORCH_CHECK(
      dilation_h > 0 && dilation_w > 0,
      "dilation_h: ",
      dilation_h,
      " dilation_w: ",
      dilation_w);
  TORCH_CHECK(
      n_weight_grps > 0 && n_offset_grps > 0,
      "n_weight_grps: ",
      n_weight_grps,
      " n_offset_grps: ",
      n_offset_grps);

  // Effective (dilated) kernel extent and resulting output spatial size.
  int ker_h = dilation_h * (weight_h - 1) + 1;
  int ker_w = dilation_w * (weight_w - 1) + 1;
  int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
  int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;

  TORCH_CHECK(weight_c.size(1) * n_weight_grps == input_c.size(1));
  TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0);
  TORCH_CHECK(
      (offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w),
      "offset.shape[1] is not valid: got: ",
      offset_c.size(1),
      " expected: ",
      n_offset_grps * 2 * weight_h * weight_w);
  TORCH_CHECK(
      (!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w),
      "mask.shape[1] is not valid: got: ",
      mask_c.size(1),
      " expected: ",
      n_offset_grps * weight_h * weight_w);
  TORCH_CHECK(input_c.size(1) % n_offset_grps == 0);

  TORCH_CHECK(
      (offset_c.size(0) == input_c.size(0)), "invalid batch size of offset");
  TORCH_CHECK(
      (offset_c.size(2) == out_h && offset_c.size(3) == out_w),
      "offset output dims: (",
      offset_c.size(2),
      ", ",
      offset_c.size(3),
      ") - ",
      "computed output dims: (",
      out_h,
      ", ",
      out_w,
      ")");
  // Checked regardless of use_mask: mask_c[b] is indexed in the loop below
  // even when the mask is not reshaped.
  TORCH_CHECK(
      (mask_c.size(0) == input_c.size(0)), "invalid batch size of mask");
  TORCH_CHECK(
      (!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)),
      "mask output dims: (",
      mask_c.size(2),
      ", ",
      mask_c.size(3),
      ") - ",
      "computed output dims: (",
      out_h,
      ", ",
      out_w,
      ")");
  TORCH_CHECK(
      out_h > 0 && out_w > 0,
      "Calculated output size too small - out_h: ",
      out_h,
      " out_w: ",
      out_w);

  auto out =
      at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options());
  if (batch_sz == 0) {
    // Nothing to compute for an empty batch.
    return out;
  }

  // Separate batches into blocks
  out = out.view({batch_sz / n_parallel_imgs,
                  n_parallel_imgs,
                  out_channels,
                  out_h,
                  out_w});
  input_c = input_c.view(
      {batch_sz / n_parallel_imgs, n_parallel_imgs, in_channels, in_h, in_w});

  offset_c = offset_c.view({batch_sz / n_parallel_imgs,
                            n_parallel_imgs,
                            n_offset_grps * 2 * weight_h * weight_w,
                            out_h,
                            out_w});

  if (use_mask) {
    mask_c = mask_c.view({batch_sz / n_parallel_imgs,
                          n_parallel_imgs,
                          n_offset_grps * weight_h * weight_w,
                          out_h,
                          out_w});
  }

  at::Tensor out_buf = at::zeros(
      {batch_sz / n_parallel_imgs,
       out_channels,
       n_parallel_imgs * out_h,
       out_w},
      out.options());

  // Separate channels into convolution groups
  out_buf = out_buf.view({out_buf.size(0),
                          n_weight_grps,
                          out_buf.size(1) / n_weight_grps,
                          out_buf.size(2),
                          out_buf.size(3)});
  weight_c = weight_c.view({n_weight_grps,
                            weight_c.size(0) / n_weight_grps,
                            weight_c.size(1),
                            weight_c.size(2),
                            weight_c.size(3)});

  // Sample points and perform convolution
  auto columns = at::zeros(
      {in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
      input_c.options());
  for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
    deformable_im2col(
        input_c[b],
        offset_c[b],
        mask_c[b],
        in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        out_h,
        out_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        columns);

    // Temporarily expose the per-group slices of `columns` for the grouped
    // matrix multiply, then restore the 2-D layout for the next iteration.
    columns = columns.view(
        {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
    for (int g = 0; g < n_weight_grps; g++) {
      out_buf[b][g] = out_buf[b][g]
                          .flatten(1)
                          .addmm_(weight_c[g].flatten(1), columns[g])
                          .view_as(out_buf[b][g]);
    }
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  // Rearrange the buffer from {block, channel, img, h, w} to
  // {block, img, channel, h, w} and copy it into the output.
  out_buf = out_buf.view({batch_sz / n_parallel_imgs,
                          out_channels,
                          n_parallel_imgs,
                          out_h,
                          out_w});
  out_buf.transpose_(1, 2);
  out.copy_(out_buf);
  out = out.view({batch_sz, out_channels, out_h, out_w});

  return out + bias_c.view({1, out_channels, 1, 1});
}

1117
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
1118
1119
1120
1121
1122
1123
1124
deform_conv2d_backward_cuda(
    const at::Tensor& grad_out,
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
1125
1126
1127
1128
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
1129
1130
    int64_t dilation_h,
    int64_t dilation_w,
1131
    int64_t n_weight_grps,
1132
1133
    int64_t n_offset_grps,
    bool use_mask) {
1134
1135
1136
1137
1138
1139
1140
1141
  at::Tensor grad_out_c = grad_out.contiguous();
  at::Tensor input_c = input.contiguous();
  at::Tensor weight_c = weight.contiguous();
  at::Tensor offset_c = offset.contiguous();
  at::Tensor mask_c = mask.contiguous();
  at::Tensor bias_c = bias.contiguous();

  const int batch_sz = input_c.size(0);
1142
1143
1144
  const int n_parallel_imgs =
      get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);

1145
1146
1147
1148
1149
1150
  auto grad_input_and_offset_and_mask = backward_gradient_inputs(
      input_c,
      weight_c,
      offset_c,
      mask_c,
      grad_out_c,
1151
1152
1153
1154
      stride_h,
      stride_w,
      pad_h,
      pad_w,
1155
1156
      dilation_h,
      dilation_w,
1157
1158
      n_weight_grps,
      n_offset_grps,
1159
1160
      n_parallel_imgs,
      use_mask);
1161

1162
1163
1164
  auto grad_input = std::get<0>(grad_input_and_offset_and_mask);
  auto grad_offset = std::get<1>(grad_input_and_offset_and_mask);
  auto grad_mask = std::get<2>(grad_input_and_offset_and_mask);
1165

1166
1167
1168
1169
1170
1171
  auto grad_weight = backward_gradient_parameters(
      input_c,
      weight_c,
      offset_c,
      mask_c,
      grad_out_c,
1172
1173
1174
1175
      stride_h,
      stride_w,
      pad_h,
      pad_w,
1176
1177
      dilation_h,
      dilation_w,
1178
1179
      n_weight_grps,
      n_offset_grps,
1180
1181
      n_parallel_imgs,
      use_mask);
1182

1183
1184
  auto value = grad_out_c.sum({0, 2, 3});
  auto grad_bias = at::ones_like(bias_c) * value;
1185

1186
1187
  return std::make_tuple(
      grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
1188
}
1189
1190
1191

} // namespace ops
} // namespace vision