llama_kernels.cu 26.3 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
// Copyright (c) OpenMMLab. All rights reserved.

lvhan028's avatar
lvhan028 committed
3
#include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
Li Zhang's avatar
Li Zhang committed
4
5
#include "src/turbomind/kernels/decoder_multihead_attention/array_ops.h"
#include "src/turbomind/kernels/gemm_s_f16/common.h"
lvhan028's avatar
lvhan028 committed
6
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
Chen Xin's avatar
Chen Xin committed
7
#include "src/turbomind/macro.h"
lvhan028's avatar
lvhan028 committed
8
9
10
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
Li Zhang's avatar
Li Zhang committed
11
12
#include "src/turbomind/utils/logger.h"
#include <type_traits>
Li Zhang's avatar
Li Zhang committed
13

lvhan028's avatar
lvhan028 committed
14
namespace turbomind {
// RMS normalization for 16-bit element types (fp16, bf16).
// Each thread handles packed pairs (T2), so `n` is the number of pairs per row: the host launcher
// halves the element count before launching, and a row therefore holds 2 * n elements.
// Layout: one thread block per row (row index = blockIdx.x); threads stride over the row.
// `m` is not used by the kernel.
template<typename T>
__global__ void rootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n)
{
    using T2 = typename TypeConverter<T>::Type;
    __shared__ float s_inv_mean;

    const T2* row_in  = reinterpret_cast<const T2*>(input) + blockIdx.x * n;
    T2*       row_out = reinterpret_cast<T2*>(out) + blockIdx.x * n;
    const T2* gamma   = reinterpret_cast<const T2*>(scale);

    // Per-thread partial sum of squares (accumulated in float).
    float sum_sq = 0.f;
    for (uint i = threadIdx.x; i < n; i += blockDim.x) {
        const float2 v = cuda_cast<float2>(row_in[i]);
        sum_sq += v.x * v.x;
        sum_sq += v.y * v.y;
    }

    sum_sq = blockReduceSum<float>(sum_sq);
    if (threadIdx.x == 0) {
        // The row has 2 * n elements, hence the 0.5 factor on sum / n.
        s_inv_mean = rsqrt(.5f * sum_sq / (float)n + eps);
    }
    __syncthreads();

    for (uint i = threadIdx.x; i < n; i += blockDim.x) {
        float2       v = cuda_cast<float2>(row_in[i]);
        const float2 g = cuda_cast<float2>(gamma[i]);
        v.x            = v.x * s_inv_mean * g.x;
        v.y            = v.y * s_inv_mean * g.y;
        row_out[i]     = cuda_cast<T2>(v);
    }
}

// fp32 specialization: one element per thread step, `n` is the full row length.
// One thread block per row (row index = blockIdx.x); `m` is not used by the kernel.
template<>
__global__ void rootMeanSquareNorm(float* out, const float* input, const float* scale, float eps, int m, int n)
{
    __shared__ float s_inv_mean;

    const float* row_in  = input + blockIdx.x * n;
    float*       row_out = out + blockIdx.x * n;

    float sum_sq = 0.f;
    for (uint i = threadIdx.x; i < n; i += blockDim.x) {
        const float v = row_in[i];
        sum_sq += v * v;
    }

    sum_sq = blockReduceSum<float>(sum_sq);
    if (threadIdx.x == 0) {
        s_inv_mean = rsqrt(sum_sq / static_cast<float>(n) + eps);
    }
    __syncthreads();

    for (uint i = threadIdx.x; i < n; i += blockDim.x) {
        const float v = row_in[i];
        row_out[i]    = v * s_inv_mean * scale[i];
    }
}

