scaled_masked_softmax.h 30.8 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */
2
3
4
5
6
7
8
9
10
11
12
13
14

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

15
16
17
18
// Copies ELEMENTS_PER_LDG consecutive Datatype values from src to dst.
// Only declared here; the explicit specializations that follow provide the
// implementations for the supported (type, width) combinations.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

// Single-element copy for c10::BFloat16: plain scalar assignment.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
20
21

// Four-element copy for c10::BFloat16, performed as one float2-sized access.
// Because the pointers are reinterpreted as float2*, both dst and src must be
// aligned for float2.
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

// Single-element copy for c10::Half.
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) {
    dst[0] = src[0];
}

// Four-element copy for c10::Half, moved as a single float2-sized value.
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) {
    const float2 packed = *reinterpret_cast<const float2*>(src);
    *reinterpret_cast<float2*>(dst) = packed;
}
Vijay Korthikanti's avatar
Vijay Korthikanti committed
29

30
31
32
33
34
35
// Single-byte copy for uint8_t.
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) {
    dst[0] = src[0];
}

// Four-byte copy for uint8_t, moved as a single half2-sized value.
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
    const half2 packed = *reinterpret_cast<const half2*>(src);
    *reinterpret_cast<half2*>(dst) = packed;
}

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/*
 * Returns the smallest exponent e >= 0 such that (1 << e) >= value,
 * i.e. ceil(log2(value)) for value >= 1. Values <= 1 yield 0.
 */
int log2_ceil(int value) {
    int log2_value = 0;
    // Probe in 64 bits: with a 32-bit probe, 1 << 31 overflows int for
    // value > 2^30 and the loop never terminates cleanly.
    while ((1LL << log2_value) < value) ++log2_value;
    return log2_value;
}

// Binary functor that returns the sum of its two arguments.
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    const T total = lhs + rhs;
    return total;
  }
};

// Binary functor that returns the larger argument; the first argument is
// returned whenever `lhs < rhs` is false.
template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};

// Wrapper around the XOR warp shuffle. Uses __shfl_xor_sync (with `mask`)
// when CUDA_VERSION >= 9000 and the mask-less __shfl_xor otherwise, in which
// case `mask` is ignored.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    const T exchanged = __shfl_xor_sync(mask, value, laneMask, width);
#else
    const T exchanged = __shfl_xor(value, laneMask, width);
#endif
    return exchanged;
}

// Reduces WARP_BATCH independent values across groups of WARP_SIZE lanes.
// At each step every lane combines its value with the one held by the lane
// whose index differs in a single bit (XOR shuffle), halving the distance
// until it reaches 1. On return sum[i] holds the combined value for row i.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> combine;
    #pragma unroll
    for (int lane_mask = WARP_SIZE >> 1; lane_mask >= 1; lane_mask >>= 1) {
        #pragma unroll
        for (int row = 0; row < WARP_BATCH; ++row) {
            const acc_t partner = WARP_SHFL_XOR_NATIVE(sum[row], lane_mask, WARP_SIZE);
            sum[row] = combine(sum[row], partner);
        }
    }
}

79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 *
 * For every row of element_count values this writes
 *     dst[row] = softmax(scale * src[row])
 * computed in acc_t precision.
 *
 * Work distribution: a group of WARP_SIZE threads (threadIdx.x) handles
 * WARP_BATCH consecutive rows; each thread keeps WARP_ITERATIONS values per
 * row in registers and row-wide max/sum are obtained with warp shuffles.
 *
 *   dst               output rows, same layout as src
 *   src               input rows, micro_batch_size rows of element_count values
 *   scale             factor applied to every input value before the softmax
 *   micro_batch_size  total number of rows
 *   element_count     row length; positions at or beyond (1 << log2_elements)
 *                     are never visited
 *
 * When ELEMENTS_PER_LDG_STG is 4, rows are read and written in groups of four
 * values and only the first index of each group is bounds-checked.
 * NOTE(review): this appears to expect element_count to be a multiple of 4
 * and suitably aligned rows in that case -- confirm callers guarantee it.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
    // values computed by the host-side dispatch function launching this kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // Row offset in 64 bits: first_batch * element_count does not fit in a
    // 32-bit int once the tensor holds 2^31 or more elements.
    const int64_t thread_offset = static_cast<int64_t>(first_batch) * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    src += thread_offset;
    dst += thread_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        // rows past the end of the input contribute no elements
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = (acc_t)temp_data[element] * scale;
                }
            } else {
                // out-of-range slots become -inf so they vanish from max and sum
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // exponentiate relative to the row maximum and accumulate the row sum
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}


