// scaled_masked_softmax.h
/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>

namespace {

// Returns the smallest non-negative exponent e such that (1 << e) >= value.
// Any value <= 1 (including zero and negatives) yields 0.
int log2_ceil(int value) {
    int log2_value = 0;
    // Shift a 64-bit one: a 32-bit shift overflows (undefined behaviour, and
    // potentially a loop that never ends) as soon as value exceeds 2^30.
    while ((static_cast<int64_t>(1) << log2_value) < value) ++log2_value;
    return log2_value;
}

// Binary functor yielding the sum of its two operands; passed to warp_reduce
// as the combining operation.
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    T total = lhs + rhs;
    return total;
  }
};

// Binary functor yielding the larger operand; passed to warp_reduce as the
// combining operation. When the comparison is false (equal operands, or an
// unordered comparison) the first operand is returned.
template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T lhs, T rhs) const {
    if (lhs < rhs) {
      return rhs;
    }
    return lhs;
  }
};

// Thin wrapper over the XOR lane shuffle: returns `value` as held by the lane
// whose id is (own lane id XOR laneMask) within a group of `width` lanes.
// Toolkits reporting CUDA_VERSION >= 9000 use the *_sync intrinsic with the
// given participation `mask`; otherwise the legacy intrinsic is used and
// `mask` is ignored.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION < 9000
    return __shfl_xor(value, laneMask, width);
#else
    return __shfl_xor_sync(mask, value, laneMask, width);
#endif
}

// Butterfly reduction across a group of WARP_SIZE lanes.
// `sum` points at WARP_BATCH per-lane partial values; each one is combined
// with the value held by the partner lane (lane id XOR lane_mask) using
// ReduceOp, halving lane_mask every round until it reaches zero.
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> combine;
    #pragma unroll
    for (int lane_mask = WARP_SIZE >> 1; lane_mask >= 1; lane_mask >>= 1) {
        #pragma unroll
        for (int batch = 0; batch < WARP_BATCH; ++batch) {
            acc_t partner = WARP_SHFL_XOR_NATIVE(sum[batch], lane_mask, WARP_SIZE);
            sum[batch] = combine(sum[batch], partner);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 *
 * Every group of WARP_SIZE lanes handles up to WARP_BATCH consecutive rows of
 * element_count values. A lane holds WARP_ITERATIONS values of each row in a
 * per-thread array, taken at a stride of WARP_SIZE.
 *
 * dst              output; written at the same offsets at which src is read
 * src              input scores
 * mask             byte mask; where an entry equals 1 the score is replaced by
 *                  -10000, otherwise the score is multiplied by scale
 * scale            multiplier applied to unmasked scores before the softmax
 * micro_batch_size total number of rows addressed through src / dst
 * element_count    row length; indices >= (1 << log2_elements) are never visited
 * pad_batches      1 selects the "gpt2 style" mask row index (blockIdx.x only),
 *                  any other value the "bert style" one (blockIdx.x and blockIdx.z)
 *
 * Launch layout expected by the index arithmetic below:
 *   blockDim = (WARP_SIZE, warps per block, 1)
 *   gridDim  = (seq_len, attn_heads, batches)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
    // values computed by the host code that launches this kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // Element offsets are formed in 64 bits: row index * row length exceeds
    // the range of a 32-bit int for large tensors.
    const int64_t data_offset = static_cast<int64_t>(first_batch) * element_count + local_idx;
    const int64_t mask_offset = static_cast<int64_t>(pad_first_batch) * element_count + local_idx;
    src += data_offset;
    dst += data_offset;
    mask += mask_offset;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        // Rows past the end of the batch are treated as empty.
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            int itr_idx = i*element_count+it*WARP_SIZE;

            if (element_index < batch_element_count) {
                if (mask[itr_idx] != 1) {
                    elements[i][it] = (acc_t)src[itr_idx] * scale;
                } else {
                    // masked position: large negative score
                    elements[i][it] = static_cast<acc_t>(-10000.0f);
                }
            } else {
                // padding beyond the row: contributes exp(-inf) == 0 to the sum
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    // exponentiate relative to the row maximum and accumulate the row sum
    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
            } else {
                break;
            }
        }
    }
}

