// softmax_fast.h
// Fused warp-level softmax forward/backward CUDA kernels with optional
// attention mask, bias, and dropout, plus their host-side dispatchers.
#pragma once
#include <iostream>
#include <type_traits>
#include <limits>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <curand_kernel.h>
#include "util.h"

// IntegerBits<N>: the narrowest unsigned integer type with at least N bits
// (uint8_t, uint16_t, uint32_t, or uint64_t); yields `void` when N > 64.
// Used below as the per-thread dropout-mask word type (one bit per warp
// iteration).
template <int N>
using IntegerBits = typename std::conditional<N <= 8, uint8_t,
                                              typename std::conditional<N <= 16, uint16_t,
                                                                        typename std::conditional<N <= 32, uint32_t,
                                                                                                  typename std::conditional<N <= 64, uint64_t, void>::type>::type>::type>::type;

// Compile-time launch geometry for a softmax over 2^LogElements columns.
template <int LogElements>
struct SoftmaxParameters
{
    // Kernels are only instantiated up to 2^11 == 2048 elements per row.
    static_assert(LogElements <= 11, "");
    static constexpr int Elements = 1 << LogElements;
    // Rows ("batches") processed per warp: two for small rows, otherwise one.
    static constexpr int WarpBatch = Elements <= 128 ? 2 : 1;
    // Strided elements each thread owns within one row.
    static constexpr int WarpIterations = Elements <= 32 ? 1 : Elements / 32;
    // Dropout-mask word: one keep bit per warp iteration.
    using MaskType = IntegerBits<WarpIterations>;
    // Lanes participating per row (a full row fits in a warp when <= 32).
    static constexpr int WarpSize = Elements <= 32 ? Elements : 32;
    // Mask words stored per row: one per participating lane.
    static constexpr int MaskStride = WarpSize;
};

// Smallest k such that (1 << k) >= value; returns 0 for value <= 1.
inline int log2_ceil(int value)
{
    int k = 0;
    for (; (1 << k) < value; ++k)
    {
    }
    return k;
}

// Torch dtype of one dropout-mask word for a row of `elements` columns.
// Mirrors SoftmaxParameters::MaskType: the padded row needs one keep bit per
// warp iteration, so wider rows need wider integer words.
inline at::ScalarType softmax_mask_dtype(int elements)
{
    if (elements <= 256)
    {
        return torch::kInt8;
    }
    if (elements <= 512)
    {
        return torch::kInt16;
    }
    if (elements <= 1024)
    {
        return torch::kInt32;
    }
    return torch::kInt64;
}

// Number of mask words required for `batch_size` rows of `elements` columns:
// one word per lane of the (power-of-two-padded, capped at 32) warp slice
// that serves each row — matches SoftmaxParameters::MaskStride.
inline int softmax_mask_size(int batch_size, int elements)
{
    const int padded = 1 << log2_ceil(elements);
    const int lanes_per_row = (padded < 32) ? padded : 32;
    return batch_size * lanes_per_row;
}

// Philox offset advance per forward launch: WarpIterations * WarpBatch for
// the power-of-two-padded element count (the per-thread draw budget implied
// by SoftmaxParameters).
inline int softmax_rng_delta_offset(int elements)
{
    const int padded = 1 << log2_ceil(elements);
    const int iterations_per_thread = (padded <= 32) ? 1 : padded / 32;
    const int rows_per_warp = (padded <= 128) ? 2 : 1;
    return iterations_per_thread * rows_per_warp;
}

// Fused softmax forward kernel: optionally adds an attention mask
// (NeedAttnMask) and/or a broadcast bias (NeedBias) to the input before the
// softmax, and optionally fuses dropout (NeedMask).
//
// Launch geometry (set up by dispatch_softmax_forward): blockDim.x ==
// Parameters::WarpSize, blockDim.y == warps per block. Each warp slot handles
// Parameters::WarpBatch consecutive rows of `element_count` elements; each
// thread owns Parameters::WarpIterations elements per row, strided by
// WarpSize.
//
// With NeedMask, `p` is the keep probability: `dst` receives the
// dropped-and-rescaled softmax (kept values multiplied by 1/p), `dst_orig`
// the plain softmax, and the keep bits go to `mask` (bit `it` of a thread's
// word covers row element local_idx + it * WarpSize).
//
// NOTE(review): `dst`/`src` are typed input_t/output_t here while
// dispatch_softmax_forward passes output_t* dst / const input_t* src —
// harmless when input_t == output_t; confirm.
template <
    typename input_t, typename output_t, typename acc_t,
    typename Parameters, bool NeedMask, bool NeedBias, bool NeedAttnMask>
