llama_kernels.cu 29.5 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
// Copyright (c) OpenMMLab. All rights reserved.

AllentDan's avatar
AllentDan committed
3
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
Li Zhang's avatar
Li Zhang committed
4
5
#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh"
#include "src/fastertransformer/models/llama/llama_kernels.h"
6
#include "src/fastertransformer/models/llama/llama_utils.h"
Li Zhang's avatar
Li Zhang committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#include "src/fastertransformer/utils/cuda_type_utils.cuh"

namespace fastertransformer {

// fp16, bf16
// n is divided by 2 for this impl
// RMS normalization, packed path for 16-bit element types (T2 = TypeConverter<T>::Type).
// One block handles one row: row blockIdx.x of a matrix whose rows hold `n` packed T2 pairs
// (the launcher halves the element count before launching).
//   out[r, j] = input[r, j] * rsqrt(sum_j(input[r, j]^2) / (2 * n) + eps) * scale[j]
// Squares are accumulated in fp32 and combined across the block with blockReduceSum.
template<typename T>
__global__ void rootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n)
{
    using T2 = typename TypeConverter<T>::Type;

    __shared__ float s_inv_mean;

    const T2* row_in  = (const T2*)input;
    T2*       row_out = (T2*)out;
    const T2* gamma   = (const T2*)scale;
    const uint row    = blockIdx.x * n;

    // Pass 1: per-thread partial sum of squares, then a block-wide reduction.
    float sq_sum = 0.f;
    for (uint col = threadIdx.x; col < n; col += blockDim.x) {
        const float2 v = cuda_cast<float2>(row_in[row + col]);
        sq_sum += v.x * v.x;
        sq_sum += v.y * v.y;
    }
    sq_sum = blockReduceSum<float>(sq_sum);

    // `n` counts pairs, so 0.5f * sum / n is the mean over the 2 * n scalar elements.
    if (threadIdx.x == 0) {
        s_inv_mean = rsqrt(.5f * sq_sum / (float)n + eps);
    }
    __syncthreads();

    // Pass 2: normalize and apply the per-column scale.
    for (uint col = threadIdx.x; col < n; col += blockDim.x) {
        float2       v     = cuda_cast<float2>(row_in[row + col]);
        const float2 g     = cuda_cast<float2>(gamma[col]);
        v.x                = v.x * s_inv_mean * g.x;
        v.y                = v.y * s_inv_mean * g.y;
        row_out[row + col] = cuda_cast<T2>(v);
    }
}

// fp32 specialization of rootMeanSquareNorm: no packing, one block per row of an [m, n] matrix.
//   out[r, j] = input[r, j] * rsqrt(sum_j(input[r, j]^2) / n + eps) * scale[j]
template<>
__global__ void rootMeanSquareNorm(float* out, const float* input, const float* scale, float eps, int m, int n)
{
    __shared__ float s_inv_mean;

    const uint row = blockIdx.x * n;

    // Pass 1: sum of squares over the row.
    float sq_sum = 0.f;
    for (uint col = threadIdx.x; col < n; col += blockDim.x) {
        const float v = input[row + col];
        sq_sum += v * v;
    }
    sq_sum = blockReduceSum<float>(sq_sum);

    if (threadIdx.x == 0) {
        s_inv_mean = rsqrt(sq_sum / static_cast<float>(n) + eps);
    }
    __syncthreads();

    // Pass 2: normalize and scale.
    for (uint col = threadIdx.x; col < n; col += blockDim.x) {
        const float v  = input[row + col];
        out[row + col] = v * s_inv_mean * scale[col];
    }
}

// Launches rootMeanSquareNorm over an [m, n] matrix: one block per row, up to 1024 threads.
// For 2-byte element types the row is processed as n / 2 packed pairs, so n must be even.
// Empty inputs (m == 0 or n == 0) are a no-op: a zero-sized grid or block is an invalid
// launch configuration and would leave a sticky launch error behind.
template<typename T>
void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream)
{
    if (sizeof(T) == 2) {
        FT_CHECK(n % 2 == 0);
        n /= 2;
    }
    if (m <= 0 || n <= 0) {
        return;
    }
    dim3 grid(m);
    dim3 block(std::min(n, 1024));
    rootMeanSquareNorm<<<grid, block, 0, stream>>>(out, input, scale, eps, m, n);
}

