/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// Copies ELEMENTS_PER_LDG consecutive values of type Datatype from src to dst.
// The primary template is only declared; each supported (type, width) pair is
// provided by one of the explicit specializations that follow.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

// A single __half: plain scalar copy.
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) {
    dst[0] = src[0];
}

// Four __half values moved with one float2-sized access.
// NOTE(review): both pointers must be aligned for a float2 access -- callers
// are expected to guarantee this.
template <>
__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) {
    const float2 *wide_src = reinterpret_cast<const float2 *>(src);
    float2 *wide_dst = reinterpret_cast<float2 *>(dst);
    *wide_dst = *wide_src;
}

// A single byte: plain scalar copy.
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) {
    dst[0] = src[0];
}

// Four bytes moved with one half2-sized access.
// NOTE(review): both pointers must be aligned for a half2 access -- callers
// are expected to guarantee this.
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
    const half2 *wide_src = reinterpret_cast<const half2 *>(src);
    half2 *wide_dst = reinterpret_cast<half2 *>(dst);
    *wide_dst = *wide_src;
}

Vijay Korthikanti's avatar
Vijay Korthikanti committed
44
45
46
47
48
49
50
// Writes ELEMENTS_PER_LDG zero-valued elements of type Datatype starting at dst.
// The primary template is only declared; supported (type, width) pairs are
// provided by the explicit specializations that follow.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

// Scalar width, mirroring copy_vector<__half, 1>, so that code instantiated
// with a vector width of 1 can also zero-fill its output.
template <>
__device__ __inline__ void copy_zero_vector<__half, 1>(__half *dst) { *dst = __float2half(0.0f); }

// Four __half zeros written with one float2-sized store (an all-zero bit
// pattern is 0.0 in both formats).
// NOTE(review): dst must be aligned for a float2 access.
template <>
__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// Returns the smallest non-negative exponent e such that (1 << e) >= value.
// Values <= 1 (including zero and negatives) yield 0.
int log2_ceil(int value) {
    int log2_value = 0;
    // Shift in 64-bit: with a 32-bit int, 1 << 31 overflows, so any value
    // above 2^30 would make the original loop undefined (and possibly endless).
    while ((1LL << log2_value) < value) ++log2_value;
    return log2_value;
}

// Binary functor that yields the sum of its two arguments; used as the
// ReduceOp of warp_reduce when accumulating row sums.
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    T total = lhs + rhs;
    return total;
  }
};

// Binary functor that yields the larger of its two arguments; used as the
// ReduceOp of warp_reduce when searching for the row maximum.
template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    // The right-hand value wins only when it is strictly greater; in every
    // other case (including unordered comparisons) the left-hand value is kept.
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};

// Thin wrapper around the warp XOR-shuffle intrinsic: returns `value` as held
// by the lane whose id is (own lane id XOR laneMask) within a group of `width`
// lanes. Uses the *_sync intrinsic with `mask` on CUDA 9.0 and newer and the
// legacy mask-less intrinsic otherwise.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    const T exchanged = __shfl_xor_sync(mask, value, laneMask, width);
#else
    const T exchanged = __shfl_xor(value, laneMask, width);
#endif
    return exchanged;
}

// Butterfly reduction across WARP_SIZE lanes.
//
// `sum` points to WARP_BATCH per-thread partial values. On return every
// participating lane holds, for each of the WARP_BATCH slots, the result of
// combining that slot over all WARP_SIZE lanes with ReduceOp.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> combine;
    // Halve the exchange distance each round: WARP_SIZE/2, WARP_SIZE/4, ..., 1.
    #pragma unroll
    for (int lane_delta = WARP_SIZE >> 1; lane_delta >= 1; lane_delta >>= 1) {
        #pragma unroll
        for (int batch = 0;  batch < WARP_BATCH;  ++batch) {
            const acc_t partner = WARP_SHFL_XOR_NATIVE(sum[batch], lane_delta, WARP_SIZE);
            sum[batch] = combine(sum[batch], partner);
        }
    }
}

/*
 * Extended softmax (from native ATen PyTorch) with the following additional features:
 * 1) input scaling
 * 2) implicit time (diagonal) masking
 */
