/* scaled_masked_softmax.h
 *
 * This code is adapted from NVIDIA Megatron, with minor changes. */

#pragma once

#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>

#include <cfloat>
#include <cmath>
#include <limits>

namespace {

// Copies ELEMENTS_PER_LDG consecutive elements of type Datatype from src to
// dst. This is a declaration only: every supported
// (Datatype, ELEMENTS_PER_LDG) pair is provided as an explicit specialization,
// so using an unsupported combination fails at link/compile time.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

// Scalar copy: moves exactly one bfloat16 element from src to dst.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  dst[0] = src[0];
}

// Vectorized copy: moves four bfloat16 elements with a single float2-wide
// (8-byte) load/store. Both pointers must satisfy float2 alignment.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
    c10::BFloat16 *dst, const c10::BFloat16 *src) {
  // Keep the source const-qualified while reinterpreting it as one wide word.
  *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

// Scalar copy: moves exactly one half-precision element from src to dst.
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
                                                     const c10::Half *src) {
  dst[0] = src[0];
}

// Vectorized copy: moves four half-precision elements with a single
// float2-wide (8-byte) load/store. Both pointers must satisfy float2
// alignment.
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
                                                     const c10::Half *src) {
  // Keep the source const-qualified while reinterpreting it as one wide word.
  *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

// Scalar copy: moves exactly one byte from src to dst.
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
                                                   const uint8_t *src) {
  dst[0] = src[0];
}

// Vectorized copy: moves four bytes with a single 32-bit load/store. Both
// pointers must be 4-byte aligned.
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
                                                   const uint8_t *src) {
  // A plain 32-bit integer is the natural carrier for four raw bytes; it has
  // the same size and alignment as the half2 used previously, without tying a
  // byte copy to a floating-point type or casting away const.
  *reinterpret_cast<uint32_t *>(dst) =
      *reinterpret_cast<const uint32_t *>(src);
}

// Returns the smallest non-negative exponent e such that (1 << e) >= value,
// i.e. ceil(log2(value)) for value >= 1. Values <= 1 (including zero and
// negatives) yield 0.
int log2_ceil(int value) {
  int log2_value = 0;
  // Shift in 64 bits: for value > 2^30 the exponent reaches 31, and a 32-bit
  // `1 << 31` would overflow int (undefined behaviour, potential endless loop).
  while ((int64_t{1} << log2_value) < value) ++log2_value;
  return log2_value;
}
// Binary addition functor; passed as the ReduceOp of warp_reduce to compute
// sums across lanes.
template <typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
// Binary maximum functor; passed as the ReduceOp of warp_reduce to compute
// the maximum across lanes. Returns b only when a < b, otherwise a.
template <typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

// Thin wrapper over the XOR warp shuffle: returns `value` as held by the lane
// whose id equals the caller's lane id XOR laneMask, within groups of `width`
// lanes. Uses __shfl_xor_sync (with participant mask `mask`) when compiled
// with CUDA 9.0 or newer, and the legacy mask-less __shfl_xor otherwise.
template <typename T>
__device__ __forceinline__ T
WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
                     unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}
// Reduces WARP_BATCH independent values across a group of WARP_SIZE lanes.
//
// `sum` points to WARP_BATCH per-thread partial values. The values are
// combined with ReduceOp through XOR shuffles at offsets WARP_SIZE/2,
// WARP_SIZE/4, ..., 1 and the result is written back into sum[] in place.
// All lanes of the group must call this function together, since the
// shuffle uses the default full participant mask.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE,
          template <typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t *sum) {
  ReduceOp<acc_t> reduce_op;
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
      // Fetch the partner lane's partial result and fold it into ours.
      const acc_t partner = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
      sum[i] = reduce_op(sum[i], partner);
    }
  }
}

/*
 * Extended softmax (from native aten pytorch) with following additional
 * features 1) input scaling 2) Explicit masking
 *
 * Each group of WARP_SIZE threads (threadIdx.x) processes WARP_BATCH rows of
 * element_count values. For every row it computes
 *   dst = softmax(mask == 1 ? -10000 : src * scale)
 * along the row.
 *
 * Parameters:
 *   dst, src          output / input buffers, addressed as consecutive rows of
 *                     element_count values.
 *   mask              one byte per element; a value of 1 marks the element as
 *                     masked out.
 *   scale             factor applied to every unmasked input value.
 *   micro_batch_size  total number of rows in src/dst.
 *   element_count     row length; must not exceed (1 << log2_elements).
 *   pad_batches       selects how mask rows are addressed: when != 1 the mask
 *                     row index ignores blockIdx.y, when == 1 it depends on
 *                     blockIdx.x only.
 *
 * Launch layout: blockDim = (WARP_SIZE, warps_per_block, 1),
 * gridDim = (seq_len blocks, attn_heads, batches).
 *
 * NOTE(review): when ELEMENTS_PER_LDG_STG == 4 only the first element of each
 * 4-wide vector is bounds-checked, so element_count is presumably required to
 * be a multiple of 4 (and rows suitably aligned) in that case -- confirm with
 * the callers.
 */
