/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"

namespace turbomind {

// Flattens the 4-D coordinate (id1, id2, id3, id4) into a linear offset in which the
// second and third coordinates swap places: id3 is multiplied by (dim_2 * dim_4) and
// id2 by dim_4. dim_1 is accepted for signature symmetry but does not enter the result.
__inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4)
{
    const int stride_id1 = dim_2 * dim_3 * dim_4;
    const int stride_id3 = dim_2 * dim_4;
    const int stride_id2 = dim_4;

    return id1 * stride_id1 + id3 * stride_id3 + id2 * stride_id2 + id4;
}

// Masked softmax along the k dimension of QK^T, computed in float.
//
// Buffers (row-major):
//   attn_score         [batch_size, head_num, q_length, k_length]  output
//   qk                 [batch_size, head_num, q_length, k_length]
//   attn_mask          [batch_size, q_length, k_length]
//   linear_bias_slopes [head_num], may be nullptr (no positional bias is added then)
//
// Launch layout: blockIdx.y selects the batch entry, blockIdx.z the head, and blocks
// stride over query rows along x. Each thread keeps up to ITEMS_PER_THREAD values of a
// row in `data`, so blockDim.x * ITEMS_PER_THREAD must be >= k_length for the whole row
// to be processed.
template<typename T, typename T_IN, int ITEMS_PER_THREAD>
__global__ void softmax_kernel(T*          attn_score,
                               const T_IN* qk,
                               const T*    attn_mask,
                               const T*    linear_bias_slopes,
                               const int   batch_size,
                               const int   head_num,
                               const int   q_length,
                               const int   k_length,
                               const float qk_scale)
{
    const int bi = blockIdx.y;  // Batch index.
    const int hi = blockIdx.z;  // Head index.

    // Row-wide results published by thread 0: the row maximum and the reciprocal of the row sum.
    __shared__ float s_mean, s_max;

    const float linear_bias_slope = linear_bias_slopes != nullptr ? (float)linear_bias_slopes[hi] : 0.0f;

    // Loop along with Q dimension.
    for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {

        // Row base offsets are computed in 64 bits: batch * heads * q_length * k_length can
        // exceed the range of a 32-bit int for long sequences.
        const size_t qk_row   = (static_cast<size_t>(bi * head_num + hi) * q_length + qi) * k_length;
        const size_t mask_row = (static_cast<size_t>(bi) * q_length + qi) * k_length;

        float data[ITEMS_PER_THREAD];
        float local_max = -1e20f;

        // Loop along with K dimension. The `i < ITEMS_PER_THREAD` bound keeps `data` in range
        // even if the launch configuration does not cover the whole row.
        for (int i = 0; blockDim.x * i + threadIdx.x < k_length && i < ITEMS_PER_THREAD; i++) {
            int ki = blockDim.x * i + threadIdx.x;  // Index of K dimension.

            float qk_val  = static_cast<float>(qk[qk_row + ki]);
            float qk_bias = 0.0f;

            if (linear_bias_slopes != nullptr) {
                // We don't handle the upper diagonal (ki > qi) separately, whose values
                // are negligible due to the negative infinity mask. And it matches with
                // the HF's implementation.
                qk_bias += static_cast<float>(linear_bias_slope * (ki - qi));
            }

            // A mask value of 1 adds nothing; a mask value of 0 adds -10000.
            float mask_val = static_cast<float>(ldg(&attn_mask[mask_row + ki]));
            qk_bias += (1.0f - mask_val) * -10000.0f;

            data[i]   = qk_scale * qk_val + qk_bias;
            local_max = fmax(local_max, data[i]);
        }

        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
        if (threadIdx.x == 0) {
            s_max = max_val;
        }
        __syncthreads();

        // Exponentiate relative to the row maximum and accumulate the row sum.
        float local_sum = 0;
        for (int i = 0; blockDim.x * i + threadIdx.x < k_length && i < ITEMS_PER_THREAD; i++) {
            data[i] = __expf(data[i] - s_max);
            local_sum += data[i];
        }

        float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);
        if (threadIdx.x == 0) {
            // The small epsilon keeps the reciprocal finite when the sum is zero.
            s_mean = sum_val + 1e-6f;
            s_mean = __fdividef(1.0f, s_mean);
        }
        __syncthreads();

        // Normalize and write the row back.
        for (int i = 0; blockDim.x * i + threadIdx.x < k_length && i < ITEMS_PER_THREAD; i++) {
            attn_score[qk_row + blockDim.x * i + threadIdx.x] = (T)(data[i] * s_mean);
        }
    }
}

// Masked softmax along the k dimension of QK^T using packed two-element arithmetic
// (T2 = TypeConverter<T>::Type); max and sum are accumulated in float.
//
// Buffers (row-major, reinterpreted as T2 so the innermost extent is k_length / 2):
//   attn_score         [batch_size, head_num, q_length, k_length]  output
//   qk_buf             [batch_size, head_num, q_length, k_length]
//   attn_mask          [batch_size, q_length, k_length]
//   linear_bias_slopes [head_num], may be nullptr
//
// Launch layout: blockIdx.y selects the batch entry, blockIdx.z the head, and blocks
// stride over query rows along x. Each thread keeps up to ITEMS_PER_THREAD packed values.
// The kernel indexes with k_length / 2, so k_length is expected to be even.
template<typename T, int ITEMS_PER_THREAD>
__global__ void softmax_kernel_h2(T*        attn_score,
                                  const T*  qk_buf,
                                  const T*  attn_mask,
                                  const T*  linear_bias_slopes,
                                  const int batch_size,
                                  const int head_num,
                                  const int q_length,
                                  const int k_length,
                                  const T   qk_scale)
{
    using T2 = typename TypeConverter<T>::Type;

    T2*       attn_score_h2 = reinterpret_cast<T2*>(attn_score);
    const T2* qk_buf_h2     = reinterpret_cast<const T2*>(qk_buf);
    const T2* attn_mask_h2  = reinterpret_cast<const T2*>(attn_mask);

    const int bi = blockIdx.y;  // Batch index
    const int hi = blockIdx.z;  // Head index.

    // Row-wide results published by thread 0: the row maximum and the reciprocal of the row sum.
    __shared__ float s_mean, s_max;

    // Constant values that will be used repeatedly in the q/k loop.
    const T2 ONE       = cuda_cast<T2>(1.0f);
    const T2 ZERO      = cuda_cast<T2>(0.0f);
    const T2 NEG_INFTY = cuda_cast<T2>(-10000.0f);

    // The normalization factor of QK.
    const T2 qk_scale_h2 = cuda_cast<T2>(qk_scale);
    // The slope of a linear position bias of the current attention head.
    const T2 linear_bias_slope = linear_bias_slopes != nullptr ? cuda_cast<T2>(linear_bias_slopes[hi]) : ZERO;

    // Loop over q dimension.
    for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) {
        // Row base offsets (in T2 units) are computed in 64 bits so that large
        // batch * heads * q_length * k_length products do not overflow a 32-bit int.
        const size_t qk_row   = (static_cast<size_t>(bi * head_num + hi) * q_length + qi) * (k_length / 2);
        const size_t mask_row = (static_cast<size_t>(bi) * q_length + qi) * (k_length / 2);

        T2    data[ITEMS_PER_THREAD];
        float local_max = -1e20f;

        // Loop over k dimension.
        for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
            // The half of the index of k dimension. We will use the elements at {2 * ki, 2 * ki + 1}.
            int ki = blockDim.x * i + threadIdx.x;

            // The packed pair of QK^T values at (qi, 2 * ki) and (qi, 2 * ki + 1).
            T2 qk = qk_buf_h2[qk_row + ki];
            // The bias value for this pair including both mask and positional bias.
            T2 qk_bias = ZERO;

            if (linear_bias_slopes != nullptr) {
                // Positional bias is slope * (k_index - qi) for the two k positions of this pair.
                // Entries above the diagonal are not treated separately; their values are
                // negligible after the mask penalty below is added.
                T2 dist(2.0f * ki - qi, 2.0f * ki + 1 - qi);
                qk_bias = hadd2<T2>(qk_bias, hmul2<T2>(linear_bias_slope, dist));
            }

            // A mask value of 1 adds nothing; a mask value of 0 adds -10000.
            T2 mask_val = ldg(&attn_mask_h2[mask_row + ki]);
            qk_bias     = hadd2<T2>(qk_bias, hmul2<T2>(hsub2<T2>(ONE, mask_val), NEG_INFTY));

            data[i] = hadd2<T2>(hmul2<T2>(qk, qk_scale_h2), qk_bias);
            // NOTE(review): lanes are read through `.data[k]`; the `.x` / `.y` form is kept below,
            // commented out as found. Presumably the packed type of the target toolchain exposes
            // `.data` - confirm.
            // if (std::is_same<T2, half2>::value) {
            local_max = fmax(local_max, fmax((float)data[i].data[0], (float)data[i].data[1]));
            // } else {
            //     local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y));
            // }
        }

        float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax<float>(local_max);
        if (threadIdx.x == 0) {
            s_max = max_val;
        }
        __syncthreads();

        // Exponentiate relative to the row maximum and accumulate the row sum in float.
        float local_sum = 0.0f;
        for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
            data[i] = hexp2<T2>(hsub2<T2>(data[i], cuda_cast<T2>(s_max)));
            // if (std::is_same<T2, half2>::value) {
            local_sum += (float)(data[i].data[0] + data[i].data[1]);
            // } else {
            //     local_sum += (float)(data[i].x + data[i].y);
            // }
        }

        float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum<float>(local_sum);

        if (threadIdx.x == 0) {
            // The small epsilon keeps the reciprocal finite when the sum is zero.
            s_mean = sum_val + 1e-6f;
            s_mean = __fdividef(1.0f, s_mean);
        }
        __syncthreads();

        // Normalize and write the row back.
        for (int i = 0; blockDim.x * i + threadIdx.x < (k_length / 2) && i < ITEMS_PER_THREAD; i++) {
            attn_score_h2[qk_row + blockDim.x * i + threadIdx.x] = hmul2<T2>(data[i], cuda_cast<T2>(s_mean));
        }
    }
}

// Masked softmax along the k dimension of QK^T using packed two-element arithmetic
// (T2 = TypeConverter<T>::Type), processing several query rows per block at once.
//
// Buffers (row-major, reinterpreted as T2 so the innermost extent is k_length / 2):
//   attn_score         [batch_size, head_num, q_length, k_length]  output
//   qk_buf             [batch_size, head_num, q_length, k_length]
//   attn_mask          [batch_size, q_length, k_length]
//   linear_bias_slopes [head_num], may be nullptr
//
// Launch layout: blockIdx.y selects the batch entry, blockIdx.z the head. In every outer
// iteration a block handles the rows qi, qi + gridDim.x, ... (at most Q_ITEMS_PER_THREAD of
// them), and each thread keeps up to K_ITEMS_PER_THREAD packed values per row.
// The kernel indexes with k_length / 2, so k_length is expected to be even.
template<typename T, int K_ITEMS_PER_THREAD, int Q_ITEMS_PER_THREAD>
__global__ void softmax_kernel_h2_v2(T*        attn_score,
                                     const T*  qk_buf,
                                     const T*  attn_mask,
                                     const T*  linear_bias_slopes,
                                     const int batch_size,
                                     const int head_num,
                                     const int q_length,
                                     const int k_length,
                                     const T   scalar)
{
    using T2 = typename TypeConverter<T>::Type;

    // QK^T matrix of shape (batch_size, head_num, q_length, k_length / 2)
    T2*       attn_score_h2 = reinterpret_cast<T2*>(attn_score);
    const T2* qk_buf_h2     = reinterpret_cast<const T2*>(qk_buf);
    const T2* attn_mask_h2  = reinterpret_cast<const T2*>(attn_mask);

    const int bi = blockIdx.y;  // Batch index
    const int hi = blockIdx.z;  // Head index.

    // Constant values that will be used repeatedly in the q/k loop.
    const T2 ONE       = cuda_cast<T2>(1.0f);
    const T2 ZERO      = cuda_cast<T2>(0.0f);
    const T2 NEG_INFTY = cuda_cast<T2>(-10000.0f);

    // The normalization factor of QK.
    const T2 qk_scale = cuda_cast<T2>(scalar);
    // The slope of a linear position bias of the current attention head.
    const T2 linear_bias_slope = linear_bias_slopes != nullptr ? cuda_cast<T2>(linear_bias_slopes[hi]) : ZERO;

    // Per-row results published by thread 0: s_max holds the row maxima, s_sum the
    // reciprocals of the row sums.
    __shared__ float s_sum[Q_ITEMS_PER_THREAD], s_max[Q_ITEMS_PER_THREAD];

    // Loop over q dimension.
    for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x * Q_ITEMS_PER_THREAD) {
        T2 data[Q_ITEMS_PER_THREAD][K_ITEMS_PER_THREAD];

        // Offsets (in T2 units) are kept in 64 bits so that large
        // batch * heads * q_length * k_length products do not overflow a 32-bit int.
        size_t qk_offset[Q_ITEMS_PER_THREAD];

        float local_max[Q_ITEMS_PER_THREAD];
#pragma unroll
        for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
            local_max[j] = -1e20f;
        }

        // Number of rows actually handled in this iteration (fewer at the tail of q).
        const int Q_ITEMS = min((q_length - qi + gridDim.x - 1) / gridDim.x, Q_ITEMS_PER_THREAD);

        // Loop over k dimension.
        for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
            // The half of the index of k dimension. We will use the elements at {2 * ki, 2 * ki + 1}.
            int ki = blockDim.x * i + threadIdx.x;

            size_t mask_offset[Q_ITEMS_PER_THREAD];
