scaled_masked_softmax.h 31.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

// copy_vector<T, N>(dst, src): copy N consecutive elements of type T from src
// to dst. The N == 1 specialisations are a plain element copy; the N == 4
// specialisations move all four elements in one wide access (float2 = 8 bytes
// for the 16-bit c10::Half / c10::BFloat16 types, half2 = 4 bytes for uint8_t),
// so dst and src must be aligned to that width (CUDA vector-typed accesses
// must be naturally aligned). Only the specialisations below are defined.
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) {
    *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) {
    *reinterpret_cast<float2 *>(dst) = *reinterpret_cast<const float2 *>(src);
}

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
    // half2 is used purely as a 4-byte carrier for four mask bytes.
    *reinterpret_cast<half2 *>(dst) = *reinterpret_cast<const half2 *>(src);
}

// Returns the smallest n >= 0 such that (1 << n) >= value; 0 for value <= 1.
// The shift is evaluated in 64-bit so inputs above 2^30 terminate with 31
// instead of shifting into the sign bit of an int (undefined behaviour).
int log2_ceil(int value) {
    int log2_value = 0;
    while ((1LL << log2_value) < value) ++log2_value;
    return log2_value;
}

// Binary device functor yielding lhs + rhs; passed as the ReduceOp of
// warp_reduce to compute row sums.
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    const T total = lhs + rhs;
    return total;
  }
};

// Binary device functor yielding the larger operand (lhs when neither
// lhs < rhs holds); passed as the ReduceOp of warp_reduce to compute row maxima.
template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};

// XOR-butterfly lane exchange: returns the `value` held by lane
// (laneId ^ laneMask) within segments of `width` lanes. With CUDA >= 9 the
// *_sync intrinsic is used with participant mask `mask` (full warp by
// default); older toolkits fall back to the legacy mask-less intrinsic.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 9000)
    T exchanged = __shfl_xor_sync(mask, value, laneMask, width);
#else
    T exchanged = __shfl_xor(value, laneMask, width);
#endif
    return exchanged;
}

// In-place all-reduce of sum[0..WARP_BATCH) across the WARP_SIZE lanes of a
// warp, combining with ReduceOp (e.g. Add or Max). Uses XOR shuffles with
// halving strides, so afterwards every participating lane holds the result.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> combine;
    #pragma unroll
    for (int lane_stride = WARP_SIZE >> 1; lane_stride >= 1; lane_stride >>= 1) {
        #pragma unroll
        for (int row = 0; row < WARP_BATCH; ++row) {
            const acc_t partner = WARP_SHFL_XOR_NATIVE(sum[row], lane_stride, WARP_SIZE);
            sum[row] = combine(sum[row], partner);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 *
 * For every row r handled by this warp and every column j < element_count:
 *   dst[r][j] = exp(scale * src[r][j] - m_r) / sum_k exp(scale * src[r][k] - m_r)
 * where m_r is the maximum of scale * src[r][:].
 *
 * Launch layout (must agree with dispatch_scaled_softmax_forward):
 *   blockDim = (WARP_SIZE, warps per block, 1); each warp owns WARP_BATCH rows.
 *   gridDim  = (query_seq_len / rows per block, attn_heads, batches).
 * Preconditions:
 *   - element_count <= 1 << log2_elements.
 *   - When ELEMENTS_PER_LDG_STG == 4, every in-range access reads/writes four
 *     consecutive elements, so element_count should be a multiple of 4 and
 *     rows aligned for the vector copy.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match warp_size and batches_per_warp as
    // computed by the host-side dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many rows have to be computed within this warp.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // 64-bit offset: rows * element_count can exceed INT_MAX for large
    // batches * attn_heads * query_seq_len * key_seq_len tensors.
    const int64_t thread_offset = static_cast<int64_t>(first_batch) * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    src += thread_offset;
    dst += thread_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = (acc_t)temp_data[element] * scale;
                }
            } else {
                // Out-of-range slots become -inf so they do not affect max or sum.
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // exponentiate relative to the row maximum (numerical stability) and sum
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}


