LlamaContextAttentionLayer.cc 19.5 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
Li Zhang's avatar
Li Zhang committed
18
19

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/layers/attention_layers/GptContextAttentionLayer.cc
Li Zhang's avatar
Li Zhang committed
21

lvhan028's avatar
lvhan028 committed
22
23
24
25
26
27
28
29
#include "src/turbomind/models/llama/LlamaContextAttentionLayer.h"
#include "src/turbomind/kernels/bert_preprocess_kernels.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
30
#include "src/turbomind/utils/logger.h"
Li Zhang's avatar
Li Zhang committed
31

lvhan028's avatar
lvhan028 committed
32
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
33
34
35
36
37
38
39

/**
 * Allocate (or resize via reMalloc) the work buffers used by one forward pass.
 *
 * \param batch_size number of sequences in the batch
 * \param num_token  number of tokens in the batch without padding
 * \param max_q_len  padded query length
 * \param max_k_len  padded key length
 *
 * k_buf_2_ / v_buf_2_ are not separate allocations: they point into the q_buf_2_ allocation.
 * Likewise v_cache_buf_ points into the k_cache_buf_ allocation (unfused path only).
 */
template<typename T>
void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size,
                                                   size_t num_token,
                                                   size_t max_q_len,
                                                   size_t max_k_len)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    // q heads + k heads + v heads; kept as size_t so the byte-size arithmetic below never narrows
    const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_;

    // no padding
    qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true);

    // padding is rebuilt for q/k/v_buf_2_
    // [qH + 2kvH, B, S, D]
    q_buf_2_ = (T*)allocator_->reMalloc(
        q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true);
    // k and v live directly behind q inside the same allocation
    k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_;
    v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_;

    if (use_fmha_) {
        // the flash attention op reports how much scratch space it needs; allocate only when non-zero
        FlashAttentionOp<T> flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);
        if (flash_attention.get_workspace_size() > 0) {
            qk_buf_float_ = (float*)allocator_->reMalloc(qk_buf_float_, flash_attention.get_workspace_size(), true);
        }
    }
    else {
        // kv heads are repeated for unfused attention, hence local_head_num_ (not local_kv_head_num_)
        // k and v share one allocation: [k | v]
        k_cache_buf_ = (T*)allocator_->reMalloc(
            k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true);
        v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_;

        qk_buf_ =
            (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true);

        // qkv_buf_2_ has padding
        qkv_buf_2_ = (T*)allocator_->reMalloc(
            qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true);
    }

    // qkv_buf_3_ padding is removed
    qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true);

    is_allocate_buffer_ = true;
}

/**
 * Release every buffer obtained in allocateBuffer(). Does nothing when no buffers are allocated.
 *
 * Which buffers are released depends on use_fmha_, mirroring the branch taken in allocateBuffer().
 */
template<typename T>
void LlamaContextAttentionLayer<T>::freeBuffer()
{
    if (is_allocate_buffer_) {
        TM_LOG_DEBUG(__PRETTY_FUNCTION__);

        allocator_->free((void**)(&qkv_buf_));
        allocator_->free((void**)(&q_buf_2_));
        // k_buf_2_ / v_buf_2_ point into the q_buf_2_ allocation that was just released;
        // clear them so they do not dangle
        k_buf_2_ = nullptr;
        v_buf_2_ = nullptr;
        if (use_fmha_) {
            allocator_->free((void**)&qk_buf_float_);
        }
        else {
            allocator_->free((void**)(&k_cache_buf_));
            // v_cache_buf_ points into the k_cache_buf_ allocation
            v_cache_buf_ = nullptr;
            allocator_->free((void**)(&qk_buf_));
            allocator_->free((void**)(&qkv_buf_2_));
        }
        allocator_->free((void**)(&qkv_buf_3_));

        is_allocate_buffer_ = false;
    }
}

/**
 * Run context (prefill) attention for one layer: qkv gemm, bias/rotary/transpose, KV-cache
 * insertion, attention (fused or unfused depending on use_fmha_), output gemm and, when
 * tensor parallelism is enabled, an all-reduce of the output.
 */