#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                qk_offset[j] =
                    (static_cast<size_t>(bi * head_num + hi) * q_length + qi + j * gridDim.x) * (k_length / 2) + ki;
                mask_offset[j] = (static_cast<size_t>(bi) * q_length + qi + j * gridDim.x) * (k_length / 2) + ki;
            }

            T2 mask_val[Q_ITEMS_PER_THREAD];
#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                mask_val[j] = ldg(&attn_mask_h2[mask_offset[j]]);
            }

            T2 qk[Q_ITEMS_PER_THREAD];
#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                qk[j] = qk_buf_h2[qk_offset[j]];
            }

            // Only filled (and only read) when linear_bias_slopes is provided.
            T2 pos_bias[Q_ITEMS_PER_THREAD];
            if (linear_bias_slopes != nullptr) {
#pragma unroll
                for (int j = 0; j < Q_ITEMS; j++) {
                    // Positional bias is slope * (k_index - q_index) for the two k positions of
                    // this pair. Entries above the diagonal are not treated separately; their
                    // values are negligible after the mask penalty is added.
                    int qidx = qi + j * gridDim.x;
                    T2  dist(2.0f * ki - qidx, 2.0f * ki + 1 - qidx);
                    pos_bias[j] = hmul2<T2>(linear_bias_slope, dist);
                }
            }

            // Turn the mask into an additive penalty: 1 -> 0, 0 -> -10000.
#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                mask_val[j] = hmul2<T2>(hsub2<T2>(ONE, mask_val[j]), NEG_INFTY);
            }

#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                T2 val = hadd2<T2>(hmul2<T2>(qk_scale, qk[j]), mask_val[j]);
                if (linear_bias_slopes != nullptr) {
                    val = hadd2<T2>(val, pos_bias[j]);
                }
                data[j][i] = val;
                // NOTE(review): lanes are read through `.data[k]`; the `.x` / `.y` form is kept
                // below, commented out as found. Presumably the packed type of the target
                // toolchain exposes `.data` - confirm.
                // if (std::is_same<T2, half2>::value) {
                local_max[j] = fmax(local_max[j], fmax((float)data[j][i].data[0], (float)data[j][i].data[1]));
                // } else {
                //     local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y));
                // }
            }
        }

        if (blockDim.x <= 32) {
            warpReduceMaxV2<float, Q_ITEMS_PER_THREAD>(local_max);
        }
        else {
            blockReduceMaxV2<float, Q_ITEMS_PER_THREAD>(local_max);
        }

        if (threadIdx.x == 0) {
#pragma unroll
            for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
                s_max[j] = local_max[j];
            }
        }
        __syncthreads();

        float local_sum[Q_ITEMS_PER_THREAD];
#pragma unroll
        for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
            local_sum[j] = {0.f};
        }

        // Exponentiate relative to each row maximum and accumulate the row sums in float.
        for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
#pragma unroll
            for (int j = 0; j < Q_ITEMS; ++j) {
                data[j][i] = hexp2<T2>(hsub2<T2>(data[j][i], cuda_cast<T2>(s_max[j])));
            }

#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                // if (std::is_same<T2, half2>::value) {
                local_sum[j] += (float)(data[j][i].data[0] + data[j][i].data[1]);
                // } else {
                //     local_sum[j] += (float)(data[j][i].x + data[j][i].y);
                // }
            }
        }

        if (blockDim.x <= 32) {
            warpReduceSumV2<float, Q_ITEMS_PER_THREAD>(local_sum);
        }
        else {
            blockReduceSumV2<float, Q_ITEMS_PER_THREAD>(local_sum);
        }

        if (threadIdx.x == 0) {
#pragma unroll
            for (int j = 0; j < Q_ITEMS_PER_THREAD; j++) {
                // The small epsilon keeps the reciprocal finite when the sum is zero.
                s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
            }
        }
        __syncthreads();

        // Normalize and write the rows back.
        for (int i = 0; blockDim.x * i + threadIdx.x < k_length / 2 && i < K_ITEMS_PER_THREAD; ++i) {
#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                qk_offset[j] =
                    (static_cast<size_t>(bi * head_num + hi) * q_length + qi + j * gridDim.x) * (k_length / 2)
                    + blockDim.x * i + threadIdx.x;
            }

#pragma unroll
            for (int j = 0; j < Q_ITEMS; j++) {
                attn_score_h2[qk_offset[j]] = hmul2<T2>(data[j][i], cuda_cast<T2>(s_sum[j]));
            }
        }
    }
}

// Launch helper for the masked-softmax kernels.
//
// The expansion relies on names that must exist at the use site: `grid` and `block` (dim3),
// `is_half2`, `param`, `stream`, and the type names `T` / `T_IN`.
//
// It MODIFIES `block` and possibly `grid`:
//   - block.x is divided by ITEMS_PER_THREAD and rounded up to a multiple of 32, then asserted
//     to be <= 1024;
//   - when is_half2 holds and grid.x is a multiple of 4, grid.x is divided by 4.
//
// Kernel selection:
//   - is_half2 and grid.x % 4 == 0 : softmax_kernel_h2_v2<T_, ITEMS_PER_THREAD, 4>
//   - is_half2 otherwise           : softmax_kernel_h2<T_, ITEMS_PER_THREAD>
//   - not is_half2                 : softmax_kernel<T, T_IN, ITEMS_PER_THREAD>
//
// The expansion is a plain statement sequence (not wrapped in do { } while (0)), so it must be
// used inside its own braces. Because block.x is rewritten, conditions evaluated on block.x after
// an expansion see the reduced value.
#define LAUNCH_MAKSED_SOFTMAX_(T_, ITEMS_PER_THREAD)                                                                   \
    block.x /= ITEMS_PER_THREAD;                                                                                       \
    block.x = (block.x + 31) / 32 * 32;                                                                                \
    assert(block.x <= 1024);                                                                                           \
    if (is_half2) {                                                                                                    \
        if (grid.x % 4 == 0) {                                                                                         \
            grid.x /= 4;                                                                                               \
            softmax_kernel_h2_v2<T_, ITEMS_PER_THREAD, 4>                                                              \
                <<<grid, block, 0, stream>>>((T_*)param.attention_score,                                               \
                                             (const T_*)param.qk,                                                      \
                                             (const T_*)param.attention_mask,                                          \
                                             (const T_*)param.linear_bias_slopes,                                      \
                                             param.batch_size,                                                         \
                                             param.num_heads,                                                          \
                                             param.q_length,                                                           \
                                             param.k_length,                                                           \
                                             (const T_)param.qk_scale);                                                \
        }                                                                                                              \
        else {                                                                                                         \
            softmax_kernel_h2<T_, ITEMS_PER_THREAD><<<grid, block, 0, stream>>>((T_*)param.attention_score,            \
                                                                                (const T_*)param.qk,                   \
                                                                                (const T_*)param.attention_mask,       \
                                                                                (const T_*)param.linear_bias_slopes,   \
                                                                                param.batch_size,                      \
                                                                                param.num_heads,                       \
                                                                                param.q_length,                        \
                                                                                param.k_length,                        \
                                                                                (const T_)param.qk_scale);             \
        }                                                                                                              \
    }                                                                                                                  \
    else {                                                                                                             \
        softmax_kernel<T, T_IN, ITEMS_PER_THREAD><<<grid, block, 0, stream>>>(param.attention_score,                   \
                                                                              param.qk,                                \
                                                                              param.attention_mask,                    \
                                                                              param.linear_bias_slopes,                \
                                                                              param.batch_size,                        \
                                                                              param.num_heads,                         \
                                                                              param.q_length,                          \
                                                                              param.k_length,                          \
                                                                              param.qk_scale);                         \
    }

// Shorthand for the launch helper with `half` as the packed-path element type.
#define LAUNCH_MAKSED_SOFTMAX(ITEMS_PER_THREAD) LAUNCH_MAKSED_SOFTMAX_(half, ITEMS_PER_THREAD)

// Host-side launcher for the masked softmax.
//
// param fields used here:
//   attention_score    (batch_size, head_num, q_length, k_length), softmax output.
//   qk                 (batch_size, head_num, q_length, k_length), QK^T.
//   attention_mask     (batch_size, q_length, k_length), attention mask.
//   linear_bias_slopes (head_num,) the slopes of the linear position bias.
//
// The launch is enqueued on `stream`. Rows longer than what a single block with 16 items per
// thread can hold are rejected through FT_CHECK.
template<typename T, typename T_IN>
void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream)
{
    // One block per query row by default; with many (batch, head) pairs the x extent is reduced
    // and the kernels stride over the remaining rows.
    dim3 grid(param.q_length, param.batch_size, param.num_heads);
    if (param.batch_size * param.num_heads > 360) {
        grid.x = ceil(float(param.q_length) / 32.0f);
    }

    // The packed path needs 2-byte input and output types and an even k_length.
    bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
    dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);

    // Exactly one branch may run: the launch macro rewrites block.x (to a value <= 1024), so the
    // conditions must form a single else-if chain. Otherwise a later condition would match the
    // rewritten block.x and launch a second kernel over the same output.
    if (block.x > 8192 && block.x <= 16384) {
        LAUNCH_MAKSED_SOFTMAX(16)
    }
    else if (block.x > 4096 && block.x <= 8192) {
        LAUNCH_MAKSED_SOFTMAX(8)
    }
    else if (block.x > 2048 && block.x <= 4096) {
        LAUNCH_MAKSED_SOFTMAX(4)
    }
    else if (block.x > 1024 && block.x <= 2048) {
        LAUNCH_MAKSED_SOFTMAX(2)
    }
    else if (block.x > 0 && block.x <= 1024) {
        LAUNCH_MAKSED_SOFTMAX(1)
    }
    else {
        FT_CHECK(param.k_length <= 16384);
    }
}

// Explicit instantiations of the generic launcher for the (T, T_IN) pairs
// (float, float), (half, float) and (half, half).
template void invokeMaskedSoftmax(MaskedSoftmaxParam<float, float>& param, cudaStream_t stream);
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half, float>& param, cudaStream_t stream);
template void invokeMaskedSoftmax(MaskedSoftmaxParam<half, half>& param, cudaStream_t stream);

#ifdef ENABLE_BF16
// Masked-softmax launcher specialized for bfloat16 output with float QK^T input.
//
// param fields used here:
//   attention_score    (batch_size, head_num, q_length, k_length), softmax output.
//   qk                 (batch_size, head_num, q_length, k_length), QK^T.
//   attention_mask     (batch_size, q_length, k_length), attention mask.
//   linear_bias_slopes (head_num,) the slopes of the linear position bias.
//
// The launch is enqueued on `stream`. Rows that do not fit a block with 4 items per thread are
// rejected through FT_CHECK.
template<>
void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, float>& param, cudaStream_t stream)
{
    using T    = __nv_bfloat16;
    using T_IN = float;

    dim3 grid(param.q_length, param.batch_size, param.num_heads);
    if (param.batch_size * param.num_heads > 360) {
        grid.x = ceil(float(param.q_length) / 32.0f);
    }

    // sizeof(T_IN) is 4 here, so the packed path is never selected at run time.
    bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
    dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);

    // Every range is bounded on both sides so that an oversized block.x reaches the FT_CHECK
    // below instead of being launched with more than 1024 threads per block.
    if (block.x > 2048 && block.x <= 4096) {
        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
    }
    else if (block.x > 1024 && block.x <= 2048) {
        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 2);
    }
    else if (block.x > 0 && block.x <= 1024) {
        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 1);
    }
    else {
        FT_CHECK(param.k_length <= 4096);
    }
}
// Masked-softmax launcher specialized for bfloat16 output with bfloat16 QK^T input.
//
// param fields used here:
//   attention_score    (batch_size, head_num, q_length, k_length), softmax output.
//   qk                 (batch_size, head_num, q_length, k_length), QK^T.
//   attention_mask     (batch_size, q_length, k_length), attention mask.
//   linear_bias_slopes (head_num,) the slopes of the linear position bias.
//
// The launch is enqueued on `stream`. Rows that do not fit a block with 4 items per thread are
// rejected through FT_CHECK.
template<>
void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, __nv_bfloat16>& param, cudaStream_t stream)
{
    using T    = __nv_bfloat16;
    using T_IN = __nv_bfloat16;

    dim3 grid(param.q_length, param.batch_size, param.num_heads);
    if (param.batch_size * param.num_heads > 360) {
        grid.x = ceil(float(param.q_length) / 32.0f);
    }

    // Both types are 2 bytes wide, so the packed path is used whenever k_length is even.
    bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
    dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);

    // Every range is bounded on both sides so that an oversized block.x reaches the FT_CHECK
    // below instead of being launched with more than 1024 threads per block.
    if (block.x > 2048 && block.x <= 4096) {
        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 4);
    }
    else if (block.x > 1024 && block.x <= 2048) {
        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 2);
    }
    else if (block.x > 0 && block.x <= 1024) {
        LAUNCH_MAKSED_SOFTMAX_(__nv_bfloat16, 1);
    }
    else {
        FT_CHECK(param.k_length <= 4096);
    }
}