/*
 * Backward pass of the scaled softmax.
 *
 * For every row this computes
 *   gradInput = scale * (grad * output - output * sum(grad * output))
 * where the sum runs over the element_count entries of the row.
 *
 * gradInput        result; written at the same offsets at which grad/output are read
 * grad             incoming gradient
 * output           softmax output saved from the forward pass
 * scale            multiplier applied to the result
 * micro_batch_size total number of rows
 * element_count    row length; indices >= (1 << log2_elements) are never visited
 *
 * Launch layout expected by the index arithmetic below:
 *   blockDim = (WARP_SIZE, warps per block, 1)
 *   gridDim  = 1-D; only blockIdx.x is read, each block covering
 *              blockDim.y * WARP_BATCH consecutive rows
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the warp_size and batches_per_warp
    // values computed by the host code that launches this kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (row groups, ) -- one-dimensional grid
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread; formed in 64 bits
    // because row index * row length exceeds a 32-bit int for large tensors
    const int64_t thread_offset = static_cast<int64_t>(first_batch) * element_count + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        // Rows past the end of the batch are treated as empty.
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
            } else {
                output_reg[i][it] = acc_t(0);
            }
        }

        // grad_reg holds grad * output, the term that is summed per row
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it];
            } else {
                grad_reg[i][it] = acc_t(0);
            }
        }
    }

    // per-row sum of grad * output: first within the lane, then across lanes
    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
            }
        }
    }
}

} // end of anonymous namespace

/*
 * Host-side launcher for scaled_masked_softmax_warp_forward.
 *
 * Picks the kernel instantiation whose compile-time row capacity is the next
 * power of two >= key_seq_len and launches it on the current CUDA stream with
 * a (query_seq_len / batches_per_block, attn_heads, batches) grid.
 *
 * dst, src, mask  device pointers forwarded to the kernel
 * scale           multiplier forwarded to the kernel
 * query_seq_len   must be a multiple of the number of rows handled per block
 * key_seq_len     row length, 0..2048; 0 makes the call a no-op
 * batches, attn_heads  grid extents in z and y
 * pad_batches     forwarded to the kernel to select its mask indexing mode
 *
 * Raises through TORCH_INTERNAL_ASSERT when a precondition is violated or the
 * kernel launch is rejected.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
        return;
    }

    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;
    int batch_count = batches * attn_heads * query_seq_len;

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;
    TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);

    // Nothing to compute; a grid with a zero extent would be rejected at launch.
    if (batch_count == 0) {
        return;
    }

    dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
    dim3 threads(warp_size, warps_per_block, 1);

    // One case per supported log2(row capacity); the template argument has to
    // be a compile-time constant, hence the macro-generated switch.
#define LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(LOG2)                                     \
    case LOG2:                                                                         \
        scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, LOG2>             \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                \
                dst, src, mask, scale, batch_count, key_seq_len, pad_batches);         \
        break

    switch (log2_elements) {
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(0);  // 1
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(1);  // 2
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(2);  // 4
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(3);  // 8
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(4);  // 16
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(5);  // 32
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(6);  // 64
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(7);  // 128
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(8);  // 256
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(9);  // 512
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(10); // 1024
        LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD(11); // 2048
        default:
            break;
    }
#undef LAUNCH_SCALED_MASKED_SOFTMAX_FORWARD

    // Kernel launches do not report errors themselves; a rejected launch
    // configuration only becomes visible through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
        "scaled_masked_softmax forward kernel launch failed: ",
        cudaGetErrorString(launch_status));
}

/*
 * Host-side launcher for scaled_masked_softmax_warp_backward.
 *
 * Picks the kernel instantiation whose compile-time row capacity is the next
 * power of two >= key_seq_len and launches it on the current CUDA stream with
 * a one-dimensional grid covering batches * attn_heads * query_seq_len rows.
 *
 * grad_input, grad, output  device pointers forwarded to the kernel
 * scale                     multiplier forwarded to the kernel
 * key_seq_len               row length, 0..2048; 0 makes the call a no-op
 * query_seq_len, batches, attn_heads  their product is the number of rows
 *
 * Raises through TORCH_INTERNAL_ASSERT when key_seq_len is out of range or
 * the kernel launch is rejected.
 */
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
    if (key_seq_len == 0) {
        return;
    }

    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;
    int batch_count = batches *  attn_heads * query_seq_len;

    // Nothing to compute; a zero-sized grid would be rejected at launch.
    if (batch_count == 0) {
        return;
    }

    // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

    // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    // use 128 threads per block to maximize gpu utilization
    constexpr int threads_per_block = 128;

    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    // Round up so rows in a final, partially filled block are still processed;
    // the kernel itself skips rows at or beyond batch_count.
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
    dim3 threads(warp_size, warps_per_block, 1);

    // One case per supported log2(row capacity); the template argument has to
    // be a compile-time constant, hence the macro-generated switch.
#define LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(LOG2)                                    \
    case LOG2:                                                                         \
        scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, LOG2>            \
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                \
                grad_input, grad, output, scale, batch_count, key_seq_len);            \
        break

    switch (log2_elements) {
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(0);  // 1
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(1);  // 2
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(2);  // 4
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(3);  // 8
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(4);  // 16
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(5);  // 32
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(6);  // 64
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(7);  // 128
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(8);  // 256
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(9);  // 512
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(10); // 1024
        LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD(11); // 2048
        default:
            break;
    }
#undef LAUNCH_SCALED_MASKED_SOFTMAX_BACKWARD

    // Kernel launches do not report errors themselves; a rejected launch
    // configuration only becomes visible through cudaGetLastError().
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_INTERNAL_ASSERT(launch_status == cudaSuccess,
        "scaled_masked_softmax backward kernel launch failed: ",
        cudaGetErrorString(launch_status));
}