190
191
192
193
194
195
196
197
198
199
200
/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 *
 * For every row of element_count values this writes
 *     dst[row] = softmax(masked(scale * src[row]))
 * where positions whose mask byte equals 1 are replaced by -10000 before the
 * softmax. A row whose maximum equals that fill value is treated as fully
 * masked and produces an all-zero output row.
 *
 * Work distribution: a group of WARP_SIZE threads (threadIdx.x) handles
 * WARP_BATCH consecutive rows; each thread keeps WARP_ITERATIONS values per
 * row in registers.
 *
 *   dst               output rows, same layout as src
 *   src               input rows, micro_batch_size rows of element_count values
 *   mask              one byte per element; 1 means "masked out"
 *   scale             factor applied to every unmasked input value
 *   micro_batch_size  total number of rows
 *   element_count     row length
 *   pad_batches       selects how mask rows are indexed: when != 1 the mask row
 *                     index ignores blockIdx.y, when == 1 it ignores both
 *                     blockIdx.y and blockIdx.z (see pad_first_batch below)
 *
 * When ELEMENTS_PER_LDG_STG is 4, rows are accessed in groups of four values
 * and only the first index of each group is bounds-checked.
 * NOTE(review): this appears to expect element_count to be a multiple of 4
 * and suitably aligned rows in that case -- confirm callers guarantee it.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
    // values computed by the host-side dispatch function launching this kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // value written into masked positions, and used to detect fully masked rows
    const acc_t masked_fill = static_cast<acc_t>(-10000.0);

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // Row offsets in 64 bits: first_batch * element_count does not fit in a
    // 32-bit int once the tensor holds 2^31 or more elements.
    const int64_t lane_offset = ELEMENTS_PER_LDG_STG * local_idx;
    const int64_t data_offset = static_cast<int64_t>(first_batch) * element_count + lane_offset;
    src += data_offset;
    dst += data_offset;
    mask += static_cast<int64_t>(pad_first_batch) * element_count + lane_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        // rows past the end of the input contribute no elements
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = masked_fill;
                    }
                }
            } else {
                // out-of-range slots become -inf so they vanish from max and sum
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // compute scale value to account for full mask
    acc_t scale_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        scale_value[i] = (max_value[i] == masked_fill) ? static_cast<acc_t>(0.0) : static_cast<acc_t>(1.0);
    }

    // exponentiate relative to the row maximum and accumulate the row sum
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] * scale_value[i] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}