#endif

#undef LAUNCH_MAKSED_SOFTMAX
#undef LAUNCH_MAKSED_SOFTMAX_

// Copies one element per thread from `src` to `dst`, swapping the head and sequence axes.
// The flat thread index is decomposed as (batch, head, seq, element) and the destination
// offset comes from target_index(). There is no bounds check, so the launch must supply
// exactly one thread per source element.
// With int8_mode == 2 the value is multiplied by *scale and stored as packed int8 through a
// reinterpreted `dst`; otherwise it is copied unchanged.
template<typename T>
__global__ void transpose(const T*     src,
                          T*           dst,
                          const int    batch_size,
                          const int    seq_len,
                          const int    head_num,
                          const int    size_per_head,
                          const float* scale,
                          int          int8_mode)
{
    const int src_idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Element counts of one head slice and of one batch slice in the source layout.
    const int head_elems  = seq_len * size_per_head;
    const int batch_elems = head_num * head_elems;

    const int b = src_idx / batch_elems;
    const int h = (src_idx % batch_elems) / head_elems;
    const int s = (src_idx % head_elems) / size_per_head;
    const int e = src_idx % size_per_head;

    const int dst_idx = target_index(b, h, s, e, batch_size, head_num, seq_len, size_per_head);

    if (int8_mode != 2) {
        dst[dst_idx] = src[src_idx];
        return;
    }

    // Quantizing path: scale in float, then narrow to int8 with the same packing width as T.
    using Int8_Packed_T  = typename packed_as<int8_t, num_elems<T>::value>::type;
    using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;

    const Float_Packed_T scale_val = cuda_cast<Float_Packed_T>(*scale);
    const Float_Packed_T scaled    = cuda_cast<Float_Packed_T>(src[src_idx]) * scale_val;

    reinterpret_cast<Int8_Packed_T*>(dst)[dst_idx] = cuda_cast<Int8_Packed_T>(scaled);
}

// float specialization: one block per (batch, head, seq) row of the source and one thread per
// element of that row. The row index is decomposed from blockIdx.x, and the destination places
// the sequence index ahead of the head index.
// With int8_mode == 2 the value is multiplied by *scale and stored as int8 through a
// reinterpreted `dst`; otherwise it is copied unchanged.
template<>
__global__ void transpose(const float* src,
                          float*       dst,
                          const int    batch_size,
                          const int    seq_len,
                          const int    head_num,
                          const int    size_per_head,
                          const float* scale,
                          int          int8_mode)
{
    const unsigned int row = blockIdx.x;

    const int b = row / (head_num * seq_len);
    const int h = (row % (head_num * seq_len)) / seq_len;
    const int s = row % seq_len;

    const int src_id = row * size_per_head + threadIdx.x;
    const int target_id =
        b * (head_num * seq_len * size_per_head) + s * head_num * size_per_head + h * size_per_head + threadIdx.x;

    if (int8_mode != 2) {
        dst[target_id] = src[src_id];
        return;
    }

    const float scale_val = *scale;
    int8_t*     dst_int8  = reinterpret_cast<int8_t*>(dst);
    dst_int8[target_id]   = cuda_cast<int8_t>(src[src_id] * scale_val);
}

// Host launcher for the `transpose` kernels.
//
// Parameters:
//   dst, src       destination / source device buffers (note the order: dst first).
//   batch_size, seq_len, head_num, size_per_head
//                  extents used by the kernels to decode a flat element index.
//   scale          device pointer read by the kernels only when int8_mode == 2.
//   int8_mode      forwarded to the kernels; 2 selects the quantizing store path.
//   stream         CUDA stream the kernel is enqueued on.
//
// For 2-byte element types up to four (batch, head, seq) rows are folded into one block, and when the head size
// is even the data is processed as packed pairs (half2 / __nv_bfloat162). All other types use one block per row
// and one thread per element.
template<typename T>
void invokeTransposeQKV(T*           dst,
                        T*           src,
                        const int    batch_size,
                        const int    seq_len,
                        const int    head_num,
                        const int    size_per_head,
                        const float* scale,
                        const int    int8_mode,
                        cudaStream_t stream)
{
    dim3 grid, block;
    if (sizeof(T) == 2) {
        // Fold up to 4 rows into a block while the row count stays evenly divisible.
        int seq_per_block = 1;
        grid.x            = batch_size * head_num * seq_len / seq_per_block;
        while (seq_per_block < 4 && grid.x % 2 == 0) {
            grid.x /= 2;
            seq_per_block *= 2;
        }

        FT_CHECK(grid.x * seq_per_block == (size_t)batch_size * head_num * seq_len);

        // The packed path hands `size_per_head / 2` to the kernel, so it is only valid when the head size itself
        // is even. (Testing `seq_per_block * size_per_head` instead would accept an odd head size whenever
        // seq_per_block is 2 or 4 and silently truncate it.)
        if (size_per_head % 2 == 0) {
            block.x = seq_per_block * size_per_head / 2;
            if (std::is_same<T, half>::value) {
                transpose<half2><<<grid, block, 0, stream>>>(
                    (half2*)src, (half2*)dst, batch_size, seq_len, head_num, size_per_head / 2, scale, int8_mode);
            }
#ifdef ENABLE_BF16
            else {
                transpose<__nv_bfloat162><<<grid, block, 0, stream>>>((__nv_bfloat162*)src,
                                                                      (__nv_bfloat162*)dst,
                                                                      batch_size,
                                                                      seq_len,
                                                                      head_num,
                                                                      size_per_head / 2,
                                                                      scale,
                                                                      int8_mode);
            }
#endif
        }
        else {
            // Odd head size: fall back to one thread per scalar element.
            block.x = seq_per_block * size_per_head;
            transpose<T>
                <<<grid, block, 0, stream>>>(src, dst, batch_size, seq_len, head_num, size_per_head, scale, int8_mode);
        }
    }
    else {
        // One block per (batch, head, seq) row, one thread per element of the head.
        const int seq_per_block = 1;
        grid.x                  = batch_size * head_num * seq_len / seq_per_block;
        block.x                 = seq_per_block * size_per_head;
        transpose<T>
            <<<grid, block, 0, stream>>>(src, dst, batch_size, seq_len, head_num, size_per_head, scale, int8_mode);
    }
}

// Explicit instantiations of invokeTransposeQKV. Parameter names mirror the definition (dst first, then src);
// names in an explicit instantiation have no effect on the signature, but swapped names are misleading.
#define INSTANTIATETRANSPOSEQKV(T)                                                                                     \
    template void invokeTransposeQKV(T*           dst,                                                                 \
                                     T*           src,                                                                 \
                                     const int    batch_size,                                                          \
                                     const int    seq_len,                                                             \
                                     const int    head_num,                                                            \
                                     const int    size_per_head,                                                       \
                                     const float* scale,                                                               \
                                     const int    int8_mode,                                                           \
                                     cudaStream_t stream)
INSTANTIATETRANSPOSEQKV(float);
INSTANTIATETRANSPOSEQKV(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEQKV(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEQKV

// Transposes one token per block while dropping padding.
//
// blockIdx.x is the index of the output token; `mask_offset[blockIdx.x]` is added to obtain the token's position
// in the padded [batch, seq_len] numbering. The source is indexed as [batch, head, seq, size_per_head] and the
// destination as [token, head, size_per_head]. Threads of the block stride over the head_num * size_per_head
// elements of the token.
//
// When int8_mode == 2 every element is multiplied by *scale, converted to the packed int8 type matching T and
// stored through `dst` reinterpreted as that packed type. `batch_size` is not used by this body.
template<typename T>
__global__ void transpose_remove_padding(const T*     src,
                                         T*           dst,
                                         const int    batch_size,
                                         const int    seq_len,
                                         const int    head_num,
                                         const int    size_per_head,
                                         const int*   mask_offset,
                                         const float* scale,
                                         const int    int8_mode)
{
    // TODO: optimize this kernel?
    using Int8_Packed_T  = typename packed_as<int8_t, num_elems<T>::value>::type;
    using Float_Packed_T = typename packed_as<float, num_elems<T>::value>::type;

    const bool quantize = (int8_mode == 2);

    // Output token handled by this block (batch * seq_len or valid_word_num blocks are launched).
    const int out_token = blockIdx.x;

    // Position of that token in the padded input.
    const int padded_token = out_token + mask_offset[out_token];
    const int in_batch     = padded_token / seq_len;
    const int in_seq       = padded_token % seq_len;

    const int in_base  = in_batch * seq_len * head_num * size_per_head + in_seq * size_per_head;
    const int out_base = out_token * head_num * size_per_head;

    // The scale is only dereferenced on the quantizing path.
    Float_Packed_T scale_val = cuda_cast<Float_Packed_T>(0.0f);
    if (quantize) {
        scale_val = cuda_cast<Float_Packed_T>(*scale);
    }

    const int elems_per_token = head_num * size_per_head;
    for (int e = threadIdx.x; e < elems_per_token; e += blockDim.x) {
        const int h      = e / size_per_head;
        const int within = e % size_per_head;
        const T   value  = ldg(&src[in_base + h * seq_len * size_per_head + within]);

        if (!quantize) {
            dst[out_base + e] = value;
            continue;
        }
        Int8_Packed_T* dst_i8 = reinterpret_cast<Int8_Packed_T*>(dst);
        dst_i8[out_base + e]  = cuda_cast<Int8_Packed_T>(cuda_cast<Float_Packed_T>(value) * scale_val);
    }
}

// clang-format off
// Host launcher for transpose_remove_padding: one block per valid token (`valid_word_num` blocks).
//
// For half / bf16 inputs with an even size_per_head the packed pair type T2 is used and size_per_head / 2 is
// passed to the kernel. The block size starts at head_num * size_per_head; on the packed path it is halved while
// it exceeds 512, and if an odd value above 512 is reached the launcher falls back to the scalar kernel. Scalar
// launches clamp the block size to 1024. The kernel strides over the elements, so a block smaller than the
// element count is fine.
template<typename T>
void invokeTransposeAttentionOutRemovePadding(T*           src,
                                              T*           dst,
                                              const int    valid_word_num,
                                              const int    batch_size,
                                              const int    seq_len,
                                              const int    head_num,
                                              const int    size_per_head,
                                              const int*   mask_offset,
                                              const float* scale,
                                              const int    int8_mode,
                                              cudaStream_t stream)
{
    using T2 = typename TypeConverter<T>::Type;  // fp16 to half2, bf16 to bf162

#ifdef ENABLE_BF16
    const bool is_16bit = std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value;
#else
    const bool is_16bit = std::is_same<T, half>::value;
#endif

    bool use_packed = is_16bit && (size_per_head % 2 == 0);
    int  threads    = head_num * size_per_head;

    if (use_packed) {
        // Shrink towards <= 512 threads by repeated halving.
        while (threads > 512 && threads % 2 == 0) {
            threads /= 2;
        }
        if (threads > 512) {
            // Stuck on an odd count above 512: give up on the packed kernel.
            use_packed = false;
            threads    = std::min(threads, 1024);
        }
    }
    else {
        threads = std::min(threads, 1024);
    }

    if (!use_packed) {
        transpose_remove_padding<<<valid_word_num, threads, 0, stream>>>(
            src, dst, batch_size, seq_len, head_num, size_per_head, mask_offset, scale, int8_mode);
        return;
    }

    transpose_remove_padding<T2><<<valid_word_num, threads, 0, stream>>>(
        (T2*)src, (T2*)dst, batch_size, seq_len, head_num, size_per_head / 2, mask_offset, scale, int8_mode);
}
// clang-format on

// Explicit instantiations of invokeTransposeAttentionOutRemovePadding for every supported element type.
// The helper macro is local to this spot and removed again right after use.
#define INSTANTIATE_TRANSPOSE_ATTN_OUT_REMOVE_PADDING(ElemT)                                                           \
    template void invokeTransposeAttentionOutRemovePadding(ElemT*       src,                                           \
                                                           ElemT*       dst,                                           \
                                                           const int    valid_word_num,                                \
                                                           const int    batch_size,                                    \
                                                           const int    seq_len,                                       \
                                                           const int    head_num,                                      \
                                                           const int    size_per_head,                                 \
                                                           const int*   mask_offset,                                   \
                                                           const float* scale,                                         \
                                                           const int    int8_mode,                                     \
                                                           cudaStream_t stream)
INSTANTIATE_TRANSPOSE_ATTN_OUT_REMOVE_PADDING(float);
INSTANTIATE_TRANSPOSE_ATTN_OUT_REMOVE_PADDING(half);
#ifdef ENABLE_BF16
INSTANTIATE_TRANSPOSE_ATTN_OUT_REMOVE_PADDING(__nv_bfloat16);
#endif
#undef INSTANTIATE_TRANSPOSE_ATTN_OUT_REMOVE_PADDING

// Adds the QKV bias and splits/transposes the fused QKV activation into q_buf, k_buf and v_buf.
//
//   QKV:                 [token_num, 3, n]  with n = head_num * size_per_head
//   qkv_bias:            [3, n]             (always dereferenced, must not be null)
//   q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head]
//
// The whole [token_num, 3, n] range is covered by a loop that advances by gridDim.x * blockDim.x, so any 1-D
// launch shape works. `padding_offset` may be null; otherwise padding_offset[token] is added to the token index
// to get its position in the padded [batch, seq_len] numbering.
//
// int8_mode == 2: QKV is read as int8 and multiplied by scale[0..2] (one factor per Q/K/V slice); in this mode
// the biased value is not written back to QKV. In every other mode QKV is updated in place with the biased value.
// `batch_size` is not used by this body.
template<typename T>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                   T* k_buf,
                                                   T* v_buf,
                                                   T* QKV,
                                                   const T* __restrict qkv_bias,
                                                   const int*   padding_offset,
                                                   const int    batch_size,
                                                   const int    seq_len,
                                                   const int    token_num,
                                                   const int    head_num,
                                                   const int    size_per_head,
                                                   const float* scale,
                                                   const int    int8_mode)
{
    T* outputs[3] = {q_buf, k_buf, v_buf};

    const int          n         = head_num * size_per_head;
    const int          row_width = 3 * n;             // elements per token in QKV
    const int          total     = token_num * 3 * n; // elements in QKV
    const unsigned int step      = gridDim.x * blockDim.x;
    const bool         from_int8 = (int8_mode == 2);

    for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < total; index += step) {
        // Position inside the [3, n] slice of this token; doubles as the bias index.
        const int in_row = index % row_width;

        // Map the packed token index to (batch, seq) in the padded layout.
        const int token  = index / row_width;
        const int padded = token + (padding_offset == nullptr ? 0 : padding_offset[token]);
        const int b      = padded / seq_len;
        const int s      = padded % seq_len;

        const int which = in_row / n;                      // 0 = Q, 1 = K, 2 = V
        const int h     = (index % n) / size_per_head;
        const int d     = index % size_per_head;

        T val;
        if (!from_int8) {
            val = ldg(&QKV[index]);
        }
        else {
            val = cuda_cast<T>(cuda_cast<float>(reinterpret_cast<const int8_t*>(QKV)[index]) * scale[which]);
        }
        val = val + ldg(&qkv_bias[in_row]);

        // TODO(mseznec): add support for int8 BMM with FusedAtt
        // (until then the int8 path leaves QKV untouched)
        if (!from_int8) {
            QKV[index] = val;
        }

        const int out_idx = b * head_num * seq_len * size_per_head + h * seq_len * size_per_head
                            + s * size_per_head + d;
        outputs[which][out_idx] = val;
    }
}

