/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGptContextDecoder.cc

#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/kernels/bert_preprocess_kernels.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/models/llama/LlamaContextDecoder.h"
#include "src/turbomind/models/llama/llama_decoder_kernels.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/Tensor.h"
namespace turbomind {
// The parameterless allocateBuffer() is not supported by this decoder: buffer sizes depend on the
// batch, so callers must use allocateBuffer(batch_size, num_token, max_q_len, max_kv_len).
// Reaching this overload is treated as a programming error.
template<typename T>
void LlamaContextDecoder<T>::allocateBuffer()
{
    FT_CHECK(false);
}

// (Re)allocates the per-call scratch buffers used by forward().
//   batch_size  number of sequences in the call
//   num_token   currently unused (kept for interface compatibility)
//   max_q_len   largest query length in the batch
//   max_kv_len  largest key/value length in the batch
// Marks the buffers as allocated so that freeBuffer() will release them.
template<typename T>
void LlamaContextDecoder<T>::allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    // batch_size * max_q_len * max_kv_len elements; filled by invokeCreateCausalMasks in forward().
    attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * batch_size * max_q_len * max_kv_len, false);
    // batch_size * max_q_len ints / batch_size + 1 ints; both filled by
    // invokeGetPaddingOffsetAndCuSeqLens in forward().
    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * max_q_len, false);
    cu_seqlens_     = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false);

    // freeBuffer() releases h_pinned_token_num_ptr_ together with the buffers above, but it is otherwise
    // only allocated once, in initialize(). Re-acquire it here (same arguments as initialize()) so that
    // forward() never hands a released pointer to invokeGetPaddingOffsetAndCuSeqLens after a previous
    // freeBuffer(), e.g. when is_free_buffer_after_forward_ is set.
    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);

    is_allocate_buffer_ = true;
}

// Releases the scratch buffers acquired by allocateBuffer(batch_size, ...).
// Does nothing unless a buffer allocation is outstanding (is_allocate_buffer_).
template<typename T>
void LlamaContextDecoder<T>::freeBuffer()
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    if (!is_allocate_buffer_) {
        return;
    }

    allocator_->free((void**)&padding_offset_);
    allocator_->free((void**)&cu_seqlens_);
    allocator_->free((void**)&attention_mask_);
    // Allocated via reMalloc(..., true, true); the trailing `true` here presumably selects the host
    // (pinned) free path — confirm against IAllocator.
    allocator_->free((void**)&h_pinned_token_num_ptr_, true);

    is_allocate_buffer_ = false;
}

template<typename T>
63
void LlamaContextDecoder<T>::initialize(size_t kv_head_num, bool use_fmha, int quant_policy)
Li Zhang's avatar
Li Zhang committed
64
65
66
67
{
    h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true);

    context_attention_layer_ = new LlamaContextAttentionLayer<T>(head_num_,
68
                                                                 kv_head_num,
Li Zhang's avatar
Li Zhang committed
69
70
71
72
73
74
75
76
                                                                 size_per_head_,
                                                                 rotary_embedding_dim_,
                                                                 false,  // neox_rotary_style
                                                                 tensor_para_,
                                                                 stream_,
                                                                 cublas_wrapper_,
                                                                 allocator_,
                                                                 is_free_buffer_after_forward_,
77
78
                                                                 use_fmha,
                                                                 quant_policy);
Li Zhang's avatar
Li Zhang committed
79
80
81
82
83
84
85
86
87
88
89
90
91

    silu_ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
                                           size_per_head_,
                                           inter_size_,
                                           tensor_para_,
                                           stream_,
                                           cublas_wrapper_,
                                           allocator_,
                                           is_free_buffer_after_forward_);
}