// Host launcher for rootMeanSquareNorm: normalizes `m` rows of `n` elements, one thread block per row.
//
// For 2-byte element types the kernel processes packed pairs, so `n` must be even (checked with
// FT_CHECK) and is halved before launch.
//
// The block size is rounded up to a multiple of the warp size (capped at 1024). Surplus threads
// fail the kernel's `idx < n` test and contribute 0 to the sum, so results are unchanged. Rounding
// guarantees the block-wide reduction never runs on a partially populated warp.
// NOTE(review): blockReduceSum lives in reduce_kernel_utils.cuh. It is assumed to shuffle with a
// full-warp mask; an incomplete trailing warp would then read inactive lanes.
//
// Empty inputs (m or n <= 0) return without launching: a zero-sized grid/block is an invalid
// launch configuration.
template<typename T>
void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream)
{
    if (sizeof(T) == 2) {
        FT_CHECK(n % 2 == 0);
        n /= 2;
    }
    if (m <= 0 || n <= 0) {
        return;
    }

    constexpr int kWarpSize     = 32;
    constexpr int kMaxBlockSize = 1024;
    const int     block_size    = std::min((n + kWarpSize - 1) / kWarpSize * kWarpSize, kMaxBlockSize);

    dim3 grid(m);
    dim3 block(block_size);
    rootMeanSquareNorm<<<grid, block, 0, stream>>>(out, input, scale, eps, m, n);
}