// Maps an element type to the vector type used for paired loads/stores and the number of elements it holds.
// The primary template reports size 0 and defines no `Type`.
template<typename T>
struct Vec_t {
    static constexpr int size = 0;
};

// float -> float2 (two elements)
template<>
struct Vec_t<float> {
    static constexpr int size = 2;
    typedef float2       Type;
};

// half -> one 32-bit word (two elements)
template<>
struct Vec_t<half> {
    static constexpr int size = 2;
    typedef uint32_t     Type;
};

#ifdef ENABLE_BF16
// bf16 -> __nv_bfloat162 (two elements)
template<>
struct Vec_t<__nv_bfloat16> {
    static constexpr int   size = 2;
    typedef __nv_bfloat162 Type;
};
#endif

/// TODO: support batch step offset
// Adds the optional bias to a fused QKV activation, applies rotary embedding to Q and K (plus optional log-n
// scaling of Q) and stores the result.
//
// Launch layout (as used below): blockIdx.x = token index, blockIdx.y = query head index, and thread `tidx`
// handles the vec_size consecutive elements starting at tidx * vec_size of that head. Threads whose first element
// lies at or beyond size_per_head are masked and neither load nor store.
//
//   QKV (src):   [token, head_num + 2 * kv_head_num, size_per_head]; Q heads first, then K heads, then V heads
//   qkv_bias:    same per-token layout as QKV, or nullptr for "no bias"
//   q_buf:       [batch, head_num,    seq_len, size_per_head]
//   k_buf/v_buf: [batch, kv_head_num, seq_len, size_per_head]; written only by blocks with head_idx < kv_head_num
//
// padding_offset may be nullptr; otherwise padding_offset[token] is added to the token index to obtain its
// position in the padded [batch, seq_len] numbering. The rotary timestep of a token is
// (context_length[batch] - input_length[batch]) + seq_idx. If rope_theta is non-null, rope_theta[batch] replaces
// rotary_embedding_base.
//
// Output mode: when q_buf is nullptr the processed values are written back into QKV in place and the q/k/v
// buffers are not touched; otherwise QKV is left unmodified and the values go to q_buf / k_buf / v_buf.
//
// PREFIX_PROMPT, batch_size and use_dynamic_ntk are not referenced by this body.
template<typename T, bool PREFIX_PROMPT>
__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf,
                                                   T* k_buf,
                                                   T* v_buf,
                                                   T* QKV,
                                                   const T* __restrict qkv_bias,
                                                   const int*   padding_offset,
                                                   const int*   context_length,
                                                   const int*   input_length,
                                                   const float* rope_theta,
                                                   int          batch_size,
                                                   int          seq_len,
                                                   int          head_num,
                                                   int          kv_head_num,
                                                   int          size_per_head,
                                                   int          rotary_embedding_dim,
                                                   float        rotary_embedding_base,
                                                   int          max_position_embeddings,
                                                   bool         use_dynamic_ntk,
                                                   bool         use_logn_attn)
{
    constexpr int vec_size = Vec_t<T>::size;
    using Vec_t            = typename Vec_t<T>::Type;

    // blockIdx.x is bounded by the grid size, so the int conversion cannot go negative.
    const int token_idx            = blockIdx.x;
    const int token_padding_offset = padding_offset == nullptr ? 0 : padding_offset[token_idx];
    const int tgt_token_idx        = token_idx + token_padding_offset;

    const int batch_idx = tgt_token_idx / seq_len;
    const int seq_idx   = tgt_token_idx % seq_len;

    const int head_idx = blockIdx.y;
    const int tidx     = threadIdx.x;

    // The block size is rounded up to a multiple of 32 by the launcher, so surplus threads must stay idle.
    const bool is_masked = tidx * vec_size >= size_per_head;

    const int hidden_idx = head_idx * size_per_head + tidx * vec_size;

    const int q_kv_head_num = head_num + 2 * kv_head_num;

    const int k_offset = head_num * size_per_head;
    const int v_offset = k_offset + kv_head_num * size_per_head;

    // src QKV: [batch, time, q_kv_head_num, hidden]
    const int src_q_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx;
    const int src_k_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + k_offset;
    const int src_v_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + v_offset;

    Vec_t q, k, v;
    Vec_t q_bias, k_bias, v_bias;

    using Vec = Array<T, vec_size>;

    static_assert(sizeof(Vec_t) == sizeof(Vec));

    using namespace ops;

    // load Q and apply bias
    if (!is_masked) {
        q = *reinterpret_cast<const Vec_t*>(&QKV[src_q_idx]);
        if (qkv_bias) {
            q_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
            (Vec&)q = (Vec&)q + (Vec&)q_bias;
        }
    }

    // load KV and apply bias (only the first kv_head_num head blocks own a K/V head)
    if (!is_masked && head_idx < kv_head_num) {
        k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
        v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);
        if (qkv_bias) {
            k_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + k_offset]);
            v_bias  = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + v_offset]);
            (Vec&)k = (Vec&)k + (Vec&)k_bias;
            (Vec&)v = (Vec&)v + (Vec&)v_bias;
        }
    }

    const int context_len = context_length[batch_idx];
    const int history_len = context_len - input_length[batch_idx];
    const int timestep    = history_len + seq_idx;

    if (rope_theta) {
        rotary_embedding_base = rope_theta[batch_idx];
    }

    RotaryEmbedding<vec_size> rotary_emb(rotary_embedding_base, rotary_embedding_dim, timestep, {tidx * vec_size, 0});
    rotary_emb.apply((Array<T, vec_size>&)q);

    if (head_idx < kv_head_num) {
        rotary_emb.apply((Array<T, vec_size>&)k);
    }

    if (use_logn_attn) {
        // +1 to convert to context length at the timestep
        LogNScaling logn_scaling(timestep + 1, max_position_embeddings);
        logn_scaling.apply((Array<T, vec_size>&)q);
    }

    if (!is_masked && !q_buf) {  // also skip modifying QKV if q/k/v_buf are present
        *reinterpret_cast<Vec_t*>(&QKV[src_q_idx]) = q;
        if (head_idx < kv_head_num) {
            *reinterpret_cast<Vec_t*>(&QKV[src_k_idx]) = k;
            *reinterpret_cast<Vec_t*>(&QKV[src_v_idx]) = v;
        }
    }

    const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len
                           + seq_idx * size_per_head + tidx * vec_size;

    const int dest_kv_idx = batch_idx * size_per_head * seq_len * kv_head_num + head_idx * size_per_head * seq_len
                            + seq_idx * size_per_head + tidx * vec_size;

    // Only store to the split buffers when they were supplied; in the in-place mode above q_buf is nullptr and
    // must not be dereferenced.
    if (!is_masked && q_buf) {
        *reinterpret_cast<Vec_t*>(&q_buf[dest_q_idx]) = q;
        if (head_idx < kv_head_num) {
            *reinterpret_cast<Vec_t*>(&k_buf[dest_kv_idx]) = k;
            *reinterpret_cast<Vec_t*>(&v_buf[dest_kv_idx]) = v;
        }
    }
}

// Launches add_fusedQKV_bias_transpose_kernel<T, PREFIX_PROMPT>. The macro expands in place, so `grid`, `block`,
// `smem_size`, `stream` and every kernel argument named below must be in scope at the expansion site. The
// expansion already ends with ';'.
#define FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, PREFIX_PROMPT)                                                              \
    add_fusedQKV_bias_transpose_kernel<T, PREFIX_PROMPT><<<grid, block, smem_size, stream>>>(q_buf,                    \
                                                                                             k_buf,                    \
                                                                                             v_buf,                    \
                                                                                             QKV,                      \
                                                                                             qkv_bias,                 \
                                                                                             padding_offset,           \
                                                                                             context_length,           \
                                                                                             input_length,             \
                                                                                             rope_theta,               \
                                                                                             batch_size,               \
                                                                                             seq_len,                  \
                                                                                             head_num,                 \
                                                                                             kv_head_num,              \
                                                                                             size_per_head,            \
                                                                                             rotary_embedding_dim,     \
                                                                                             rotary_embedding_base,    \
                                                                                             max_position_embeddings,  \
                                                                                             use_dynamic_ntk,          \
                                                                                             use_logn_attn);

// Host launcher for add_fusedQKV_bias_transpose_kernel<T, false>.
//
// Grid:  (token_num, head_num) - one block per token and query head.
// Block: size_per_head / Vec_t<T>::size threads, rounded up to a multiple of 32.
// No dynamic shared memory is requested.
//
// rotary_embedding_dim must be non-zero (enforced with FT_CHECK). All remaining arguments are forwarded to the
// kernel unchanged; the launch is enqueued on `stream`.
template<typename T>
void invokeAddFusedQKVBiasTranspose(T*           q_buf,
                                    T*           k_buf,
                                    T*           v_buf,
                                    T*           QKV,
                                    const T*     qkv_bias,
                                    const int*   padding_offset,
                                    const int*   context_length,
                                    const int*   input_length,
                                    const float* rope_theta,
                                    const int    batch_size,
                                    const int    seq_len,
                                    const int    token_num,
                                    const int    head_num,
                                    const int    kv_head_num,
                                    const int    size_per_head,
                                    const int    rotary_embedding_dim,
                                    float        rotary_embedding_base,
                                    int          max_position_embeddings,
                                    bool         use_dynamic_ntk,
                                    bool         use_logn_attn,
                                    cudaStream_t stream)
{
    FT_CHECK(rotary_embedding_dim);
    // To implement rotary embeddings, each thread processes two QKV elems:
    dim3   block((size_per_head / Vec_t<T>::size + 31) / 32 * 32);
    dim3   grid(token_num, head_num);
    size_t smem_size = 0;
    FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false);
}