template void invokeRootMeanSquareNorm(float*, const float*, const float*, float, int, int, cudaStream_t);
template void invokeRootMeanSquareNorm(half*, const half*, const half*, float, int, int, cudaStream_t);

// #ifdef ENABLE_BF16

// template void invokeRootMeanSquareNorm(__nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t);

// #endif

// Generic conversion used for every (T, T0) pair without a dedicated specialization:
// a plain implicit conversion, no clamping.
template<typename T, typename T0>
__device__ T saturate_cast(T0 x)
{
    return x;
}

// float -> half: values above 64512 or below -64512 are clamped to +/-64512 before the
// conversion; everything else (including NaN, for which both comparisons are false) passes through.
template<>
__device__ half saturate_cast<half, float>(float x)
{
    if (x > 64512.f) {
        return 64512.f;
    }
    if (x < -64512.f) {
        return -64512.f;
    }
    return x;
}

// In-place element-wise residual add over a flat buffer of n elements:
//   out[i] = out[i] + in[i]   (sum formed in fp32, one thread per element)
template<typename T>
__global__ void addResidual(T* out, const T* in, size_t n)
{
    const size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    const float sum = static_cast<float>(out[i]) + static_cast<float>(in[i]);
    out[i]          = static_cast<T>(sum);
}

// out += in over an [m, n] buffer (m * n elements), one thread per element, up to 1024
// threads per block. Empty inputs are a no-op: a zero-sized launch is an invalid configuration.
template<typename T>
void invokeAddResidual(T* out, const T* in, int m, int n, cudaStream_t stream)
{
    if (m <= 0 || n <= 0) {
        return;
    }
    const size_t total = static_cast<size_t>(m) * n;
    // Explicit template argument: `1024UL` is not size_t on LLP64 targets (e.g. Windows),
    // where std::min(size_t, unsigned long) fails template deduction.
    dim3 block(static_cast<unsigned>(std::min<size_t>(total, 1024)));
    dim3 grid(static_cast<unsigned>((total + block.x - 1) / block.x));

    addResidual<<<grid, block, 0, stream>>>(out, in, total);
}

template void invokeAddResidual(float*, const float*, int, int, cudaStream_t);
template void invokeAddResidual(half*, const half*, int, int, cudaStream_t);

// Scatters each prompt from the batch-major input buffer into the step-major id buffer.
//   ids       : [seq_len, batch_size]        (written at seq_id * batch_size + batch_id)
//   input_ids : [batch_size, max_input_len]  (read at batch_id * max_input_len + seq_id)
// One block per batch entry; threads stride over the positions of that entry's prompt.
// The copy length is clamped to seq_len (rows of `ids`) and max_input_len (row stride of
// `input_ids`), so an oversized input_lengths entry cannot access memory out of bounds.
__global__ void
fixInputIds(int* ids, const int* input_ids, const int* input_lengths, int batch_size, int seq_len, int max_input_len)
{
    const int batch_id = blockIdx.x;
    const int len      = min(input_lengths[batch_id], min(seq_len, max_input_len));
    for (int seq_id = threadIdx.x; seq_id < len; seq_id += blockDim.x) {
        ids[seq_id * batch_size + batch_id] = input_ids[batch_id * max_input_len + seq_id];
    }
}

// Launches fixInputIds with one block per batch entry and min(1024, max_input_len) threads.
// Empty batches or empty prompt buffers are a no-op (a zero-sized launch is invalid).
void invokeFixInputIds(int*         ids,
                       const int*   input_ids,
                       const int*   input_lengths,
                       int          batch_size,
                       int          seq_len,
                       int          max_input_len,
                       cudaStream_t st)
{
    if (batch_size <= 0 || max_input_len <= 0) {
        return;
    }
    dim3 block(std::min(1024, max_input_len));
    dim3 grid(batch_size);
    fixInputIds<<<grid, block, 0, st>>>(ids, input_ids, input_lengths, batch_size, seq_len, max_input_len);
}

// Fills one [seq_len, key_len] mask per block (blockIdx.x selects the batch entry).
// Entry (row, col) is 1 when col <= row + step and 0 otherwise.
template<typename T>
__global__ void sliceCausalMask(T* mask, int seq_len, int key_len, int step)
{
    T* const  block_mask = mask + (size_t)blockIdx.x * seq_len * key_len;
    const int count      = seq_len * key_len;
    for (int i = threadIdx.x; i < count; i += blockDim.x) {
        const int  row     = i / key_len;
        const int  col     = i % key_len;
        const bool visible = col <= row + step;
        block_mask[i]      = visible ? static_cast<T>(1.f) : static_cast<T>(0.f);
    }
}