template<typename T>
inline void LlamaContextAttentionLayer<T>::forward(TensorMap*                     output_tensors,
                                                   const TensorMap*               input_tensors,
                                                   const LlamaAttentionWeight<T>* weights)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    /**
     * input_tensors:
     *   \param input_query [token_num, hidden_dim]
     *   \param attention_mask [batch_size, 1, max_q_len, max_kv_len]
     *   \param padding_offset [token_num], int
     *   \param input_lengths [batch_size], int
     *   \param history_lengths [batch_size], int
     *   \param context_lengths [batch_size], int
     *   \param cu_seqlens [batch_size+1], int
     *   \param max_seq_len [1], int on cpu
     *   \param is_final_layer [1], bool on cpu
     *   \param layer_id [1], int on cpu
     *
     * output_tensors:
     *   \param hidden_features [token_num, hidden_dim]
     *   \param key_cache [batch_size], uint64
     *   \param value_cache [batch_size], uint64
     */

    /////////////////////////////////////////////
    /// parse inputs
    const int batch_size = input_tensors->at("attention_mask").shape[0];
    const int max_q_len  = input_tensors->at("attention_mask").shape[2];
    const int max_k_len  = input_tensors->at("attention_mask").shape[3];
    const int layer_id   = input_tensors->getVal<int>("layer_id");

    const int num_token = input_tensors->at("input_query").shape[0];

    const int max_seq_len = input_tensors->at("max_seq_len").getVal<int>();

    T* attention_out   = output_tensors->at("hidden_features").getPtr<T>();
    T* attention_input = input_tensors->at("input_query").getPtr<T>();
    T* attention_mask  = input_tensors->at("attention_mask").getPtr<T>();

    const auto input_length   = input_tensors->at("input_lengths").getPtr<const int>();
    const auto history_length = input_tensors->at("history_lengths").getPtr<const int>();
    const auto context_length = input_tensors->at("context_lengths").getPtr<const int>();
    int*       cu_seqlens     = input_tensors->at("cu_seqlens").getPtr<int>();

    const auto padding_offset = input_tensors->at("padding_offset").getPtr<int>();

    /////////////////////////////////////////////
    /// allocate buffers
    allocateBuffer(batch_size, num_token, max_q_len, max_k_len);

    //////////////////////////////////////////////
    /// qkv gemm
    // [token_num, hidden_dim] -> [token_num, (H + 2kvH) * D]
    linear_.forward(qkv_buf_, attention_input, num_token, weights->qkv);

    //////////////////////////////////////////////
    /// transpose qkv & apply rotary embedding & rebuild padding
    /// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D])
    invokeAddFusedQKVBiasTranspose(q_buf_2_,
                                   k_buf_2_,
                                   v_buf_2_,
                                   PrefixPromptBatchWeightsParam<T>{},
                                   qkv_buf_,
                                   weights->qkv.bias,
                                   padding_offset,  // padding_offset,
                                   history_length,  // used for applying rotary embedding
                                   batch_size,
                                   max_q_len,  // seq_len
                                   num_token,  // batch_size * seq_len
                                   local_head_num_,
                                   local_kv_head_num_,
                                   size_per_head_,
                                   rotary_embedding_dim_,
                                   neox_rotary_style_,
                                   nullptr,  // query_weight.scale_out
                                   0,        // int8 mode
                                   stream_);
    sync_check_cuda_error();

    // Element offset of this layer passed to the cache kernels. Widen layer_id first so the
    // whole product is evaluated in size_t and cannot overflow int for deep models / long caches.
    const size_t layer_offset = static_cast<size_t>(layer_id) * local_kv_head_num_ * max_seq_len * size_per_head_;

    auto k_cache_ptrs = output_tensors->getPtr<T*>("key_cache");
    auto v_cache_ptrs = output_tensors->getPtr<T*>("value_cache");
    //////////////////////////////////////////////////////////
    /// insert the k/v computed from inputs into k/v cache
    /// transpose kv -> kv cache
    // put k/v_buf from shape [B, kvH, s, D] to
    // k_buf_2 [B, kvH, s, D] -> key_cache [B, kvH, S[t:t+s], D/x, x]
    // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x]
    invokeExtendKVCache(k_cache_ptrs,
                        v_cache_ptrs,
                        layer_offset,
                        k_buf_2_,
                        v_buf_2_,
                        batch_size,
                        input_length,
                        max_q_len,
                        history_length,
                        max_seq_len,
                        size_per_head_,
                        local_kv_head_num_,
                        stream_,
                        quant_policy_,
                        weights->past_kv_scale.data());

    sync_check_cuda_error();
    if (use_fmha_) {
        // the fused path is only taken when q and kv head counts match
        FT_CHECK(local_head_num_ == local_kv_head_num_);
        fusedMultiHeadAttention(k_cache_ptrs,
                                v_cache_ptrs,
                                layer_offset,
                                attention_mask,
                                cu_seqlens,
                                batch_size,
                                max_q_len,
                                max_k_len,
                                max_seq_len);
    }
    else {
        unfusedMultiHeadAttention(k_cache_ptrs,
                                  v_cache_ptrs,
                                  layer_offset,
                                  attention_mask,
                                  padding_offset,
                                  context_length,
                                  batch_size,
                                  num_token,
                                  max_q_len,
                                  max_k_len,
                                  max_seq_len,
                                  quant_policy_,
                                  weights->past_kv_scale.data());
    }

    //////////////////////////////////////////////
    /// output gemm <Bs,HD> -> <Bs,HD>
    linear_.forward(attention_out, qkv_buf_3_, num_token, weights->output);

    if (tensor_para_.world_size_ > 1) {
        // sum the partial outputs across tensor-parallel ranks, in place
        NcclGuard nccl_guard(tensor_para_, stream_);
        ftNcclAllReduceSum(attention_out, attention_out, num_token * hidden_units_, tensor_para_, stream_);
        sync_check_cuda_error();
    }

    if (is_free_buffer_after_forward_) {
        freeBuffer();
    }
    sync_check_cuda_error();
}

