LlamaDecoderSelfAttentionLayer.cc 13.9 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

Li Zhang's avatar
Li Zhang committed
18
// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
#include "src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h"
#include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
Chen Xin's avatar
Chen Xin committed
22
#include "src/turbomind/macro.h"
lvhan028's avatar
lvhan028 committed
23
24
25
26
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/cuda_utils.h"
27
#include "src/turbomind/utils/logger.h"
lvhan028's avatar
lvhan028 committed
28
#include "src/turbomind/utils/nvtx_utils.h"
Li Zhang's avatar
Li Zhang committed
29
30
31
#include <string>
// #include <glog/logging.h>

lvhan028's avatar
lvhan028 committed
32
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

// Maps a layer element type to the element type used to instantiate
// Masked_multihead_attention_params. Identity for every type except `half`,
// which is mapped to `uint16_t`.
template<typename T>
struct SATypeConverter {
    typedef T Type;
};

// `half` is exposed to the attention parameters as `uint16_t`.
template<>
struct SATypeConverter<half> {
    typedef uint16_t Type;
};

template<typename T>
static inline void fusedQKV_masked_attention_dispatch(const T*     qkv_buf,
                                                      const T*     qkv_bias,
                                                      const T*     relative_attention_bias,
                                                      T*           key_cache,
                                                      T*           value_cache,
                                                      T**          k_cache_per_sample,
                                                      T**          v_cache_per_sample,
                                                      size_t       kv_cache_per_sample_offset,
                                                      const int*   cache_indir,
                                                      T*           context_buf,
                                                      const bool*  finished,
                                                      const int*   sequence_lengths,
                                                      const int    max_batch_size,
                                                      const int    inference_batch_size,
                                                      const int    beam_width,
                                                      const int    head_num,
61
                                                      const int    kv_head_num,
Li Zhang's avatar
Li Zhang committed
62
63
                                                      const int    size_per_head,
                                                      const int    rotary_embedding_dim,
64
65
66
                                                      const int    max_position_embeddings,
                                                      const bool   use_dynamic_ntk,
                                                      const bool   use_logn_attn,
Li Zhang's avatar
Li Zhang committed
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
                                                      const int    memory_max_len,
                                                      const int*   prefix_prompt_lengths,
                                                      const int    max_prefix_prompt_length,
                                                      const int    max_input_len,
                                                      const int*   total_padding_tokens,
                                                      const int    step,
                                                      const float  q_scaling,
                                                      const int    relative_attention_bias_stride,
                                                      const T*     linear_bias_slopes,
                                                      const bool*  masked_tokens,
                                                      const int*   ia3_tasks,
                                                      const T*     ia3_key_weights,
                                                      const T*     ia3_value_weights,
                                                      const float* qkv_scale_out,
                                                      const float* attention_out_scale,
                                                      const int    int8_mode,
83
                                                      const float* attention_kv_scale,
Li Zhang's avatar
Li Zhang committed
84
85
86
87
88
89
                                                      cudaStream_t stream)
{
    using DataType = typename SATypeConverter<T>::Type;
    // Prepare the parameters.
    Masked_multihead_attention_params<DataType> params;
    memset(&params, 0, sizeof(params));
90
    // int hidden_units = head_num * size_per_head;
Li Zhang's avatar
Li Zhang committed
91
92
    if (qkv_bias != nullptr) {
        params.q_bias = reinterpret_cast<const DataType*>(qkv_bias);
93
94
        params.k_bias = reinterpret_cast<const DataType*>(qkv_bias) + head_num * size_per_head;
        params.v_bias = reinterpret_cast<const DataType*>(qkv_bias) + (head_num + kv_head_num) * size_per_head;
Li Zhang's avatar
Li Zhang committed
95
96
97
98
99
100
101
102
103
104
105
    }
    else {
        params.q_bias = nullptr;
        params.k_bias = nullptr;
        params.v_bias = nullptr;
    }

    // Set the output buffer.
    params.out = reinterpret_cast<DataType*>(context_buf);

    // Set the input buffers.
106
    // [B, nH + kvH, D]
Li Zhang's avatar
Li Zhang committed
107
    params.q = reinterpret_cast<const DataType*>(qkv_buf);
108
109
    params.k = reinterpret_cast<const DataType*>(qkv_buf) + head_num * size_per_head;
    params.v = reinterpret_cast<const DataType*>(qkv_buf) + (head_num + kv_head_num) * size_per_head;
110

111
    params.stride   = (head_num + 2 * kv_head_num) * size_per_head;
Li Zhang's avatar
Li Zhang committed
112
113
    params.finished = const_cast<bool*>(finished);

114
115
    FT_CHECK(k_cache_per_sample && v_cache_per_sample);

Li Zhang's avatar
Li Zhang committed
116
117
118
119
120
121
122
123
124
125
    params.k_cache_per_sample         = reinterpret_cast<DataType**>(k_cache_per_sample);
    params.v_cache_per_sample         = reinterpret_cast<DataType**>(v_cache_per_sample);
    params.kv_cache_per_sample_offset = kv_cache_per_sample_offset;
    params.batch_size                 = inference_batch_size;
    params.beam_width                 = beam_width;
    params.memory_max_len             = memory_max_len;
    params.prefix_prompt_lengths      = prefix_prompt_lengths;
    params.max_prefix_prompt_length   = max_prefix_prompt_length;
    params.length_per_sample          = sequence_lengths;  // max_input_length + current output length
    // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation
126
127
128
129
    params.timestep     = step + max_prefix_prompt_length - 1;
    params.num_heads    = head_num;
    params.num_kv_heads = kv_head_num;

130
131
132
133
134
135
    params.hidden_size_per_head    = size_per_head;
    params.rotary_embedding_dim    = rotary_embedding_dim;
    params.max_position_embeddings = max_position_embeddings;
    params.use_dynamic_ntk         = use_dynamic_ntk;
    params.use_logn_attn           = use_logn_attn;

Li Zhang's avatar
Li Zhang committed
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
    // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust)
    params.inv_sqrt_dh = 1.F / (sqrtf((float)params.hidden_size_per_head) * q_scaling);

    params.total_padding_tokens = total_padding_tokens;
    if (relative_attention_bias != nullptr) {
        params.relative_attention_bias = reinterpret_cast<const DataType*>(relative_attention_bias);
    }
    params.relative_attention_bias_stride = relative_attention_bias_stride;
    params.masked_tokens                  = masked_tokens;

    // The slope of linear position bias per head, e.g., ALiBi.
    if (linear_bias_slopes != nullptr) {
        params.linear_bias_slopes = reinterpret_cast<const DataType*>(linear_bias_slopes);
    }
    params.max_input_length = max_input_len;

    params.int8_mode = int8_mode;
153
154
155

    if (int8_mode & QuantPolicy::kCacheKVInt8) {
        params.attention_k_scale = attention_kv_scale[0];
156
157
158
        params.attention_k_zp    = attention_kv_scale[1];
        params.attention_v_scale = attention_kv_scale[2];
        params.attention_v_zp    = attention_kv_scale[3];
Li Zhang's avatar
Li Zhang committed
159
160
161
162
163
164
165
166
167
168
    }

    PUSH_RANGE("scaled dot-product fusion");
    masked_multihead_attention(params, stream);
    POP_RANGE;
}