// Produces rows [step, step + seq_len) of the key_len x key_len causal mask for each of
// batch_size entries. The FT_CHECK requires step == key_len - seq_len, i.e. the query rows
// are the last seq_len rows of the full square mask. One 256-thread block per entry.
template<typename T>
void invokeSliceCausalMask(T* mask, int seq_len, int key_len, int step, int batch_size, cudaStream_t stream)
{
    FT_CHECK(step == key_len - seq_len);
    constexpr int kThreads = 256;
    sliceCausalMask<<<batch_size, kThreads, 0, stream>>>(mask, seq_len, key_len, step);
}

template void invokeSliceCausalMask(float*, int, int, int, int, cudaStream_t);
template void invokeSliceCausalMask(half*, int, int, int, int, cudaStream_t);

// Builds per-entry causal masks for a padded batch.
//   mask   : [bsz, max_q_len, max_k_len], one block per batch entry
//   q_lens : query length of each entry; k_lens : key length of each entry
// Entry (q, k) is 1 iff q < q_len, k < k_len and k <= q + (k_len - q_len): queries are
// aligned to the last q_len keys, and padded rows/columns are 0.
template<typename T>
__global__ void createCausalMasks(T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len)
{
    const auto q_len = q_lens[blockIdx.x];
    const auto k_len = k_lens[blockIdx.x];
    // Widen before multiplying: blockIdx.x * max_q_len * max_k_len would otherwise be evaluated
    // in 32-bit unsigned arithmetic and wrap for large batches / long contexts.
    mask += (size_t)blockIdx.x * max_q_len * max_k_len;
    for (int i = threadIdx.x; i < max_q_len * max_k_len; i += blockDim.x) {
        const int q        = i / max_k_len;  // [0, max_q_len)
        const int k        = i % max_k_len;  // [0, max_k_len)
        bool      is_valid = q < q_len && k < k_len && k <= q + (k_len - q_len);
        mask[i]            = static_cast<T>(is_valid);
    }
}

// Launches createCausalMasks with one 512-thread block per batch entry.
template<typename T>
void invokeCreateCausalMasks(
    T* mask, const int* q_lens, const int* k_lens, int max_q_len, int max_k_len, int batch_size, cudaStream_t stream)
{
    const dim3 grid(batch_size);
    const dim3 block(512);
    createCausalMasks<<<grid, block, 0, stream>>>(mask, q_lens, k_lens, max_q_len, max_k_len);
}

template void invokeCreateCausalMasks(float* mask, const int*, const int*, int, int, int, cudaStream_t);
template void invokeCreateCausalMasks(half* mask, const int*, const int*, int, int, int, cudaStream_t);

// Appends query_length[b] new key rows for sequence b into its cache k_dst[b] + dst_offset,
// starting at position history_length[b]. Layout change (per the index math below):
//   src [B, H, max_q_len, D/x]  ->  dst [H, D/x, max_seq_len]  (one 16-byte uint4 per slot)
// x = 4 for 4-byte T and 8 otherwise, so each uint4 carries x elements.
// Launch: grid = (ceil(max_q_len * D / x / blockDim.x), B, H).
// NOTE(review): the visible invokeExtendKVCache uses extend_value_cache for keys as well, so this
// kernel may be unused here — confirm before relying on it.
template<typename T>
__global__ void extend_key_cache(T**          k_dst,
                                 const size_t dst_offset,
                                 const T*     k_src,
                                 const int    head_num,
                                 const int    size_per_head,
                                 const int*   query_length,
                                 const int*   history_length,
                                 const int    max_q_len,
                                 const int    max_seq_len)
{
    const int     batch_id = blockIdx.y;
    const int     head_id  = blockIdx.z;
    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;

    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
    int       size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is now handled by uint4 type
    const auto key_src = reinterpret_cast<const uint4*>(k_src);
    const auto key_dst = reinterpret_cast<uint4*>(k_dst[batch_id] + dst_offset);

    const auto seq_len  = query_length[batch_id];
    const auto t_offset = history_length[batch_id];

    const int k_head_size_id = idx % size_per_head_div_x;
    const int k_seq_len_id   = idx / size_per_head_div_x;

    if (k_seq_len_id < seq_len) {
        // [B, H, s, D/x] -> [H, D/x, S[t:t+s]]
        // Each product is widened to int64_t first: the int products overflow before the
        // assignment otherwise, for large head counts / batch sizes / context lengths.
        const int64_t dst_idx = (int64_t)head_id * size_per_head_div_x * max_seq_len  // H
                                + (int64_t)k_head_size_id * max_seq_len              // D/x
                                + t_offset + k_seq_len_id;                           // s + offset

        const int64_t src_idx = (int64_t)batch_id * head_num * size_per_head_div_x * max_q_len  // B
                                + (int64_t)head_id * size_per_head_div_x * max_q_len            // H
                                + (int64_t)k_seq_len_id * size_per_head_div_x                   // s
                                + k_head_size_id;                                               // D/x

        key_dst[dst_idx] = key_src[src_idx];
    }
}