// Explicit instantiations of invokeAddFusedQKVBiasTranspose. Parameter names mirror the definition
// (`context_length`, not `history_length`: the kernel derives the history length from it).
#define INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(T)                                                                         \
    template void invokeAddFusedQKVBiasTranspose(T*           q_buf,                                                   \
                                                 T*           k_buf,                                                   \
                                                 T*           v_buf,                                                   \
                                                 T*           QKV,                                                     \
                                                 const T*     qkv_bias,                                                \
                                                 const int*   padding_offset,                                          \
                                                 const int*   context_length,                                          \
                                                 const int*   input_length,                                            \
                                                 const float* rope_theta,                                              \
                                                 const int    batch_size,                                              \
                                                 const int    seq_len,                                                 \
                                                 const int    token_num,                                               \
                                                 const int    head_num,                                                \
                                                 const int    kv_head_num,                                             \
                                                 const int    size_per_head,                                           \
                                                 const int    rotary_embedding_dim,                                    \
                                                 float        rotary_embedding_base,                                   \
                                                 int          max_position_embeddings,                                 \
                                                 bool         use_dynamic_ntk,                                         \
                                                 bool         use_logn_attn,                                           \
                                                 cudaStream_t stream)
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(float);
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(half);
#ifdef ENABLE_BF16
INSTANTIATEADDFUSEDQKVBIASTRANSPOSE(__nv_bfloat16);
#endif
#undef INSTANTIATEADDFUSEDQKVBIASTRANSPOSE

// Scatters a contiguous [dim0, dim1, dim2, dim3] tensor (`src`) into a
// [dim2, dim0_leading_dim, dim1, dim3] tensor (`dst`). The dim0 slice handled by this call is
// written at offset `ite * dim0` along the second output axis.
// Each thread starts at its global 1-D thread index and advances by gridDim.x * blockDim.x until
// the flattened source is exhausted, so any 1-D launch with at least one block covers the input.
template<typename T>
__global__ void transpose_4d(T*        dst,
                             T*        src,
                             const int dim0,
                             const int dim1,
                             const int dim2,
                             const int dim3,
                             const int dim0_leading_dim,
                             const int ite)
{
    const int          total  = dim0 * dim1 * dim2 * dim3;
    const unsigned int stride = blockDim.x * gridDim.x;

    for (int flat = threadIdx.x + blockIdx.x * blockDim.x; flat < total; flat += stride) {
        // Peel the source coordinates off the flat index, innermost axis first.
        int       rest = flat;
        const int c3   = rest % dim3;
        rest /= dim3;
        const int c2 = rest % dim2;
        rest /= dim2;
        const int c1 = rest % dim1;
        rest /= dim1;
        const int c0 = rest % dim0;

        // Destination coordinates are (c2, c0 + dim0 * ite, c1, c3).
        const int out = c2 * dim0_leading_dim * dim1 * dim3 + (c0 + dim0 * ite) * dim1 * dim3 + c1 * dim3 + c3;
        dst[out]      = src[flat];
    }
}

// half specialization of transpose_4d: moves two halves at a time by viewing both buffers as
// half2, i.e. it transposes [dim0, dim1, dim2, dim3 / 2] half2 elements into
// [dim2, dim0_leading_dim, dim1, dim3 / 2], writing this call's dim0 slice at offset `ite * dim0`
// along the second output axis.
// dim3 / 2 is an integer division, so a trailing odd element of the last axis is not copied.
// NOTE(review): the half2 view presumably requires 4-byte aligned `dst` / `src` — confirm at callers.
template<>
__global__ void transpose_4d(half*     dst,
                             half*     src,
                             const int dim0,
                             const int dim1,
                             const int dim2,
                             const int dim3,
                             const int dim0_leading_dim,
                             const int ite)
{
    half2* const       dst2   = (half2*)dst;
    half2* const       src2   = (half2*)src;
    const int          inner  = dim3 / 2;
    const int          total  = dim0 * dim1 * dim2 * inner;
    const unsigned int stride = blockDim.x * gridDim.x;

    for (int flat = threadIdx.x + blockIdx.x * blockDim.x; flat < total; flat += stride) {
        // Peel the source coordinates off the flat half2 index, innermost axis first.
        int       rest = flat;
        const int c3   = rest % inner;
        rest /= inner;
        const int c2 = rest % dim2;
        rest /= dim2;
        const int c1 = rest % dim1;
        rest /= dim1;
        const int c0 = rest % dim0;

        // Destination coordinates are (c2, c0 + dim0 * ite, c1, c3) in half2 units.
        const int out = c2 * dim0_leading_dim * dim1 * inner + (c0 + dim0 * ite) * dim1 * inner + c1 * inner + c3;
        dst2[out]     = src2[flat];
    }
}

// Launches transpose_4d on `stream`, rearranging `src`
// [local_batch_size, local_head_num, seq_len, size_per_head] into `dst`
// [seq_len, batch_size, local_head_num, size_per_head]; the local batch slice is placed at
// offset `ite * local_batch_size` along the batch axis of `dst`.
//
// The grid is sized from local_batch_size * seq_len * local_hidden_units, 512 elements per block
// (512 threads for 4-byte T, 256 threads for 2-byte T).
// NOTE(review): local_hidden_units is presumably local_head_num * size_per_head — confirm at callers.
//
// The block count is rounded up and an empty input returns without launching. The previous floor
// division produced a 0-block launch (an invalid configuration, so nothing was copied) whenever
// the tensor held fewer than 512 elements. The kernel strides over the input, so the extra
// partial block does not change results.
template<typename T>
void invokeTranspose4d(T*           dst,
                       T*           src,
                       const int    local_batch_size,
                       const int    seq_len,
                       const int    size_per_head,
                       const int    local_hidden_units,
                       const int    local_head_num,
                       const int    batch_size,
                       const int    ite,
                       cudaStream_t stream)
{
    static_assert(sizeof(T) == 2 || sizeof(T) == 4, "invokeTranspose4d expects a 2- or 4-byte element type");

    constexpr int kElemsPerBlock   = 512;
    constexpr int kThreadsPerBlock = kElemsPerBlock / (4 / (sizeof(T)));

    // 64-bit product so that large shapes cannot wrap around before the division.
    const long long total = static_cast<long long>(local_batch_size) * seq_len * local_hidden_units;
    if (total <= 0) {
        return;  // nothing to transpose; a 0-block launch would be an invalid configuration
    }
    const int blocks = static_cast<int>((total + kElemsPerBlock - 1) / kElemsPerBlock);

    transpose_4d<<<blocks, kThreadsPerBlock, 0, stream>>>(
        dst, src, local_batch_size, local_head_num, seq_len, size_per_head, batch_size, ite);
}

// Explicit instantiations of invokeTranspose4d for float and half. The helper macro expands to
// one `template void invokeTranspose4d(...)` declaration per type and is undefined right after use.
#define INSTANTIATETRANSPOSE4D(T)                                                                                      \
    template void invokeTranspose4d(T*           dst,                                                                  \
                                    T*           src,                                                                  \
                                    const int    local_batch_size,                                                     \
                                    const int    seq_len,                                                              \
                                    const int    size_per_head,                                                        \
                                    const int    local_hidden_units,                                                   \
                                    const int    local_head_num,                                                       \
                                    const int    batch_size,                                                           \
                                    const int    ite,                                                                  \
                                    cudaStream_t stream)
INSTANTIATETRANSPOSE4D(float);
INSTANTIATETRANSPOSE4D(half);
#undef INSTANTIATETRANSPOSE4D
// Copies the keys of one (batch, head) pair into the batch-major K cache, 16 bytes (one uint4,
// i.e. X_ELEMS elements of T) per thread.
//   k_src, per (batch, head): [seq_len, size_per_head / X_ELEMS] uint4 packets
//   k_dst, per (batch, head): [size_per_head / X_ELEMS, max_seq_len] uint4 packets
// Launch: blockIdx.y = batch index, blockIdx.z = head index; the x dimension covers the
// (size_per_head / X_ELEMS) * max_seq_len output packets, and surplus threads exit early.
// Cache slots for positions >= seq_len are left untouched.
// NOTE(review): the uint4 accesses presumably require 16-byte aligned per-head slabs — confirm.
//
// The per-head element offsets are computed in 64-bit: batch * head_num * size_per_head *
// max_seq_len can exceed INT_MAX for large caches, which previously wrapped the pointer offset.
template<typename T>
__global__ void transpose_4d_batch_major_k_cache(
    T* k_dst, const T* k_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len)
{
    const int     batch_id = blockIdx.y;
    const int     head_id  = blockIdx.z;
    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;

    // Linear (batch, head) index; multiplied out in size_t to avoid 32-bit overflow.
    const size_t head_slab  = static_cast<size_t>(batch_id) * head_num + head_id;
    const size_t src_offset = head_slab * size_per_head * seq_len;
    const size_t dst_offset = head_slab * size_per_head * max_seq_len;

    auto key_src = reinterpret_cast<const uint4*>(k_src + src_offset);
    auto key_dst = reinterpret_cast<uint4*>(k_dst + dst_offset);

    const int out_idx             = blockIdx.x * blockDim.x + threadIdx.x;
    const int size_per_head_div_x = size_per_head / X_ELEMS;
    if (out_idx >= size_per_head_div_x * max_seq_len) {
        return;
    }

    // out_idx enumerates the destination as [size_per_head_div_x, max_seq_len].
    const int k_seq_len_id   = out_idx % max_seq_len;
    const int k_head_size_id = (out_idx / max_seq_len) % size_per_head_div_x;

    if (k_seq_len_id < seq_len) {
        key_dst[out_idx] = key_src[k_seq_len_id * size_per_head_div_x + k_head_size_id];
    }
}

// Copies the values of one (batch, head) pair into the batch-major V cache, 16 bytes (one uint4,
// i.e. X_ELEMS elements of T) per thread. The element order inside a (batch, head) slab is kept;
// only the slab stride differs (seq_len rows in `v_src`, max_seq_len rows in `v_dst`).
// Launch: blockIdx.y = batch index, blockIdx.z = head index; the x dimension covers the
// seq_len * (size_per_head / X_ELEMS) packets to copy, and surplus threads exit early.
// NOTE(review): the uint4 accesses presumably require 16-byte aligned per-head slabs — confirm.
//
// The per-head element offsets are computed in 64-bit: batch * head_num * size_per_head *
// max_seq_len can exceed INT_MAX for large caches, which previously wrapped the pointer offset.
template<typename T>
__global__ void transpose_4d_batch_major_v_cache(
    T* v_dst, const T* v_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len)
{
    const int batch_id = blockIdx.y;
    const int head_id  = blockIdx.z;

    // Linear (batch, head) index; multiplied out in size_t to avoid 32-bit overflow.
    const size_t head_slab  = static_cast<size_t>(batch_id) * head_num + head_id;
    const size_t src_offset = head_slab * size_per_head * seq_len;
    const size_t dst_offset = head_slab * size_per_head * max_seq_len;

    // 16 byte loads will handle "x" dimension
    auto val_src = reinterpret_cast<const uint4*>(v_src + src_offset);
    auto val_dst = reinterpret_cast<uint4*>(v_dst + dst_offset);

    // idx is over output dimension L * size_per_head / x for values
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    constexpr int X_ELEMS             = (sizeof(T) == 4) ? 4 : 8;
    const int     size_per_head_div_x = size_per_head / X_ELEMS;

    if (idx >= size_per_head_div_x * seq_len) {
        return;
    }

    val_dst[idx] = val_src[idx];
}

// Fills the batch-major K and V caches from the freshly computed keys / values by launching
// transpose_4d_batch_major_k_cache and transpose_4d_batch_major_v_cache on `stream`.
//   k_src / v_src : per (batch, head) slab of seq_len * size_per_head elements
//   k_dst / v_dst : per (batch, head) slab of max_seq_len * size_per_head elements
// Both kernels move 16-byte packets (x elements of T) with 128 threads per block, using
// grid.y = local_batch_size and grid.z = local_head_num.
//
// A launch with any zero grid dimension is an invalid configuration, so launches that would not
// copy anything (no batches, no heads, seq_len <= 0, or less than one packet of work) are
// skipped instead of being issued.
template<typename T>
void invokeTranspose4dBatchMajor(T*           k_dst,
                                 T*           v_dst,
                                 const T*     k_src,
                                 const T*     v_src,
                                 const int    local_batch_size,
                                 const int    seq_len,
                                 const int    max_seq_len,
                                 const int    size_per_head,
                                 const int    local_head_num,
                                 cudaStream_t stream)
{
    constexpr int block_sz = 128;
    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;

    if (local_batch_size <= 0 || local_head_num <= 0 || seq_len <= 0) {
        return;  // both kernels would be no-ops; avoid a zero-sized grid
    }

    const int k_packets = max_seq_len * size_per_head / x;
    const int v_packets = seq_len * size_per_head / x;

    if (k_packets > 0) {
        dim3 grid((k_packets + block_sz - 1) / block_sz, local_batch_size, local_head_num);
        transpose_4d_batch_major_k_cache<<<grid, block_sz, 0, stream>>>(
            k_dst, k_src, local_head_num, size_per_head, seq_len, max_seq_len);
    }

    if (v_packets > 0) {
        dim3 grid_v((v_packets + block_sz - 1) / block_sz, local_batch_size, local_head_num);
        transpose_4d_batch_major_v_cache<<<grid_v, block_sz, 0, stream>>>(
            v_dst, v_src, local_head_num, size_per_head, seq_len, max_seq_len);
    }
}