__global__ void softmax_warp_forward(input_t *dst, input_t *dst_orig, const output_t *src, const input_t *attn_mask, const input_t *bias,
                                     typename Parameters::MaskType *mask, acc_t p, int64_t batch_size, int64_t attn_inner_skip_batch, int64_t bias_batch_size, int element_count, uint64_t seed, uint64_t rand_offset)
{
    using MaskType = typename Parameters::MaskType;
    curandStatePhilox4_32_10_t state;
    // First row this warp slot is responsible for.
    int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;
    // there might be multiple batches per warp. compute the index within the batch
    int64_t local_idx = threadIdx.x;
    const int64_t thread_offset = first_batch * element_count + local_idx;
    if IF_CONSTEXPR (NeedMask)
    {
        // Philox stream keyed by the global element offset so every thread
        // draws independently; rand_offset advances the stream between calls.
        curand_init(seed, thread_offset, rand_offset, &state);
    }

    // batch_size might not be a multiple of Parameters::WarpBatch. Check how
    // many batches have to computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > Parameters::WarpBatch)
        local_batches = Parameters::WarpBatch;

    src += thread_offset;
    dst += thread_offset;
    if IF_CONSTEXPR (NeedMask)
    {
        dst_orig += thread_offset;
        mask += first_batch * Parameters::MaskStride;
    }

    // Bias is broadcast over the batch: indexed modulo this period.
    int64_t bias_mod_size = bias_batch_size * element_count;

    // Attention mask is shared by groups of attn_inner_skip_batch rows.
    int64_t attn_mask_div_size = element_count;
    if IF_CONSTEXPR (NeedAttnMask)
    {
        attn_mask_div_size = attn_inner_skip_batch * element_count;
    }

    // load data from global memory
    input_t elements_input[Parameters::WarpBatch][Parameters::WarpIterations];
#pragma unroll
    for (int i = 0; i < Parameters::WarpBatch; ++i)
    {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
        for (int it = 0; it < Parameters::WarpIterations; ++it)
        {
            int element_index = local_idx + it * Parameters::WarpSize;
            // Pad out-of-range slots with -inf so they vanish under max/exp.
            elements_input[i][it] = -std::numeric_limits<float>::infinity();

            if (element_index < batch_element_count)
            {
                elements_input[i][it] = src[i * element_count + it * Parameters::WarpSize];
            }
        }
    }

    // convert input_t to acc_t
    acc_t elements[Parameters::WarpBatch][Parameters::WarpIterations];
#pragma unroll
    for (int i = 0; i < Parameters::WarpBatch; ++i)
    {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
        for (int it = 0; it < Parameters::WarpIterations; ++it)
        {
            elements[i][it] = elements_input[i][it];
            int element_index = local_idx + it * Parameters::WarpSize;
            if (element_index < batch_element_count)
            {
                int64_t global_idx = thread_offset + i * element_count + it * Parameters::WarpSize;
                if IF_CONSTEXPR (NeedAttnMask)
                {
                    auto attn_mask_idx = static_cast<int64_t>(global_idx / attn_mask_div_size) * element_count + (global_idx % element_count);
                    elements[i][it] += attn_mask[attn_mask_idx];
                }
                if IF_CONSTEXPR (NeedBias)
                {
                    elements[i][it] += bias[global_idx % bias_mod_size];
                }
            }
        }
    }

    // compute local max_value

    // take the max_value of the first element to avoid one max call
    acc_t max_value[Parameters::WarpBatch];
#pragma unroll
    for (int i = 0; i < Parameters::WarpBatch; ++i)
    {
        max_value[i] = elements[i][0];
    }

#pragma unroll
    for (int it = 1; it < Parameters::WarpIterations; ++it)
    {
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }

// reduction max_value
#pragma unroll
    for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
    {
        // NOTE(review): val is float even though the operands are acc_t —
        // lossless only when acc_t is float; confirm.
        float val[Parameters::WarpBatch];
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            val[i] = SHFL_XOR(max_value[i], offset, Parameters::WarpSize);
        }
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i];
        }
    }

    // compute local sum
    acc_t sum[Parameters::WarpBatch]{0.0f};

#pragma unroll
    for (int i = 0; i < Parameters::WarpBatch; ++i)
    {
#pragma unroll
        for (int it = 0; it < Parameters::WarpIterations; ++it)
        {
            // Shift by the row max for numerical stability before exp.
            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
            sum[i] += elements[i][it];
        }
    }