// Appends query_length[b] new rows for sequence b into its cache v_dst[b] + dst_offset,
// starting at position history_length[b]. Layout change (per the index math below):
//   src [B, H, max_q_len, D/x]  ->  dst [H, max_seq_len, D/x]  (one 16-byte uint4 per slot)
// x = 4 for 4-byte T and 8 otherwise, so each uint4 carries x elements.
// Launch: grid = (ceil(max_q_len * D / x / blockDim.x), B, H).
template<typename T>
__global__ void extend_value_cache(T**          v_dst,
                                   const size_t dst_offset,
                                   const T*     v_src,
                                   const int    head_num,
                                   const int    size_per_head,
                                   const int*   query_length,
                                   const int*   history_length,
                                   const int    max_q_len,
                                   const int    max_seq_len)
{
    const int     batch_id = blockIdx.y;
    const int     head_id  = blockIdx.z;
    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;

    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
    int       size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is now handled by uint4 type
    const auto val_src = reinterpret_cast<const uint4*>(v_src);
    const auto val_dst = reinterpret_cast<uint4*>(v_dst[batch_id] + dst_offset);

    const auto seq_len  = query_length[batch_id];
    const auto t_offset = history_length[batch_id];

    const int v_head_size_id = idx % size_per_head_div_x;
    const int v_seq_len_id   = idx / size_per_head_div_x;

    if (v_seq_len_id < seq_len) {
        // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
        // Products are widened to int64_t first so they cannot overflow in int.
        const int64_t dst_idx = (int64_t)head_id * size_per_head_div_x * max_seq_len        // H
                                + (int64_t)(v_seq_len_id + t_offset) * size_per_head_div_x  // s + offset
                                + v_head_size_id;                                           // D/x

        const int64_t src_idx = (int64_t)batch_id * head_num * size_per_head_div_x * max_q_len  // B
                                + (int64_t)head_id * size_per_head_div_x * max_q_len            // H
                                + (int64_t)v_seq_len_id * size_per_head_div_x                   // s
                                + v_head_size_id;                                               // D/x

        val_dst[dst_idx] = val_src[src_idx];
    }
}

288
289
290
291
292
293
294
295
// Divides both lanes of `b` by the scalar `a`. Note the argument order: the result is b / a.
inline __device__ float2 float2div(float a, float2 b)
{
    float2 quotient = {b.x / a, b.y / a};
    return quotient;
}

AllentDan's avatar
AllentDan committed
296
297
298
299
300
301
302
303
// Dequantizes four int8 values to half: each lane becomes half(value.lane * scale),
// with the product formed in fp32.
static inline __device__ half4 char4_scale_to_half4(char4 value, const float scale)
{
    half4 dst;
    dst.x = __float2half(value.x * scale);
    dst.y = __float2half(value.y * scale);
    dst.z = __float2half(value.z * scale);
    dst.w = __float2half(value.w * scale);
    return dst;
}