// Explicit instantiations of invokeTranspose4dBatchMajor for float and half, plus __nv_bfloat16
// when ENABLE_BF16 is defined. The helper macro is undefined right after use.
#define INSTANTIATETRANSPOSE4DBATCHMAJOR(T)                                                                            \
    template void invokeTranspose4dBatchMajor(T*           k_dst,                                                      \
                                              T*           v_dst,                                                      \
                                              const T*     k_src,                                                      \
                                              const T*     v_src,                                                      \
                                              const int    local_batch_size,                                           \
                                              const int    seq_len,                                                    \
                                              const int    max_seq_len,                                                \
                                              const int    size_per_head,                                              \
                                              const int    local_head_num,                                             \
                                              cudaStream_t stream)
INSTANTIATETRANSPOSE4DBATCHMAJOR(float);
INSTANTIATETRANSPOSE4DBATCHMAJOR(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSE4DBATCHMAJOR(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSE4DBATCHMAJOR

// Adds a batch-independent bias row to the matching row of every batch entry of `qk_buf`, in place.
// Each block owns one row of `seq_len` elements (row index = blockIdx.x, so gridDim.x is the
// number of rows per batch entry); its threads stride over all (batch, column) pairs of that row.
// `head_num` is not read by the kernel.
template<typename T>
__global__ void addRelativeAttentionBias(
    T* qk_buf, const T* relative_attention_bias, const int batch_size, const int head_num, const int seq_len)
{
    const int work_items = batch_size * seq_len;
    for (int item = threadIdx.x; item < work_items; item += blockDim.x) {
        const int b   = item / seq_len;
        const int col = item % seq_len;

        // Offset inside one batch entry, shared by the bias and the score buffer.
        const int bias_pos = blockIdx.x * seq_len + col;
        const int qk_pos   = b * gridDim.x * seq_len + bias_pos;

        qk_buf[qk_pos] = add(qk_buf[qk_pos], relative_attention_bias[bias_pos]);
    }
}

// Adds `relative_attention_bias` ([1, head_num, seq_len, seq_len]) to `qk_buf`
// ([batch_size, head_num, seq_len, seq_len]) in place on `stream`.
// One block of 512 threads is launched per (head, row) pair. For half (and __nv_bfloat16 when
// ENABLE_BF16 is defined) with an even seq_len, both buffers are reinterpreted as the packed
// two-element type from TypeConverter<T> and the row length passed to the kernel is halved.
template<typename T>
void invokeAddRelativeAttentionBias(T*           qk_buf,
                                    const T*     relative_attention_bias,
                                    const int    batch_size,
                                    const int    head_num,
                                    const int    seq_len,
                                    cudaStream_t stream)
{
    using T2 = typename TypeConverter<T>::Type;

#ifdef ENABLE_BF16
    constexpr bool is_16bit_type = std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value;
#else
    constexpr bool is_16bit_type = std::is_same<T, half>::value;
#endif
    const bool use_packed = is_16bit_type && (seq_len % 2 == 0);

    const dim3 grid(head_num * seq_len);
    const dim3 block(512);

    if (!use_packed) {
        // Scalar path: one element of T per (batch, column) work item.
        addRelativeAttentionBias<<<grid, block, 0, stream>>>(
            qk_buf, relative_attention_bias, batch_size, head_num, seq_len);
        return;
    }

    // Packed path: each work item covers two adjacent columns.
    addRelativeAttentionBias<T2><<<grid, block, 0, stream>>>(
        (T2*)qk_buf, (const T2*)relative_attention_bias, batch_size, head_num, seq_len / 2);
}

// Explicit instantiations of invokeAddRelativeAttentionBias for float and half, plus
// __nv_bfloat16 when ENABLE_BF16 is defined. The helper macro is undefined right after use.
#define INSTANTIATEADDRELATIVEATTENTIONBIAS(T)                                                                         \
    template void invokeAddRelativeAttentionBias(T*           qk_buf,                                                  \
                                                 const T*     relative_attention_bias,                                 \
                                                 const int    batch_size,                                              \
                                                 const int    head_num,                                                \
                                                 const int    seq_len,                                                 \
                                                 cudaStream_t stream)
INSTANTIATEADDRELATIVEATTENTIONBIAS(float);
INSTANTIATEADDRELATIVEATTENTIONBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDRELATIVEATTENTIONBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEADDRELATIVEATTENTIONBIAS

/*******************  invokeAddHead3SizeQKVBias  ***********************/
// Generic variant: one element per thread, one of Q / K / V per block.
// With m = batch * window_num * window_len:
//   mm_qkv                   : [m, num_head * 3 * size_per_head], row-major
//   bias_qkv                 : [num_head * 3 * size_per_head]
//   q_buf_ / k_buf_ / v_buf_ : [batch * window_num, num_head, window_len, size_per_head], row-major
// Expected launch: grid(window_len, window_num, 3 * batch), block(num_head * size_per_head).
// blockIdx.z / batch picks the destination: 0 -> q_buf_, 1 -> k_buf_, anything else -> v_buf_.
template<typename T>
__global__ void add_head3Size_QKV_bias(const T*  mm_qkv,
                                       const T*  bias_qkv,
                                       T*        q_buf_,
                                       T*        k_buf_,
                                       T*        v_buf_,
                                       const int batch,
                                       const int window_num,
                                       const int window_len,
                                       const int num_head,
                                       const int size_per_head)
{
    const int qkv_id = blockIdx.z / batch;
    T* const  out    = (qkv_id == 0) ? q_buf_ : ((qkv_id == 1) ? k_buf_ : v_buf_);

    const int batch_id   = blockIdx.z % batch;
    const int token_id   = blockIdx.x;
    const int window_id  = blockIdx.y;
    const int head_id    = threadIdx.x / size_per_head;
    const int id_in_head = threadIdx.x % size_per_head;

    // Column inside one fused row; per head the layout is [q | k | v], each size_per_head wide.
    const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;

    const int flat_window = batch_id * window_num + window_id;
    const int input_idx   = (flat_window * window_len + token_id) * num_head * 3 * size_per_head + bias_idx;
    const int target_id   = ((flat_window * num_head + head_id) * window_len + token_id) * size_per_head + id_in_head;

    const T bias   = ldg(bias_qkv + bias_idx);
    out[target_id] = mm_qkv[input_idx] + bias;
}

// float2 variant: two floats per thread, one of Q / K / V per block. All sizes below are in
// float2 units, i.e. the caller passes size_per_head / 2.
// With m = batch * window_num * window_len:
//   mm_qkv                   : [m, num_head * 3 * size_per_head], row-major
//   bias_qkv                 : [num_head * 3 * size_per_head]
//   q_buf_ / k_buf_ / v_buf_ : [batch * window_num, num_head, window_len, size_per_head], row-major
// Expected launch: grid(window_len, window_num, 3 * batch), block(num_head * size_per_head).
// blockIdx.z / batch picks the destination: 0 -> q_buf_, 1 -> k_buf_, anything else -> v_buf_.
template<>
__global__ void add_head3Size_QKV_bias(const float2* mm_qkv,
                                       const float2* bias_qkv,
                                       float2*       q_buf_,
                                       float2*       k_buf_,
                                       float2*       v_buf_,
                                       const int     batch,
                                       const int     window_num,
                                       const int     window_len,
                                       const int     num_head,
                                       const int     size_per_head)
{
    const int     qkv_id = blockIdx.z / batch;
    float2* const out    = (qkv_id == 0) ? q_buf_ : ((qkv_id == 1) ? k_buf_ : v_buf_);

    const int batch_id   = blockIdx.z % batch;
    const int token_id   = blockIdx.x;
    const int window_id  = blockIdx.y;
    const int head_id    = threadIdx.x / size_per_head;
    const int id_in_head = threadIdx.x % size_per_head;

    // Column inside one fused row; per head the layout is [q | k | v], each size_per_head wide.
    const int bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;

    const int flat_window = batch_id * window_num + window_id;
    const int input_idx   = (flat_window * window_len + token_id) * num_head * 3 * size_per_head + bias_idx;
    const int target_id   = ((flat_window * num_head + head_id) * window_len + token_id) * size_per_head + id_in_head;

    const float2 bias = ldg(bias_qkv + bias_idx);
    float2       sum  = mm_qkv[input_idx];
    sum.x += bias.x;
    sum.y += bias.y;

    out[target_id] = sum;
}

// half2 variant: two halves per thread, and every thread produces its Q, K and V element in turn.
// All sizes below are in half2 units, i.e. the caller passes size_per_head / 2.
// With m = batch * window_num * window_len:
//   mm_qkv                   : [m, num_head * 3 * size_per_head], row-major
//   bias_qkv                 : [num_head * 3 * size_per_head]
//   q_buf_ / k_buf_ / v_buf_ : [batch * window_num, num_head, window_len, size_per_head], row-major
// Expected launch: grid(window_len, window_num, batch), block(num_head * size_per_head).
template<>
__global__ void add_head3Size_QKV_bias(const half2* mm_qkv,
                                       const half2* bias_qkv,
                                       half2*       q_buf_,
                                       half2*       k_buf_,
                                       half2*       v_buf_,
                                       const int    batch,
                                       const int    window_num,
                                       const int    window_len,
                                       const int    num_head,
                                       const int    size_per_head)
{
    const int batch_id   = blockIdx.z;
    const int token_id   = blockIdx.x;
    const int window_id  = blockIdx.y;
    const int head_id    = threadIdx.x / size_per_head;
    const int id_in_head = threadIdx.x % size_per_head;

    const int flat_window  = batch_id * window_num + window_id;
    const int input_offset = (flat_window * window_len + token_id) * num_head * 3 * size_per_head;
    const int target_id = ((flat_window * num_head + head_id) * window_len + token_id) * size_per_head + id_in_head;

    // Destinations in fused order: 0 -> Q, 1 -> K, 2 -> V.
    half2* const outputs[3] = {q_buf_, k_buf_, v_buf_};

#pragma unroll
    for (int qkv_id = 0; qkv_id < 3; ++qkv_id) {
        // Column inside one fused row; per head the layout is [q | k | v].
        const int   bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
        const half2 bias     = __ldg(bias_qkv + bias_idx);
        const half2 value    = mm_qkv[input_offset + bias_idx];

        outputs[qkv_id][target_id] = __hadd2(value, bias);
    }
}

#ifdef ENABLE_BF16
// __nv_bfloat162 variant: two bfloat16 values per thread, and every thread produces its Q, K and
// V element in turn. Sizes are in __nv_bfloat162 units (the launcher in this file passes
// size_per_head / 2 and uses grid(window_len, window_num, batch),
// block(num_head * size_per_head / 2)).
// Indexing follows the same pattern as the other variants:
//   mm_qkv                   : [batch * window_num * window_len, num_head * 3 * size_per_head]
//   bias_qkv                 : [num_head * 3 * size_per_head]
//   q_buf_ / k_buf_ / v_buf_ : [batch * window_num, num_head, window_len, size_per_head]
template<>
__global__ void add_head3Size_QKV_bias(const __nv_bfloat162* mm_qkv,
                                       const __nv_bfloat162* bias_qkv,
                                       __nv_bfloat162*       q_buf_,
                                       __nv_bfloat162*       k_buf_,
                                       __nv_bfloat162*       v_buf_,
                                       const int             batch,
                                       const int             window_num,
                                       const int             window_len,
                                       const int             num_head,
                                       const int             size_per_head)
{
    const int batch_id   = blockIdx.z;
    const int token_id   = blockIdx.x;
    const int window_id  = blockIdx.y;
    const int head_id    = threadIdx.x / size_per_head;
    const int id_in_head = threadIdx.x % size_per_head;

    const int flat_window  = batch_id * window_num + window_id;
    const int input_offset = (flat_window * window_len + token_id) * num_head * 3 * size_per_head;
    const int target_id = ((flat_window * num_head + head_id) * window_len + token_id) * size_per_head + id_in_head;

    // Destinations in fused order: 0 -> Q, 1 -> K, 2 -> V.
    __nv_bfloat162* const outputs[3] = {q_buf_, k_buf_, v_buf_};

#pragma unroll
    for (int qkv_id = 0; qkv_id < 3; ++qkv_id) {
        // Column inside one fused row; per head the layout is [q | k | v].
        const int            bias_idx = (head_id * 3 + qkv_id) * size_per_head + id_in_head;
        const __nv_bfloat162 bias     = ldg(bias_qkv + bias_idx);
        const __nv_bfloat162 value    = mm_qkv[input_offset + bias_idx];

        outputs[qkv_id][target_id] = bf16hadd2(value, bias);
    }
}
#endif

// Adds `bias_qkv` to the fused QKV projection `mm_qkv` and scatters the result into the separate
// q_buf_ / k_buf_ / v_buf_ tensors by launching add_head3Size_QKV_bias on `stream`.
//   mm_qkv                   : [batch * window_num * window_len, num_head * 3 * size_per_head]
//   bias_qkv                 : [num_head * 3 * size_per_head]
//   q_buf_ / k_buf_ / v_buf_ : [batch * window_num, num_head, window_len, size_per_head]
//
// float:
//   - num_head * size_per_head <= 1024: scalar kernel, one thread per element.
//   - otherwise, if size_per_head is even and half that count fits in a block: float2 kernel.
//   - otherwise the configuration is unsupported; an error is printed and the process exits.
// half / __nv_bfloat16 (the latter only with ENABLE_BF16): packed two-element kernel; requires an
//   even size_per_head and num_head * size_per_head / 2 <= 1024, else prints an error and exits.
// Any other T launches nothing.
//
// Changes from the previous revision:
//   - the float2 path is gated on size_per_head being even (it used to test the thread count,
//     which let odd head sizes through with a truncated size_per_head / 2, including 0);
//   - blocks of exactly 1024 threads are accepted in the float paths, matching the half path;
//   - an odd size_per_head on the half / bf16 path is reported instead of silently mis-indexing.
template<typename T>
void invokeAddHead3SizeQKVBias(const T*     mm_qkv,
                               const T*     bias_qkv,
                               T*           q_buf_,
                               T*           k_buf_,
                               T*           v_buf_,
                               const int    batch,
                               const int    window_num,
                               const int    window_len,
                               const int    num_head,
                               const int    size_per_head,
                               cudaStream_t stream)
{
    // Largest block size accepted here; matches the limit already enforced on the half path.
    constexpr unsigned int kMaxThreadsPerBlock = 1024;

#ifdef ENABLE_BF16
    constexpr bool is_16bit_type = std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value;
#else
    constexpr bool is_16bit_type = std::is_same<T, half>::value;
#endif

    if (std::is_same<T, float>::value) {
        dim3 grid(window_len, window_num, 3 * batch);
        dim3 block(num_head * size_per_head);

        if (block.x <= kMaxThreadsPerBlock) {
            add_head3Size_QKV_bias<<<grid, block, 0, stream>>>(
                mm_qkv, bias_qkv, q_buf_, k_buf_, v_buf_, batch, window_num, window_len, num_head, size_per_head);
        }
        else if ((size_per_head % 2 == 0) && (block.x / 2 <= kMaxThreadsPerBlock)) {
            // Two floats per thread; size_per_head must be even so that a float2 never straddles heads.
            block.x /= 2;
            add_head3Size_QKV_bias<<<grid, block, 0, stream>>>((const float2*)mm_qkv,
                                                               (const float2*)bias_qkv,
                                                               (float2*)q_buf_,
                                                               (float2*)k_buf_,
                                                               (float2*)v_buf_,
                                                               batch,
                                                               window_num,
                                                               window_len,
                                                               num_head,
                                                               size_per_head / 2);
        }
        else {
            printf("[ERROR][invokeAddHead3SizeQKVBias] unsupported block.x!\n");
            exit(-1);
        }
    }
    else if (is_16bit_type) {
        if (size_per_head % 2 != 0) {
            // The packed kernel indexes in two-element units; an odd head size would mis-index.
            printf("[ERROR][invokeAddHead3SizeQKVBias] size_per_head must be even!\n");
            exit(-1);
        }

        dim3 grid(window_len, window_num, batch);
        dim3 block(num_head * size_per_head / 2);

        using T2 = typename TypeConverter<T>::Type;  // half2 or bfloat162

        if (block.x > kMaxThreadsPerBlock) {
            printf("[ERROR][invokeAddHead3SizeQKVBias] block.x > 1024!\n");
            exit(-1);
        }

        add_head3Size_QKV_bias<<<grid, block, 0, stream>>>((const T2*)mm_qkv,
                                                           (const T2*)bias_qkv,
                                                           (T2*)q_buf_,
                                                           (T2*)k_buf_,
                                                           (T2*)v_buf_,
                                                           batch,
                                                           window_num,
                                                           window_len,
                                                           num_head,
                                                           size_per_head / 2);
    }
}

// Explicit instantiations of invokeAddHead3SizeQKVBias for float and half, plus __nv_bfloat16
// when ENABLE_BF16 is defined. The helper macro is undefined right after use.
#define INSTANTIATEADDHEAD3SIZEQKVBIAS(T)                                                                              \
    template void invokeAddHead3SizeQKVBias<T>(const T*     mm_qkv,                                                    \
                                               const T*     bias_qkv,                                                  \
                                               T*           q_buf_,                                                    \
                                               T*           k_buf_,                                                    \
                                               T*           v_buf_,                                                    \
                                               const int    batch,                                                     \
                                               const int    window_num,                                                \
                                               const int    window_len,                                                \
                                               const int    num_head,                                                  \
                                               const int    size_per_head,                                             \
                                               cudaStream_t stream)
INSTANTIATEADDHEAD3SIZEQKVBIAS(float);
INSTANTIATEADDHEAD3SIZEQKVBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEADDHEAD3SIZEQKVBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEADDHEAD3SIZEQKVBIAS

/*******************  invokeMaskedSoftMaxWithRelPosBias  ***********************/

// In-place row softmax of qk_scale * qk + attn_mask + relative_pos_bias, one element per thread.
// Expected launch:
//   grid    = (window_len / word_per_thread, window_num * num_head, batch_size)
//   block.x = max(32, (window_len + 31) / 32 * 32)
// Tensors (row-major):
//   qk_buf            : [batch, window_num, num_head, window_len, window_len]
//   attn_mask         : [window_num, window_len, window_len]; nullptr means "no mask" (adds 0)
//   relative_pos_bias : [num_head, window_len, window_len]
// Each block walks the rows of its window starting at blockIdx.x with step gridDim.x. Threads
// with threadIdx.x >= window_len only take part in the reductions (contributing -1e20 to the
// maximum and 0 to the sum).
// NOTE(review): blockReduceMax / blockReduceSum are project helpers, presumably block-wide
// reductions whose result is valid on thread 0 — confirm in reduce_kernel_utils.cuh.
template<typename T>
__global__ void softmax_withRelPosBias_element1_kernel(T*          qk_buf,
                                                       const T*    attn_mask,
                                                       const T*    relative_pos_bias,
                                                       const int   batch_size,
                                                       const int   num_head,
                                                       const int   window_num,
                                                       const int   window_len,
                                                       const int   window_len_x_window_len,
                                                       const float qk_scale)
{
    // s_max: row maximum; s_mean: reciprocal of (sum of exponentials + 1e-6).
    __shared__ float s_mean, s_max;

    const bool active = threadIdx.x < window_len;

    for (int row = blockIdx.x; row < window_len; row += gridDim.x) {
        float logit     = -1e20f;
        int   qk_offset = 0;

        if (active) {
            const int in_window = row * window_len + threadIdx.x;
            qk_offset           = (blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + in_window;
            // blockIdx.y enumerates (window, head) pairs with the head index varying fastest.
            const int bias_offset = (blockIdx.y % num_head) * window_len_x_window_len + in_window;

            float mask_val = 0.0f;
            if (attn_mask != nullptr) {
                mask_val = static_cast<float>(
                    ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + in_window)));
            }

            logit = qk_scale * static_cast<float>(qk_buf[qk_offset]) + mask_val
                    + static_cast<float>(ldg(relative_pos_bias + bias_offset));
        }

        const float row_max = blockReduceMax<float>(logit);
        if (threadIdx.x == 0) {
            s_max = row_max;
        }
        __syncthreads();

        const float expo    = active ? __expf(logit - s_max) : 0.0f;
        const float row_sum = blockReduceSum<float>(expo);
        if (threadIdx.x == 0) {
            // The small epsilon keeps the reciprocal finite.
            s_mean = __fdividef(1.0f, row_sum + 1e-6f);
        }
        __syncthreads();

        if (active) {
            qk_buf[qk_offset] = (T)(expo * s_mean);
        }
    }
}

// In-place row softmax of qk_scale * qk + attn_mask + relative_pos_bias, two adjacent elements
// (one T2) per thread.
// Expected launch:
//   grid    = (window_len / word_per_thread, window_num * num_head, batch_size)
//   block.x = max(32, (window_len / 2 + 31) / 32 * 32)
// Tensors (row-major, addressed here in T2 units):
//   qk_buf            : [batch, window_num, num_head, window_len, window_len]
//   attn_mask         : [window_num, window_len, window_len]; nullptr means "no mask" (adds 0)
//   relative_pos_bias : [num_head, window_len, window_len]
// Each block walks the rows of its window starting at blockIdx.x with step gridDim.x. Threads
// with threadIdx.x >= window_len / 2 only take part in the reductions. Warp-level reductions are
// used when the block has at most 32 threads, block-level ones otherwise.
// NOTE(review): the element offsets are halved to index T2, which presumably assumes an even
// window_len — confirm at the launcher.
template<typename T2, typename T>
__global__ void softmax_withRelPosBias_element2_kernel(T2*         qk_buf,
                                                       const T2*   attn_mask,
                                                       const T2*   relative_pos_bias,
                                                       const int   batch_size,
                                                       const int   num_head,
                                                       const int   window_num,
                                                       const int   window_len,
                                                       const int   window_len_x_window_len,
                                                       const float qk_scale)
{
    // s_max: row maximum; s_mean: reciprocal of (sum of exponentials + 1e-6).
    __shared__ float s_mean, s_max;

    const int  tidx       = threadIdx.x;
    const int  bdim       = blockDim.x;
    const int  pairs      = window_len / 2;
    const bool active     = tidx < pairs;
    const bool single_warp = bdim <= 32;
    const T2   zero       = {T(0.0f), T(0.0f)};

    for (int row = blockIdx.x; row < window_len; row += gridDim.x) {
        float  pair_max  = -1e20f;
        int    qk_offset = 0;
        float2 logits;

        if (active) {
            const int in_window = row * window_len + 2 * tidx;
            qk_offset = ((blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + in_window) / 2;
            // blockIdx.y enumerates (window, head) pairs with the head index varying fastest.
            const int bias_offset = ((blockIdx.y % num_head) * window_len_x_window_len + in_window) / 2;

            T2 mask_val = zero;
            if (attn_mask != nullptr) {
                mask_val = ldg(attn_mask + ((blockIdx.y / num_head) * window_len_x_window_len + in_window) / 2);
            }

            const T2 qk_val   = qk_buf[qk_offset];
            const T2 bias_val = ldg(relative_pos_bias + bias_offset);

            logits.x = static_cast<float>(qk_val.x);
            logits.y = static_cast<float>(qk_val.y);
            logits.x = qk_scale * logits.x + static_cast<float>(mask_val.x) + static_cast<float>(bias_val.x);
            logits.y = qk_scale * logits.y + static_cast<float>(mask_val.y) + static_cast<float>(bias_val.y);

            pair_max = logits.x > logits.y ? logits.x : logits.y;
        }

        const float row_max = single_warp ? warpReduceMax<float>(pair_max) : blockReduceMax<float>(pair_max);
        if (tidx == 0) {
            s_max = row_max;
        }
        __syncthreads();

        logits.x = active ? __expf(logits.x - s_max) : 0.0f;
        logits.y = active ? __expf(logits.y - s_max) : 0.0f;

        const float row_sum = single_warp ? warpReduceSum<float>(logits.x + logits.y) :
                                            blockReduceSum<float>(logits.x + logits.y);
        if (tidx == 0) {
            // The small epsilon keeps the reciprocal finite.
            s_mean = __fdividef(1.0f, row_sum + 1e-6f);
        }
        __syncthreads();

        if (active) {
            T2 packed;
            packed.x          = T(logits.x * s_mean);
            packed.y          = T(logits.y * s_mean);
            qk_buf[qk_offset] = packed;
        }
    }
}

// Row-wise softmax of `qk_scale * qk + mask + relative_pos_bias`, four elements per thread.
//
// grid    = (window_len/word_per_thread, window_num*num_head, batch_size)
// block.x = max(32, (window_len/4 + 31)/32*32)
// qk_buf            is [batch, window_num, num_head, window_len, window_len], overwritten in place
// attn_mask         is [window_num, window_len, window_len] + row-major, nullptr means "no mask"
// relative_pos_bias is [num_head, window_len, window_len] + row-major
//
// All three buffers are addressed in units of T4, so every element offset is divided by 4;
// the host launcher in this file only selects this kernel when window_len is a multiple of 4.
template<typename T4, typename T>
__global__ void softmax_withRelPosBias_element4_kernel(T4*         qk_buf,
                                                       const T4*   attn_mask,
                                                       const T4*   relative_pos_bias,
                                                       const int   batch_size,
                                                       const int   num_head,
                                                       const int   window_num,
                                                       const int   window_len,
                                                       const int   window_len_x_window_len,
                                                       const float qk_scale)
{
    // Block-wide broadcast slots: the row maximum and the reciprocal of the row sum.
    __shared__ float s_mean, s_max;

    const int  lane_id      = threadIdx.x;
    const int  num_threads  = blockDim.x;
    const int  vecs_per_row = window_len / 4;
    const bool is_active    = lane_id < vecs_per_row;  // threads past the end of the row only join the reductions
    const bool single_warp  = num_threads <= 32;
    const T4   no_mask      = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)};

    for (int row = blockIdx.x; row < window_len; row += gridDim.x) {
        float  thread_max = -1e20f;
        int    vec_idx;
        float4 acc;
        T4     packed;

        if (is_active) {
            const int elem_in_window = row * window_len + 4 * lane_id;
            const int bias_idx       = ((blockIdx.y % num_head) * window_len_x_window_len + elem_in_window) / 4;
            vec_idx = ((blockIdx.z * gridDim.y + blockIdx.y) * window_len_x_window_len + elem_in_window) / 4;

            T4 mask_vec = no_mask;
            if (attn_mask != nullptr) {
                mask_vec = attn_mask[((blockIdx.y / num_head) * window_len_x_window_len + elem_in_window) / 4];
            }
            const T4 bias_vec = relative_pos_bias[bias_idx];
            packed            = qk_buf[vec_idx];

            // Accumulate in float regardless of the storage type.
            acc.x = static_cast<float>(packed.x);
            acc.y = static_cast<float>(packed.y);
            acc.z = static_cast<float>(packed.z);
            acc.w = static_cast<float>(packed.w);

            acc.x = qk_scale * acc.x + static_cast<float>(mask_vec.x) + static_cast<float>(bias_vec.x);
            acc.y = qk_scale * acc.y + static_cast<float>(mask_vec.y) + static_cast<float>(bias_vec.y);
            acc.z = qk_scale * acc.z + static_cast<float>(mask_vec.z) + static_cast<float>(bias_vec.z);
            acc.w = qk_scale * acc.w + static_cast<float>(mask_vec.w) + static_cast<float>(bias_vec.w);

            // Largest of this thread's four logits.
            thread_max = acc.x > acc.y ? acc.x : acc.y;
            thread_max = thread_max > acc.z ? thread_max : acc.z;
            thread_max = thread_max > acc.w ? thread_max : acc.w;
        }

        // Row maximum, published through shared memory by thread 0.
        const float row_max = single_warp ? warpReduceMax<float>(thread_max) : blockReduceMax<float>(thread_max);
        if (lane_id == 0) {
            s_max = row_max;
        }
        __syncthreads();

        // Shifted exponentials; inactive threads contribute zero to the sum.
        if (is_active) {
            acc.x = __expf(acc.x - s_max);
            acc.y = __expf(acc.y - s_max);
            acc.z = __expf(acc.z - s_max);
            acc.w = __expf(acc.w - s_max);
        }
        else {
            acc.x = 0.0f;
            acc.y = 0.0f;
            acc.z = 0.0f;
            acc.w = 0.0f;
        }

        const float thread_sum = acc.x + acc.y + acc.z + acc.w;
        const float row_sum    = single_warp ? warpReduceSum<float>(thread_sum) : blockReduceSum<float>(thread_sum);
        if (lane_id == 0) {
            // The small epsilon keeps the reciprocal finite.
            s_mean = __fdividef(1.0f, row_sum + 1e-6f);
        }
        __syncthreads();

        if (is_active) {
            const float inv_sum = s_mean;
            packed.x            = T(acc.x * inv_sum);
            packed.y            = T(acc.y * inv_sum);
            packed.z            = T(acc.z * inv_sum);
            packed.w            = T(acc.w * inv_sum);
            qk_buf[vec_idx]     = packed;
        }
    }
}