// reduction sum
#pragma unroll
    for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
    {
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
        }
    }

    // store result
    if IF_CONSTEXPR (NeedMask)
    {
        const acc_t pinv = 1.0 / p;
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            if (i >= local_batches)
                break;
            // Draw all keep bits for this row up front; specialized paths for
            // 1 and 2 iterations avoid the 4-wide generator.
            MaskType m = 0;
            if IF_CONSTEXPR (Parameters::WarpIterations == 1)
            {
                float rand = curand_uniform(&state);
                m = rand < p;
            }
            else if IF_CONSTEXPR (Parameters::WarpIterations == 2)
            {
                m = curand_uniform(&state) < p;
                m |= (curand_uniform(&state) < p) << 1;
            }
            else
            {
#pragma unroll
                for (int j = 0; j < DIV_CELL(Parameters::WarpIterations, 4); ++j)
                {
                    float4 rand4 = curand_uniform4(&state);
                    m |= (((MaskType)(rand4.x < p)) << (j * 4)) | (((MaskType)(rand4.y < p)) << (j * 4 + 1)) | (((MaskType)(rand4.z < p)) << (j * 4 + 2)) | (((MaskType)(rand4.w < p)) << (j * 4 + 3));
                }
            }
            mask[i * Parameters::MaskStride + local_idx] = m;
#pragma unroll
            for (int it = 0; it < Parameters::WarpIterations; ++it)
            {
                int element_index = local_idx + it * Parameters::WarpSize;
                if (element_index < element_count)
                {
                    // Plain softmax to dst_orig; dropped-and-rescaled to dst.
                    const output_t d = elements[i][it] / sum[i];
                    dst[i * element_count + it * Parameters::WarpSize] = (acc_t)d * ((acc_t)((m >> it) & 1) * pinv);
                    dst_orig[i * element_count + it * Parameters::WarpSize] = d;
                }
                else
                {
                    break;
                }
            }
        }
    }
    else
    {
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            if (i >= local_batches)
                break;
#pragma unroll
            for (int it = 0; it < Parameters::WarpIterations; ++it)
            {
                int element_index = local_idx + it * Parameters::WarpSize;
                if (element_index < element_count)
                {
                    dst[i * element_count + it * Parameters::WarpSize] = elements[i][it] / sum[i];
                }
                else
                {
                    break;
                }
            }
        }
    }
}

// Instantiates and launches softmax_warp_forward for LogElements == l on the
// current CUDA stream, then returns true from the enclosing function. Only
// valid inside dispatch_softmax_forward, which provides `blocks`, `threads`,
// and the kernel arguments.
#define LAUNCH_FORWARD_KERNEL(l)                                                                           \
    softmax_warp_forward<input_t, output_t, acc_t, SoftmaxParameters<l>, NeedMask, NeedBias, NeedAttnMask> \
        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                                        \
            dst, dst_orig, src, attn_mask, bias, (typename SoftmaxParameters<l>::MaskType *)mask, p,       \
            batch_count, attn_inner_skip_batch, bias_batch_count, softmax_elements, seed, offset);         \
    return true;

// Host-side launcher for softmax_warp_forward: picks the kernel
// instantiation matching the next power of two of softmax_elements (<= 2048)
// and launches it on the current CUDA stream. Returns true when a kernel was
// launched, false when softmax_elements == 0 (or no instantiation matches).
template <typename input_t, typename output_t, typename acc_t, bool NeedMask, bool NeedBias, bool NeedAttnMask>
bool dispatch_softmax_forward(output_t *dst, output_t *dst_orig, const input_t *src, const input_t *attn_mask, const input_t *bias, void *mask, acc_t p,
                              int softmax_elements, int64_t batch_count, int64_t attn_inner_skip_batch, int64_t bias_batch_count, uint64_t seed, uint64_t offset)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0)
    {
        return false;
    }
    else
    {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;

        // This value must match the Parameters::WarpSize constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;

        // This value must match the Parameters::WarpBatch constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements)
        {
        case 0:
            LAUNCH_FORWARD_KERNEL(0)
        case 1:
            LAUNCH_FORWARD_KERNEL(1)
        case 2:
            LAUNCH_FORWARD_KERNEL(2)
        case 3:
            LAUNCH_FORWARD_KERNEL(3)
        case 4:
            LAUNCH_FORWARD_KERNEL(4)
        case 5:
            LAUNCH_FORWARD_KERNEL(5)
        case 6:
            LAUNCH_FORWARD_KERNEL(6)
        case 7:
            LAUNCH_FORWARD_KERNEL(7)
        case 8:
            LAUNCH_FORWARD_KERNEL(8)
        case 9:
            LAUNCH_FORWARD_KERNEL(9)
        case 10:
            LAUNCH_FORWARD_KERNEL(10)
        case 11:
            LAUNCH_FORWARD_KERNEL(11)
        default:
            return false;
        }
    }
    return false;
}