/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 *
 * Like scaled_softmax_warp_forward, except that positions whose mask byte is
 * 1 are replaced by -10000 before the softmax, and rows whose maximum equals
 * -10000 (which is the case when every in-range element is masked) are
 * written as all zeros.
 *
 * Mask addressing (each mask row holds element_count uint8 values):
 *   - pad_batches != 1 ("bert style"): one mask row per (batch, query) pair,
 *     shared across attention heads (blockIdx.y is not used for the mask).
 *   - pad_batches == 1 ("gpt2 style"): one mask row per query position,
 *     shared across batches and heads.
 *
 * Launch layout (must agree with dispatch_scaled_masked_softmax_forward):
 *   blockDim = (WARP_SIZE, warps per block, 1); each warp owns WARP_BATCH rows.
 *   gridDim  = (query_seq_len / rows per block, attn_heads, batches).
 * When ELEMENTS_PER_LDG_STG == 4, element_count should be a multiple of 4 and
 * rows aligned for the vector copies of data and mask.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match warp_size and batches_per_warp as
    // computed by the host-side dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // Value written for masked positions; also used to detect fully masked rows.
    const acc_t masked_value = static_cast<acc_t>(-10000.0f);

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many rows have to be computed within this warp.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // 64-bit offsets: rows * element_count can exceed INT_MAX for large tensors.
    const int64_t thread_offset = static_cast<int64_t>(first_batch) * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    const int64_t mask_offset = static_cast<int64_t>(pad_first_batch) * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    src += thread_offset;
    dst += thread_offset;
    mask += mask_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = masked_value;
                    }
                }
            } else {
                // Out-of-range slots become -inf so they do not affect max or sum.
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // compute scale value to account for full mask: a row whose maximum is the
    // mask value is written as zeros rather than a uniform distribution.
    acc_t scale_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        scale_value[i] = (max_value[i] == masked_value) ? static_cast<acc_t>(0) : static_cast<acc_t>(1);
    }

    // exponentiate relative to the row maximum (numerical stability) and sum
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] * scale_value[i] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}

/*
 * Backward of the scaled (masked) softmax. With y = output and g = grad, for
 * every row r handled by this warp and every column j < element_count:
 *   gradInput[r][j] = scale * (g[r][j] * y[r][j] - y[r][j] * sum_k g[r][k] * y[r][k])
 * No mask is read here; only the forward output y is used.
 *
 * Launch layout (must agree with dispatch_scaled_masked_softmax_backward):
 *   blockDim = (WARP_SIZE, warps per block, 1); each warp owns WARP_BATCH rows.
 *   gridDim  = (number of blocks, 1, 1): a 1-D grid over all rows.
 * Warps whose rows start at or beyond micro_batch_size write nothing.
 * When ELEMENTS_PER_LDG_STG == 4, element_count should be a multiple of 4 and
 * rows aligned for the vector copies.
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match warp_size and batches_per_warp as
    // computed by the host-side dispatcher.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (blocks, ) -- 1-D grid over all rows
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many rows have to be computed within this warp.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread; 64-bit because
    // rows * element_count can exceed INT_MAX for large tensors.
    const int64_t thread_offset = static_cast<int64_t>(first_batch) * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory; slots that are not loaded stay 0 and so
    // contribute nothing to the row sum.
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                // grad_reg holds g * y
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    // per-row sum of g * y
    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
} // end of anonymous namespace

/*
 * Number of rows (query positions) processed by one thread block for the
 * given key_seq_len, mirroring the launch configuration computed by the
 * dispatchers below: 128 threads per block, warp_size lanes per warp and
 * batches_per_warp rows per warp. The forward dispatchers assert that
 * query_seq_len is a multiple of this value.
 * query_seq_len, batches and attn_heads are currently unused.
 * Marked inline because it is defined in a header and would otherwise cause
 * multiple-definition link errors when included from several translation units.
 */
inline int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