template void invokeRootMeanSquareNorm(float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeRootMeanSquareNorm(half*, const half*, const half*, float, int, int, cudaStream_t);

// bf16 instantiation is currently disabled.
// #ifdef ENABLE_BF16

// template void invokeRootMeanSquareNorm(__nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int,
// cudaStream_t);

// #endif

// Conversion helper: the primary template is a plain implicit conversion from T0 to T.
template<typename T, typename T0>
__device__ T saturate_cast(T0 x)
{
    return x;
}

// float -> half specialization: values beyond +/-64512 are clamped to +/-64512 so that large
// magnitudes do not convert to +/-inf. NaN fails both comparisons and is passed through as-is.
template<>
__device__ half saturate_cast<half, float>(float x)
{
    constexpr float kLimit = 64512.f;
    if (x > kLimit) {
        return kLimit;
    }
    if (x < -kLimit) {
        return -kLimit;
    }
    return x;
}

// Element-wise in-place residual add: out[i] += in[i] for i < n.
// The sum is formed in float and converted back to T. Launched 1-D, one element per thread.
template<typename T>
__global__ void addResidual(T* out, const T* in, size_t n)
{
    const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    const float sum = static_cast<float>(out[i]) + static_cast<float>(in[i]);
    out[i]          = static_cast<T>(sum);
}

// Host launcher for addResidual: out += in over an m x n buffer, one element per thread,
// up to 1024 threads per block.
//
// Sizes stay in size_t throughout. `unsigned long` is only 32 bits on LLP64 platforms (MSVC), so
// narrowing through it would truncate large totals. Empty inputs return without launching, since a
// zero-sized block is an invalid launch configuration.
template<typename T>
void invokeAddResidual(T* out, const T* in, int m, int n, cudaStream_t stream)
{
    if (m <= 0 || n <= 0) {
        return;
    }
    const size_t total = static_cast<size_t>(m) * n;

    constexpr size_t kMaxBlockSize = 1024;
    const dim3       block(static_cast<unsigned>(std::min(total, kMaxBlockSize)));
    const dim3       grid(static_cast<unsigned>((total + block.x - 1) / block.x));

    addResidual<<<grid, block, 0, stream>>>(out, in, total);
}

template void invokeAddResidual(float*, const float*, int, int, cudaStream_t);
template void invokeAddResidual(half*, const half*, int, int, cudaStream_t);

// ids       [seq_len, batch_size]      (destination, token-major)
// input_ids [batch_size, max_input_len] (source, batch-major)
// One thread block per batch entry b: copies the first input_lengths[b] tokens of row b of
// input_ids into column b of ids. `seq_len` is not used by the kernel.
__global__ void
fixInputIds(int* ids, const int* input_ids, const int* input_lengths, int batch_size, int seq_len, int max_input_len)
{
    const int  b       = blockIdx.x;
    const int* src_row = input_ids + b * max_input_len;
    for (int t = threadIdx.x; t < input_lengths[b]; t += blockDim.x) {
        ids[t * batch_size + b] = src_row[t];
    }
}

// Host launcher for fixInputIds: one block per batch entry, min(max_input_len, 1024) threads each.
void invokeFixInputIds(int*         ids,
                       const int*   input_ids,
                       const int*   input_lengths,
                       int          batch_size,
                       int          seq_len,
                       int          max_input_len,
                       cudaStream_t st)
{
    const int  threads = max_input_len < 1024 ? max_input_len : 1024;
    const dim3 grid(batch_size);
    fixInputIds<<<grid, threads, 0, st>>>(ids, input_ids, input_lengths, batch_size, seq_len, max_input_len);
}

// Writes one [seq_len, key_len] causal mask per batch entry (batch index = blockIdx.x).
// Entry (row, col) is 1 when col <= row + step, otherwise 0. Threads stride over the flattened mask.
template<typename T>
__global__ void sliceCausalMask(T* mask, int seq_len, int key_len, int step)
{
    T* const  batch_mask = mask + (size_t)blockIdx.x * seq_len * key_len;
    const int count      = seq_len * key_len;
    for (int i = threadIdx.x; i < count; i += blockDim.x) {
        const int  row     = i / key_len;
        const int  col     = i % key_len;
        const bool visible = col <= row + step;
        batch_mask[i]      = static_cast<T>(visible ? 1.f : 0.f);
    }
}

// Rows [step, step + seq_len) of the key_len x key_len causal mask, written for each of batch_size
// entries. Requires step == key_len - seq_len (enforced with FT_CHECK). 256 threads per block.
template<typename T>
void invokeSliceCausalMask(T* mask, int seq_len, int key_len, int step, int batch_size, cudaStream_t stream)
{
    FT_CHECK(step == key_len - seq_len);
    constexpr int kThreads = 256;
    sliceCausalMask<<<batch_size, kThreads, 0, stream>>>(mask, seq_len, key_len, step);
}

template void invokeSliceCausalMask(half*, int, int, int, int, cudaStream_t);
template void invokeSliceCausalMask(float*, int, int, int, int, cudaStream_t);

// mask [bsz, max_q_len, max_k_len]
//
// Builds one causal mask per sequence: one thread block per sequence (blockIdx.x), threads
// striding over the flattened max_q_len * max_k_len entries.
// Entry (q, k) is 1 when q < q_len, k < k_len and k <= q + (k_len - q_len), and 0 otherwise
// (including padding rows/columns). The condition aligns the q_len queries with the last q_len
// of the k_len keys.
template<typename T>
__global__ void createCausalMasks(T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len)
{
    const auto q_len = q_lens[blockIdx.x];
    const auto k_len = k_lens[blockIdx.x];
    // 64-bit offset: evaluated as 32-bit unsigned, blockIdx.x * max_q_len * max_k_len wraps once the
    // offset of a sequence's mask exceeds 2^32 elements (large batches with long contexts).
    mask += (size_t)blockIdx.x * max_q_len * max_k_len;
    for (int i = threadIdx.x; i < max_q_len * max_k_len; i += blockDim.x) {
        const int q        = i / max_k_len;  // [0, max_q_len)
        const int k        = i % max_k_len;  // [0, max_k_len)
        bool      is_valid = q < q_len && k < k_len && k <= q + (k_len - q_len);
        mask[i]            = static_cast<T>(is_valid);
    }
}

// Host launcher for createCausalMasks: one block of 512 threads per sequence.
template<typename T>
void invokeCreateCausalMasks(
    T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream)
{
    constexpr int kThreads = 512;
    const dim3    grid(batch_size);
    createCausalMasks<<<grid, kThreads, 0, stream>>>(mask, q_lens, k_lens, max_q_len, max_k_len);
}

template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t);
template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t);