// Fused softmax / log-softmax backward kernel, with optional fused dropout
// backward (NeedMask).
//
// Same launch geometry as softmax_warp_forward (one warp slot handles
// Parameters::WarpBatch rows; blockDim.x == Parameters::WarpSize).
// Per element it computes
//     grad_reg  = grad * output        (NeedMask: the stored keep bits and
//                                       the 1/p rescale are re-applied to
//                                       `grad` first)
//     sum       = warp-reduced row sum of grad_reg
//     gradInput = grad_reg - output * sum            (softmax)
//     gradInput = grad_reg - exp(output) * sum       (log-softmax)
// NOTE(review): grad_reg includes a factor of `output` even on the
// IsLogSoftmax path — confirm callers pass what this expects.
template <
    typename input_t, typename output_t, typename acc_t, typename Parameters,
    bool IsLogSoftmax, bool NeedMask>
__global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output,
                                      const typename Parameters::MaskType *mask, acc_t p, int64_t batch_size, int element_count)
{
    using MaskType = typename Parameters::MaskType;
    int64_t first_batch = (static_cast<int64_t>(blockDim.y) * static_cast<int64_t>(blockIdx.x) + threadIdx.y) * Parameters::WarpBatch;

    // batch_size might not be a multiple of Parameters::WarpBatch. Check how
    // many batches have to computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > Parameters::WarpBatch)
        local_batches = Parameters::WarpBatch;

    // there might be multiple batches per warp. compute the index within the batch
    int64_t local_idx = threadIdx.x;

    // the first element to process by the current thread
    int64_t thread_offset = first_batch * element_count + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;
    if IF_CONSTEXPR (NeedMask)
    {
        mask += first_batch * Parameters::MaskStride;
    }

    // The nested loops over Parameters::WarpBatch and then Parameters::WarpIterations can be simplified to one loop,
    // but I think doing so would obfuscate the logic of the algorithm, thus I chose to keep
    // the nested loops.
    // This should have no impact on performance because the loops are unrolled anyway.

    // load data from global memory
    acc_t grad_reg[Parameters::WarpBatch][Parameters::WarpIterations];
    acc_t output_reg[Parameters::WarpBatch][Parameters::WarpIterations];
    if IF_CONSTEXPR (NeedMask)
    {
        // One mask word per row for this lane, written by the forward pass.
        MaskType mask_reg[Parameters::WarpBatch];
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            if (i >= local_batches)
                break;
            mask_reg[i] = mask[i * Parameters::MaskStride + local_idx];
        }

        const acc_t pinv = 1.0 / p;

#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            int batch_element_count = (i >= local_batches) ? 0 : element_count;
            MaskType m = mask_reg[i];
#pragma unroll
            for (int it = 0; it < Parameters::WarpIterations; ++it)
            {
                int element_index = local_idx + it * Parameters::WarpSize;
                if (element_index < batch_element_count)
                {
                    // Bit `it` is the keep bit for this element (forward
                    // layout); dropped elements contribute zero gradient.
                    grad_reg[i][it] =
                        (input_t)((acc_t)((m >> it) & 1) *
                                  (acc_t)grad[i * element_count + it * Parameters::WarpSize] *
                                  pinv) *
                        output[i * element_count + it * Parameters::WarpSize];
                    output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
                }
                else
                {
                    grad_reg[i][it] = acc_t(0);
                    output_reg[i][it] = acc_t(0);
                }
            }
        }
    }
    else
    {
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
            for (int it = 0; it < Parameters::WarpIterations; ++it)
            {
                int element_index = local_idx + it * Parameters::WarpSize;
                if (element_index < batch_element_count)
                {
                    grad_reg[i][it] = grad[i * element_count + it * Parameters::WarpSize] *
                                      output[i * element_count + it * Parameters::WarpSize];
                    output_reg[i][it] = output[i * element_count + it * Parameters::WarpSize];
                }
                else
                {
                    grad_reg[i][it] = acc_t(0);
                    output_reg[i][it] = acc_t(0);
                }
            }
        }
    }

    // Per-thread partial row sums of grad_reg.
    acc_t sum[Parameters::WarpBatch];