template <typename input_t, typename output_t, typename acc_t,
          int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale,
    int micro_batch_size, int element_count, int pad_batches) {
  // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
  // values computed by the host-side launcher
  // (dispatch_scaled_masked_softmax_forward).
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
  // gridDim/blockIdx = (seq_len, attn_heads, batches)
  int first_batch =
      (blockDim.y *
           (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) +
       threadIdx.y) *
      WARP_BATCH;
  int pad_first_batch = 0;
  if (pad_batches != 1) {  // bert style
    pad_first_batch =
        (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) *
        WARP_BATCH;
  } else {  // gpt2 style
    pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
  }

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // Row offsets are computed in 64 bits: row_index * element_count exceeds
  // the int range once the tensor holds more than 2^31 elements.
  const int64_t lane_offset = ELEMENTS_PER_LDG_STG * local_idx;
  const int64_t data_offset =
      static_cast<int64_t>(first_batch) * element_count + lane_offset;
  const int64_t mask_offset =
      static_cast<int64_t>(pad_first_batch) * element_count + lane_offset;
  src += data_offset;
  dst += data_offset;
  mask += mask_offset;

  // load data from global memory
  acc_t elements[WARP_BATCH][WARP_ITERATIONS];
  input_t temp_data[ELEMENTS_PER_LDG_STG];
  uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    // Rows past the end of the batch are treated as empty.
    int batch_element_count = (i >= local_batches) ? 0 : element_count;

#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

      if (element_index < batch_element_count) {
        int itr_idx = i * element_count + it * WARP_SIZE;
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
        copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          if (temp_mask[element] != 1) {
            elements[i][it + element] = (acc_t)temp_data[element] * scale;
          } else {
            // Masked positions get a large negative logit instead of -inf.
            elements[i][it + element] = -10000.0;
          }
        }
      } else {
        // Padding beyond the row: -inf contributes exp(-inf) = 0 to the sum.
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
        }
      }
    }
  }

  // compute max_value
  acc_t max_value[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    max_value[i] = elements[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      max_value[i] =
          (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

  // Exponentiate relative to the row maximum and accumulate the row sum.
  acc_t sum[WARP_BATCH]{0.0f};
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; ++it) {
      elements[i][it] = std::exp((elements[i][it] - max_value[i]));
      sum[i] += elements[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
  output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] = elements[i][it + element] / sum[i];
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
            dst + i * element_count + it * WARP_SIZE, out);
      } else {
        break;
      }
    }
  }
}
/*
 * Backward pass of the scaled masked softmax.
 *
 * Each group of WARP_SIZE threads (threadIdx.x) processes WARP_BATCH rows of
 * element_count values. For every row it computes
 *   gradInput = scale * (grad * output - output * sum(grad * output))
 * where the sum runs over the row.
 *
 * Parameters:
 *   gradInput         destination for the gradient w.r.t. the softmax input.
 *   grad              incoming gradient w.r.t. the softmax output.
 *   output            softmax output saved from the forward pass.
 *   scale             the scale factor that was applied in the forward pass.
 *   micro_batch_size  total number of rows.
 *   element_count     row length; must not exceed (1 << log2_elements).
 *
 * Launch layout: blockDim = (WARP_SIZE, warps_per_block, 1) on a
 * one-dimensional grid; only blockIdx.x is used to locate the rows.
 *
 * NOTE(review): when ELEMENTS_PER_LDG_STG == 4 only the first element of each
 * 4-wide vector is bounds-checked, so element_count is presumably required to
 * be a multiple of 4 in that case -- confirm with the callers.
 */