// Device functor that scatters the K/V of the current query tokens into a block-paged KV cache,
// converting each vector from Ti to To with ConvertKvCache.
//
// Expected launch: grid.y = batch, grid.z = head; blockIdx.x * blockDim.x + threadIdx.x enumerates
// (token, vector) pairs, one X_ELEMS-wide vector per thread.
//   src: [B, H, max_q_len, D] (dense)
//   dst: each cache block is [H, block_length, D], starting at dst_layer_offset within the block
// Token s of sequence b lands at cache position history_len + s, where
// history_len = context_length[b] - query_length[b].
template<typename Ti, typename To>
struct ExtendKvCache {

    static constexpr int MaxElemSize = std::max(sizeof(Ti), sizeof(To));
    static constexpr int X_ELEMS     = 16 / MaxElemSize;

    using Vi = Array<Ti, X_ELEMS>;
    using Vo = Array<To, X_ELEMS>;

    using Transform = ConvertKvCache<Ti, To>;

    struct Params {
        To**       k_dst_ptrs;        // cache block pointers of all sequences, concatenated
        To**       v_dst_ptrs;
        const Ti*  k_src;             // [B, H, max_q_len, D]
        const Ti*  v_src;
        const int* cu_block_counts;   // index of each sequence's first block in *_dst_ptrs
        const int* query_length;      // new tokens per sequence
        const int* context_length;    // history + new tokens per sequence
        int        block_length;      // tokens per cache block
        size_t     dst_layer_offset;  // element offset added to every cache block pointer
        int        max_q_len;
        int        head_num;
        int        head_dim;
        Transform  transform_k;
        Transform  transform_v;
    };

    __device__ void operator()(const Params& params) const
    {
        const int batch_id = blockIdx.y;
        const int head_id  = blockIdx.z;

        const int size_per_head_div_x = params.head_dim / X_ELEMS;
        const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
        const int head_size_id        = idx % size_per_head_div_x;
        const int seq_len_id          = idx / size_per_head_div_x;

        const int query_len = params.query_length[batch_id];

        // Padding threads leave before touching the block pointer table: their cache block index
        // may lie past this sequence's blocks (or past the end of the table).
        if (seq_len_id >= query_len) {
            return;
        }

        const int history_len  = params.context_length[batch_id] - query_len;
        const int cu_block_cnt = params.cu_block_counts[batch_id];

        const int cache_block_index  = (seq_len_id + history_len) / params.block_length;
        const int cache_block_offset = (seq_len_id + history_len) % params.block_length;

        const auto k_val_dst = params.k_dst_ptrs[cu_block_cnt + cache_block_index] + params.dst_layer_offset;
        const auto v_val_dst = params.v_dst_ptrs[cu_block_cnt + cache_block_index] + params.dst_layer_offset;

        // Indices are formed in 64-bit arithmetic: B * H * max_q_len * D/x can exceed INT_MAX.
        // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
        const int64_t dst_idx =
            ((int64_t)head_id * params.block_length + cache_block_offset) * size_per_head_div_x + head_size_id;
        const int64_t src_idx =
            (((int64_t)batch_id * params.head_num + head_id) * params.max_q_len + seq_len_id) * size_per_head_div_x
            + head_size_id;

        Vi k_vi;
        Vi v_vi;

        Ldg(k_vi, params.k_src + src_idx * X_ELEMS);
        Ldg(v_vi, params.v_src + src_idx * X_ELEMS);

        Vo k_vo = params.transform_k(k_vi);
        Vo v_vo = params.transform_v(v_vi);

        Store(k_val_dst + dst_idx * X_ELEMS, k_vo);
        Store(v_val_dst + dst_idx * X_ELEMS, v_vo);
    }
};
namespace {

// Generic __global__ entry point: default-constructs the functor `Kernel` on the device and
// invokes it with the launch parameters, which are passed by value.
template<class Kernel, class Params>
__global__ void KernelWrapper(Params params)
{
    Kernel kernel{};
    kernel(params);
}

}  // namespace
// Host launcher for ExtendKvCache, using 128 threads per block.
// The cache element type is int8_t when `quant` has the QuantPolicy::kCacheKVInt8 bit set, otherwise T.
// kv_params is dereferenced on the host. kv_params[0..1] initialize the K conversion and
// kv_params[2..3] the V conversion (presumably scale / zero-point; confirm against ConvertKvCache).
template<typename T>
void invokeExtendKVCache(void**       k_dst_ptrs,
                         void**       v_dst_ptrs,
                         const T*     k_src,
                         const T*     v_src,
                         const int*   cu_block_counts,
                         const int*   query_length,
                         const int*   context_length,
                         int          batch_size,
                         int          block_length,
                         size_t       dst_layer_offset,
                         int          max_q_len,
                         int          head_dim,
                         int          head_num,
                         int          quant,
                         const float* kv_params,
                         cudaStream_t stream)
{
    constexpr int kThreads = 128;

    const auto launch = [&](auto tag) {
        using CacheT = decltype(tag);
        using Kernel = ExtendKvCache<T, CacheT>;

        // Number of X_ELEMS-wide vectors per (sequence, head).
        const int  vecs_per_head = max_q_len * head_dim / Kernel::X_ELEMS;
        const dim3 grid((vecs_per_head + kThreads - 1) / kThreads, batch_size, head_num);

        typename Kernel::Params params{(CacheT**)k_dst_ptrs,
                                       (CacheT**)v_dst_ptrs,
                                       k_src,
                                       v_src,
                                       cu_block_counts,
                                       query_length,
                                       context_length,
                                       block_length,
                                       dst_layer_offset,
                                       max_q_len,
                                       head_num,
                                       head_dim,
                                       {kv_params[0], kv_params[1]},
                                       {kv_params[2], kv_params[3]}};

        KernelWrapper<Kernel><<<grid, kThreads, 0, stream>>>(params);
    };

    if (quant & QuantPolicy::kCacheKVInt8) {
        launch(int8_t{});
    }
    else {
        launch(T{});
    }
}

