/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>     // TORCH_INTERNAL_ASSERT
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream

namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((uint32_t*) dst) = *((uint32_t*) src); }
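
// Assumed 4-wide float specialization, mirroring the __half path above: it
// would be required if the kernels below were instantiated with float I/O,
// since ELEMENTS_PER_LDG_STG is fixed at 4 and float4 gives the matching
// 16-byte vector access.
template <>
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) { *((float4*) dst) = *((float4*) src); }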

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}
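
// For example, log2_ceil(600) == 10, since 1 << 10 == 1024 is the smallest
// power of two >= 600; the dispatch routines below use this value to select
// a kernel instantiation.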

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0;  i < WARP_BATCH;  ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
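
// warp_reduce is a butterfly (XOR-shuffle) reduction: after log2(WARP_SIZE)
// rounds, every lane of the warp holds the fully reduced value for each of
// its WARP_BATCH rows, so no separate broadcast step is needed afterwards.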

/*
 * Extended softmax (from native aten pytorch) with the following additional
 * features:
 * 1) input scaling
 * 2) explicit masking (a mask value of 1 marks a position as masked out)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst, 
    const input_t *src,
    const uint8_t *mask, 
    const acc_t scale, 
    int micro_batch_size,
    int element_count,
    int pad_batches) 
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches) 
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) {
        // bert style: one mask per batch element, shared across attention heads
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else {
        // gpt2 style: a single mask shared across both heads and batches
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        // masked positions get a large negative logit rather than
                        // -inf, so a fully masked row still yields finite outputs
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
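    // (the max found above is subtracted before exponentiation below, the
    // standard max trick that keeps exp() in a numerically safe range)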

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);  
            } else {
                break;
            } 
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput, 
    input_t *grad, 
    const input_t *output,
    acc_t scale, 
    int micro_batch_size, 
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches) 
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    
    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    // per-thread staging buffers for the vectorized global loads below
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            } 
        }
    }
   
    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients: grad_reg holds dy * y, so this evaluates
                // dx = scale * y * (dy - sum(dy * y)), the softmax backward
                // fused with the forward input scaling
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            } 
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst, 
    const input_t *src, 
    const uint8_t *mask,
    const input_t scale, 
    int query_seq_len, 
    int key_seq_len, 
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;
        }
    }
}
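
// Usage sketch (illustrative only): how a caller might launch the FP16
// forward path from a torch extension. The tensor names (`input`, `mask`,
// `softmax_results`) and the surrounding setup are assumptions, not part of
// this header; shapes follow the (batches, attn_heads, query_seq_len,
// key_seq_len) layout assumed by the grid computation above.
//
//   at::Tensor softmax_results = at::empty_like(input);
//   dispatch_scaled_masked_softmax_forward<half, half, float>(
//       reinterpret_cast<half*>(softmax_results.data_ptr<at::Half>()),
//       reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
//       mask.data_ptr<uint8_t>(),
//       scale_factor,
//       query_seq_len,
//       key_seq_len,
//       batches,
//       attn_heads,
//       pad_batches);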

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input, 
    input_t *grad, 
    const input_t *output, 
    const acc_t scale, 
    int query_seq_len, 
    int key_seq_len, 
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
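        // note: unlike the forward launch, the backward pass uses a flat 1-D grid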
        int blocks = batch_count/batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            default:
                break;
        }
    }
}
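
// Usage sketch (illustrative only; the tensor names are assumptions): the
// backward dispatch consumes the upstream gradient and the saved forward
// output, writing dx = scale * y * (dy - sum(dy * y)) into `input_grads`.
//
//   dispatch_scaled_masked_softmax_backward<half, half, float>(
//       reinterpret_cast<half*>(input_grads.data_ptr<at::Half>()),
//       reinterpret_cast<half*>(output_grads.data_ptr<at::Half>()),
//       reinterpret_cast<const half*>(softmax_results.data_ptr<at::Half>()),
//       scale_factor,
//       query_seq_len,
//       key_seq_len,
//       batches,
//       attn_heads);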