template <typename input_t, typename output_t, typename acc_t,
          int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput, input_t *grad, const input_t *output, acc_t scale,
    int micro_batch_size, int element_count) {
  // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
  // values computed by the host-side launcher
  // (dispatch_scaled_masked_softmax_backward).
  constexpr int next_power_of_two = 1 << log2_elements;
  constexpr int WARP_SIZE =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
  constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

  // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
  // gridDim/blockIdx = (blocks, 1, 1)
  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

  // micro_batch_size might not be a multiple of WARP_BATCH. Check how
  // many batches have to computed within this WARP.
  int local_batches = micro_batch_size - first_batch;
  if (local_batches > WARP_BATCH) local_batches = WARP_BATCH;

  // there might be multiple batches per warp. compute the index within the
  // batch
  int local_idx = threadIdx.x;

  // the first element to process by the current thread. Computed in 64 bits:
  // row_index * element_count exceeds the int range once the tensor holds
  // more than 2^31 elements.
  const int64_t thread_offset =
      static_cast<int64_t>(first_batch) * element_count +
      ELEMENTS_PER_LDG_STG * local_idx;
  grad += thread_offset;
  output += thread_offset;
  gradInput += thread_offset;

  // load data from global memory
  acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f};
  input_t temp_grad[ELEMENTS_PER_LDG_STG];
  input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    // Rows past the end of the batch are treated as empty (registers stay 0).
    int batch_element_count = (i >= local_batches) ? 0 : element_count;

#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < batch_element_count) {
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
            temp_grad, grad + i * element_count + it * WARP_SIZE);
        copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
            temp_output, output + i * element_count + it * WARP_SIZE);

#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          output_reg[i][it + element] = (acc_t)temp_output[element];
        }
        // grad_reg holds the product grad * output, not the raw gradient.
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          grad_reg[i][it + element] =
              (acc_t)temp_grad[element] * output_reg[i][it + element];
        }
      }
    }
  }

  // Row-wise sum of grad * output: per-thread partials, then across lanes.
  acc_t sum[WARP_BATCH];
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    sum[i] = grad_reg[i][0];
#pragma unroll
    for (int it = 1; it < WARP_ITERATIONS; ++it) {
      sum[i] += grad_reg[i][it];
    }
  }
  warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

  // store result
#pragma unroll
  for (int i = 0; i < WARP_BATCH; ++i) {
    if (i >= local_batches) break;
#pragma unroll
    for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
      int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
      if (element_index < element_count) {
        // compute gradients
        output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
        for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
          out[element] =
              (output_t)(scale * (grad_reg[i][it + element] -
                                  output_reg[i][it + element] * sum[i]));
        }
        copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
            gradInput + i * element_count + it * WARP_SIZE, out);
      }
    }
  }
}
}  // end of anonymous namespace

// Returns how many softmax rows one thread block processes for a given
// key_seq_len, using the same 128-threads-per-block configuration as the
// dispatch functions. Only key_seq_len influences the result; query_seq_len,
// batches and attn_heads are accepted but not used.
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches,
                        int attn_heads) {
  (void)query_seq_len;
  (void)batches;
  (void)attn_heads;

  const int log2_elements = log2_ceil(key_seq_len);
  const int next_power_of_two = 1 << log2_elements;

  // Rows shorter than a hardware warp use a narrower logical warp.
  const int warp_size =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
  // Short rows (<= 128 elements) are processed two per warp.
  const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  constexpr int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;

  return warps_per_block * batches_per_warp;
}
// Host-side launcher for scaled_masked_softmax_warp_forward.
//
// Selects the kernel instantiation for ceil(log2(key_seq_len)) and launches
// it on the current CUDA stream with grid
// (query_seq_len / batches_per_block, attn_heads, batches).
//
// Parameters:
//   dst, src       device output / input buffers holding
//                  batches * attn_heads * query_seq_len rows of key_seq_len
//                  elements.
//   mask           device byte mask (1 == masked out).
//   scale          scaling factor; converted to acc_t when passed to the
//                  kernel.
//   query_seq_len  rows per (batch, head); must be a multiple of the number of
//                  rows handled per block.
//   key_seq_len    row length, 0..2048. A length of 0 is a no-op.
//   pad_batches    forwarded to the kernel to select mask addressing.
//
// Fails a TORCH_INTERNAL_ASSERT when key_seq_len is out of range or
// query_seq_len is not divisible by the per-block row count.
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src,
                                            const uint8_t *mask,
                                            const input_t scale,
                                            int query_seq_len, int key_seq_len,
                                            int batches, int attn_heads,
                                            int pad_batches) {
  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
  if (key_seq_len == 0) {
    return;
  }

  const int log2_elements = log2_ceil(key_seq_len);
  const int next_power_of_two = 1 << log2_elements;
  const int batch_count = batches * attn_heads * query_seq_len;

  // This value must match the WARP_SIZE constexpr value computed inside
  // scaled_masked_softmax_warp_forward.
  const int warp_size =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

  // This value must match the WARP_BATCH constexpr value computed inside
  // scaled_masked_softmax_warp_forward.
  const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  // use 128 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 128;

  const int warps_per_block = threads_per_block / warp_size;
  const int batches_per_block = warps_per_block * batches_per_warp;
  TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);

  // Nothing to compute; a grid with a zero dimension would also be an invalid
  // launch configuration.
  if (batch_count == 0) {
    return;
  }

  dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
  dim3 threads(warp_size, warps_per_block, 1);
  const auto stream = at::cuda::getCurrentCUDAStream();

  // log2_elements is a template argument, so every supported value needs its
  // own instantiation; the macro stamps out one switch case per value.