#pragma unroll
    for (int i = 0; i < Parameters::WarpBatch; ++i)
    {
        sum[i] = grad_reg[i][0];
#pragma unroll
        for (int it = 1; it < Parameters::WarpIterations; ++it)
        {
            sum[i] += grad_reg[i][it];
        }
    }

// Butterfly warp reduction of the row sums.
#pragma unroll
    for (int offset = Parameters::WarpSize / 2; offset > 0; offset /= 2)
    {
#pragma unroll
        for (int i = 0; i < Parameters::WarpBatch; ++i)
        {
            sum[i] += SHFL_XOR(sum[i], offset, Parameters::WarpSize);
        }
    }

// store result
#pragma unroll
    for (int i = 0; i < Parameters::WarpBatch; ++i)
    {
        if (i >= local_batches)
            break;
#pragma unroll
        for (int it = 0; it < Parameters::WarpIterations; ++it)
        {
            int element_index = local_idx + it * Parameters::WarpSize;
            if (element_index < element_count)
            {
                // compute gradients
                if IF_CONSTEXPR (IsLogSoftmax)
                {
                    gradInput[i * element_count + it * Parameters::WarpSize] =
                        (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]);
                }
                else
                {
                    gradInput[i * element_count + it * Parameters::WarpSize] =
                        (grad_reg[i][it] - output_reg[i][it] * sum[i]);
                }
            }
        }
    }
}

// Instantiates and launches softmax_warp_backward for LogElements == l on the
// current CUDA stream, then breaks out of the enclosing switch. Only valid
// inside the dispatch switch in dispatch_softmax_backward, which provides
// `blocks`, `threads`, and the kernel arguments.
#define LAUNCH_BACKWARD_KERNEL(l)                                                                 \
    softmax_warp_backward<input_t, output_t, acc_t, SoftmaxParameters<l>, IsLogSoftmax, NeedMask> \
        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(                               \
            grad_input, grad, output, (const typename SoftmaxParameters<l>::MaskType *)mask, p,   \
            batch_count, softmax_elements);                                                       \
    break;

// Host-side launcher for softmax_warp_backward: picks the kernel
// instantiation matching the next power of two of softmax_elements (<= 2048)
// and launches it on the current CUDA stream. No-op when
// softmax_elements == 0.
template <typename input_t, typename output_t, typename acc_t, bool IsLogSoftmax, bool NeedMask>
void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const input_t *output,
                               const void *mask, acc_t p, int softmax_elements, int64_t batch_count)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0)
    {
        return;
    }

    const int log2_elements = log2_ceil(softmax_elements);
    const int pow2_elements = 1 << log2_elements;

    // Must agree with the Parameters::WarpSize constexpr inside softmax_warp_backward.
    const int warp_width = (pow2_elements < 32) ? pow2_elements : 32;

    // Must agree with the Parameters::WarpBatch constexpr inside softmax_warp_backward.
    const int rows_per_warp = (pow2_elements <= 128) ? 2 : 1;

    // 128 threads per block to keep the gpu well utilized.
    constexpr int threads_per_block = 128;

    const int warps_per_block = threads_per_block / warp_width;
    const int rows_per_block = warps_per_block * rows_per_warp;
    // `blocks` and `threads` are referenced by LAUNCH_BACKWARD_KERNEL.
    int blocks = (batch_count + rows_per_block - 1) / rows_per_block;
    dim3 threads(warp_width, warps_per_block, 1);
    // Launch code would be more elegant if C++ supported FOR CONSTEXPR.
    switch (log2_elements)
    {
    case 0:
        LAUNCH_BACKWARD_KERNEL(0)
    case 1:
        LAUNCH_BACKWARD_KERNEL(1)
    case 2:
        LAUNCH_BACKWARD_KERNEL(2)
    case 3:
        LAUNCH_BACKWARD_KERNEL(3)
    case 4:
        LAUNCH_BACKWARD_KERNEL(4)
    case 5:
        LAUNCH_BACKWARD_KERNEL(5)
    case 6:
        LAUNCH_BACKWARD_KERNEL(6)
    case 7:
        LAUNCH_BACKWARD_KERNEL(7)
    case 8:
        LAUNCH_BACKWARD_KERNEL(8)
    case 9:
        LAUNCH_BACKWARD_KERNEL(9)
    case 10:
        LAUNCH_BACKWARD_KERNEL(10)
    case 11:
        LAUNCH_BACKWARD_KERNEL(11)
    default:
        break;
    }
}