/**
 * Attention through FlashAttentionOp. Q comes from q_buf_2_, the K/V layouts carry the
 * per-sequence cache pointers plus the layer offset, and the result is written to qkv_buf_3_.
 *
 * NOTE(review): k_cache_buf_ / v_cache_buf_ are passed as key/val although allocateBuffer()
 * only allocates them on the non-FMHA path; presumably the op reads K/V through
 * `batch_seqs` instead - confirm against FlashAttentionOp.
 */
template<typename T>
void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T**    key_cache_ptrs,
                                                            T**    val_cache_ptrs,
                                                            size_t cache_layer_offset,
                                                            T*     attention_mask,
                                                            int*   cu_seqlens,
                                                            int    batch_size,
                                                            int    max_q_len,
                                                            int    max_k_len,
                                                            int    max_seq_len)
{
    using FlashOp     = FlashAttentionOp<T>;
    using FlashLayout = typename FlashOp::AttentionLayout;

    // strides shared by several layouts, in elements
    const int head_dim           = int(size_per_head_);
    const int q_head_stride      = int(max_q_len * size_per_head_);
    const int q_batch_stride     = int(local_head_num_ * max_q_len * size_per_head_);
    const int cache_head_stride  = int(max_seq_len * size_per_head_);
    const int cache_batch_stride = int(local_head_num_ * max_seq_len * size_per_head_);
    const int cache_offset       = int(cache_layer_offset);
    const int out_token_stride   = int(local_head_num_ * size_per_head_);

    // query: head-major within each batch entry
    FlashLayout q_layout{.stride_batch = q_batch_stride, .stride_seq = head_dim, .stride_head = q_head_stride};

    // key / value: addressed through the per-sequence cache pointers plus the layer offset
    FlashLayout k_layout{.stride_batch      = cache_batch_stride,
                         .stride_seq        = head_dim,
                         .stride_head       = cache_head_stride,
                         .batch_seqs_offset = cache_offset,
                         .batch_seqs        = key_cache_ptrs};
    FlashLayout v_layout{.stride_batch      = cache_batch_stride,
                         .stride_seq        = head_dim,
                         .stride_head       = cache_head_stride,
                         .batch_seqs_offset = cache_offset,
                         .batch_seqs        = val_cache_ptrs};

    // output: token-major, with use_seqlens enabled
    FlashLayout o_layout{
        .stride_batch = q_batch_stride,
        .stride_seq   = out_token_stride,
        .stride_head  = head_dim,
        .use_seqlens  = true,
    };

    FlashOp attention_op(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_);

    typename FlashOp::Params op_params{.attn_out     = qkv_buf_3_,
                                       .query        = q_buf_2_,
                                       .key          = k_cache_buf_,
                                       .val          = v_cache_buf_,
                                       .mask         = attention_mask,
                                       .out_accum    = qk_buf_float_,
                                       .cu_seqlens_q = cu_seqlens,
                                       .cu_seqlens_k = nullptr,
                                       .layout_q     = q_layout,
                                       .layout_k     = k_layout,
                                       .layout_v     = v_layout,
                                       .layout_o     = o_layout};

    // launch on the layer's stream
    attention_op(op_params, stream_);
}