#define SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(LOG2_ELEMENTS)                 \
  case LOG2_ELEMENTS:                                                       \
    scaled_masked_softmax_warp_forward<input_t, output_t, acc_t,            \
                                       LOG2_ELEMENTS>                       \
        <<<blocks, threads, 0, stream>>>(dst, src, mask, scale,             \
                                         batch_count, key_seq_len,          \
                                         pad_batches);                      \
    break

  switch (log2_elements) {
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(0);   // 1
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(1);   // 2
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(2);   // 4
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(3);   // 8
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(4);   // 16
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(5);   // 32
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(6);   // 64
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(7);   // 128
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(8);   // 256
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(9);   // 512
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(10);  // 1024
    SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD(11);  // 2048
    default:
      break;
  }
#undef SCALED_MASKED_SOFTMAX_LAUNCH_FORWARD
}
// Host-side launcher for scaled_masked_softmax_warp_backward.
//
// Selects the kernel instantiation for ceil(log2(key_seq_len)) and launches
// it on the current CUDA stream with a one-dimensional grid that covers all
// batches * attn_heads * query_seq_len rows.
//
// Parameters:
//   grad_input     device buffer receiving the gradient w.r.t. the softmax
//                  input.
//   grad           device buffer with the gradient w.r.t. the softmax output.
//   output         device buffer with the softmax output of the forward pass.
//   scale          scaling factor used in the forward pass.
//   query_seq_len, batches, attn_heads
//                  together give the number of rows.
//   key_seq_len    row length, 0..2048. A length of 0 is a no-op.
//
// Fails a TORCH_INTERNAL_ASSERT when key_seq_len is out of range.
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(output_t *grad_input,
                                             input_t *grad,
                                             const input_t *output,
                                             const acc_t scale,
                                             int query_seq_len, int key_seq_len,
                                             int batches, int attn_heads) {
  TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
  if (key_seq_len == 0) {
    return;
  }

  const int log2_elements = log2_ceil(key_seq_len);
  const int next_power_of_two = 1 << log2_elements;
  const int batch_count = batches * attn_heads * query_seq_len;

  // Nothing to compute; a zero-sized grid would also be an invalid launch
  // configuration.
  if (batch_count == 0) {
    return;
  }

  // This value must match the WARP_SIZE constexpr value computed inside
  // scaled_masked_softmax_warp_backward.
  const int warp_size =
      (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

  // This value must match the WARP_BATCH constexpr value computed inside
  // scaled_masked_softmax_warp_backward.
  const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

  // use 128 threads per block to maximize gpu utilization
  constexpr int threads_per_block = 128;

  const int warps_per_block = threads_per_block / warp_size;
  const int batches_per_block = warps_per_block * batches_per_warp;

  // Round up so that a trailing partial block of rows is still processed; the
  // kernel skips rows at or beyond batch_count through its local_batches
  // check.
  const int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
  dim3 threads(warp_size, warps_per_block, 1);
  const auto stream = at::cuda::getCurrentCUDAStream();

  // log2_elements is a template argument, so every supported value needs its
  // own instantiation; the macro stamps out one switch case per value.
#define SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(LOG2_ELEMENTS)                \
  case LOG2_ELEMENTS:                                                       \
    scaled_masked_softmax_warp_backward<input_t, output_t, acc_t,           \
                                        LOG2_ELEMENTS>                      \
        <<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale,   \
                                         batch_count, key_seq_len);         \
    break

  switch (log2_elements) {
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(0);   // 1
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(1);   // 2
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(2);   // 4
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(3);   // 8
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(4);   // 16
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(5);   // 32
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(6);   // 64
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(7);   // 128
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(8);   // 256
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(9);   // 512
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(10);  // 1024
    SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD(11);  // 2048
    default:
      break;
  }
#undef SCALED_MASKED_SOFTMAX_LAUNCH_BACKWARD
}