// Explicit instantiations of invokeExtendKVCache for fp32 and fp16 activations.
// Parameter names mirror the definition. The 7th argument is the per-sequence context length
// (history + new tokens), and the last float pointer holds the K/V conversion parameters.
template void invokeExtendKVCache(void**       k_dst_ptrs,
                                  void**       v_dst_ptrs,
                                  const float* k_src,
                                  const float* v_src,
                                  const int*   cu_block_counts,
                                  const int*   query_length,
                                  const int*   context_length,
                                  int          batch_size,
                                  int          block_length,
                                  size_t       dst_layer_offset,
                                  int          max_q_len,
                                  int          head_dim,
                                  int          head_num,
                                  int          quant,
                                  const float* kv_params,
                                  cudaStream_t stream);

template void invokeExtendKVCache(void**       k_dst_ptrs,
                                  void**       v_dst_ptrs,
                                  const half*  k_src,
                                  const half*  v_src,
                                  const int*   cu_block_counts,
                                  const int*   query_length,
                                  const int*   context_length,
                                  int          batch_size,
                                  int          block_length,
                                  size_t       dst_layer_offset,
                                  int          max_q_len,
                                  int          head_dim,
                                  int          head_num,
                                  int          quant,
                                  const float* kv_params,
                                  cudaStream_t stream);

// Device functor that gathers the first seq_length[b] tokens of each sequence's KV cache into a
// dense buffer, converting each vector from Ti to To with ConvertKvCache.
//
// Expected launch: grid.y = batch, grid.z = head; blockIdx.x * blockDim.x + threadIdx.x enumerates
// (token, vector) pairs, one X_ELEMS-wide vector per thread.
//   src (per sequence b, offset by src_offset): [head_num / head_n_rep, max_seq_len, size_per_head]
//   dst:                                         [batch, head_num, max_kv_len, size_per_head]
// Head h reads source head h / head_n_rep, so head_n_rep consecutive heads share one source head.
template<typename Ti, typename To>
struct TransposeKvCache {
    static constexpr int MaxElemSize = std::max(sizeof(Ti), sizeof(To));
    static constexpr int X_ELEMS     = 16 / MaxElemSize;