AllentDan's avatar
AllentDan committed
306
307
308
// Quantizes four floats to int8 and packs them into one 32-bit word, laid out like a char4
// {x, y, z, w}. Each value is rounded to nearest-even and saturated to [-128, 127].
// On SM72+ this uses cvt.rni.sat + cvt.pack.sat PTX.
static inline __device__ uint32_t float4_to_char4(float x, float y, float z, float w)
{
    uint32_t dst;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 720
    uint32_t a;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(a) : "f"(x));
    uint32_t b;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(b) : "f"(y));
    uint32_t c;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(c) : "f"(z));
    uint32_t d;
    asm volatile("cvt.rni.sat.s32.f32 %0, %1;\n" : "=r"(d) : "f"(w));

    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2,  0;\n" : "=r"(dst) : "r"(d), "r"(c));
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, %0;\n" : "+r"(dst) : "r"(b), "r"(a));
#else
    // Portable fallback with the same rounding/saturation as the PTX path. A plain
    // float -> signed char assignment truncates toward zero and is undefined behavior for
    // values outside [-128, 127].
    auto to_int8 = [](float v) -> signed char {
        if (v != v) {  // NaN converts to 0, as cvt does for float -> int
            return 0;
        }
        return static_cast<signed char>(fminf(fmaxf(rintf(v), -128.f), 127.f));
    };
    char4 tmp;
    tmp.x = to_int8(x);
    tmp.y = to_int8(y);
    tmp.z = to_int8(z);
    tmp.w = to_int8(w);
    dst   = reinterpret_cast<const uint32_t&>(tmp);
#endif
    return dst;
}

// int8 variant of extend_value_cache. It quantizes query_length[b] new rows to int8
// (q = round_sat(v / v_scale)) and appends them to v_dst[b] + dst_offset at position
// history_length[b].
//   src [B, H, max_q_len, D/x] (uint4 per slot)  ->  dst [H, max_seq_len, D/x] (uint2 per slot)
// NOTE(review): each 32-bit lane of the source uint4 is decoded with mmha::half2_to_float2, so
// this kernel is only correct for half inputs (8 halves -> 8 int8 per slot).
template<typename T>
__global__ void extend_value_cache_int8(int8_t**     v_dst,
                                        const size_t dst_offset,
                                        const T*     v_src,
                                        const int    head_num,
                                        const int    size_per_head,
                                        const int*   query_length,
                                        const int*   history_length,
                                        const int    max_q_len,
                                        const int    max_seq_len,
                                        const float  v_scale)
{
    const int     batch_id = blockIdx.y;
    const int     head_id  = blockIdx.z;
    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;

    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
    int       size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is handled as one uint4 on the source side and one uint2 (8 x int8) on the destination
    const auto val_src = reinterpret_cast<const uint4*>(v_src);
    const auto val_dst = reinterpret_cast<uint2*>(v_dst[batch_id] + dst_offset);

    const auto seq_len  = query_length[batch_id];
    const auto t_offset = history_length[batch_id];

    const int v_head_size_id = idx % size_per_head_div_x;
    const int v_seq_len_id   = idx / size_per_head_div_x;

    if (v_seq_len_id < seq_len) {
        // [B, H, s, D/x] -> [H, S[t:t+s], D/x]
        // Products are widened to int64_t first so they cannot overflow in int.
        const int64_t dst_idx = (int64_t)head_id * size_per_head_div_x * max_seq_len        // H
                                + (int64_t)(v_seq_len_id + t_offset) * size_per_head_div_x  // s + offset
                                + v_head_size_id;                                           // D/x

        const int64_t src_idx = (int64_t)batch_id * head_num * size_per_head_div_x * max_q_len  // B
                                + (int64_t)head_id * size_per_head_div_x * max_q_len            // H
                                + (int64_t)v_seq_len_id * size_per_head_div_x                   // s
                                + v_head_size_id;                                               // D/x

        // scale to int8 and write as two packed 32-bit words
        const auto value  = val_src[src_idx];
        auto       to_ptr = reinterpret_cast<uint32_t*>(val_dst + dst_idx);

        float2 float2_0 = float2div(v_scale, mmha::half2_to_float2(value.x));
        float2 float2_1 = float2div(v_scale, mmha::half2_to_float2(value.y));
        to_ptr[0]       = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);

        float2_0  = float2div(v_scale, mmha::half2_to_float2(value.z));
        float2_1  = float2div(v_scale, mmha::half2_to_float2(value.w));
        to_ptr[1] = float4_to_char4(float2_0.x, float2_0.y, float2_1.x, float2_1.y);
    }
}

