#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>  // for getCurrentCUDAStream()
#include <cooperative_groups.h>




/*
  Tiled summing reduction within a warp.  Requires that the thread-block
  be 1-dimensional, i.e.  blockDim.y == blockDim.z == 1.  Does not use
  __syncthreads, so it is safe to call in a subset of threads.
  TODO: we can in principle do this without a buffer, using __shfl_down()
  (see here https://sodocumentation.net/cuda/topic/6566/parallel-reduction--e-g--how-to-sum-an-array-)
  if CC >= 3.0.

  Args:
      threads_per_tile:  Must be a power of 2 in the interval [1,32].  Summation is
                         within blocks of threads of this size.
       buf:              Pointer to the start of a __shared__ buffer of size
                         blockDim.x, to be used as a temporary within this function.
       val:              The value to be summed
   Return:
       Threads where threadIdx.x % threads_per_tile == 0 will return the sum:
         \sum_{i=0}^{threads_per_tile-1} [val in thread threadIdx.x + i]
       Return value in other threads is undefined.
 */
template <typename scalar_t>
__forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
                                                          __volatile__ scalar_t *buf,
                                                          scalar_t val) {
  // Each iteration halves the number of distinct partial sums within each tile:
  // every participating thread adds the value written by the thread `i` lanes
  // above it.
  for (int i = threads_per_tile / 2; i > 0; i /= 2) {
    buf[threadIdx.x] = val;
    if (threadIdx.x % threads_per_tile < i)
      val += buf[threadIdx.x + i];
  }
  return val; // Only threads with threadIdx.x % threads_per_tile == 0 will
              // return the full sums of their tiles.
}
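
/*
  A possible buffer-free variant of the reduction above, using warp shuffles as
  suggested in the TODO (requires CC >= 3.0).  This is only an illustrative
  sketch: it assumes the full warp is active (unlike the version above, which is
  safe to call in a subset of threads), and it is not used by the kernels below.
 */
template <typename scalar_t>
__forceinline__ __device__ scalar_t tiled_warp_reduce_sum_shfl(int threads_per_tile,
                                                               scalar_t val) {
  // Each iteration halves the number of distinct partial sums within the tile.
  for (int i = threads_per_tile / 2; i > 0; i /= 2)
    val += __shfl_down_sync(0xffffffffu, val, i, threads_per_tile);
  return val;  // The full tile sum ends up in threads where
               // threadIdx.x % threads_per_tile == 0.
}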

/*
  Forward of integrated_conv.  Each thread group handles a single channel (equal
  to blockIdx.x), and loops over patches of the output and over the image n
  within the batch (different thread groups may be responsible for different
  subsets of patches and/or images, see docs of gridDim below).

  Template args:
      scalar_t: the floating-point type, e.g. float, double, maybe half.

  Args:
      input:  input image, shape (N, 2*C, H, W)
      pos_add:  positional encoding, additive part,  shape (C, kH, kW)
      pos_mul:  positional encoding, multiplicative part, shape (C, kH, kW)
      output:   output image, shape (N, C, H, W)
   Note: kH and kW must both be odd so that it's clear how to zero-pad.

  The thread-block should have one dimension (x); blockDim.x should equal
  some small power of 2 (threads_per_opixel) times the output-patch size, which is
  opatchH * opatchW (the output-patch height times width).  We expect
  threads_per_opixel to be 1, 2, or 4; the partial sums of the threads within
  each tile are combined with tiled_warp_reduce_sum().

   The requirements on the grid dimension are:
       gridDim.x == num-channels C (required)
       gridDim.y <= num-patches per image (recommended)
       gridDim.z <= batch-size N (recommended)
  When we invoke this kernel, we'll invoke it as:
   integrated_conv_kernel<<<gridDim, blockDim, bytesShared, stream>>>
  where bytesShared is the number of bytes needed in `extern_buf`:
    bytesShared = sizeof(scalar_t) * numel, where
    numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (opatchW + kW - 1))
 */
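/*
  For example (illustrative numbers only): with kH = kW = 5, an output patch of
  8 x 8 and threads_per_opixel = 4 (so blockDim.x = 256), the padded input patch
  has 12 * 12 = 144 elements, so numel = 2*25 + max(256, 144) = 306.
 */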
extern __shared__ int extern_buf[];

template <typename scalar_t>
__global__
void integrated_conv_kernel(
    torch::PackedTensorAccessor32<scalar_t, 4> input,  // N, 2*C, H, W
    torch::PackedTensorAccessor32<scalar_t, 3> pos_add,  // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 3> pos_mul,  // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 4> output,  // N, C, H, W
    int opatchH,  // output-patch height,
    int opatchW  // output-patch width
                             ) {
  const int H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2),
      npatchH = (H + opatchH - 1) / opatchH,  // num patches in vertical dim
      npatchW = (W + opatchW - 1) / opatchW,  // num patches in horizontal dim
      npatch = npatchH * npatchW;  // total number of patches per image

  // Channel index.
  const int c = blockIdx.x;
  // We don't need to check the range of `c` because we set gridDim.x to the
  // exact number of channels.

  const int ipatchH = opatchH + kH - 1,
      ipatchW = opatchW + kW - 1,
      ipatch_size = ipatchH * ipatchW,
      opatch_size = opatchH * opatchW;

  // `extern_buf` is general-purpose shared memory, which we'll divide between
  // pos_add, pos_mul and src_img_buf; src_img_buf must be large enough for both
  // the padded input patch (ipatch_size) and the reduction buffer (blockDim.x),
  // since it is reused for both purposes.

  // these are pointers to __shared__ memory; the compiler should
  // be able to figure this out.
  scalar_t
      *pos_add_buf = (scalar_t*)extern_buf,     // pos_add positional-encoding / kernel parameters,
                                // indexed [kh*kW + kw] where kh and kw are vertical
                                // and horizontal positions in the kernel.
      *pos_mul_buf = pos_add_buf + (kH * kW), // pos_mul positional-encoding / kernel parameters,
                                              // indexed [kh*kW + kw] where kh and kw are vertical
                                              // and horizontal positions in the kernel.
      *src_img_buf = pos_mul_buf + (kH * kW);    // version of input image that relates to source position,
                             // of size [ipatch_size], indexed [h*ipatchW + w]...
                             // note, the 'h' and 'w' indexes are into the zero-padded input
                             // image.


  int threads_per_opixel = blockDim.x / opatch_size;
  assert(blockDim.x == opatch_size * threads_per_opixel);

  // pos_in_patch will be interpreted as h_in_patch * opatchW + w_in_patch.
  int pos_in_patch = threadIdx.x / threads_per_opixel;

  // Load parts of the kernel parameters pos_add and pos_mul into shared memory,
  // in pos_add_buf and pos_mul_buf
  for (int i = threadIdx.x; i < kH * kW; i += blockDim.x) {
    int kh = i / kW,
        kw = i % kW;
    pos_add_buf[i] = pos_add[c][kh][kw];
    pos_mul_buf[i] = pos_mul[c][kh][kw];
  }

  // n is the index within the batch.  Loop to make sure we cover all images in
  // the batch.  input.size(0) is the batch size N.  All threads in the thread-block
  // loop the same number of times.
  for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {

    // Loop over the patch within the output image.  All threads in the
    // thread-block loop the same number of times.
    for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
      // (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
      // of the lowest-numbered pixel in the patch of output that this thread
      // block is responsible for.
      int patch_h_offset = (patch_idx / npatchW) * opatchH,
          patch_w_offset = (patch_idx % npatchW) * opatchW;

      // This __syncthreads() is only necessary if we have already looped at
      // least once over n or patch_idx: it's in case other threads are still
      // using the `src_img_buf` buffer for something else.
      __syncthreads();

      // Load the 'src' part of the input patch; the size of this is the size of
      // the output patch plus a border of sizes kH//2, kW//2 on each side.
      for (int i = threadIdx.x; i < ipatch_size; i += blockDim.x) {
        int h_in_kernel = i / ipatchW,
            w_in_kernel = i % ipatchW;
        int src_h = patch_h_offset + h_in_kernel - (kH / 2),  // kH / 2 is offset due to padding
            src_w = patch_w_offset + w_in_kernel - (kW / 2);
        scalar_t src_val = scalar_t(0);
        if ((unsigned int)src_h < (unsigned int)H &&  // h >= 0 && h < H
            (unsigned int)src_w < (unsigned int)W)    // w >= 0 && w < W
          src_val = input[n][c][src_h][src_w];
        src_img_buf[i] = src_val;
      }
      // make sure all threads have written to `src_img_buf`
      __syncthreads();


      // 'h' and 'w' are the positions within the output image, that this tile
      // of size threads_per_opixel is responsible for.
      int h = patch_h_offset + pos_in_patch / opatchW,
          w = patch_w_offset + pos_in_patch % opatchW;

      // The "destination" pixel; this is an input.  It gets added to each
      // src pixel, prior to the relu, in the loop below.
      scalar_t dest_val = scalar_t(0);
      if (h < H && w < W) {
        // Several threads (within the same tile, which implies the same warp)
        // may load the same value here, but I believe the device's memory
        // subsystem handles this well enough that we can just ignore the issue
        // rather than try to optimize it.
        // https://forums.developer.nvidia.com/t/accessing-same-global-memory-address-within-warps/66574
        int C = input.size(1) / 2;
        dest_val = input[n][c + C][h][w];  // else 0.
      }

      // `sum` is the partial sum that this thread computes; we'll sum this over
      // the `threads_per_opixel` threads in the tile to get the output pixel
      // value.
      scalar_t sum = 0.0;

      for (int pos_in_kernel = threadIdx.x % threads_per_opixel;
           pos_in_kernel < (kH * kW);
           pos_in_kernel += threads_per_opixel) {
        int h_in_kernel = pos_in_kernel / kW,
            w_in_kernel = pos_in_kernel % kW;
        // Note: this is actually more like cross-correlation, as we don't
        // have a negative sign on the h and w indexes in the kernel.
        // Also note: we already took care of padding and the associated
        // offsets of -(kH / 2) and -(kW / 2).
        int h_in_src_patch = (pos_in_patch / opatchW) + h_in_kernel,
            w_in_src_patch = (pos_in_patch % opatchW) + w_in_kernel;
        scalar_t src_val = src_img_buf[h_in_src_patch * ipatchW + w_in_src_patch],
            pos_add_val = pos_add_buf[pos_in_kernel];
        scalar_t relu = (src_val + dest_val + pos_add_val);
        if (relu > 0.0)
          sum += relu * pos_mul_buf[pos_in_kernel];
      }
      // Sync threads before the reduction below reuses src_img_buf as its
      // scratch buffer, since threads may still be reading it above.
      __syncthreads();
      // Aggregate `sum` over threads
      sum = tiled_warp_reduce_sum(threads_per_opixel, src_img_buf, sum);
      if (threadIdx.x % threads_per_opixel == 0 && h < H && w < W) {
        output[n][c][h][w] = sum;
      }
    }
  }
}


/*
  Backward of integrated_conv.  Each thread group handles a single channel (equal
  to blockIdx.x), and loops over patches of the output and over the image n
  within the batch (different thread groups may be responsible for different
  subsets of patches and/or images, see docs of gridDim below).

  If you want to understand this code, you should first understand the forward
  code.  Here are some points to understand how this works:

  First, understand the difference between the patch of size patchH by
  patchW, which is the basic patch size that is related to the blockDim.x,
  and the padded patch size ppatchH and ppatchW, where:
     ppatchH = patchH + kH - 1
     ppatchW = patchW + kW - 1.

  In the forward pass, we dealt with a patch of output and a padded patch of
  input.  In this backward-pass code, when computing the `grad_input` we deal
  with a patch of input and a padded patch of output (this ensures that
  different thread-blocks write to distinct patches of `grad_input`).  But this
  approach is not sufficient to update `grad_pos_add` and `grad_pos_mul`,
  because it's possible for elements of the zero-padding of `input` to
  contribute to `grad_pos_add` and `grad_pos_mul`.  So when computing the
  gradients for those quantities, we actually use a padded input patch and an
  un-padded output patch.  This requires that we load into shared memory the
  padded versions of both input and grad_output.


  Template args:
      scalar_t: the floating-point type, e.g. float, double, maybe half.

  Args:
      input [in]:  input image, shape (N, 2*C, H, W)
      pos_add [in]:  positional encoding, additive part,  shape (C, kH, kW)
      pos_mul [in]:  positional encoding, multiplicative part, shape (C, kH, kW)
      grad_output [in]: the gradient w.r.t. the output of the forward pass, shape (N, C, H, W)
      grad_input [out]: the gradient w.r.t. the input, of shape N, 2*C, H, W
      grad_pos_add [out]: the gradient w.r.t. pos_add, indexed [block][c][kh][kw],
            of shape num_blocks, C, kH, kW,
            where `block` is an index we'll later sum over, that corresponds to
            the identity of the thread-block (except, not including the channel
            dimension == gridDim.x).  So, block == blockIdx.z * gridDim.y + blockIdx.y,
            and num_blocks == gridDim.y * gridDim.z.
      grad_pos_mul [out]: the gradient w.r.t. pos_mul, like grad_pos_add above.
      patchH: the height of the patch size this kernel operates on (prior to padding)
      patchW: the width of the patch size this kernel operates on (prior to padding)
      threads_per_pixel: the number of threads assigned to compute each pixel
              of grad_input.  Require patchH * patchW * threads_per_pixel <= blockDim.x
              and threads_per_pixel must be a power of 2 in the interval [1,32].
      threads_per_kernel_pos: the number of threads assigned to compute each kernel
              position of grad_pos_add and grad_pos_mul.
              Require kH * kW * threads_per_kernel_pos <= blockDim.x,
              and threads_per_kernel_pos must be a power of 2 in the interval [1,32].
              This requires that kH * kW must not be greater than 1024.

  Note: kH and kW must both be odd so that it's clear how to zero-pad.

  The thread-block should have one dimension (x); see the docs for threads_per_pixel
  and threads_per_kernel_pos for requirements on blockDim.x.  Also, blockDim.x
  must be an exact multiple of 64, so that when we split the threads into two
  halves, the halves fall in different warps.

  The requirements on the grid dimension are:
       gridDim.x == num-channels C (required)
       gridDim.y <= num-patches per image (recommended)
       gridDim.z <= batch-size N (recommended)
  When we invoke this kernel, we'll invoke it as:
   integrated_conv_kernel_backward<<<gridDim, blockDim, bytesShared, stream>>>
  where bytesShared is the number of bytes needed in `extern_buf`:

   bytesShared = sizeof(scalar_t) * numel, where
    numel = 4 * (kH * kW) + 3 * (ppatchH * ppatchW) + blockDim.x
 */
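/*
  For example (illustrative numbers only): with kH = kW = 3, a patch of 8 x 8 and
  threads_per_pixel = 4 (so blockDim.x = 8 * 8 * 4 = 256, a multiple of 64), the
  padded patch has 10 * 10 = 100 elements, so numel = 4*9 + 3*100 + 256 = 592.
 */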


template <typename scalar_t>
__global__
void integrated_conv_kernel_backward(
    torch::PackedTensorAccessor32<scalar_t, 4> input,  // N, 2*C, H, W
    torch::PackedTensorAccessor32<scalar_t, 3> pos_add,  // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 3> pos_mul,  // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 4> grad_output,  // N, C, H, W
    torch::PackedTensorAccessor32<scalar_t, 4> grad_input,  // N, 2*C, H, W
    torch::PackedTensorAccessor32<scalar_t, 4> grad_pos_add, // block, C, kH, kW, see above for `block`
    torch::PackedTensorAccessor32<scalar_t, 4> grad_pos_mul, // block, C, kH, kW, see above for `block`
    int patchH,  // non-padded patch height
    int patchW,  // non-padded patch width
    int threads_per_pixel,
    int threads_per_kernel_pos) {

  const int H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2),
      npatchH = (H + patchH - 1) / patchH,  // num patches in vertical dim
      npatchW = (W + patchW - 1) / patchW,  // num patches in horizontal dim
      npatch = npatchH * npatchW;  // total number of patches per image

  // Channel index.
  const int c = blockIdx.x;
  // We don't need to check the range of `c` because we set gridDim.x to the
  // exact number of channels.

  const int ppatchH = patchH + kH - 1,  // ppatchH is the padded patch height.
      ppatchW = patchW + kW - 1,  // ppatchW is the padded patch width
      patch_size = patchH * patchW,  // un-padded patch size
      ppatch_size = ppatchH * ppatchW;  // padded patch size

  // `extern_buf` is general-purpose shared memory, which we'll divide between
  // various buffers.

  // these are pointers to __shared__ memory; the compiler should
  // be able to figure this out.
  scalar_t
      *pos_add_buf = (scalar_t*)extern_buf,     // pos_add positional-encoding / kernel parameters,
                                // indexed [kh*kW + kw] where kh and kw are vertical
                                // and horizontal positions in the kernel.
      *pos_mul_buf = pos_add_buf + (kH * kW), // pos_mul positional-encoding / kernel parameters,
                                              // indexed [kh*kW + kw] where kh and kw are vertical
                                              // and horizontal positions in the kernel.
      *src_img_buf = pos_mul_buf + (kH * kW),    // version of input image that relates to source position,
                             // of size [ppatch_size], indexed [h*ppatchW + w],
                             // where the 'h' and 'w' indexes are into the zero-padded input
                             // image.
      *dest_img_buf = src_img_buf + ppatch_size,  // version of input image that relates to destination position
      *grad_output_buf = dest_img_buf + ppatch_size, // output gradient for padded patch, indexed [h*ppatchW + w]
      *grad_pos_add_buf = grad_output_buf + ppatch_size,  // total grad for pos_add for this thread block, indexed [kh*kW + kw]
      *grad_pos_mul_buf = grad_pos_add_buf + (kH * kW),  // total grad for pos_mul for this thread block, indexed [kh*kW + kw]
      *reduce_buf = grad_pos_mul_buf + (kH * kW);  // buffer for reduction over threads, size == blockDim.x


  // pos_in_patch will be interpreted as h_in_patch * patchW + w_in_patch.
  int pos_in_patch = threadIdx.x / threads_per_pixel;

  // Load parts of the kernel parameters pos_add and pos_mul into shared memory,
  // in pos_add_buf and pos_mul_buf; zero the corresponding gradient buffers.
  // We know that blockDim.x >= kH * kW, see threads_per_kernel_pos.

  for (int i = threadIdx.x % (blockDim.x / 2); i < kH * kW; i += (blockDim.x / 2)) {
    int kh = i / kW, kw = i % kW;
    if (threadIdx.x < blockDim.x / 2) {  // First half of threads take care of pos_add..
      pos_add_buf[i] = pos_add[c][kh][kw];
      grad_pos_add_buf[i] = 0.0;
    } else {  // Second half take care of pos_mul... there is no warp divergence
              // because we make sure blockDim.x is a multiple of 64.
      pos_mul_buf[i] = pos_mul[c][kh][kw];
      grad_pos_mul_buf[i] = 0.0;
    }
  }

  // n is the index within the batch of images.  Loop to make sure we cover all
  // images in the batch.  input.size(0) is the batch size N.  All threads in
  // the thread-block loop the same number of times.
  for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {

    // Loop over the patch within the output image.  All threads in the
    // thread-block loop the same number of times.
    for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
      // (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
      // of the lowest-numbered pixel in the *un-padded* patch that this thread
      // block is responsible for.  (We'll actually be loading the padded patches
      // into memory, so be careful).
      int patch_h_offset = (patch_idx / npatchW) * patchH,
          patch_w_offset = (patch_idx % npatchW) * patchW;

      // This __syncthreads() is only necessary if we have already looped at
      // least once over n or patch_idx: it's in case other threads are still
      // using the `src_img_buf` or `dst_img_buf` buffers for a previous patch.
      __syncthreads();

      // Load the 'src' and 'dest' versions of the padded patch into
      // shared-memory buffers, and also the output gradient.
      for (int i = threadIdx.x % (blockDim.x / 2);
           i < ppatch_size; i += (blockDim.x / 2)) {
        int h_in_ppatch = i / ppatchW,
            w_in_ppatch = i % ppatchW;
        int h = patch_h_offset + h_in_ppatch - (kH / 2),  // kH / 2 is offset due to padding
            w = patch_w_offset + w_in_ppatch - (kW / 2);

        if (threadIdx.x < blockDim.x / 2) {  // The first half of the threads of the block
                                             // load `input`
          scalar_t src_val = scalar_t(0),
              dest_val = scalar_t(0);
          if ((unsigned int)h < (unsigned int)H &&  // h >= 0 && h < H
              (unsigned int)w < (unsigned int)W) {  // w >= 0 && w < W
            int C = grad_output.size(1);
            src_val = input[n][c][h][w];
            dest_val = input[n][c + C][h][w];
          }
          src_img_buf[i] = src_val;
          dest_img_buf[i] = dest_val;
        } else {  // second half of threads load `grad_output`.  We require
                  // blockDim.x be an even multiple of the warp size, so there
                  // is no warp divergence here.
          scalar_t grad_output_val = scalar_t(0);
          if ((unsigned int)h < (unsigned int)H &&
              (unsigned int)w < (unsigned int)W)
            grad_output_val = grad_output[n][c][h][w];
          grad_output_buf[i] = grad_output_val;
        }
      }
      // make sure all threads have written to `src_img_buf`, `dest_img_buf` and
      // `grad_output_buf`.
      __syncthreads();

      scalar_t grad_input_src_sum = 0.0,  // grad for channel c, for our pixel
                                          // of `input` (contribution of this
                                          // thread)
          grad_input_dest_sum = 0.0;   // grad for channel c + C, for our pixel
                                       // of `input` (contribution of this thread)
      if (pos_in_patch < patch_size) {
        // This block computes `grad_input_src_sum` and `grad_input_dest_sum`
        // The num-threads for the backward kernel may not be an exact multiple
        // of patch_size, so we need the if-guard.

        int h_in_patch = pos_in_patch / patchW,
            w_in_patch = pos_in_patch % patchW,
            h_in_ppatch = h_in_patch + kH / 2,
            w_in_ppatch = w_in_patch + kW / 2,
            pos_in_ppatch = h_in_ppatch * ppatchW + w_in_ppatch;

        // this_dest_val is the `destination` version of our current pixel; this
        // is an input.  It gets added to each src pixel, prior to the relu, in
        // the loop below.
        // this_src_val is the `src` version of our current pixel; it contributes
        // to the outputs of other pixels.
        scalar_t this_dest_val = dest_img_buf[pos_in_ppatch],
            this_src_val = src_img_buf[pos_in_ppatch];

        for (int pos_in_kernel = threadIdx.x % threads_per_pixel;
             pos_in_kernel < (kH * kW);
             pos_in_kernel += threads_per_pixel) {

          int h_in_kernel = pos_in_kernel / kW,
              w_in_kernel = pos_in_kernel % kW;

          // This is actually more like cross-correlation, as we don't have a
          // negative sign on the h and w indexes in the kernel.
          int src_h_in_ppatch = h_in_patch + h_in_kernel,
              src_w_in_ppatch = w_in_patch + w_in_kernel;
          int src_pos_in_ppatch = src_h_in_ppatch * ppatchW + src_w_in_ppatch;

          scalar_t src_val = src_img_buf[src_pos_in_ppatch],
              pos_add_val = pos_add_buf[pos_in_kernel],
              pos_mul_val = pos_mul_buf[pos_in_kernel];
          scalar_t relu = (src_val + this_dest_val + pos_add_val);
          if (relu >= 0.0) {
            scalar_t this_grad_output = grad_output_buf[pos_in_ppatch];
            grad_input_dest_sum += this_grad_output * pos_mul_val;
          }
          // To compute a contribution to `grad_input_src_sum`, we need to
          // consider the destination pixel that this pixel, acting as a source,
          // would have contributed to at this same kernel offset.
          // We have to flip the offsets: instead of "+ h_in_kernel",
          // we use (kH - 1) - h_in_kernel.
          int dest_h_in_ppatch = h_in_patch + (kH - 1) - h_in_kernel,
              dest_w_in_ppatch = w_in_patch + (kW - 1) - w_in_kernel,
              dest_pos_in_ppatch = dest_h_in_ppatch * ppatchW + dest_w_in_ppatch;
          scalar_t dest_val = dest_img_buf[dest_pos_in_ppatch];
          relu = dest_val + this_src_val + pos_add_val;
          if (relu >= 0.0) {
            scalar_t dest_grad_output = grad_output_buf[dest_pos_in_ppatch];
            grad_input_src_sum += dest_grad_output * pos_mul_val;
          }
        }
      }
      // Aggregate `grad_input_src_sum` over threads, if needed; and write the
      // result to `grad_input`.
      // h and w are un-padded indexes into the entire image.
      int h = patch_h_offset + pos_in_patch / patchW,
          w = patch_w_offset + pos_in_patch % patchW;

      if (h < H && w < W) {
        grad_input_src_sum = tiled_warp_reduce_sum(threads_per_pixel,
                                                   reduce_buf,
                                                   grad_input_src_sum);
        grad_input_dest_sum = tiled_warp_reduce_sum(threads_per_pixel,
                                                    reduce_buf,
                                                    grad_input_dest_sum);
        if (threadIdx.x % threads_per_pixel == 0) {
          grad_input[n][c][h][w] = grad_input_src_sum;
          int C = grad_output.size(1);
          grad_input[n][c + C][h][w] = grad_input_dest_sum;
        }
      }

      // OK, we are done computing grad_input for this patch.  Now we
      // accumulate this patch's contributions to grad_pos_add_buf and
      // grad_pos_mul_buf.
      // 0 <= pos_in_kernel < (kH * kW).
      int pos_in_kernel = threadIdx.x / threads_per_kernel_pos;
      scalar_t this_grad_pos_add = 0.0,
              this_grad_pos_mul = 0.0;
      if (pos_in_kernel < (kH * kW)) {
        int kh = pos_in_kernel / kW,
            kw = pos_in_kernel % kW;

        // This group of (threads_per_kernel_pos) threads is responsible
        // for position (kh, kw) in the kernel; we iterate over the patch
        // (an un-padded patch of output).
        scalar_t pos_add_val = pos_add_buf[pos_in_kernel],
                 pos_mul_val = pos_mul_buf[pos_in_kernel];

        for (int pos_in_patch = threadIdx.x % threads_per_kernel_pos;
             pos_in_patch < patch_size; pos_in_patch += threads_per_kernel_pos) {
          // We are working out the contribution to the gradients for pos_add
          // and pos_mul; we let `pos_in_patch` correspond to the *output*
          // position, and work out the input position based on the kernel position.

          int h_in_patch = pos_in_patch / patchW,
              w_in_patch = pos_in_patch % patchW;

          // pos_in_ppatch is the position in the padded patch corresponding to
          // `pos_in_patch`.
          int pos_in_ppatch = (h_in_patch + kH / 2) * ppatchW + (w_in_patch + kW / 2);
          scalar_t dest_val = dest_img_buf[pos_in_ppatch];
          int src_pos_in_ppatch = (h_in_patch + kh) * ppatchW + (w_in_patch + kw);
          scalar_t src_val = src_img_buf[src_pos_in_ppatch];

          scalar_t relu = dest_val + src_val + pos_add_val;
          if (relu >= 0.0) {
            scalar_t this_grad_output = grad_output_buf[pos_in_ppatch];
            this_grad_pos_add += this_grad_output * pos_mul_val;
            this_grad_pos_mul += this_grad_output * relu;
          }
        }
        this_grad_pos_add = tiled_warp_reduce_sum(
            threads_per_kernel_pos, reduce_buf, this_grad_pos_add);
        this_grad_pos_mul = tiled_warp_reduce_sum(
            threads_per_kernel_pos, reduce_buf, this_grad_pos_mul);
        if (threadIdx.x % threads_per_kernel_pos == 0) {
          grad_pos_add_buf[pos_in_kernel] += this_grad_pos_add;
          grad_pos_mul_buf[pos_in_kernel] += this_grad_pos_mul;
        }
      }
    }
  }

  __syncthreads(); // make sure all threads have written to grad_pos_add_buf and
                   // grad_pos_mul_buf.
  int block = blockIdx.z * gridDim.y + blockIdx.y;

  int kernel_pos = threadIdx.x;
  if (kernel_pos < (kH * kW)) {
    int kh = kernel_pos / kW,
        kw = kernel_pos % kW;
    grad_pos_add[block][c][kh][kw] = grad_pos_add_buf[kernel_pos];
    grad_pos_mul[block][c][kh][kw] = grad_pos_mul_buf[kernel_pos];
  }
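  // Note: these are per-thread-block partial gradients; the host code sums
  // them over the `block` dimension afterwards (see the at::sum calls in
  // integrated_conv_backward_cuda).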
}



torch::Tensor integrated_conv_cuda(torch::Tensor input,
                                   torch::Tensor pos_add,
                                   torch::Tensor pos_mul) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
  const int N = input.size(0),
      C = input.size(1) / 2,
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");
  auto scalar_t = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
              pos_mul.scalar_type() == scalar_t,
              "Input dtypes mismatch");

  torch::Tensor output = torch::empty({N, C, H, W},
                                      torch::TensorOptions().dtype(scalar_t).device(input.device()));


  // Work out the configuration to call the kernel with..
  int patchH = std::min(H, kH),  // output patch height
      patchW = std::min(W, kW);  // output patch width
  // We don't want the height or width of the patch to be less than the kernel
  // height or width, or the padding would make the input patch more than twice
  // the size of the output patch.
  // We aim for an output-patch size of up to about 128 pixels; this is not
  // exact, but it roughly corresponds to wanting up to 4 threads per output
  // pixel under the limit of 512 threads per thread-block, which we impose so
  // that we can run on architectures with little shared memory.
  while (patchW < W && patchH * (patchW + 1) <= 128)
    patchW++;
  while(patchH < H && (patchH + 1) * patchW <= 128)
    patchH++;

  // We are assuming that the thread-block size can be as large as 512; this
  // works even on old CUDA architectures.
  int threads_per_opixel;
  if (patchH * patchW * 4 <= 512 && (kH * kW) > 16)
    threads_per_opixel = 4;
  else if (patchH * patchW * 2 <= 512 && (kH * kW) > 8)
    threads_per_opixel = 2;
  else
    threads_per_opixel = 1;

  int input_patchH = patchH + kH - 1,
         input_patchW = patchW + kW - 1,
         input_patch_size = input_patchH * input_patchW;

  int threads_per_block = patchH * patchW * threads_per_opixel;

  int buffer_numel = 2 * (kH * kW) + std::max<int>(threads_per_block,
                                                   input_patch_size);
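  // Illustrative example of the heuristics above (not a requirement): with
  // H = W = 32 and kH = kW = 3, the patch grows to patchH = 4, patchW = 32
  // (4 * 32 = 128 output pixels); threads_per_opixel = 2 since kH * kW = 9 > 8,
  // giving threads_per_block = 256; the padded input patch is 6 * 34 = 204, so
  // buffer_numel = 2*9 + max(256, 204) = 274.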

  int num_patches_H = (H + patchH - 1) / patchH,
      num_patches_W = (W + patchW - 1) / patchW,
      num_patches = num_patches_H * num_patches_W;

  // gridDim.x == C.
  int num_blocks_patch = 1,  // gridDim.y.
       num_blocks_batch = 1;  // gridDim.z
  while (C * num_blocks_patch <= 256 &&
         num_blocks_patch * 2 <= num_patches)
    num_blocks_patch *= 2;
  if (C * num_patches <= 512)
    num_blocks_patch = num_patches;
  while (C * num_blocks_patch * num_blocks_batch <= 512 &&
         num_blocks_batch * 2 <= N)
    num_blocks_batch *= 2;
  if (C * num_blocks_patch * N <= 1024)
    num_blocks_batch = N;

  assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);

#if 0
  std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
            << "; kW,kH=" << kW << "," << kH
            << "; patchH,patchW=" << patchH << ","
            << patchW << ", num_blocks_patch="
            << num_blocks_patch << ", num_blocks_batch="
            << num_blocks_batch
            << ", threads_per_opixel=" << threads_per_opixel
            << ", threads_per_block=" << threads_per_block
            << std::endl;
#endif

  dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
  // blockDim is scalar, just threads_per_block.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel", ([&] {
        integrated_conv_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
              input.packed_accessor32<scalar_t, 4>(),
              pos_add.packed_accessor32<scalar_t, 3>(),
              pos_mul.packed_accessor32<scalar_t, 3>(),
              output.packed_accessor32<scalar_t, 4>(),
              patchH,
              patchW);
      }));
  return output;
}



std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
                                                         torch::Tensor pos_add,
                                                         torch::Tensor pos_mul,
                                                         torch::Tensor grad_output) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
  const int N = input.size(0),
      C = input.size(1) / 2,
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");
  auto scalar_t = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
              pos_mul.scalar_type() == scalar_t,
              "Input dtypes mismatch");
  TORCH_CHECK(grad_output.dim() == 4 && grad_output.size(0) == N
              && grad_output.size(1) == C && grad_output.size(2) == H
              && grad_output.size(3) == W);


  // Work out the configuration to call the kernel with..
  int patchH = std::min(H, kH),  // output patch height
      patchW = std::min(W, kW);  // output patch width
  // We don't want the height or width of the patch to be less than the kernel
  // height or width, or the padding would make the input patch more than twice
  // the size of the output patch.
  // We aim for an output-patch size of up to about 128 pixels; this is not
  // exact, but it roughly corresponds to wanting up to 4 threads per output
  // pixel under the limit of 512 threads per thread-block, which we impose so
  // that we can run on architectures with little shared memory.
  while (patchW < W && patchH * (patchW + 1) <= 128)
    patchW++;
  while(patchH < H && (patchH + 1) * patchW <= 128)
    patchH++;

  // We are assuming that the thread-block size can be as large as 512; this
  // works even on old CUDA architectures.
  int threads_per_pixel;
  if (patchH * patchW * 4 <= 512 && (kH * kW) > 8)
    threads_per_pixel = 4;
  else if (patchH * patchW * 2 <= 512 && (kH * kW) > 4)
    threads_per_pixel = 2;
  else
    threads_per_pixel = 1;

  int threads_per_block = patchH * patchW * threads_per_pixel;
  // round threads_per_block up to a multiple of 64.  We need it to be
  // equivalent to an even number of warps, because at one point we divide the
  // threads into two halves and we want them to be an even number of warps.
  threads_per_block = 64 * ((threads_per_block + 63) / 64);

  {
    // If it's possible to increase the patch width or height while not exceeding
    // this number of threads, do so.  (This is a small optimization).
    int patchW_old = patchW;
    while (patchH * (patchW + 1) * threads_per_pixel <= threads_per_block)
      patchW++;
    // If the above change to patchW did not actually reduce the number of patches
    // needed to cover the image, then there is no point to the change; it only
    // increases the shared-memory requirement, so revert it.
    if ((W + patchW_old - 1) / patchW_old == (W + patchW - 1) / patchW)
      patchW = patchW_old;
    int patchH_old = patchH;
    while ((patchH + 1) * patchW * threads_per_pixel <= threads_per_block)
      patchH++;
    if ((H + patchH_old - 1) / patchH_old == (H + patchH - 1) / patchH)
      patchH = patchH_old;
  }


  int threads_per_kernel_pos = 1;
  while (threads_per_kernel_pos < 32 &&
         threads_per_kernel_pos * 2 * kH * kW <= threads_per_block)
    threads_per_kernel_pos *= 2;

  // dimensions of padded patches
  int ppatchH = patchH + kH - 1,
      ppatchW = patchW + kW - 1,
      ppatch_size = ppatchH * ppatchW;

  int buffer_numel = 4 * (kH * kW) + 3 * ppatch_size + threads_per_block;

  int num_patches_H = (H + patchH - 1) / patchH,
      num_patches_W = (W + patchW - 1) / patchW,
      num_patches = num_patches_H * num_patches_W;

  // gridDim.x == C.
  int num_blocks_patch = 1,  // gridDim.y; should not be more than num_patches
      num_blocks_batch = 1;  // gridDim.z
  // We have a rough target of no more than 256 thread-groups.
  while (C * num_blocks_patch * 2 <= 256 &&
         num_blocks_patch * 2 <= num_patches)
    num_blocks_patch *= 2;
  if (C * num_patches <= 512)
    num_blocks_patch = num_patches;
  while (C * num_blocks_patch * num_blocks_batch * 2 <= 256 &&
         num_blocks_batch * 2 <= N)
    num_blocks_batch *= 2;

  assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);
  assert(patchH * patchW * threads_per_pixel <= threads_per_block);
  assert(kH * kW * threads_per_kernel_pos <= threads_per_block);

#if 0
  std::cout << "[backward:] N,C,H,W=" << N << "," << C << "," << H << "," << W
            << "; kW,kH=" << kW << "," << kH
            << "; patchH,patchW=" << patchH << ","
            << patchW << ", num_blocks_patch="
            << num_blocks_patch << ", num_blocks_batch="
            << num_blocks_batch
            << ", threads_per_pixel=" << threads_per_pixel
            << ", threads_per_kernel_pos=" << threads_per_kernel_pos
            << ", threads_per_block=" << threads_per_block
            << ", buffer_numel=" << buffer_numel
            << std::endl;
#endif

  int num_blocks = num_blocks_patch * num_blocks_batch;

  torch::Tensor grad_input = torch::zeros({N, 2*C, H, W},
                                          torch::TensorOptions().dtype(scalar_t).device(input.device())),
      grad_pos_add = torch::zeros({num_blocks, C, kH, kW},
                                  torch::TensorOptions().dtype(scalar_t).device(input.device())),
      grad_pos_mul = torch::zeros({num_blocks, C, kH, kW},
                                  torch::TensorOptions().dtype(scalar_t).device(input.device()));


  dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
  // blockDim is scalar, just threads_per_block.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel_backward", ([&] {
        integrated_conv_kernel_backward<scalar_t><<<gridDim, threads_per_block,
                                                    sizeof(scalar_t) * buffer_numel,
                                                    at::cuda::getCurrentCUDAStream()>>>(
              input.packed_accessor32<scalar_t, 4>(),
              pos_add.packed_accessor32<scalar_t, 3>(),
              pos_mul.packed_accessor32<scalar_t, 3>(),
              grad_output.packed_accessor32<scalar_t, 4>(),
              grad_input.packed_accessor32<scalar_t, 4>(),
              grad_pos_add.packed_accessor32<scalar_t, 4>(),
              grad_pos_mul.packed_accessor32<scalar_t, 4>(),
              patchH,
              patchW,
              threads_per_pixel,
              threads_per_kernel_pos);
      }));
  grad_pos_add = at::sum(grad_pos_add, {0});
  grad_pos_mul = at::sum(grad_pos_mul, {0});

  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
}
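
/*
  The Python-facing bindings are presumably registered in a separate source
  file; a minimal sketch of what that registration might look like, using the
  standard torch-extension macro and hypothetical binding names, is:

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("integrated_conv_cuda", &integrated_conv_cuda,
            "integrated_conv forward (CUDA)");
      m.def("integrated_conv_backward_cuda", &integrated_conv_backward_cuda,
            "integrated_conv backward (CUDA)");
    }
 */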