/*
 * Backward pass of the scaled (masked) softmax.
 *
 * For every row this writes
 *     gradInput = scale * (grad * output - output * sum_j(grad_j * output_j))
 * where `output` is the softmax result of the forward pass and `grad` the
 * incoming gradient; arithmetic is done in acc_t.
 *
 * Work distribution: a group of WARP_SIZE threads (threadIdx.x) handles
 * WARP_BATCH consecutive rows; each thread keeps WARP_ITERATIONS values per
 * row in registers.
 *
 *   gradInput         result rows, same layout as grad/output
 *   grad              incoming gradient rows
 *   output            forward softmax output rows
 *   scale             factor applied to the resulting gradient
 *   micro_batch_size  total number of rows
 *   element_count     row length
 *
 * When ELEMENTS_PER_LDG_STG is 4, rows are accessed in groups of four values
 * and only the first index of each group is bounds-checked.
 * NOTE(review): this appears to expect element_count to be a multiple of 4
 * and suitably aligned rows in that case -- confirm callers guarantee it.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
    // values computed by the host-side dispatch function launching this kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // only blockIdx.x is used to locate the rows of this block
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread; kept in 64 bits
    // because first_batch * element_count does not fit in a 32-bit int once
    // the tensor holds 2^31 or more elements
    const int64_t thread_offset = static_cast<int64_t>(first_batch) * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory; slots that are never loaded stay zero
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        // rows past the end of the input contribute no elements
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                // grad_reg holds grad * output from here on
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    // per-row sum of grad * output
    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
} // end of anonymous namespace

419
420
421
422
423
424
425
426
427
428
429
430
431
432
/*
 * Returns how many softmax rows one thread block processes for the given
 * key_seq_len, using the same warp_size / batches_per_warp / 128-thread block
 * rules as the dispatch functions in this file.
 *
 * Only key_seq_len influences the result; query_seq_len, batches and
 * attn_heads are accepted but not used.
 *
 * Declared inline because this header may be included from more than one
 * translation unit.
 */
inline int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    (void)query_seq_len;
    (void)batches;
    (void)attn_heads;

    const int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    const int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    const int warps_per_block = (threads_per_block / warp_size);
    const int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