/*
 * Host launcher for scaled_softmax_warp_forward on the current CUDA stream.
 * Computes dst = softmax(scale * src) along rows of key_seq_len elements.
 * The kernel's row indexing treats src/dst as batches * attn_heads *
 * query_seq_len contiguous rows ordered (batch, head, query).
 *
 * Preconditions (asserted): 0 <= key_seq_len <= 4096, query_seq_len a
 * multiple of the rows per block (see get_batch_per_block), and attn_heads /
 * batches within the 65535 limit of gridDim.y / gridDim.z.
 * NOTE(review): scale is passed as input_t, so it is rounded to the input
 * precision before being converted to acc_t for the kernel.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
    int batch_count = batches * attn_heads * query_seq_len;
    // Nothing to compute; also avoids launching a grid with a zero dimension,
    // which CUDA rejects as an invalid configuration.
    if (key_seq_len == 0 || batch_count == 0) {
        return;
    }

    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside the kernel.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside the kernel.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
    // gridDim.y and gridDim.z are limited to 65535.
    TORCH_INTERNAL_ASSERT(attn_heads <= 65535 && batches <= 65535);
    dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
        case 0: // 1
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 1: // 2
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 2: // 4
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 3: // 8
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 4: // 16
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 5: // 32
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 6: // 64
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 7: // 128
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 8: // 256
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 9: // 512
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 10: // 1024
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 11: // 2048
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        case 12: // 4096
            scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
            break;
        default:
            break;
    }
}

/*
 * Host launcher for scaled_masked_softmax_warp_forward on the current CUDA
 * stream. Computes dst = softmax(scale * src with mask==1 positions set to
 * -10000) along rows of key_seq_len elements; fully masked rows become zeros.
 * The kernel's row indexing treats src/dst as batches * attn_heads *
 * query_seq_len contiguous rows ordered (batch, head, query). pad_batches
 * selects how the mask is broadcast (1: one mask shared by all batches;
 * otherwise one mask per batch, shared across heads).
 *
 * Preconditions (asserted): 0 <= key_seq_len <= 4096, query_seq_len a
 * multiple of the rows per block (see get_batch_per_block), and attn_heads /
 * batches within the 65535 limit of gridDim.y / gridDim.z.
 * NOTE(review): scale is passed as input_t, so it is rounded to the input
 * precision before being converted to acc_t for the kernel.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
    int batch_count = batches * attn_heads * query_seq_len;
    // Nothing to compute; also avoids launching a grid with a zero dimension,
    // which CUDA rejects as an invalid configuration.
    if (key_seq_len == 0 || batch_count == 0) {
        return;
    }

    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside the kernel.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside the kernel.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
    // gridDim.y and gridDim.z are limited to 65535.
    TORCH_INTERNAL_ASSERT(attn_heads <= 65535 && batches <= 65535);
    dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
        case 0: // 1
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 1: // 2
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 2: // 4
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 3: // 8
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 4: // 16
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 5: // 32
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 6: // 64
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 7: // 128
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 8: // 256
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 9: // 512
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 10: // 1024
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 11: // 2048
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        case 12: // 4096
            scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
            break;
        default:
            break;
    }
}

/*
 * Host launcher for scaled_masked_softmax_warp_backward on the current CUDA
 * stream. Given the upstream gradient `grad` and the forward result `output`
 * (batches * attn_heads * query_seq_len rows of key_seq_len elements), writes
 * grad_input = scale * (grad * output - output * rowsum(grad * output)).
 *
 * Uses a 1-D grid of ceil(rows / rows_per_block) blocks; the kernel skips
 * rows past the end, so the row count need not be a multiple of the rows per
 * block. Precondition (asserted): 0 <= key_seq_len <= 4096.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
    int batch_count = batches *  attn_heads * query_seq_len;
    // Nothing to compute; also avoids launching a zero-sized grid, which CUDA
    // rejects as an invalid configuration.
    if (key_seq_len == 0 || batch_count == 0) {
       return;
    }

    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    // This value must match the WARP_SIZE constexpr value computed inside the kernel.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside the kernel.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    // Round up: plain division would drop the trailing rows (leaving their
    // gradients unwritten) whenever batch_count is not a multiple of
    // batches_per_block. The kernel ignores rows >= batch_count.
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR
    switch (log2_elements) {
        case 0: // 1
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 1: // 2
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 2: // 4
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 3: // 8
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 4: // 16
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 5: // 32
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 6: // 64
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 7: // 128
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 8: // 256
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 9: // 512
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 10: // 1024
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 11: // 2048
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        case 12: // 4096
            scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
            break;
        default:
            break;
    }
}