/*
 * Forward kernel: dst = softmax(scale * src) per row, where row r
 * (r = blockIdx.x) only attends to its first r + 1 columns. Columns past
 * that point are treated as -infinity on input and written as 0 on output.
 *
 * Launch layout (see dispatch_scaled_upper_triang_masked_softmax_forward):
 *   gridDim.x  = number of rows (element_count)
 *   blockDim.x = WARP_SIZE lanes cooperating on one row
 *   blockDim.y = row groups per block; each group handles WARP_BATCH rows
 *                that sit element_count * stride elements apart.
 *
 * Parameters:
 *   dst, src         - output / input buffers; consecutive rows are `stride`
 *                      elements apart.
 *   scale            - factor applied to every input before the softmax.
 *   micro_batch_size - total number of rows; bounds the rows this warp owns.
 *   stride           - distance in elements between consecutive rows.
 *   element_count    - row length (number of columns).
 *
 * Preconditions (not checked here): copy_vector must be specialized for
 * (input_t, width) and (output_t, width). When the padded row length
 * (1 << log2_elements) is at least 128 the data is moved in 4-element
 * vectors, so row starts must stay aligned for those accesses and
 * element_count rounded up to 4 must not exceed stride.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the values warp_size and
    // batches_per_warp computed by the host-side dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    // Vectorize only when a thread owns at least 4 elements; otherwise the
    // per-thread arrays (WARP_ITERATIONS entries) would be indexed out of bounds.
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    // Number of unmasked columns in this row.
    int local_seq = blockIdx.x + 1;
    // Upper bound on the per-thread slots that can hold unmasked data in any
    // lane. Slot `it` belongs to a vector starting at slot it - it % ELEMENTS_PER_LDG_STG,
    // hence the extra ELEMENTS_PER_LDG_STG * WARP_SIZE margin.
    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this warp.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // 64-bit offsets: first_batch * stride can exceed the range of int.
    const int64_t thread_offset =
        static_cast<int64_t>(first_batch) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    src += thread_offset;
    dst += thread_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;
        const int64_t batch_offset = static_cast<int64_t>(i) * element_count * stride;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + batch_offset + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    // Mask per element: a vector may straddle the diagonal.
                    if (element_index + element < batch_element_count) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // exponentiate (shifted by the row maximum) and accumulate the row sum
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        const int64_t batch_offset = static_cast<int64_t>(i) * element_count * stride;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            // Nothing left to write for this thread once past the row end.
            if (element_index >= element_count)
                break;

            #pragma unroll
            for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                // Masked columns are written as exact zeros.
                if (element_index + element < local_seq) {
                    out[element] = elements[i][it + element] / sum[i];
                } else {
                    out[element] = static_cast<output_t>(0.0f);
                }
            }
            copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + batch_offset + it * WARP_SIZE, out);
        }
    }
}

/*
 * Backward kernel of the scaled, causally masked softmax.
 *
 * For row r (r = blockIdx.x) and every column c <= r it computes
 *   gradInput[c] = scale * (grad[c] * output[c] - output[c] * sum_k(grad[k] * output[k]))
 * where the sum runs over the unmasked columns of the row. Columns c > r are
 * never read; their per-thread registers stay at 0, so 0 is written for them
 * (as long as the row sum is finite).
 *
 * Launch layout (see dispatch_scaled_upper_triang_masked_softmax_backward):
 *   gridDim.x  = number of rows (element_count)
 *   blockDim.x = WARP_SIZE lanes cooperating on one row
 *   blockDim.y = row groups per block; each group handles WARP_BATCH rows
 *                that sit element_count * stride elements apart.
 *
 * Parameters:
 *   gradInput        - destination for the gradient w.r.t. the softmax input.
 *   grad             - incoming gradient w.r.t. the softmax output.
 *   output           - softmax output saved from the forward pass.
 *   scale            - the scaling factor that was applied in the forward pass.
 *   micro_batch_size - total number of rows; bounds the rows this warp owns.
 *   stride           - distance in elements between consecutive rows.
 *   element_count    - row length (number of columns).
 *
 * Preconditions (not checked here): copy_vector must be specialized for
 * (input_t, width) and (output_t, width). When the padded row length
 * (1 << log2_elements) is at least 128 the data is moved in 4-element
 * vectors, so row starts must stay aligned for those accesses and
 * element_count rounded up to 4 must not exceed stride.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the values warp_size and
    // batches_per_warp computed by the host-side dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    // Vectorize only when a thread owns at least 4 elements; otherwise the
    // per-thread arrays (WARP_ITERATIONS entries) would be indexed out of bounds.
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    // Number of unmasked columns in this row.
    int local_seq = blockIdx.x + 1;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this warp.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    // (64-bit: first_batch * stride can exceed the range of int)
    const int64_t thread_offset =
        static_cast<int64_t>(first_batch) * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory; slots that are never loaded stay at 0
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;
        const int64_t batch_offset = static_cast<int64_t>(i) * element_count * stride;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + batch_offset + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + batch_offset + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    // Mask per element: a vector may straddle the diagonal, and
                    // whatever lies beyond it must not reach the row sum.
                    if (element_index + element < batch_element_count) {
                        output_reg[i][it + element] = (acc_t)temp_output[element];
                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                    }
                }
            }
        }
    }

    // per-row sum of grad * output
    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        const int64_t batch_offset = static_cast<int64_t>(i) * element_count * stride;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + batch_offset + it * WARP_SIZE, out);
            }
        }
    }
}

} // end of anonymous namespace

/*
 * Host-side launcher for scaled_upper_triang_masked_softmax_warp_forward.
 *
 * Picks the kernel instantiation whose compile-time row length is the
 * smallest power of two >= softmax_elements and launches it on the current
 * CUDA stream with one grid column per row.
 *
 * Parameters:
 *   dst, src                - device output / input buffers.
 *   scale                   - scaling factor applied before the softmax.
 *                             NOTE(review): taken as input_t here although the
 *                             kernel (and the backward dispatcher) use acc_t;
 *                             it is converted at the launch site.
 *   softmax_elements        - row length; must be in [0, 2048].
 *   softmax_elements_stride - distance in elements between consecutive rows.
 *   attn_batches            - number of square attention matrices; must be a
 *                             multiple of the batches handled per block.
 *
 * Fails via TORCH_INTERNAL_ASSERT on violated preconditions or when the
 * launch itself reports an error. Does nothing for empty input.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst, 
    const input_t *src, 
    const input_t scale, 
    int softmax_elements, 
    int softmax_elements_stride, 
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 );
    // Nothing to compute; also avoids launching a grid with a zero dimension.
    if (softmax_elements == 0 || attn_batches == 0) {
        return;
    }

    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // When the padded row length is 128 or more the kernel moves data in
    // 4-element vectors, which requires every row start to keep that alignment.
    TORCH_INTERNAL_ASSERT(next_power_of_two < 128 || softmax_elements_stride % 4 == 0,
        "scaled_upper_triang_masked_softmax forward: row stride must be a multiple of 4 for vectorized access");

    // This value must match the WARP_SIZE constexpr value computed inside
    // scaled_upper_triang_masked_softmax_warp_forward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // scaled_upper_triang_masked_softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);
    auto stream = at::cuda::getCurrentCUDAStream();

    // log2_elements has to be a compile-time constant for the kernel, so every
    // supported value needs its own instantiation.
#define SUTM_SOFTMAX_LAUNCH_FORWARD(LOG2_ELEMENTS)                                              \
    scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, LOG2_ELEMENTS>    \
        <<<blocks, threads, 0, stream>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements)

    switch (log2_elements) {
        case 0:  SUTM_SOFTMAX_LAUNCH_FORWARD(0);  break; // 1
        case 1:  SUTM_SOFTMAX_LAUNCH_FORWARD(1);  break; // 2
        case 2:  SUTM_SOFTMAX_LAUNCH_FORWARD(2);  break; // 4
        case 3:  SUTM_SOFTMAX_LAUNCH_FORWARD(3);  break; // 8
        case 4:  SUTM_SOFTMAX_LAUNCH_FORWARD(4);  break; // 16
        case 5:  SUTM_SOFTMAX_LAUNCH_FORWARD(5);  break; // 32
        case 6:  SUTM_SOFTMAX_LAUNCH_FORWARD(6);  break; // 64
        case 7:  SUTM_SOFTMAX_LAUNCH_FORWARD(7);  break; // 128
        case 8:  SUTM_SOFTMAX_LAUNCH_FORWARD(8);  break; // 256
        case 9:  SUTM_SOFTMAX_LAUNCH_FORWARD(9);  break; // 512
        case 10: SUTM_SOFTMAX_LAUNCH_FORWARD(10); break; // 1024
        case 11: SUTM_SOFTMAX_LAUNCH_FORWARD(11); break; // 2048
        default:
            // Unreachable while the range assertion above holds; fail loudly
            // instead of silently skipping the computation.
            TORCH_INTERNAL_ASSERT(false, "unsupported softmax_elements: ", softmax_elements);
    }
#undef SUTM_SOFTMAX_LAUNCH_FORWARD

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
        "scaled_upper_triang_masked_softmax forward launch failed: ", cudaGetErrorString(launch_status));
}

/*
 * Host-side launcher for scaled_upper_triang_masked_softmax_warp_backward.
 *
 * Picks the kernel instantiation whose compile-time row length is the
 * smallest power of two >= softmax_elements and launches it on the current
 * CUDA stream with one grid column per row.
 *
 * Parameters:
 *   grad_input              - device buffer receiving the input gradient.
 *   grad                    - device buffer with the gradient w.r.t. the
 *                             softmax output.
 *   output                  - device buffer with the forward softmax output.
 *   scale                   - scaling factor used in the forward pass.
 *   softmax_elements        - row length; must be in [0, 2048].
 *   softmax_elements_stride - distance in elements between consecutive rows.
 *   attn_batches            - number of square attention matrices; must be a
 *                             multiple of the batches handled per block.
 *
 * Fails via TORCH_INTERNAL_ASSERT on violated preconditions or when the
 * launch itself reports an error. Does nothing for empty input.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input, 
    input_t *grad, 
    const input_t *output, 
    const acc_t scale, 
    int softmax_elements, 
    int softmax_elements_stride, 
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 );
    // Nothing to compute; also avoids launching a grid with a zero dimension.
    if (softmax_elements == 0 || attn_batches == 0) {
        return;
    }

    int log2_elements = log2_ceil(softmax_elements);
    const int next_power_of_two = 1 << log2_elements;
    int seq_len = softmax_elements;
    int batch_count = attn_batches * seq_len;

    // When the padded row length is 128 or more the kernel moves data in
    // 4-element vectors, which requires every row start to keep that alignment.
    TORCH_INTERNAL_ASSERT(next_power_of_two < 128 || softmax_elements_stride % 4 == 0,
        "scaled_upper_triang_masked_softmax backward: row stride must be a multiple of 4 for vectorized access");

    // This value must match the WARP_SIZE constexpr value computed inside
    // scaled_upper_triang_masked_softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside
    // scaled_upper_triang_masked_softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
    int blocks_per_seq = attn_batches / batches_per_block;
    dim3 blocks(seq_len, blocks_per_seq, 1);
    dim3 threads(warp_size, warps_per_block, 1);
    auto stream = at::cuda::getCurrentCUDAStream();

    // log2_elements has to be a compile-time constant for the kernel, so every
    // supported value needs its own instantiation.
#define SUTM_SOFTMAX_LAUNCH_BACKWARD(LOG2_ELEMENTS)                                             \
    scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, LOG2_ELEMENTS>   \
        <<<blocks, threads, 0, stream>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements)

    switch (log2_elements) {
        case 0:  SUTM_SOFTMAX_LAUNCH_BACKWARD(0);  break; // 1
        case 1:  SUTM_SOFTMAX_LAUNCH_BACKWARD(1);  break; // 2
        case 2:  SUTM_SOFTMAX_LAUNCH_BACKWARD(2);  break; // 4
        case 3:  SUTM_SOFTMAX_LAUNCH_BACKWARD(3);  break; // 8
        case 4:  SUTM_SOFTMAX_LAUNCH_BACKWARD(4);  break; // 16
        case 5:  SUTM_SOFTMAX_LAUNCH_BACKWARD(5);  break; // 32
        case 6:  SUTM_SOFTMAX_LAUNCH_BACKWARD(6);  break; // 64
        case 7:  SUTM_SOFTMAX_LAUNCH_BACKWARD(7);  break; // 128
        case 8:  SUTM_SOFTMAX_LAUNCH_BACKWARD(8);  break; // 256
        case 9:  SUTM_SOFTMAX_LAUNCH_BACKWARD(9);  break; // 512
        case 10: SUTM_SOFTMAX_LAUNCH_BACKWARD(10); break; // 1024
        case 11: SUTM_SOFTMAX_LAUNCH_BACKWARD(11); break; // 2048
        default:
            // Unreachable while the range assertion above holds; fail loudly
            // instead of silently skipping the computation.
            TORCH_INTERNAL_ASSERT(false, "unsupported softmax_elements: ", softmax_elements);
    }
#undef SUTM_SOFTMAX_LAUNCH_BACKWARD

    // Kernel launches do not return a status; surface configuration errors here.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
        "scaled_upper_triang_masked_softmax backward launch failed: ", cudaGetErrorString(launch_status));
}