    using Vi = Array<Ti, X_ELEMS>;
    using Vo = Array<To, X_ELEMS>;

    using Transform = ConvertKvCache<Ti, To>;

    struct Params {
        To*        k_dst;          // dense destination [B, H, max_kv_len, D]
        To*        v_dst;
        const Ti** k_src;          // one source pointer per sequence
        const Ti** v_src;
        size_t     src_offset;     // element offset added to every k_src[b] / v_src[b]
        int        head_num;
        int        head_n_rep;     // heads per source head
        int        size_per_head;
        const int* seq_length;     // tokens to copy per sequence
        int        max_kv_len;     // token stride of dst
        int        max_seq_len;    // token stride of src
        Transform  transform_k;
        Transform  transform_v;
    };

    __device__ void operator()(const Params& params) const
    {
        const int batch_id = blockIdx.y;
        const int head_id  = blockIdx.z;

        const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
        const int size_per_head_div_x = params.size_per_head / X_ELEMS;

        const int v_head_size_id = idx % size_per_head_div_x;
        const int v_seq_len_id   = idx / size_per_head_div_x;

        const auto seq_len = params.seq_length[batch_id];
        if (v_seq_len_id >= seq_len) {
            return;
        }

        const auto k_src = params.k_src[batch_id] + params.src_offset;
        const auto v_src = params.v_src[batch_id] + params.src_offset;
        const auto k_dst = params.k_dst;
        const auto v_dst = params.v_dst;

        const int kv_head_id = head_id / params.head_n_rep;

        // Indices are formed in 64-bit arithmetic: B * H * max_kv_len * D/x can exceed INT_MAX.
        // [B, H, s, D/x] <- [B, H_kv, S[:s], D/x]
        const int64_t src_idx =
            ((int64_t)kv_head_id * params.max_seq_len + v_seq_len_id) * size_per_head_div_x + v_head_size_id;
        const int64_t dst_idx =
            (((int64_t)batch_id * params.head_num + head_id) * params.max_kv_len + v_seq_len_id) * size_per_head_div_x
            + v_head_size_id;

        Vi k_vi;
        Vi v_vi;

        Ldg(k_vi, k_src + src_idx * X_ELEMS);
        Ldg(v_vi, v_src + src_idx * X_ELEMS);

        Vo k_vo = params.transform_k(k_vi);
        Vo v_vo = params.transform_v(v_vi);

        Store(k_dst + dst_idx * X_ELEMS, k_vo);
        Store(v_dst + dst_idx * X_ELEMS, v_vo);
    }
};
// Host launcher for TransposeKvCache, using 128 threads per block.
// The cache (source) element type is int8_t when `quant` has the QuantPolicy::kCacheKVInt8 bit set,
// otherwise T; the gathered output is always T.
// kv_params is dereferenced on the host. kv_params[0..1] initialize the K conversion and
// kv_params[2..3] the V conversion.
template<typename T>
void invokeTransposeKVCache(T*           key_cache_trans,
                            T*           val_cache_trans,
                            const T**    key_cache,
                            const T**    val_cache,
                            size_t       src_offset,
                            int          batch_size,
                            const int*   key_length,
                            int          max_kv_len,
                            int          max_seq_len,
                            int          size_per_head,
                            int          head_num,
                            int          head_n_rep,
                            cudaStream_t stream,
                            int          quant,
                            const float* kv_params)
{
    constexpr int kThreads = 128;

    const auto launch = [&](auto tag) {
        using CacheT = decltype(tag);
        using Kernel = TransposeKvCache<CacheT, T>;

        // Number of X_ELEMS-wide vectors per (sequence, head).
        const int  vecs_per_head = max_kv_len * size_per_head / Kernel::X_ELEMS;
        const dim3 grid((vecs_per_head + kThreads - 1) / kThreads, batch_size, head_num);

        typename Kernel::Params params{key_cache_trans,
                                       val_cache_trans,
                                       (const CacheT**)key_cache,
                                       (const CacheT**)val_cache,
                                       src_offset,
                                       head_num,
                                       head_n_rep,
                                       size_per_head,
                                       key_length,
                                       max_kv_len,
                                       max_seq_len,
                                       {kv_params[0], kv_params[1]},
                                       {kv_params[2], kv_params[3]}};

        KernelWrapper<Kernel><<<grid, kThreads, 0, stream>>>(params);
    };

    if (quant & QuantPolicy::kCacheKVInt8) {
        launch(int8_t{});
    }
    else {
        launch(T{});
    }
}