/*
 * Host-side launcher for scaled_softmax_warp_forward.
 *
 * Picks the kernel instantiation for ceil(log2(key_seq_len)) and launches it
 * on the current CUDA stream with a (query_seq_len / batches_per_block,
 * attn_heads, batches) grid.
 *
 *   dst            output buffer
 *   src            input buffer
 *   scale          scaling factor; converted to acc_t when passed to the kernel
 *   query_seq_len  must be a multiple of batches_per_block (asserted)
 *   key_seq_len    softmax row length, must be in [0, 4096] (asserted)
 *   batches        grid z dimension
 *   attn_heads     grid y dimension
 *
 * Returns without launching when there is nothing to compute.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
    const int batch_count = batches * attn_heads * query_seq_len;

    // Empty input: nothing to compute, and a grid with a zero-sized dimension
    // is not a valid launch configuration.
    if (key_seq_len == 0 || batch_count == 0) {
        return;
    }

    const int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside scaled_softmax_warp_forward.
    const int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside scaled_softmax_warp_forward.
    const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    const int warps_per_block = (threads_per_block / warp_size);
    const int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
    dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
    dim3 threads(warp_size, warps_per_block, 1);

    // One switch case per supported log2(row length); the macro keeps the
    // thirteen launches identical apart from the template argument.
#define LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(LOG2_ELEMENTS)                              \
    case LOG2_ELEMENTS:                                                                \
        scaled_softmax_warp_forward<input_t, output_t, acc_t, LOG2_ELEMENTS>           \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                \
                dst, src, scale, batch_count, key_seq_len);                            \
        break

    switch (log2_elements) {
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(0);  // 1
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(1);  // 2
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(2);  // 4
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(3);  // 8
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(4);  // 16
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(5);  // 32
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(6);  // 64
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(7);  // 128
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(8);  // 256
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(9);  // 512
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(10); // 1024
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(11); // 2048
        LAUNCH_SCALED_SOFTMAX_WARP_FORWARD(12); // 4096
        default:
            break;
    }
#undef LAUNCH_SCALED_SOFTMAX_WARP_FORWARD
}

525
526
527
528
529
530
/*
 * Host-side launcher for scaled_masked_softmax_warp_forward.
 *
 * Picks the kernel instantiation for ceil(log2(key_seq_len)) and launches it
 * on the current CUDA stream with a (query_seq_len / batches_per_block,
 * attn_heads, batches) grid.
 *
 *   dst            output buffer
 *   src            input buffer
 *   mask           byte mask forwarded to the kernel
 *   scale          scaling factor; converted to acc_t when passed to the kernel
 *   query_seq_len  must be a multiple of batches_per_block (asserted)
 *   key_seq_len    softmax row length, must be in [0, 4096] (asserted)
 *   batches        grid z dimension
 *   attn_heads     grid y dimension
 *   pad_batches    forwarded to the kernel, where it selects the mask indexing
 *
 * Returns without launching when there is nothing to compute.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
    const int batch_count = batches * attn_heads * query_seq_len;

    // Empty input: nothing to compute, and a grid with a zero-sized dimension
    // is not a valid launch configuration.
    if (key_seq_len == 0 || batch_count == 0) {
        return;
    }

    const int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside scaled_masked_softmax_warp_forward.
    const int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside scaled_masked_softmax_warp_forward.
    const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    const int warps_per_block = (threads_per_block / warp_size);
    const int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
    dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
    dim3 threads(warp_size, warps_per_block, 1);

    // One switch case per supported log2(row length); the macro keeps the
    // thirteen launches identical apart from the template argument.
#define LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(LOG2_ELEMENTS)                       \
    case LOG2_ELEMENTS:                                                                \
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, LOG2_ELEMENTS>    \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                \
                dst, src, mask, scale, batch_count, key_seq_len, pad_batches);         \
        break

    switch (log2_elements) {
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(0);  // 1
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(1);  // 2
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(2);  // 4
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(3);  // 8
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(4);  // 16
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(5);  // 32
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(6);  // 64
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(7);  // 128
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(8);  // 256
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(9);  // 512
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(10); // 1024
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(11); // 2048
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD(12); // 4096
        default:
            break;
    }
#undef LAUNCH_SCALED_MASKED_SOFTMAX_WARP_FORWARD
}

/*
 * Host-side launcher for scaled_masked_softmax_warp_backward.
 *
 * Picks the kernel instantiation for ceil(log2(key_seq_len)) and launches it
 * on the current CUDA stream with a one-dimensional grid covering all
 * batches * attn_heads * query_seq_len rows.
 *
 *   grad_input     result buffer
 *   grad           incoming gradient buffer
 *   output         forward softmax output buffer
 *   scale          factor applied to the resulting gradient
 *   query_seq_len  number of rows per (batch, head)
 *   key_seq_len    softmax row length, must be in [0, 4096] (asserted)
 *   batches        number of batches
 *   attn_heads     number of attention heads
 *
 * Returns without launching when there is nothing to compute.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
    const int batch_count = batches *  attn_heads * query_seq_len;

    // Empty input: nothing to compute, and a zero-block grid is not a valid
    // launch configuration.
    if (key_seq_len == 0 || batch_count == 0) {
        return;
    }

    const int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside scaled_masked_softmax_warp_backward.
    const int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside scaled_masked_softmax_warp_backward.
    const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    const int warps_per_block = (threads_per_block / warp_size);
    const int batches_per_block = warps_per_block * batches_per_warp;

    // Round up so that a trailing partial group of rows still gets a block;
    // the kernel limits the rows it touches using batch_count.
    const int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // One switch case per supported log2(row length); the macro keeps the
    // thirteen launches identical apart from the template argument.
#define LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(LOG2_ELEMENTS)                      \
    case LOG2_ELEMENTS:                                                                \
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, LOG2_ELEMENTS>   \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                \
                grad_input, grad, output, scale, batch_count, key_seq_len);            \
        break

    switch (log2_elements) {
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(0);  // 1
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(1);  // 2
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(2);  // 4
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(3);  // 8
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(4);  // 16
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(5);  // 32
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(6);  // 64
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(7);  // 128
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(8);  // 256
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(9);  // 512
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(10); // 1024
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(11); // 2048
        LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD(12); // 4096
        default:
            break;
    }
#undef LAUNCH_SCALED_MASKED_SOFTMAX_WARP_BACKWARD
}