// Runs the context self-attention of a single decoder layer.
//   sess           per-call batch description built by forward()
//   attn_io        [token_num, hidden_units] buffer registered as both "input_query" and
//                  "hidden_features", so the attention result presumably overwrites the input
//   input_tensors  forward()'s input map; only "max_seq_len" is read here
//   layer          index into sess.weights, also passed to the layer as "layer_id"
//   is_final       passed to the layer as "is_final_layer"
template<typename T>
void LlamaContextDecoder<T>::forwardSelfAttn(const Session&                                 sess,
                                             T*                                             attn_io,
                                             const std::unordered_map<std::string, Tensor>* input_tensors,
                                             int                                            layer,
                                             bool                                           is_final)
{
    TensorMap self_attention_input_tensors{
        {"input_query", Tensor{MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
        {"attention_mask",
         {MEMORY_GPU, data_type_, {sess.batch_size, 1, sess.max_query_len, sess.max_key_len}, attention_mask_}},
        {"layer_id", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &layer}},
        {"is_final_layer", Tensor{MEMORY_CPU, TYPE_BOOL, {1}, &is_final}},
        {"padding_offset", {MEMORY_GPU, TYPE_INT32, {sess.token_num}, padding_offset_}},
        {"cu_seqlens", {MEMORY_GPU, TYPE_INT32, {sess.batch_size + 1}, cu_seqlens_}},
        {"input_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.input_length}},
        {"history_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.history_length}},
        {"context_lengths", {MEMORY_GPU, TYPE_INT32, {sess.batch_size}, sess.context_length}},
        {"max_seq_len", input_tensors->at("max_seq_len")}};

    // The cache tensors are copied into the map as descriptors; they refer to the caller's cache memory.
    TensorMap self_attention_output_tensors{
        {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
        {"key_cache", *sess.k_cache},
        {"value_cache", *sess.v_cache},
    };

    context_attention_layer_->forward(&self_attention_output_tensors,  //
                                      &self_attention_input_tensors,
                                      &sess.weights->at(layer)->self_attn_weights);
}

// Builds a context decoder with `num_layer` Llama layers.
// hidden_units_ is derived as head_num * size_per_head. kv_head_num, use_fmha and quant_policy are not
// stored on this object; they are only forwarded (via initialize()) to the attention sub-layer.
template<typename T>
LlamaContextDecoder<T>::LlamaContextDecoder(size_t           head_num,
                                            size_t           kv_head_num,
                                            size_t           size_per_head,
                                            size_t           inter_size,
                                            size_t           num_layer,
                                            size_t           rotary_embedding_dim,
                                            float            rmsnorm_eps,
                                            NcclParam        tensor_para,
                                            cudaStream_t     stream,
                                            cublasMMWrapper* cublas_wrapper,
                                            IAllocator*      allocator,
                                            bool             is_free_buffer_after_forward,
                                            bool             use_fmha,
                                            int              quant_policy):
    BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward),
    head_num_(head_num),
    size_per_head_(size_per_head),
    inter_size_(inter_size),
    hidden_units_(head_num * size_per_head),
    num_layer_(num_layer),
    rotary_embedding_dim_(rotary_embedding_dim),
    rmsnorm_eps_(rmsnorm_eps),
    tensor_para_(tensor_para),
    data_type_(getTensorType<T>())
{
    initialize(kv_head_num, use_fmha, quant_policy);
}

// Destroys the owned sub-layers and releases every buffer this decoder acquired.
template<typename T>
LlamaContextDecoder<T>::~LlamaContextDecoder()
{
    delete context_attention_layer_;
    delete silu_ffn_layer_;
    freeBuffer();

    // h_pinned_token_num_ptr_ is allocated in initialize(), but freeBuffer() only releases it while
    // is_allocate_buffer_ is set (i.e. after a forward() that did not already free its buffers).
    // Release it here in the remaining cases so it does not leak.
    // NOTE(review): relies on IAllocator::free resetting the pointer it is given (it takes void**), so
    // this branch is skipped when freeBuffer() has already released it — confirm against the allocator.
    if (h_pinned_token_num_ptr_ != nullptr) {
        allocator_->free((void**)&h_pinned_token_num_ptr_, true);
    }
}

// The std::vector-based forward() interface is not supported by this decoder; callers must use the
// std::unordered_map<std::string, Tensor> overload. Reaching this overload is a programming error.
template<typename T>
void LlamaContextDecoder<T>::forward(std::vector<Tensor>*                            output_tensors,
                                     const std::vector<Tensor>*                      input_tensors,
                                     const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
{
    FT_CHECK(false);
}

template<typename T>
void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*        output_tensors,
                                     const std::unordered_map<std::string, Tensor>*  input_tensors,
                                     const std::vector<LlamaDecoderLayerWeight<T>*>* decoder_layer_weights)
{
    /**
     * Runs the context decoder over all num_layer_ layers for one batch.
     *
     * input tensors:
     *   \param decoder_input [num_token, hidden_units], T
     *   \param input_lengths [batch_size], int
     *   \param history_lengths [batch_size], int
     *   \param context_lengths [batch_size], int
     *   \param output_norm_weight [hidden_dims], T
     *   \param max_q_len [1], int on cpu
     *   \param max_kv_len [1], int on cpu
     *   \param max_seq_len [1], int on cpu
     *
     * output tensors:
     *   \param decoder_output [num_token, hidden_units], T
     *   \param key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x]
     *   \param value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head]
     *   \param last_token_hidden_units [batch_size, hidden_units] (not accessed by this function)
     *
     * decoder_input's buffer is passed as the residual argument of every
     * invokeFusedAddBiasResidualRMSNorm call below, so callers should not expect it to be preserved.
     */

    Session sess{};

    sess.token_num     = input_tensors->at("decoder_input").shape[0];
    sess.batch_size    = input_tensors->at("input_lengths").shape[0];
    sess.max_query_len = input_tensors->at("max_q_len").getVal<int>();
    sess.max_key_len   = input_tensors->at("max_kv_len").getVal<int>();
    sess.weights       = decoder_layer_weights;

    sess.input_length   = input_tensors->at("input_lengths").getPtr<int>();
    sess.history_length = input_tensors->at("history_lengths").getPtr<int>();
    sess.context_length = input_tensors->at("context_lengths").getPtr<int>();

    // Residual stream shared by all layers.
    T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
    // Working buffer: normalized activations in, attention / FFN results out.
    T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();

    sess.k_cache = &output_tensors->at("key_cache");
    sess.v_cache = &output_tensors->at("value_cache");

    allocateBuffer(sess.batch_size, sess.token_num, sess.max_query_len, sess.max_key_len);

    // Derive padding offsets / cu_seqlens from the per-sequence input lengths. The kernel also reports
    // the token count it computed, which must match the row count of decoder_input.
    size_t tmp_token_num{};
    invokeGetPaddingOffsetAndCuSeqLens(h_pinned_token_num_ptr_,
                                       &tmp_token_num,  // updated token num
                                       padding_offset_,
                                       cu_seqlens_,
                                       sess.input_length,
                                       sess.batch_size,
                                       sess.max_query_len,
                                       stream_);
    sync_check_cuda_error();
    FT_CHECK(tmp_token_num == sess.token_num);

    invokeCreateCausalMasks(attention_mask_,
                            sess.input_length,
                            sess.context_length,
                            sess.max_query_len,
                            sess.max_key_len,
                            sess.batch_size,
                            stream_);
    sync_check_cuda_error();

    /////////////////////////////////////////////
    /// RMSNorm of decoder_input with layer 0's attention-norm weights
    invokeRootMeanSquareNorm(decoder_output,
                             decoder_input_output,
                             decoder_layer_weights->at(0)->self_attn_norm_weights,
                             rmsnorm_eps_,
                             sess.token_num,
                             hidden_units_,
                             stream_);
    sync_check_cuda_error();

    for (size_t layer = 0; layer < num_layer_; ++layer) {
        LlamaDecoderLayerWeight<T>* const layer_weights = decoder_layer_weights->at(layer);

        /////////////////////////////////////////////
        /// self-attention
        forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);

        // Add attention output bias + residual, then normalize with this layer's FFN-norm weights.
        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
                                          decoder_output,
                                          layer_weights->self_attn_weights.output.bias,
                                          layer_weights->ffn_norm_weights,
                                          rmsnorm_eps_,
                                          sess.token_num,
                                          hidden_units_,
                                          stream_);
        sync_check_cuda_error();

        ////////////////////////////////////////////
        /// feed-forward network (input and output share decoder_output)
        TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
        TensorMap ffn_outputs{
            {"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
        silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &layer_weights->ffn_weights);

        // The trailing norm uses the next layer's attention-norm weights, or the final output norm
        // weights after the last layer.
        auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                     input_tensors->at("output_norm_weight").getPtr<T>();
        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
                                          decoder_output,
                                          layer_weights->ffn_weights.output.bias,
                                          scale_weight,
                                          rmsnorm_eps_,
                                          sess.token_num,
                                          hidden_units_,
                                          stream_);
        sync_check_cuda_error();
    }

    if (is_free_buffer_after_forward_) {
        freeBuffer();
    }
}

// Explicit instantiations: the member definitions live in this .cc file, so each supported element
// type (fp32 and fp16 here) must be instantiated explicitly.
template class LlamaContextDecoder<float>;
template class LlamaContextDecoder<half>;
}  // namespace turbomind