// Explicit instantiations of invokeTransposeKVCache for fp32 and fp16 outputs.
template void invokeTransposeKVCache(float*        key_cache_trans,
                                     float*        val_cache_trans,
                                     const float** key_cache,
                                     const float** val_cache,
                                     size_t        src_offset,
                                     int           batch_size,
                                     const int*    key_length,
                                     int           max_kv_len,
                                     int           max_seq_len,
                                     int           size_per_head,
                                     int           head_num,
                                     int           head_n_rep,
                                     cudaStream_t  stream,
                                     int           quant,
                                     const float*  kv_params);
template void invokeTransposeKVCache(half*        key_cache_trans,
                                     half*        val_cache_trans,
                                     const half** key_cache,
                                     const half** val_cache,
                                     size_t       src_offset,
                                     int          batch_size,
                                     const int*   key_length,
                                     int          max_kv_len,
                                     int          max_seq_len,
                                     int          size_per_head,
                                     int          head_num,
                                     int          head_n_rep,
                                     cudaStream_t stream,
                                     int          quant,
                                     const float* kv_params);
// Copies generated token ids into a padding-free per-request output row.
//   ids        [max_gen_step, batch_size] (step-major)
//   output_ids [batch_size, max_output_len]
// One thread block per batch entry. Steps in [context_len, max_context_len) are padding and are
// skipped. Later steps are shifted left by (max_context_len - context_len), so generated tokens
// follow the context directly.
__global__ void gatherOutput(int*       output_ids,
                             const int* ids,
                             const int* context_length,
                             int        max_context_len,
                             int        max_gen_step,
                             int        max_output_len,
                             int        batch_size)
{
    const int b       = blockIdx.x;
    const int ctx_len = context_length[b];
    const int pad     = max_context_len - ctx_len;
    int*      dst     = output_ids + b * max_output_len;

    for (int step = threadIdx.x; step < max_gen_step; step += blockDim.x) {
        const bool is_context = step < ctx_len;
        const bool is_padding = !is_context && step < max_context_len;
        if (!is_padding) {
            dst[is_context ? step : step - pad] = ids[step * batch_size + b];
        }
    }
}