// (Re)allocates the per-step scratch buffers through allocator_:
//   qkv_buf_     : batch_size * (local_head_num_ + 2 * local_kv_head_num_) * size_per_head_ elements
//   context_buf_ : batch_size * local_hidden_units_ elements
// `key_len` and `max_memory_len` are accepted but not used by this implementation.
template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::allocateBuffer(size_t batch_size, int key_len, int max_memory_len)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    // Query heads plus one set of key heads and one set of value heads.
    const size_t fused_head_count = local_head_num_ + 2 * local_kv_head_num_;

    const size_t qkv_bytes = sizeof(T) * batch_size * fused_head_count * size_per_head_;
    void*        qkv_mem   = allocator_->reMalloc(qkv_buf_, qkv_bytes, false);
    qkv_buf_               = reinterpret_cast<T*>(qkv_mem);

    const size_t context_bytes = sizeof(T) * batch_size * local_hidden_units_;
    void*        context_mem   = allocator_->reMalloc(context_buf_, context_bytes, false);
    context_buf_               = reinterpret_cast<T*>(context_mem);

    is_allocate_buffer_ = true;
}

// Releases qkv_buf_ and context_buf_ through allocator_, but only if allocateBuffer()
// has run since the last release; clears is_allocate_buffer_ afterwards.
template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::freeBuffer()
{
    if (!is_allocate_buffer_) {
        return;  // nothing to release
    }

    allocator_->free(reinterpret_cast<void**>(&qkv_buf_));
    allocator_->free(reinterpret_cast<void**>(&context_buf_));
    is_allocate_buffer_ = false;
}