// Host launcher for the windowed softmax with relative position bias.
//
// Picks a kernel by how many elements one thread can process:
//   * 4 per thread when window_len is a multiple of 4 and window_len / 4 >= 32,
//   * 2 per thread when window_len is otherwise even,
//   * 1 per thread for odd window_len.
// The vectorized kernels are launched for float and half only. When ENABLE_BF16 is defined,
// any other T falls back to the 1-element kernel.
//
// qk_buf            attention scores; the vectorized kernels overwrite them with the softmax
//                   result (the 1-element kernel is expected to do the same).
// attn_mask         optional additive mask, forwarded unchanged to the kernel.
// relative_pos_bias additive bias, forwarded unchanged to the kernel.
// batch_size        becomes grid.z.
// num_head          together with window_num forms grid.y = window_num * num_head.
// window_len        row length; also grid.x (one block per row, word_per_thread == 1).
// qk_scale          scale applied to the raw scores inside the kernel.
// stream            stream the kernel is enqueued on; the call does not synchronize and does
//                   not check the launch status.
//
// The buffers are reinterpreted as float4/half4/float2/half2 on the vectorized paths, so they
// are assumed to be suitably aligned for those types.
template<typename T>
void invokeMaskedSoftMaxWithRelPosBias(T*           qk_buf,
                                       const T*     attn_mask,
                                       const T*     relative_pos_bias,
                                       const int    batch_size,
                                       const int    num_head,
                                       const int    window_num,
                                       const int    window_len,
                                       float        qk_scale,
                                       cudaStream_t stream)
{
    const int word_per_thread = 1;
    dim3      grid((window_len + word_per_thread - 1) / word_per_thread, window_num * num_head, batch_size);
    if ((window_len % 4 == 0) && window_len / 4 >= 32) {
        // One thread per 4 elements, rounded up to a whole number of warps.
        dim3 block((window_len / 4 + 31) / 32 * 32);
        if (std::is_same<T, float>::value) {
            softmax_withRelPosBias_element4_kernel<float4, float>
                <<<grid, block, 0, stream>>>((float4*)qk_buf,
                                             (const float4*)attn_mask,
                                             (const float4*)relative_pos_bias,
                                             batch_size,
                                             num_head,
                                             window_num,
                                             window_len,
                                             window_len * window_len,
                                             qk_scale);
        }
        else if (std::is_same<T, half>::value) {
            softmax_withRelPosBias_element4_kernel<half4, half>
                <<<grid, block, 0, stream>>>((half4*)qk_buf,
                                             (const half4*)attn_mask,
                                             (const half4*)relative_pos_bias,
                                             batch_size,
                                             num_head,
                                             window_num,
                                             window_len,
                                             window_len * window_len,
                                             qk_scale);
        }
#ifdef ENABLE_BF16
        else {
            // No vectorized variant for this type: one thread per element.
            dim3 block((window_len + 31) / 32 * 32);
            softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
                                                                               attn_mask,
                                                                               relative_pos_bias,
                                                                               batch_size,
                                                                               num_head,
                                                                               window_num,
                                                                               window_len,
                                                                               window_len * window_len,
                                                                               qk_scale);
        }
#endif
    }
    else if (window_len % 2 == 0) {
        // One thread per 2 elements, rounded up to a whole number of warps.
        dim3 block((window_len / 2 + 31) / 32 * 32);
        if (std::is_same<T, float>::value) {
            softmax_withRelPosBias_element2_kernel<float2, float>
                <<<grid, block, 0, stream>>>((float2*)qk_buf,
                                             (const float2*)attn_mask,
                                             (const float2*)relative_pos_bias,
                                             batch_size,
                                             num_head,
                                             window_num,
                                             window_len,
                                             window_len * window_len,
                                             qk_scale);
        }
        else if (std::is_same<T, half>::value) {
            softmax_withRelPosBias_element2_kernel<half2, half>
                <<<grid, block, 0, stream>>>((half2*)qk_buf,
                                             (const half2*)attn_mask,
                                             (const half2*)relative_pos_bias,
                                             batch_size,
                                             num_head,
                                             window_num,
                                             window_len,
                                             window_len * window_len,
                                             qk_scale);
        }
#ifdef ENABLE_BF16
        else {
            // No vectorized variant for this type: one thread per element.
            dim3 block((window_len + 31) / 32 * 32);
            softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
                                                                               attn_mask,
                                                                               relative_pos_bias,
                                                                               batch_size,
                                                                               num_head,
                                                                               window_num,
                                                                               window_len,
                                                                               window_len * window_len,
                                                                               qk_scale);
        }
