/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 * Copyright (c) 2022, SK Telecom Authored by A. Dialog
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGpt.cc

#include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/models/llama/unified_decoder.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <functional>
#include <memory>
#include <sstream>

namespace turbomind {

template<typename T>
LlamaV2<T>::LlamaV2(size_t                       head_num,
                    size_t                       kv_head_num,
                    size_t                       size_per_head,
                    size_t                       inter_size,
                    size_t                       num_layer,
                    size_t                       vocab_size,
                    float                        norm_eps,
                    const LlamaAttentionParams&  attn_params,
                    int                          start_id,
                    int                          end_id,
                    int                          cache_block_seq_len,
                    int                          quant_policy,
                    bool                         use_context_fmha,
                    const EngineParams&          engine_params,
                    std::shared_ptr<SharedState> shared_state,
                    LlamaWeight<T>*              weights,
                    NcclParam                    tensor_para,
                    cudaStream_t                 stream,
                    cublasMMWrapper*             cublas_wrapper,
                    IAllocator*                  allocator,
                    bool                         is_free_buffer_after_forward,
                    cudaDeviceProp*              cuda_device_prop):
    head_num_(head_num),
    size_per_head_(size_per_head),
    inter_size_(inter_size),
    num_layer_(num_layer),
    vocab_size_(vocab_size),
    attn_params_(attn_params),
    vocab_size_padded_(vocab_size),
    rmsnorm_eps_(norm_eps),
    start_id_(start_id),
    end_id_(end_id),
    hidden_units_(head_num * size_per_head),
    local_head_num_(head_num / tensor_para.world_size_),
    local_kv_head_num_(kv_head_num / tensor_para.world_size_),
    weights_(weights),
    tensor_para_(tensor_para),
    stream_(stream),
    cublas_wrapper_(cublas_wrapper),
    allocator_(allocator),
    is_free_buffer_after_forward_(is_free_buffer_after_forward),
    cuda_device_prop_(cuda_device_prop),
    debug_(isDebug()),
    shared_state_(shared_state)

{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);

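    // pad the vocabulary size up to a multiple of the TP world size so the output
    // embedding (LM head) can be sharded evenly across tensor-parallel ranks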
    vocab_size_padded_ =
        (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;

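    // batch/scheduling state (LlamaBatch); it receives the KV-cache block length and
    // quant policy, and is started once the decoder below has been initialized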
    batch_ = std::make_unique<LlamaBatch<T>>(engine_params, cache_block_seq_len, quant_policy, this);

    initialize(attn_params, kv_head_num, use_context_fmha, cache_block_seq_len, quant_policy);

    /// TODO: decouple Llama model and batch inference
    batch_->Start();
}

template<typename T>
LlamaV2<T>::~LlamaV2()
{
    unified_decoder_.reset();
    delete dynamic_decode_layer_;
}

template<typename T>
void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
                            size_t                      kv_head_num,
                            bool                        use_context_fmha,
                            int                         cache_block_seq_len,
                            int                         quant_policy)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

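    // the unified decoder runs both decode and prefill batches through the
    // transformer layers; the sampling layer created below always works on fp32 logits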
    unified_decoder_.reset(new UnifiedDecoder<T>(head_num_,
                                                 kv_head_num,
                                                 size_per_head_,
                                                 inter_size_,
                                                 num_layer_,
                                                 attn_params,
                                                 rmsnorm_eps_,
                                                 tensor_para_,
                                                 stream_,
                                                 cublas_wrapper_,
                                                 allocator_,
                                                 is_free_buffer_after_forward_,
                                                 use_context_fmha,
                                                 cache_block_seq_len,
                                                 quant_policy));

    dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
                                                          vocab_size_padded_,
                                                          0,  // end_id, deprecated
                                                          stream_,
                                                          cublas_wrapper_,
                                                          allocator_,
                                                          is_free_buffer_after_forward_,
                                                          cuda_device_prop_);
}

template<typename T>
void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step)
{
    NvtxScope scope("embeddingLookup");
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
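    // gather the embedding rows of the tokens produced at `step` (one per sequence);
    // position encoding and padding counts are not used here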
    // ! This kernel can't be used in context decoding
    invokeEmbeddingLookupPosEncodingPadCount(embeddings,
                                             weights_->pre_decoder_embedding_table,
                                             static_cast<T*>(nullptr),  // position encoding
                                             token_ids_buf,
                                             static_cast<int*>(nullptr),  // padding count, not used w/o pos-code
                                             batch_size,
                                             hidden_units_,
                                             static_cast<T>(1.),  // scale
                                             step,                // step, used as index into output_ids_buf_
                                             batch_size,          // token_num
                                             0,                   // ite
                                             stream_);
    sync_check_cuda_error();
}

template<typename T>
void LlamaV2<T>::updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    for (int i = 0; i < bsz; i++) {
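        // copy user-provided input embeddings into this step's decoder input for the
        // ranges that extend beyond what is already in the KV cache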
        const auto& seq        = *sequences[i];
        const auto& embeddings = seq.input_embeddings;
        const auto& ranges     = seq.input_embedding_ranges;
        for (int j = embeddings.size() - 1; j >= 0; j--) {
            int begin = ranges[j].first;
            int end   = ranges[j].second;
            if (end <= seq.cache_len) {
                break;
            }
            int    off_dst   = std::max(0, begin - seq.cache_len);
            int    off_src   = std::max(0, seq.cache_len - begin);
            size_t byte_size = (end - begin) * hidden_units_ * sizeof(T);
            T*     dst_ptr   = decoder_input + off_dst * hidden_units_;
            auto   src_ptr   = embeddings[j].data() + off_src * hidden_units_ * sizeof(T);
            cudaMemcpyAsync(dst_ptr, src_ptr, byte_size, cudaMemcpyDefault, stream_);
        }
        decoder_input += h_input_length[i] * hidden_units_;
    }
    sync_check_cuda_error();
}

template<typename T>
void LlamaV2<T>::forwardUnified(T*               out,
                                T*               decoder_output,
                                T*               decoder_input,
                                void**           k_block_ptrs,
                                void**           v_block_ptrs,
                                const int*       input_ids,
                                const int*       cu_block_cnts,
                                const float*     rope_theta,
                                const bool*      dc_finished,
                                const int*       pf_input_length,
                                const int*       pf_context_length,
                                T**              pf_tmp_k_ptrs,
                                T**              pf_tmp_v_ptrs,
                                size_t           token_num,
                                int              dc_batch_size,
                                int              dc_step,
                                int              dc_sum_seq_len,
                                int              dc_max_seq_len,
                                int              pf_batch_size,
                                int              pf_max_input_len,
                                int              pf_max_context_len,
                                int              pf_session_len,
                                const int*       h_input_length,
                                const Sequence** sequences)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);

    invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
                                             nullptr,  // processed somewhere else
                                             weights_->pre_decoder_embedding_table,
                                             static_cast<T*>(nullptr),
                                             pPromptTuningParam<T>{},
                                             input_ids,
                                             0,  // only used for position encoding
                                             token_num,
                                             token_num,
                                             1,
                                             hidden_units_,
                                             stream_);

    updateEmbedding(decoder_input, dc_batch_size + pf_batch_size, h_input_length, sequences);

    sync_check_cuda_error();

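    // pack decode ("dc_*") and prefill ("pf_*") batch metadata for the unified decoder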
    const auto   dtype = getTensorType<T>();
    const size_t bsz   = dc_batch_size + pf_batch_size;

    TensorMap inputs{{"decoder_input", {MEMORY_GPU, dtype, {token_num, hidden_units_}, decoder_input}},
                     {"output_norm_weight", {MEMORY_GPU, dtype, {hidden_units_}, weights_->output_norm_weight}},
                     {"input_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, pf_input_length}},
                     {"context_lengths", {MEMORY_GPU, TYPE_INT32, {bsz}, pf_context_length}},
                     {"dc_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &dc_batch_size}},
                     {"dc_sum_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &dc_sum_seq_len}},
                     {"dc_max_seq_len", {MEMORY_CPU, TYPE_INT32, {1}, &dc_max_seq_len}},
                     {"finished", {MEMORY_GPU, TYPE_BOOL, {bsz}, dc_finished}},
                     {"pf_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &pf_batch_size}},
                     {"pf_max_q_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_max_input_len}},
                     {"pf_max_k_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_max_context_len}},
                     {"session_len", {MEMORY_CPU, TYPE_INT32, {1}, &pf_session_len}},
                     {"rope_theta", {MEMORY_GPU, TYPE_FP32, {hidden_units_}, rope_theta}},
                     {"cu_block_counts", {MEMORY_GPU, TYPE_INT32, {bsz}, cu_block_cnts}}};

    TensorMap outputs{{"decoder_output", {MEMORY_GPU, dtype, {token_num, hidden_units_}, decoder_output}},
                      {"key_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, k_block_ptrs}},
                      {"value_cache", {MEMORY_GPU, TYPE_UINT64, {bsz}, v_block_ptrs}},
                      {"tmp_k", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_k_ptrs}},
                      {"tmp_v", {MEMORY_GPU, TYPE_UINT64, {bsz}, pf_tmp_v_ptrs}},
                      {"last_token_hidden_units", {MEMORY_GPU, dtype, {bsz, hidden_units_}, out}}};

    unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights);
}

template<typename T>
void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size)
{
    NvtxScope scope("postDecodeEmbedding");
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    cudaDataType_t data_type = getCudaDataType<T>();
    float          alpha     = 1.f;
    float          beta      = 0.f;
    if (tensor_para_.world_size_ == 1) {
        cublas_wrapper_->Gemm(CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              vocab_size_,  // n
                              batch_size,
                              hidden_units_,  // k
                              &alpha,
                              weights_->post_decoder_embedding_kernel,
                              data_type,
                              hidden_units_,  // k
                              decoder_output,
                              data_type,
                              hidden_units_,  // k
                              &beta,
                              logits,
                              CUDA_R_32F,
                              vocab_size_,  // n
                              CUDA_R_32F,
                              cublasGemmAlgo_t(-1));
    }
    else {
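        // TP > 1: each rank computes logits for its shard of the padded vocab, the
        // shards are all-gathered, and the result is transposed from
        // [world_size, batch, local_vocab] to [batch, vocab_padded]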
        FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0);
        const size_t local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_;
        cublas_wrapper_->Gemm(CUBLAS_OP_T,
                              CUBLAS_OP_N,
                              local_vocab_size,  // n
                              batch_size,
                              hidden_units_,  // k
                              &alpha,
                              weights_->post_decoder_embedding_kernel
                                  + tensor_para_.rank_ * local_vocab_size * hidden_units_,
                              data_type,
                              hidden_units_,  // k
                              decoder_output,
                              data_type,
                              hidden_units_,  // k
                              &beta,
                              local_logits + tensor_para_.rank_ * batch_size * local_vocab_size,
                              CUDA_R_32F,
                              local_vocab_size,  // n
                              CUDA_R_32F,
                              cublasGemmAlgo_t(-1));
        {
            NcclGuard nccl_guard(tensor_para_, stream_);
            ftNcclAllGather(local_logits,                   // send_buf
                            local_logits,                   // recv_buf
                            batch_size * local_vocab_size,  // data_size
                            tensor_para_.rank_,
                            tensor_para_,
                            stream_);
        }
        invokeTransposeAxis01(logits, local_logits, tensor_para_.world_size_, batch_size, local_vocab_size, stream_);
        sync_check_cuda_error();
    }
}

template<typename T>
void LlamaV2<T>::dynamicDecode(int*            token_ids,
                               bool*           finished,
                               int*            sequence_length,
                               bool*           should_stop,
                               curandState_t*  curand_state,
                               TensorMap*      inputs,
                               TensorMap*      outputs,
                               const float*    logits,
                               const uint32_t* seq_limit_len,
                               const int*      context_length,
                               const int*      end_ids,
                               int             step,
                               int             ite,
                               size_t          max_context_len,
                               size_t          token_ids_len,
                               size_t          batch_size)
{
    NvtxScope scope("dynamicDecode");
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    int local_batch_size = (int)batch_size;

    std::unordered_map<std::string, Tensor> dynamic_decode_input_tensors{
        {"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, (size_t)1, vocab_size_padded_}, logits}},
        {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
        {"max_input_length", {MEMORY_CPU, TYPE_INT32, {1}, &max_context_len}},
        {"sequence_limit_length", {MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len}},
        {"input_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size, 1}, context_length}},
        {"ite", {MEMORY_CPU, TYPE_UINT32, {1}, &ite}},
        {"end_id", {MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}},
        {"local_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}},
    };

    const std::vector<std::string> optional_inputs{"stop_words_list",
                                                   "bad_words_list",
                                                   "runtime_top_k",
                                                   "runtime_top_p",
                                                   "temperature",
                                                   "repetition_penalty",
                                                   "random_seed"};
    for (const auto& key : optional_inputs) {
        if (inputs->isExist(key)) {
            dynamic_decode_input_tensors.insert({key, inputs->at(key)});
        }
    }

    std::unordered_map<std::string, Tensor> dynamic_decode_output_tensors{
        {"output_ids", {MEMORY_GPU, TYPE_INT32, {token_ids_len, batch_size, 1U}, token_ids}},
        {"finished", {MEMORY_GPU, TYPE_BOOL, {batch_size}, finished}},
        {"sequence_length", {MEMORY_GPU, TYPE_INT32, {batch_size}, sequence_length}},
        {"should_stop", {MEMORY_CPU, TYPE_BOOL, {1}, should_stop}},
        {"curand_state", {MEMORY_GPU, TYPE_VOID, {batch_size}, curand_state}}};

    const std::vector<std::string> optional_outputs{"cum_log_probs", "output_log_probs"};
    for (const auto& key : optional_outputs) {
        if (outputs->isExist(key)) {
            dynamic_decode_output_tensors.insert({key, outputs->at(key)});
        }
    }

    dynamic_decode_layer_->forward(&dynamic_decode_output_tensors, &dynamic_decode_input_tensors);
}

static inline Tensor slice(const Tensor& tensor, int index)
{
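    // view of the `index`-th batch item; tensors with batch dim 1 are shared (broadcast)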
    auto shape = tensor.shape;
    if (shape.at(0) == 1) {
        return tensor;
    }
    shape[0]          = 1;
    const auto offset = std::accumulate(shape.begin(), shape.end(), (size_t)index, std::multiplies<>{});
    return tensor.slice(shape, offset);
}

// ! implicit conversion from `unordered_map` to `TensorMap` drops 0-sized tensors
static inline TensorMap slice(const std::unordered_map<std::string, Tensor>& src, int index)
{
    TensorMap dst;
    for (const auto& kv : src) {
        dst.insert({kv.first, slice(kv.second, index)});
    }
    return dst;
}

template<typename T>
void LlamaV2<T>::forward(std::unordered_map<std::string, Tensor>*       outputs,
                         const std::unordered_map<std::string, Tensor>* inputs,
                         Control                                        control)
{
    if (debug_) {
        if (tensor_para_.rank_ == 0) {
            for (const auto& kv : *inputs) {
                TM_LOG_INFO("[forward][rank=%d] INPUT: %s", (int)tensor_para_.rank_, format(kv).c_str());
            }
            for (const auto& kv : *outputs) {
                TM_LOG_INFO("[forward][rank=%d] OUTPUT: %s", (int)tensor_para_.rank_, format(kv).c_str());
            }
        }
    }

    const int batch_size = outputs->at("output_ids").shape[0];

    const auto rank = tensor_para_.rank_;

    std::vector<std::shared_ptr<Request>> requests(batch_size);

    // rank-0 allocates all requests for the batch
    if (rank == 0) {
        for (int i = 0; i < batch_size; ++i) {
            requests[i] = std::make_shared<Request>();
            requests[i]->inputs.resize(tensor_para_.world_size_);
            requests[i]->outputs.resize(tensor_para_.world_size_);
        }
        control.comm->setSharedObject(&requests);
    }

    control.comm->barrier();

    if (rank != 0) {
        requests = *(std::vector<std::shared_ptr<Request>>*)control.comm->getSharedObject();
    }

    for (int i = 0; i < batch_size; ++i) {
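        // every rank attaches its own view (slice) of the request's input/output tensors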
        auto& r = requests[i];

        r->inputs[rank]  = slice(*inputs, i);
        r->outputs[rank] = slice(*outputs, i);

        if (rank == 0) {
            r->id         = r->inputs[rank].getVal<uint64_t>("CORRID", i);
            r->start_flag = r->inputs[rank].getVal<int>("START", 1);
            r->end_flag   = r->inputs[rank].getVal<int>("END", 1);
            r->stop_flag  = r->inputs[rank].getVal<int>("STOP", 0);
            r->stream_cb  = control.callback;
        }
    }

    control.comm->barrier();

    // rank-0 now takes ownership of `requests`
    // rank-0 submits the tasks and waits for them to finish
    std::vector<int> error_codes;
    bool             has_error = false;
    if (rank == 0) {
        TM_LOG_INFO("[forward] Enqueue requests");

        std::vector<uint64_t> ids;
        for (const auto& r : requests) {
            ids.push_back(r->id);
        }

        auto futures = shared_state_->request_queue.enqueue(std::move(requests));

        FT_CHECK_WITH_INFO(ids.size() == futures.size(), "check failed");

        TM_LOG_INFO("[forward] Wait for requests to complete ...");

        for (int i = 0; i < futures.size(); ++i) {
            auto ec = futures[i].get();
            error_codes.push_back(ec);
            if (ec) {
                has_error = true;
            }
            TM_LOG_INFO("[forward] Request complete for %ld, code %d", (long)ids[i], (int)ec);
        }
    }

    // prevents request tensors from being freed before the batch completes
    control.comm->barrier();

    if (rank == 0 && has_error) {
        std::stringstream ss;
        for (int i = 0; i < error_codes.size(); ++i) {
            ss << (i ? " " : "") << error_codes[i];
        }
        throw std::runtime_error(ss.str());
    }
}

template class LlamaV2<half>;
template class LlamaV2<float>;
#ifdef ENABLE_BF16
template class LlamaV2<__nv_bfloat16>;
#endif

}  // namespace turbomind