template<typename T>
void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap*                     output_tensors,
                                                const TensorMap*               input_tensors,
                                                const LlamaAttentionWeight<T>* weights)
{
    /**
     * One decoding step of self-attention: fused QKV GEMM, masked multi-head attention
     * against the per-sample KV cache, output GEMM, then an all-reduce across the tensor-
     * parallel group when its world size is > 1.
     *
     * input tensors:
     *    \param input_query [batch_size, hidden_units],
     *    \param sequence_lengths [batch_size]
     *    \param step [1] on cpu
     *    \param finished [batch_size] (optional)
     *    \param total_padding_tokens [batch_size]
     *    \param layer_id [1], int on cpu
     *    \param max_seq_len [1] on cpu
     *    \param masked_tokens [batch_size, memory_len], (optional), NOT USED YET
     *    \param cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional);
     *           only its shape[1] (beam width) is consumed here
     *
     * output tensors:
     *    \param attention_output [batch_size, hidden_units],
     *    \param key_cache   array of per-sample key cache pointers; this layer's slice starts
     *                       layer_id * local_kv_head_num * max_seq_len * size_per_head elements in
     *    \param value_cache array of per-sample value cache pointers, same per-layer offset
     */

    const T*    input_query_data      = input_tensors->getPtr<T>("input_query");
    const int*  sequence_lengths_data = input_tensors->getPtr<int>("sequence_lengths");
    const int*  total_padding_len     = input_tensors->getPtr<int>("total_padding_tokens");
    const bool* finished_data         = input_tensors->getPtr<bool>("finished", nullptr);
    const int*  cache_indir           = input_tensors->getPtr<int>("cache_indirection", nullptr);

    T*  hidden_features_data = output_tensors->getPtr<T>("attention_output");
    T** key_cache_ptrs       = output_tensors->getPtr<T*>("key_cache");
    T** value_cache_ptrs     = output_tensors->getPtr<T*>("value_cache");

    const int layer_id    = input_tensors->getVal<int>("layer_id");
    const int max_seq_len = input_tensors->getVal<int>("max_seq_len");
    const int step        = input_tensors->getVal<int>("step");

    const int batch_size = input_tensors->at("input_query").shape[0];
    const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1;

    allocateBuffer(batch_size, step, max_seq_len);

    PUSH_RANGE("qkv_gemm");
    linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv);
    POP_RANGE;

    // Element offset of this layer's slice inside each sample's KV cache. Widen before
    // multiplying so the product of layer / head / sequence / head-dim sizes is never
    // evaluated in int arithmetic.
    const size_t kv_cache_layer_offset =
        static_cast<size_t>(layer_id) * local_kv_head_num_ * max_seq_len * size_per_head_;
    const int memory_len = max_seq_len;

    fusedQKV_masked_attention_dispatch<T>(
        qkv_buf_,
        weights->qkv.bias,  // query_weight.bias,
        nullptr,            // relative_attention_bias,
        nullptr,            // key_cache (addressed via key_cache_ptrs)
        nullptr,            // value_cache (addressed via value_cache_ptrs)
        key_cache_ptrs,
        value_cache_ptrs,
        kv_cache_layer_offset,
        cache_indir,
        context_buf_,
        finished_data,
        sequence_lengths_data,  // NOTE: current seq len including padding (fixed after meeting the finished id)
        batch_size,
        batch_size,
        beam_width,
        local_head_num_,
        local_kv_head_num_,
        size_per_head_,
        params_.rotray_embedding_dim,
        params_.max_position_embeddings,
        params_.use_dynamic_ntk,
        params_.use_logn_attn,
        memory_len,
        nullptr,                        // prefix_prompt_lengths
        0,                              // max_prefix_prompt_length
        0,                              // max_input_length, not used w/o linear_bias_slopes
        total_padding_len,              // total_padding_tokens (presence already checked above)
        step,
        1.f,                            // q_scaling
        0,                              // relative_attention_bias_stride
        nullptr,                        // linear_bias_slopes
        nullptr,                        // masked_tokens_data,
        nullptr,                        // ia3_tasks
        nullptr,                        // ia3_key_weights
        nullptr,                        // ia3_value_weights
        nullptr,                        // qkv_scale_out
        nullptr,                        // attention_out_scale
        quant_policy_,                  // int8_mode
        weights->past_kv_scale.data(),  // attention kv scale
        stream_);
    sync_check_cuda_error();

    linear_.forward(hidden_features_data, context_buf_, batch_size, weights->output);

    // Each tensor-parallel rank holds a partial output projection; sum them in place.
    if (tensor_para_.world_size_ > 1) {
        NcclGuard nccl_guard(tensor_para_, stream_);
        ftNcclAllReduceSum(
            hidden_features_data, hidden_features_data, batch_size * hidden_units_, tensor_para_, stream_);
        sync_check_cuda_error();
    }

    if (is_free_buffer_after_forward_) {
        freeBuffer();
    }
}

// Explicit instantiations: this layer is compiled for half and float element types.
template class LlamaDecoderSelfAttentionLayer<half>;
template class LlamaDecoderSelfAttentionLayer<float>;

lvhan028's avatar
lvhan028 committed
306
}  // namespace turbomind