template<typename T>
AllentDan's avatar
AllentDan committed
308
309
310
311
312
313
314
315
316
317
318
319
320
void LlamaContextAttentionLayer<T>::unfusedMultiHeadAttention(T**          key_cache_ptrs,
                                                              T**          val_cache_ptrs,
                                                              size_t       cache_layer_offset,
                                                              const T*     attention_mask,
                                                              const int*   padding_offset,
                                                              const int*   context_length,
                                                              int          batch_size,
                                                              int          num_token,
                                                              int          max_q_len,
                                                              int          max_k_len,
                                                              int          max_seq_len,
                                                              int          quant,
                                                              const float* kv_scale)
Li Zhang's avatar
Li Zhang committed
321
{
322
323
    // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
    // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D]
Li Zhang's avatar
Li Zhang committed
324
325
326
327
328
329
330
331
332
333
334
    invokeTransposeKVCache(k_cache_buf_,
                           v_cache_buf_,
                           (const T**)key_cache_ptrs,
                           (const T**)val_cache_ptrs,
                           cache_layer_offset,
                           batch_size,
                           context_length,  // history_len + input_len = context_len
                           max_k_len,
                           max_seq_len,
                           size_per_head_,
                           local_head_num_,
335
                           head_n_rep_,
336
337
338
                           stream_,
                           quant,
                           kv_scale);
Li Zhang's avatar
Li Zhang committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
    sync_check_cuda_error();

    const T qk_scale = static_cast<T>(1.f / sqrtf(size_per_head_ * 1.f));

    //////////////////////////////////////////////
    /// Q*K batch gemm
    /// -> [B, H, s, t + s]
    cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T,
                                        CUBLAS_OP_N,
                                        max_k_len,                      // m
                                        max_q_len,                      // n
                                        size_per_head_,                 // k
                                        k_cache_buf_,                   // A
                                        size_per_head_,                 // lda
                                        max_k_len * size_per_head_,     // strideA
                                        q_buf_2_,                       // B
                                        size_per_head_,                 // ldb
                                        max_q_len * size_per_head_,     // strideB
                                        qk_buf_,                        // C
                                        max_k_len,                      // ldc
                                        max_q_len * max_k_len,          // strideC
                                        batch_size * local_head_num_);  // batchCount

    //////////////////////////////////////////////
    /// ! masked softmax (kernel asserts k_length <= 4096)
    MaskedSoftmaxParam<T, T> param{};
    param.attention_score    = qk_buf_;
    param.qk                 = qk_buf_;
    param.attention_mask     = attention_mask;
    param.batch_size         = batch_size;
    param.q_length           = max_q_len;
    param.k_length           = max_k_len;
    param.num_heads          = local_head_num_;
    param.qk_scale           = qk_scale;
    param.linear_bias_slopes = nullptr;
    invokeMaskedSoftmax(param, stream_);
    sync_check_cuda_error();

    //////////////////////////////////////////////
    /// softmax(QK)*V batch gemm
    // -> [B, H, S, D]
    cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N,
                                        CUBLAS_OP_N,
                                        size_per_head_,                 // m
                                        max_q_len,                      // n
                                        max_k_len,                      // k
                                        v_cache_buf_,                   // A
                                        size_per_head_,                 // lda
                                        max_k_len * size_per_head_,     // strideA,
                                        qk_buf_,                        // B
                                        max_k_len,                      // ldb
                                        max_k_len * max_q_len,          // strideB
                                        qkv_buf_2_,                     // C
                                        size_per_head_,                 // ldc,
                                        max_q_len * size_per_head_,     // strideC
                                        batch_size * local_head_num_);  // batchCount

    //////////////////////////////////////////////
    /// transpose <B,h,s,D> -> <B,s,h,D>
    invokeTransposeAttentionOutRemovePadding(qkv_buf_2_,
                                             qkv_buf_3_,
                                             num_token,
                                             batch_size,
                                             max_q_len,
                                             local_head_num_,
                                             size_per_head_,
                                             padding_offset,
                                             nullptr,
                                             0,
                                             stream_);
    sync_check_cuda_error();
}

// Explicit instantiations: the member definitions above live in this translation unit,
// so the layer is instantiated here for the two element types float and half.
template class LlamaContextAttentionLayer<float>;
template class LlamaContextAttentionLayer<half>;

lvhan028's avatar
lvhan028 committed
415
}  // namespace turbomind