#endif
    }
    else {
        // Odd window_len: one thread per element for every T.
        dim3 block((window_len + 31) / 32 * 32);
        softmax_withRelPosBias_element1_kernel<<<grid, block, 0, stream>>>(qk_buf,
                                                                           attn_mask,
                                                                           relative_pos_bias,
                                                                           batch_size,
                                                                           num_head,
                                                                           window_num,
                                                                           window_len,
                                                                           window_len * window_len,
                                                                           qk_scale);
    }
}

// Explicit instantiations of invokeMaskedSoftMaxWithRelPosBias for the element types used by
// callers: float and half always, __nv_bfloat16 only when ENABLE_BF16 is defined. The helper
// macro is local to this section and is undefined again right after use.
#define INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(T)                                                                      \
    template void invokeMaskedSoftMaxWithRelPosBias(T*           qk_buf,                                               \
                                                    const T*     attn_mask,                                            \
                                                    const T*     relative_pos_bias,                                    \
                                                    const int    batch_size,                                           \
                                                    const int    num_head,                                             \
                                                    const int    window_num,                                           \
                                                    const int    window_len,                                           \
                                                    const float  qk_scale,                                             \
                                                    cudaStream_t stream)
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(float);
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(half);
#ifdef ENABLE_BF16
INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS(__nv_bfloat16);
#endif
#undef INSTANTIATEMASKEDSOFTMAXWITHRELPOSBIAS

// Copies the attention maps of one layer into that layer's slot of a multi-layer output buffer.
//
// Expected launch (see invokeTransposeAttentions): grid = (batch_size, num_heads), 1-D block.
// Each block copies the seq_len x seq_len map of one (batch, head) pair; its threads walk the
// map with a stride of blockDim.x, so any block size is valid.
//
// attentions_out  destination, already shifted to the target layer (see note below)
// attentions_in   source maps of a single layer
// batch_size      unused inside the kernel; kept for interface compatibility
// num_layers      number of layers in the destination layout (enters the batch stride only)
// num_heads       number of heads
// seq_len         side length of each attention map
template<typename T>
__global__ void transpose_attentions(
    T* attentions_out, const T* attentions_in, size_t batch_size, size_t num_layers, size_t num_heads, size_t seq_len)
{
    // attentions_in  shape [B, H, S, S]
    // attentions_out shape [B, L, H, S, S].
    // Note that we write the L dimension as if it was index 0.
    // In reality, the pointer has already been shifted to point to the correct layer.

    const size_t batch_idx = blockIdx.x;
    const size_t head_idx  = blockIdx.y;
    const size_t map_size  = seq_len * seq_len;

    const size_t dst_offset = (batch_idx * num_layers * num_heads + head_idx) * map_size;
    const size_t src_offset = (batch_idx * num_heads + head_idx) * map_size;

    // The counter must be 64-bit: threadIdx.x is a 32-bit unsigned value, and a counter of that
    // type would wrap around before reaching map_size once seq_len * seq_len exceeds 2^32 - 1.
    for (size_t x = threadIdx.x; x < map_size; x += blockDim.x) {
        attentions_out[dst_offset + x] = attentions_in[src_offset + x];
    }
}

// Host launcher for transpose_attentions.
//
// Reads batch size, head count and sequence length from attentions_in.shape[0..2] and the layer
// count from attentions_out.shape[1], then enqueues one block of 512 threads per (batch, head)
// pair on `stream`. The call does not synchronize and does not check the launch status.
//
// attentions_out  destination tensor; its data pointer is passed to the kernel as-is
// attentions_in   source tensor holding the attention maps of one layer
// stream          CUDA stream the copy kernel is enqueued on
template<typename T>
void invokeTransposeAttentions(Tensor& attentions_out, const Tensor& attentions_in, cudaStream_t stream)
{
    const size_t batch_size = attentions_in.shape[0];
    const size_t num_heads  = attentions_in.shape[1];
    const size_t seq_len    = attentions_in.shape[2];
    const size_t num_layers = attentions_out.shape[1];

    // Nothing to copy for an empty tensor, and a zero-sized grid dimension would be rejected
    // as an invalid launch configuration.
    if (batch_size == 0 || num_heads == 0 || seq_len == 0) {
        return;
    }

    const dim3 gridSize(batch_size, num_heads);
    const dim3 blockSize(512);

    transpose_attentions<<<gridSize, blockSize, 0, stream>>>(
        attentions_out.getPtr<T>(), attentions_in.getPtr<const T>(), batch_size, num_layers, num_heads, seq_len);
}

// Explicit instantiations of invokeTransposeAttentions: float and half always, __nv_bfloat16
// only when ENABLE_BF16 is defined. The helper macro is undefined again right after use.
#define INSTANTIATETRANSPOSEATTENTIONS(T)                                                                              \
    template void invokeTransposeAttentions<T>(                                                                        \
        Tensor & attentions_out, const Tensor& attentions_in, cudaStream_t stream)
INSTANTIATETRANSPOSEATTENTIONS(float);
INSTANTIATETRANSPOSEATTENTIONS(half);
#ifdef ENABLE_BF16
INSTANTIATETRANSPOSEATTENTIONS(__nv_bfloat16);
#endif
#undef INSTANTIATETRANSPOSEATTENTIONS

}  // namespace turbomind