Li Zhang's avatar
Li Zhang committed
386
387
388
389
390
391
392
393
394
395
396
397
398
// Appends this step's keys and values into the per-sequence caches k_dst[b] / v_dst[b]
// (at element offset dst_offset, sequence position history_length[b]). Both keys and values use
// the extend_value_cache(_int8) layout [H, max_seq_len, D/x].
// When `quant` has QuantPolicy::kCacheKVInt8 set, entries are quantized to int8 with host-side
// scales kv_scale[0] (keys) and kv_scale[1] (values).
// Launch: 128 threads per block, grid = (ceil(max_q_len * D / x / 128), batch, heads).
template<typename T>
void invokeExtendKVCache(T**          k_dst,
                         T**          v_dst,
                         size_t       dst_offset,
                         const T*     k_src,
                         const T*     v_src,
                         int          local_batch_size,
                         const int*   query_length,
                         int          max_q_len,
                         const int*   history_length,
                         int          max_seq_len,
                         int          size_per_head,
                         int          local_head_num,
                         cudaStream_t stream,
                         int          quant,
                         const float* kv_scale)
{
    constexpr int block_sz = 128;
    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;

    dim3 grid((max_q_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num);

    if (quant & QuantPolicy::kCacheKVInt8) {
        // The int8 kernel decodes its input as half2 lanes, so only 16-bit T is supported;
        // kv_scale is dereferenced here on the host.
        FT_CHECK(sizeof(T) == 2);
        FT_CHECK(kv_scale != nullptr);

        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(reinterpret_cast<int8_t**>(k_dst),
                                                               dst_offset,
                                                               k_src,
                                                               local_head_num,
                                                               size_per_head,
                                                               query_length,
                                                               history_length,
                                                               max_q_len,
                                                               max_seq_len,
                                                               kv_scale[0]);

        extend_value_cache_int8<<<grid, block_sz, 0, stream>>>(reinterpret_cast<int8_t**>(v_dst),
                                                               dst_offset,
                                                               v_src,
                                                               local_head_num,
                                                               size_per_head,
                                                               query_length,
                                                               history_length,
                                                               max_q_len,
                                                               max_seq_len,
                                                               kv_scale[1]);
    }
    else {
        extend_value_cache<<<grid, block_sz, 0, stream>>>(k_dst,
                                                          dst_offset,
                                                          k_src,
                                                          local_head_num,
                                                          size_per_head,
                                                          query_length,
                                                          history_length,
                                                          max_q_len,
                                                          max_seq_len);

        extend_value_cache<<<grid, block_sz, 0, stream>>>(v_dst,
                                                          dst_offset,
                                                          v_src,
                                                          local_head_num,
                                                          size_per_head,
                                                          query_length,
                                                          history_length,
                                                          max_q_len,
                                                          max_seq_len);
    }
}

// Explicit instantiations of invokeExtendKVCache for fp32 and fp16 caches.
template void invokeExtendKVCache(float**,
                                  float**,
                                  size_t,
                                  const float*,
                                  const float*,
                                  int,
                                  const int*,
                                  int,
                                  const int*,
                                  int,
                                  int,
                                  int,
                                  cudaStream_t stream,
                                  int,
                                  const float*);

template void invokeExtendKVCache(half**,
                                  half**,
                                  size_t,
                                  const half*,
                                  const half*,
                                  int,
                                  const int*,
                                  int,
                                  const int*,
                                  int,
                                  int,
                                  int,
                                  cudaStream_t stream,
                                  int,
                                  const float*);
Li Zhang's avatar
Li Zhang committed
485
486

// Gathers the first seq_length[b] rows of each sequence's cache into a dense, padded buffer.
//   src v_src[b] + src_offset : [H, max_seq_len, D/x]
//   dst v_dst                 : [B, H, max_kv_len, D/x]
// Each thread moves one 16-byte uint4 (x = 4 for 4-byte T, 8 otherwise).
// Launch: grid = (ceil(max_kv_len * D / x / blockDim.x), B, H).
template<typename T>
__global__ void transpose_value_cache(T*           v_dst,  //
                                      const T**    v_src,
                                      const size_t src_offset,
                                      const int    head_num,
                                      const int    size_per_head,
                                      const int*   seq_length,
                                      const int    max_kv_len,
                                      const int    max_seq_len)
{
    const int     batch_id = blockIdx.y;
    const int     head_id  = blockIdx.z;
    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;

    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
    int       size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is now handled by uint4 type
    const auto val_src = reinterpret_cast<const uint4*>(v_src[batch_id] + src_offset);
    const auto val_dst = reinterpret_cast<uint4*>(v_dst);

    const auto seq_len = seq_length[batch_id];

    const int v_head_size_id = idx % size_per_head_div_x;
    const int v_seq_len_id   = idx / size_per_head_div_x;

    if (v_seq_len_id < seq_len) {
        // [B, H, s, D/x] <- [B, H, S[:s], D/x]
        // Products are widened to int64_t first so they cannot overflow in int.
        const int64_t src_idx = (int64_t)head_id * size_per_head_div_x * max_seq_len  // H
                                + (int64_t)v_seq_len_id * size_per_head_div_x         // s
                                + v_head_size_id;                                     // D/x

        const int64_t dst_idx = (int64_t)batch_id * head_num * size_per_head_div_x * max_kv_len  // B
                                + (int64_t)head_id * size_per_head_div_x * max_kv_len            // H
                                + (int64_t)v_seq_len_id * size_per_head_div_x                    // s
                                + v_head_size_id;                                                // D/x

        val_dst[dst_idx] = val_src[src_idx];
    }
}

// int8 variant of transpose_value_cache. It reads int8 cache rows (8 x int8 = one uint2 per slot),
// dequantizes them with v_scale (half(q * v_scale)) and writes them as 8 halves (one uint4
// per slot) into the dense buffer.
//   src v_src[b] + src_offset : [H, max_seq_len, D/x]
//   dst v_dst                 : [B, H, max_kv_len, D/x]
// NOTE(review): the output is written as two half4 per slot, so this kernel is only correct
// for T = half.
template<typename T>
__global__ void transpose_value_cache_int8(T*             v_dst,  //
                                           const int8_t** v_src,
                                           const size_t   src_offset,
                                           const int      head_num,
                                           const int      size_per_head,
                                           const int*     seq_length,
                                           const int      max_kv_len,
                                           const int      max_seq_len,
                                           const float    v_scale)
{
    const int     batch_id = blockIdx.y;
    const int     head_id  = blockIdx.z;
    constexpr int X_ELEMS  = (sizeof(T) == 4) ? 4 : 8;

    const int idx                 = blockIdx.x * blockDim.x + threadIdx.x;
    int       size_per_head_div_x = size_per_head / X_ELEMS;

    // x dim is handled as one uint2 (8 x int8) on the source side and one uint4 on the destination
    const auto val_src = reinterpret_cast<const uint2*>(v_src[batch_id] + src_offset);
    const auto val_dst = reinterpret_cast<uint4*>(v_dst);

    const auto seq_len = seq_length[batch_id];

    const int v_head_size_id = idx % size_per_head_div_x;
    const int v_seq_len_id   = idx / size_per_head_div_x;

    if (v_seq_len_id < seq_len) {
        // [B, H, s, D/x] <- [B, H, S[:s], D/x]
        // Products are widened to int64_t first so they cannot overflow in int.
        const int64_t src_idx = (int64_t)head_id * size_per_head_div_x * max_seq_len  // H
                                + (int64_t)v_seq_len_id * size_per_head_div_x         // s
                                + v_head_size_id;                                     // D/x

        const int64_t dst_idx = (int64_t)batch_id * head_num * size_per_head_div_x * max_kv_len  // B
                                + (int64_t)head_id * size_per_head_div_x * max_kv_len            // H
                                + (int64_t)v_seq_len_id * size_per_head_div_x                    // s
                                + v_head_size_id;                                                // D/x

        // int8x8 -> fp16x8
        const auto from_ptr = reinterpret_cast<const char4*>(val_src + src_idx);
        auto       to_ptr   = reinterpret_cast<half4*>(val_dst + dst_idx);

        to_ptr[0] = char4_scale_to_half4(from_ptr[0], v_scale);
        to_ptr[1] = char4_scale_to_half4(from_ptr[1], v_scale);
    }
}

// Copies the valid prefix (key_length[b] rows) of every sequence's key and value cache into
// dense buffers of layout [B, H, max_kv_len, D/x].
// When `quant` has QuantPolicy::kCacheKVInt8 set, the caches are int8 and are dequantized with
// host-side scales kv_scale[0] (keys) and kv_scale[1] (values).
// Launch: 128 threads per block, grid = (ceil(max_kv_len * D / x / 128), batch, heads).
template<typename T>
void invokeTransposeKVCache(T*           key_cache_trans,
                            T*           val_cache_trans,
                            const T**    key_cache,
                            const T**    val_cache,
                            size_t       src_offset,
                            int          batch_size,
                            const int*   key_length,
                            int          max_kv_len,
                            int          max_seq_len,
                            int          size_per_head,
                            int          head_num,
                            cudaStream_t stream,
                            int          quant,
                            const float* kv_scale)
{
    constexpr int block_sz = 128;
    constexpr int x        = (sizeof(T) == 4) ? 4 : 8;

    dim3 grid((max_kv_len * size_per_head / x + block_sz - 1) / block_sz, batch_size, head_num);

    if (quant & QuantPolicy::kCacheKVInt8) {
        // The int8 kernel writes half4 pairs, so only 16-bit T is supported; kv_scale is
        // dereferenced here on the host.
        FT_CHECK(sizeof(T) == 2);
        FT_CHECK(kv_scale != nullptr);

        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(key_cache_trans,
                                                                  reinterpret_cast<const int8_t**>(key_cache),
                                                                  src_offset,
                                                                  head_num,
                                                                  size_per_head,
                                                                  key_length,
                                                                  max_kv_len,
                                                                  max_seq_len,
                                                                  kv_scale[0]);

        transpose_value_cache_int8<<<grid, block_sz, 0, stream>>>(val_cache_trans,
                                                                  reinterpret_cast<const int8_t**>(val_cache),
                                                                  src_offset,
                                                                  head_num,
                                                                  size_per_head,
                                                                  key_length,
                                                                  max_kv_len,
                                                                  max_seq_len,
                                                                  kv_scale[1]);
    }
    else {
        transpose_value_cache<<<grid, block_sz, 0, stream>>>(
            key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);

        transpose_value_cache<<<grid, block_sz, 0, stream>>>(
            val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);
    }
}

AllentDan's avatar
AllentDan committed
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
// Explicit instantiations of invokeTransposeKVCache for fp32 and fp16 caches.
template void invokeTransposeKVCache<float>(float*,
                                            float*,
                                            const float**,
                                            const float**,
                                            size_t,
                                            int,
                                            const int*,
                                            int,
                                            int,
                                            int,
                                            int,
                                            cudaStream_t,
                                            int,
                                            const float*);
template void invokeTransposeKVCache<half>(half*,
                                           half*,
                                           const half**,
                                           const half**,
                                           size_t,
                                           int,
                                           const int*,
                                           int,
                                           int,
                                           int,
                                           int,
                                           cudaStream_t,
                                           int,
                                           const float*);
Li Zhang's avatar
Li Zhang committed
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690

// Compacts one batch entry's generated ids per block, removing the padding between the
// entry's real context and the padded context length.
//   ids        : step-major, read at step * batch_size + batch_id for step in [0, max_gen_step)
//   output_ids : row batch_id starts at batch_id * max_output_len
// Steps in [context_len, max_context_len) are padding and are skipped; steps at or after
// max_context_len are shifted left by (max_context_len - context_len).
__global__ void gatherOutput(int*       output_ids,
                             const int* ids,
                             const int* context_length,
                             int        max_context_len,
                             int        max_gen_step,
                             int        max_output_len,
                             int        batch_size)
{
    const int batch_id    = blockIdx.x;
    const int context_len = context_length[batch_id];
    const int pad         = max_context_len - context_len;
    int*      dst_row     = output_ids + batch_id * max_output_len;

    for (int step = threadIdx.x; step < max_gen_step; step += blockDim.x) {
        const bool is_padding = step >= context_len && step < max_context_len;
        if (is_padding) {
            continue;
        }
        const int dst = (step < context_len) ? step : step - pad;
        dst_row[dst]  = ids[step * batch_size + batch_id];
    }
}

// Launches gatherOutput with one 512-thread block per batch entry.
void invokeGatherOutput(int*         output_ids,
                        const int*   ids,
                        const int*   context_length,
                        int          max_context_len,
                        int          max_gen_step,
                        int          max_output_len,
                        int          batch_size,
                        cudaStream_t stream)
{
    constexpr int kBlockSize = 512;
    gatherOutput<<<batch_size, kBlockSize, 0, stream>>>(
        output_ids, ids, context_length, max_context_len, max_gen_step, max_output_len, batch_size);
}

AllentDan's avatar
AllentDan committed
691
}  // namespace fastertransformer