// Host launcher for gatherOutput: one block of 128 threads per batch entry.
void invokeGatherOutput(int*         output_ids,
                        const int*   ids,
                        const int*   context_length,
                        int          max_context_len,
                        int          max_gen_step,
                        int          max_output_len,
                        int          batch_size,
                        cudaStream_t stream)
{
    constexpr int kThreads = 128;
    const dim3    grid(batch_size);
    gatherOutput<<<grid, kThreads, 0, stream>>>(
        output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
}

// Copies each request's tokens from the batch output buffer into its own buffer and publishes the
// request's sequence length. One thread block per batch entry.
//   output_ids              batch buffer with row stride max_session_len
//   request_output_ids_ptrs per-request destination pointers
//   request_output_ids_lens bounds the number of tokens copied per request
//                           (presumably the destination capacity; confirm with caller)
//   token_generated         when true, adds one to sequence_lengths[b] for both the copy and the
//                           published length
__global__ void updateOutput(int**      request_output_ids_ptrs,
                             int**      request_seqlen_ptrs,
                             const int* output_ids,
                             const int* sequence_lengths,
                             const int* request_output_ids_lens,
                             int        max_session_len,
                             bool       token_generated)
{
    const int batch_id = blockIdx.x;

    auto request_output_ids = request_output_ids_ptrs[batch_id];

    // 64-bit row offset so large max_session_len * batch products cannot overflow int.
    output_ids += (size_t)max_session_len * batch_id;

    const int seqlen     = sequence_lengths[batch_id] + (int)token_generated;
    const int output_len = min(seqlen, request_output_ids_lens[batch_id]);

    for (int i = threadIdx.x; i < output_len; i += blockDim.x) {
        request_output_ids[i] = output_ids[i];
    }

    // Every thread would store the same value; a single writer is sufficient.
    if (threadIdx.x == 0) {
        *request_seqlen_ptrs[batch_id] = seqlen;
    }
}

// Host launcher for updateOutput: one block of 128 threads per batch entry.
void invokeUpdateOutput(int**        request_output_ids_ptrs,
                        int**        request_seqlen_ptrs,
                        const int*   output_ids,
                        const int*   sequence_lengths,
                        const int*   request_output_ids_lens,
                        int          max_session_len,
                        bool         token_generated,
                        int          batch_size,
                        cudaStream_t stream)
{
    const dim3 grid(batch_size);
    const dim3 block(128);

    updateOutput<<<grid, block, 0, stream>>>(request_output_ids_ptrs,
                                             request_seqlen_ptrs,
                                             output_ids,
                                             sequence_lengths,
                                             request_output_ids_lens,
                                             max_session_len,
                                             token_generated);
}

// VERSION_SWITCH(VERSION, CONST_NAME, fn): turns a runtime version into a compile-time constant.
// Expands to an immediately invoked lambda. Each branch declares `CONST_NAME` as a constexpr int
// (1 for any VERSION other than 2, otherwise 2) and returns fn(). Because `fn` is expanded once per
// branch, its body may use CONST_NAME as a template argument.
#define VERSION_SWITCH(VERSION, CONST_NAME, ...)                                                                       \
    [&] {                                                                                                              \
        if (VERSION != 2) {                                                                                            \
            constexpr static int CONST_NAME = 1;                                                                       \
            return __VA_ARGS__();                                                                                      \
        }                                                                                                              \
        constexpr static int CONST_NAME = 2;                                                                           \
        return __VA_ARGS__();                                                                                          \
    }()

// Stores the problem shape and picks the implementation version.
// Under MSVC version 1 is always selected. Elsewhere version 2 is selected only when T is half and
// getSMVersion() reports SM80 or newer; all other cases use version 1.
template<typename T>
FlashAttentionOp<T>::FlashAttentionOp(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
    batch_size_(batch_size), head_num_(head_num), key_len_(key_len), seq_len_(seq_len), size_per_head_(size_per_head)
{
#ifdef _MSC_VER
    op_version_ = 1;
#else
    const bool is_half = std::is_same<half, typename std::decay<T>::type>::value;
    op_version_        = (is_half && getSMVersion() >= 80) ? 2 : 1;
#endif
}

// Workspace size required by the implementation selected in the constructor.
template<typename T>
int FlashAttentionOp<T>::get_workspace_size() const
{
#ifdef _MSC_VER
    const FlashAttentionOpImpl<T, 1> impl(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
    return impl.get_workspace_size();
#else
    const auto query = [&]() {
        return VERSION_SWITCH(op_version_, OP_VERSION, [&]() {
            FlashAttentionOpImpl<T, OP_VERSION> impl(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
            return impl.get_workspace_size();
        });
    };
    return query();
#endif
}

// Runs attention on stream `st` with the implementation selected in the constructor.
template<typename T>
void FlashAttentionOp<T>::operator()(Params& params, cudaStream_t st) const
{
#ifdef _MSC_VER
    FlashAttentionOpImpl<T, 1> impl(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
    impl(params, st);
#else
    VERSION_SWITCH(op_version_, OP_VERSION, [&]() {
        FlashAttentionOpImpl<T, OP_VERSION> impl(batch_size_, head_num_, key_len_, seq_len_, size_per_head_);
        impl(params, st);
    });
#endif
}

// Explicit instantiations for the supported activation types.
template class FlashAttentionOp<float>;
template class FlashAttentionOp<half>;

}  // namespace turbomind