LlamaBatch.cc 58 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
// Copyright (c) OpenMMLab. All rights reserved.

lvhan028's avatar
lvhan028 committed
3
4
#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/kernels/decoding_kernels.h"
Chen Xin's avatar
Chen Xin committed
5
#include "src/turbomind/macro.h"
lvhan028's avatar
lvhan028 committed
6
7
8
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/Request.h"
Li Zhang's avatar
Li Zhang committed
9
10
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_kernels.h"
lvhan028's avatar
lvhan028 committed
11
12
#include "src/turbomind/models/llama/llama_utils.h"
#include "src/turbomind/utils/Tensor.h"
Li Zhang's avatar
Li Zhang committed
13
14
15
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/gemm_test/gemm_func.h"
lvhan028's avatar
lvhan028 committed
16
#include "src/turbomind/utils/logger.h"
Li Zhang's avatar
Li Zhang committed
17
18
#include <algorithm>
#include <cmath>
Li Zhang's avatar
Li Zhang committed
19
20
#include <cstdint>
#include <iomanip>
Li Zhang's avatar
Li Zhang committed
21
22
#include <mutex>
#include <numeric>
Li Zhang's avatar
Li Zhang committed
23
24
25
#include <sstream>
#include <unordered_map>

lvhan028's avatar
lvhan028 committed
26
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
27

Li Zhang's avatar
Li Zhang committed
28
29
30
31
32
33
34
// Reset a batch state: release the first `size` request/sequence slots and
// zero both counters so the state can be reused as an empty buffer.
void ClearState(BatchState& s)
{
    const int n = s.size;
    std::fill(s.requests.begin(), s.requests.begin() + n, nullptr);
    std::fill(s.sequences.begin(), s.sequences.begin() + n, nullptr);
    s.active_size = 0;
    s.size        = 0;
}

Li Zhang's avatar
Li Zhang committed
35
template<typename T>
// Validates both request lists in place, signaling an error code through each
// rejected request's promise and nulling its slot, then compacts the lists.
// Rejection reasons: duplicate id across the two lists (kConflict), both
// start & stop flags set (kInvalid), input longer than `session_len_`
// (kTooLong), continuation of an unknown sequence (kInvalid), stop request
// for a sequence that is not currently active (kInactive), and infer request
// for a sequence already in the running batch (kBusy).
void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
{
    // number of times each request id appears across BOTH lists
    std::unordered_map<uint64_t, int> occurrence;

    auto count_occurrence = [&occurrence](const Requests& rs) {
        for (const auto& r : rs) {
            ++occurrence[r->id];
        }
    };

    // Signal the error code to the waiting client and drop the request
    // (slot becomes nullptr; compacted later by `drop_invalid`).
    auto reject = [](const char* type, std::shared_ptr<Request>& req, int ec) {
        TM_LOG_WARNING(
            "[RejectInvalidRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
        req->signal.set_value(ec);
        req.reset();
    };

    // Checks shared between stop and infer requests.
    auto handle_conflict_or_invalid = [this, &occurrence, &reject](Requests& rs, const char* type) {
        for (auto& r : rs) {
            if (r) {
                int ec = 0;

                const int  input_length = r->inputs[rank_].getVal<int>("input_lengths", 0);
                // clamp the user-provided `step` into [0, token_count]; defaults to
                // token_count when `step` is absent
                const auto get_offset   = [&](int token_count) {
                    return std::max(0, std::min(token_count, r->inputs[rank_].getVal<int>("step", token_count)));
                };

                if (occurrence[r->id] != 1) {
                    ec = Request::kConflict;
                }
                else if (r->start_flag && r->stop_flag) {
                    ec = Request::kInvalid;
                }
                else if (input_length > session_len_) {
                    ec = Request::kTooLong;
                }
                else if (!r->start_flag) {
                    // continuation request: the sequence must already exist ...
                    if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
                        ec = Request::kInvalid;
                    }
                    // ... and resume offset + new input must still fit in the session
                    else if (get_offset(seq->tokens.size()) + input_length > session_len_) {
                        ec = Request::kTooLong;
                    }
                }

                if (ec) {
                    reject(type, r, ec);
                }
            }
        }
    };

    // Compact the list: keep only non-null entries, preserving order.
    auto drop_invalid = [](Requests& rs) {
        int count = 0;
        for (int i = 0; i < rs.size(); ++i) {
            if (rs[i]) {
                rs[count++] = std::move(rs[i]);
            }
        }
        rs.resize(count);
    };

    count_occurrence(stop_reqs);
    count_occurrence(infer_reqs);

    if (!stop_reqs.empty()) {
        handle_conflict_or_invalid(stop_reqs, "stop");

        // invalidate stop-only requests for inactive sequences
        for (auto& r : stop_reqs) {
            if (r && r->end_flag == false) {
                int ec = Request::kInactive;
                // the sequence is "active" iff some slot in the running batch holds it
                for (int i = 0; i < state_->size; ++i) {
                    if (state_->requests[i] && state_->requests[i]->id == r->id) {
                        ec = 0;
                        break;
                    }
                }
                if (ec) {
                    reject("stop", r, ec);
                }
            }
        }

        drop_invalid(stop_reqs);
    }

    if (!infer_reqs.empty()) {
        handle_conflict_or_invalid(infer_reqs, "infer");

        // invalidate requests for busy sequences
        for (auto& r : infer_reqs) {
            if (r) {
                for (int i = 0; i < state_->size; ++i) {
                    if (state_->requests[i] && state_->requests[i]->id == r->id) {
                        reject("infer", r, Request::kBusy);
                        break;
                    }
                }
            }
        }

        drop_invalid(infer_reqs);
    }
}

template<typename T>
// Handles stop requests: finishes the matching active slot (optionally
// erasing the sequence when `end_flag` is set), or erases an inactive
// sequence directly. Returns one deferred signal per request; the caller is
// expected to invoke them to fulfill the client-side promises.
auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
{
    std::vector<Signal> signals;
    for (const auto& r : requests) {
        int ec = Request::kFail;
        // find matching active sequence
        for (int i = 0; i < state_->size; ++i) {
            // stop & optionally erase active sequence
            if (state_->requests[i] && state_->requests[i]->id == r->id) {
                ec = 0;
                CompleteRequest(i, true, r->end_flag);
                state_->requests[i].reset();
                break;
            }
        }
        // mismatch, try erase inactive sequence, in this case there is no active request to finish
        if (ec && r->end_flag) {
            ec = 0;
            sequence_manager_->Erase(r->id);
        }
        // clear output buffers (prevent leaking conversations) if request is successful
        if (ec == 0) {
            // wait until pending output transfers have drained before touching
            // the client-visible output tensors (rank 0 coordinates output)
            if (rank_ == 0) {
                std::unique_lock lock{output_mutex_};
                output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
            }
            auto& output_ids      = r->outputs[rank_].at("output_ids");
            auto& sequence_length = r->outputs[rank_].at("sequence_length");
            // zero the whole token buffer (last dim of `output_ids`) and the length
            Clear(output_ids.getPtr<int>(), output_ids.shape.at(2));
            Clear(sequence_length.getPtr<int>(), 1);
            // ensure the async clears complete before the client is signaled
            check_cuda_error(cudaStreamSynchronize(stream_));
        }
        // capture `r` by copy so the request outlives this scope until signaled
        signals.push_back([=] { r->signal.set_value(ec); });
    }
    return signals;
}
akhoroshev's avatar
akhoroshev committed
179

Li Zhang's avatar
Li Zhang committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
template<typename T>
// Admits new inference requests into the `incoming_` state: creates or looks
// up the sequence, materializes history + input tokens into the per-slot
// `output_ids` row, computes the per-slot generation length limit, derives
// the rope theta for new sequences, and restores sampling RNG state for
// continued ones. `incoming_` must be empty on entry; its `size` is set to
// the number of admitted requests on exit.
void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
{
    auto& state = *incoming_;

    FT_CHECK(state.size == 0);
    FT_CHECK(state.active_size == 0);

    int i = 0;
    for (const auto& r : requests) {

        // sanity check, incoming request in previous iter should have been moved to `state_`
        FT_CHECK(!state.requests[i]);

        TM_LOG_WARNING("[ProcessInferRequests] Request for %ld received.", (long)r->id);

        state.requests[i] = r;

        // get sequence for the request
        state.sequences[i] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);

        auto& seq = *state.sequences[i];

        // optional `step` rewinds the sequence to an earlier position
        if (int step = r->inputs[rank_].getVal<int>("step", -1); step >= 0) {
            /// TODO: revise step setting
            if (step <= seq.tokens.size()) {
                seq.tokens.resize(step);
                // cached KV beyond the rewind point is no longer valid
                seq.cache_len = std::min(seq.cache_len, step);
            }
            else if (rank_ == 0) {
                TM_LOG_WARNING(
                    "[ProcessInferRequests] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
            }
        }

        const int  input_length = r->inputs[rank_].getVal<int>("input_lengths");
        const int* input_ids    = r->inputs[rank_].getPtr<int>("input_ids");

        // `output_ids` contains all token ids of the sequences
        const auto output_ids_base = state.output_ids + session_len_ * i;
        auto       output_ids      = output_ids_base;

        // copy history tokens
        if (!seq.tokens.empty()) {
            output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);
        }

        // copy input tokens
        if (input_length) {
            output_ids = Copy(input_ids, input_length, output_ids);
        }

        // total context length (history + input)
        state.h_context_length[i] = output_ids - output_ids_base;
        state.h_finished[i]       = false;

        const int request_output_len = state.requests[i]->inputs[rank_].getVal<int>("request_output_len");
        state.seq_len_limit[i]       = state.h_context_length[i] + request_output_len;
        // `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
        // the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
        if (state.seq_len_limit[i] >= session_len_) {
            state.seq_len_limit[i] = session_len_ - 1;
            if (rank_ == 0) {
                const int trunc_output_len = state.seq_len_limit[i] - state.h_context_length[i];
                TM_LOG_WARNING(
                    "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
                    (long)seq.id,
                    state.h_context_length[i],
                    request_output_len,
                    (int)session_len_,
                    trunc_output_len);
            }
        }

        // compute rope scaling factor
        if (r->start_flag) {
            seq.rope_theta      = model_->attn_params_.rotary_embedding_base;
            auto scaling_factor = 1.f;
            if (r->inputs[rank_].isExist("rope_scaling_factor")) {  // runtime scaling factor
                scaling_factor = r->inputs[rank_].getVal<float>("rope_scaling_factor");
            }
            else if (model_->attn_params_.rope_scaling_factor >= 1.f) {  // infer by `seq_len_limit`
                scaling_factor   = model_->attn_params_.rope_scaling_factor;
                auto max_seq_len = state.seq_len_limit[i];
                auto max_pos_emb = model_->attn_params_.max_position_embeddings;
                if (max_seq_len > max_pos_emb) {
                    scaling_factor = scaling_factor * max_seq_len / max_pos_emb - (scaling_factor - 1);
                    // scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f))
                    // - 1.f, 1.f);
                }
            }
            if (scaling_factor != 1.f) {
                float rope_dim = model_->attn_params_.rotary_embedding_dim;
                seq.rope_theta *= powf(scaling_factor, rope_dim / (rope_dim - 2.f));
                TM_LOG_INFO("[ProcessInferRequests] %ld rope_scaling_factor: %f, rope_theta = %f",
                            (long)seq.id,
                            scaling_factor,
                            seq.rope_theta);
            }
        }
        state.h_rope_theta[i] = seq.rope_theta;

        // recover device states if not a new sequence
        if (!r->start_flag) {
            // BUG FIX: the destinations previously lacked the `+ i` slot offset, so
            // every continued sequence overwrote slot 0's RNG state instead of its
            // own (all other per-slot arrays in this loop are indexed by `i`, and
            // SaveRandomState/CopyState address these same arrays per slot).
            Copy((curandState_t*)seq.random_state.data() + 0, 1, (curandState_t*)state.top_k_curand_state + i);
            Copy((curandState_t*)seq.random_state.data() + 1, 1, (curandState_t*)state.top_p_curand_state + i);
        }

        // assign priority based on arrival time
        r->priority = request_count_++;

        // increment pointer
        i++;
    }

    incoming_->size = i;
}

template<typename T>
// Re-plans the running batch for the next iteration: gathers all live
// sequences from `state_` and `incoming_`, lets the sequence manager
// (re)allocate KV blocks (possibly swapping sequences in/out), reorders the
// slots (active first, swap-ins last sorted by missing prefill length) into
// the back buffer, and refreshes the device-side block pointer tables.
// Returns true when the slot layout changed, i.e. generation & sampling
// state must be re-initialized by the caller.
bool LlamaBatch<T>::Initialize()
{
    NvtxScope                                scope("initialize");
    std::vector<const Sequence*>             sequences;        // all live sequences (state_ then incoming_)
    std::vector<Sequence::Status>            status;           // status BEFORE Materialize (the "past" status)
    std::vector<uint64_t>                    priorities;
    std::vector<int>                         context_lengths;
    std::vector<std::pair<BatchState*, int>> coords;           // (owning state, slot index) per sequence

    // count the holes introduced by finished requests in from previous iteration or stop requests from
    // current iteration
    int holes{};
    int active_holes{};
    for (int i = 0; i < state_->size; ++i) {
        if (!state_->requests[i]) {
            ++holes;
            if (i < state_->active_size) {
                ++active_holes;
            }
        }
    }

    // dbg(holes, active_holes);

    auto process = [&](BatchState* state) {
        for (int i = 0; i < state->size; ++i) {
            if (auto& r = state->requests[i]) {
                sequences.push_back(state->sequences[i]);
                status.push_back(state->sequences[i]->status);
                priorities.push_back(r->priority);
                context_lengths.push_back(state->h_context_length[i]);
                coords.emplace_back(state, i);
                // clear swap-in flags
                state->is_swap_in[i] = 0;
            }
        }
    };

    process(state_);
    process(incoming_);

    // May change each sequence's status (swap in/out) and allocate blocks.
    auto outcome = sequence_manager_->Materialize(sequences, context_lengths, priorities, step_length_);

    if (outcome.allocation || outcome.swap_in || outcome.swap_out) {
        dbg(outcome);
    }

    bool exchange = outcome.swap_in + outcome.swap_out > 0;

    // permutation of sequence indices defining the new slot order
    std::vector<int> idxs(sequences.size());
    std::iota(idxs.begin(), idxs.end(), 0);

    if (exchange || holes || incoming_->size) {
        // put active ones first
        auto active_end = std::stable_partition(idxs.begin(), idxs.end(), [&](int idx) {
            return sequences[idx]->status == Sequence::kActive;  // present status
        });

        // all blocks are not enough to hold a single sequence
        if (!sequences.empty()) {
            FT_CHECK_WITH_INFO(active_end != idxs.begin(), "No enough blocks.");
        }

        // move swap-ins to the back
        auto swapin_beg = std::stable_partition(idxs.begin(), active_end, [&](int idx) {
            return status[idx] == Sequence::kActive;  // past status
        });

        // sort swap-ins according to missing length
        if (swapin_beg != active_end) {
            std::vector<int> missing_len(sequences.size());
            for (int i = 0; i < sequences.size(); ++i) {
                missing_len[i] = context_lengths[i] - sequences[i]->cache_len;
            }
            std::stable_sort(swapin_beg, active_end, [&](int i, int j) { return missing_len[i] < missing_len[j]; });
        }

        // Copy sequence states to back buffer
        FT_CHECK(back_->size == 0 && back_->active_size == 0);
        for (const auto& i : idxs) {
            auto& s = *sequences[i];
            if (exchange) {
                const auto& [state, idx] = coords[i];
                // backup random states from dynamic decode layers for swap-outs
                if (status[i] == Sequence::kActive && s.status != Sequence::kActive) {
                    SaveRandomState(*state, idx);
                }
                // mark swap-ins
                if (status[i] != Sequence::kActive && s.status == Sequence::kActive) {
                    state->is_swap_in[idx] = 1;
                }
            }
            if (s.status == Sequence::kActive) {
                ++back_->active_size;
            }
            CopyState(coords[i], {back_, back_->size++});
        }
        // Swap the buffers
        std::swap(state_, back_);

        ClearState(*back_);
        ClearState(*incoming_);
    }

    FT_CHECK(state_->size <= max_batch_size_);

    /// Update block ptrs when there were
    //  1. swap-in or swap-out
    //  2. holes in the active buffer
    //  3. new allocations (for existing active sequences)
    if (exchange || active_holes || outcome.allocation) {
        // Prepare intermediate buffers
        h_cu_block_counts_[0] = 0;

        auto k_ptrs = h_k_block_ptrs_;
        auto v_ptrs = h_v_block_ptrs_;

        const int batch_size = state_->active_size;

        for (int i = 0; i < batch_size; ++i) {
            const auto& seq = *state_->sequences[i];

            // cumulative num of blocks
            h_cu_block_counts_[i + 1] = h_cu_block_counts_[i] + seq.blocks.size();

            // flatten each sequence's K/V block pointers into the host arrays
            k_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), k_ptrs, [&](auto p) {
                return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetKey(p->data));
            });
            v_ptrs = std::transform(seq.blocks.cbegin(), seq.blocks.cend(), v_ptrs, [&](auto p) {
                return reinterpret_cast<uintptr_t>(sequence_manager_->OffsetVal(p->data));
            });
        }

        static_assert(sizeof(uintptr_t) == sizeof(void*));

        // upload the refreshed tables to the device
        Copy(h_cu_block_counts_, batch_size + 1, cu_block_counts_);
        Copy(h_k_block_ptrs_, h_cu_block_counts_[batch_size], k_block_ptrs_);
        Copy(h_v_block_ptrs_, h_cu_block_counts_[batch_size], v_block_ptrs_);
    }

    /// Layout of the buffers is changed, generation & sampling need to be re-initialized for correctness when there
    /// were
    //  1. swap-in or swap-out
    //  2. holes in the active buffer
    return exchange || active_holes;
}

template<typename T>
// Transfer one slot's bookkeeping and device buffers from the (state, index)
// pair `_src` to `_dst`. The source slot must hold a request and the
// destination slot must be empty; request ownership moves to the destination.
void LlamaBatch<T>::CopyState(const std::pair<BatchState*, int> _src, const std::pair<BatchState*, int>& _dst)
{
    BatchState* src_state = _src.first;
    BatchState* dst_state = _dst.first;
    const int   si        = _src.second;
    const int   di        = _dst.second;

    FT_CHECK((bool)src_state->requests[si]);
    FT_CHECK(!(bool)dst_state->requests[di]);

    // host-side per-slot metadata
    dst_state->h_context_length[di] = src_state->h_context_length[si];
    dst_state->h_finished[di]       = src_state->h_finished[si];
    dst_state->h_rope_theta[di]     = src_state->h_rope_theta[si];
    dst_state->seq_len_limit[di]    = src_state->seq_len_limit[si];
    dst_state->sequences[di]        = src_state->sequences[si];
    dst_state->is_swap_in[di]       = src_state->is_swap_in[si];
    dst_state->requests[di]         = std::move(src_state->requests[si]);

    // device-side token ids (only the valid context-length prefix)
    Copy(src_state->output_ids + si * session_len_,
         src_state->h_context_length[si],
         dst_state->output_ids + di * session_len_);

    // device-side sampling RNG states
    Copy((curandState_t*)src_state->top_k_curand_state + si, 1, (curandState_t*)dst_state->top_k_curand_state + di);
    Copy((curandState_t*)src_state->top_p_curand_state + si, 1, (curandState_t*)dst_state->top_p_curand_state + di);
}

template<typename T>
// Back up the sampling layers' curand states for slot `idx` into `state`, so
// they can be restored later by LoadRandomState (e.g. across a swap-out).
void LlamaBatch<T>::SaveRandomState(BatchState& state, int idx)
{
    Copy(model_->GetTopKState(idx), 1, (curandState_t*)state.top_k_curand_state + idx);
    // BUG FIX: this previously wrote into `top_k_curand_state`, clobbering the
    // top-k backup above and never saving the top-p state (LoadRandomState
    // reads it back from `top_p_curand_state`).
    Copy(model_->GetTopPState(idx), 1, (curandState_t*)state.top_p_curand_state + idx);
}

template<typename T>
// Restore the curand states previously saved by SaveRandomState from `state`
// back into the model's sampling layers for slot `idx`.
void LlamaBatch<T>::LoadRandomState(BatchState& state, int idx)
{
    // NOTE: removed leftover `dbg(idx);` debug trace (the sibling Initialize()
    // keeps its equivalent call commented out).
    Copy((curandState_t*)state.top_k_curand_state + idx, 1, model_->GetTopKState(idx));
    Copy((curandState_t*)state.top_p_curand_state + idx, 1, model_->GetTopPState(idx));
}

template<typename T>
// (Re)allocates the per-iteration device buffers sized by `batch_size` and
// `session_len`. Uses `reMalloc`, so calling again with larger sizes grows
// the buffers in place; the final bool argument requests zero-initialization.
void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    const size_t batchxbeam = batch_size;

    const size_t hidden_units      = model_->hidden_units_;
    const size_t vocab_size        = model_->vocab_size_padded_;
    const size_t head_dim          = model_->size_per_head_;
    const size_t local_kv_head_num = model_->local_kv_head_num_;
    // +1 padding, BlockIterator does not use predicate
    const size_t max_block_count = sequence_manager_->max_block_count() + 1;

    // context (prefill) decoder activations and token ids
    context_decoder_input_buf_ =
        (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
    context_decoder_output_buf_ =
        (T*)allocator_->reMalloc(context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
    context_decoder_ids_buf_ =
        (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);

    // temporary KV staging for prefill, plus per-slot pointer tables
    tmp_k_cache_buf_ = (T*)allocator_->reMalloc(
        tmp_k_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);
    tmp_v_cache_buf_ = (T*)allocator_->reMalloc(
        tmp_v_cache_buf_, sizeof(T) * max_context_token_num_ * local_kv_head_num * head_dim, false);

    tmp_k_ptrs_ = (void**)allocator_->reMalloc(tmp_k_ptrs_, sizeof(void*) * batch_size, false);
    tmp_v_ptrs_ = (void**)allocator_->reMalloc(tmp_v_ptrs_, sizeof(void*) * batch_size, false);

    // single-step (decode) activations
    decoder_input_buf_  = (T*)allocator_->reMalloc(decoder_input_buf_, sizeof(T) * batchxbeam * hidden_units, false);
    decoder_output_buf_ = (T*)allocator_->reMalloc(decoder_output_buf_, sizeof(T) * batchxbeam * hidden_units, false);

    input_ids_buf_      = (int*)allocator_->reMalloc(input_ids_buf_, sizeof(int) * batchxbeam * session_len, true);
    input_length_buf_   = (int*)allocator_->reMalloc(input_length_buf_, sizeof(int) * batchxbeam);
    context_length_buf_ = (int*)allocator_->reMalloc(context_length_buf_, sizeof(int) * batchxbeam);

    sequence_lengths_ = (int*)allocator_->reMalloc(sequence_lengths_, sizeof(int) * batchxbeam, false);

    // paged-KV addressing tables (device mirrors of the host tables filled in Initialize)
    cu_block_counts_ = (int*)allocator_->reMalloc(cu_block_counts_, sizeof(int) * (batch_size + 1));
    k_block_ptrs_    = (uintptr_t*)allocator_->reMalloc(k_block_ptrs_, sizeof(uintptr_t) * max_block_count);
    v_block_ptrs_    = (uintptr_t*)allocator_->reMalloc(v_block_ptrs_, sizeof(uintptr_t) * max_block_count);

    logits_buf_       = (float*)allocator_->reMalloc(logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);
    local_logits_buf_ = (float*)allocator_->reMalloc(local_logits_buf_, sizeof(float) * batchxbeam * vocab_size, false);

    // x2: holds both the transposed and the step-major token id layouts
    token_ids_buf_ = (int*)allocator_->reMalloc(token_ids_buf_, sizeof(int) * batchxbeam * session_len * 2, true);

    end_ids_buf_   = (int*)allocator_->reMalloc(end_ids_buf_, sizeof(int) * batch_size, false);
    finished_buf_  = (bool*)allocator_->reMalloc(finished_buf_, sizeof(bool) * batchxbeam, false);
    seq_limit_len_ = (uint32_t*)allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false);

    // per-request output tensor addresses for batched output copy
    request_output_ids_ptrs_ = (int**)allocator_->reMalloc(request_output_ids_ptrs_, sizeof(int*) * batch_size, true);
    request_output_ids_lens_ = (int*)allocator_->reMalloc(request_output_ids_lens_, sizeof(int) * batch_size, true);
    request_seqlen_ptrs_     = (int**)allocator_->reMalloc(request_seqlen_ptrs_, sizeof(int*) * batch_size, true);

    rope_theta_ = (float*)allocator_->reMalloc(rope_theta_, sizeof(float) * batch_size, false);

    is_allocate_buffer_ = true;
}

template<typename T>
// Allocates buffers that live for the lifetime of the batch object:
// device-side sampling parameter buffers, the per-state output/RNG buffers,
// and the pinned host mirrors (reMalloc's 4th argument selects host memory).
void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
{
    stop_words_buf_ =
        (int*)allocator_->reMalloc(stop_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);
    bad_words_buf_ =
        (int*)allocator_->reMalloc(bad_words_buf_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true);

    // host-side staging for per-request sampling parameters
    h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
    h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
    h_temperature_   = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
    h_repetition_penalty_ =
        (float*)allocator_->reMalloc(h_repetition_penalty_, sizeof(float) * max_batch_size, true, true);
    h_random_seed_ = (uint64_t*)allocator_->reMalloc(h_random_seed_, sizeof(uint64_t) * max_batch_size, true, true);

    // name -> buffer mapping consumed when building the sampling TensorMap
    sampling_params_ = {{"stop_words_list", stop_words_buf_},
                        {"bad_words_list", bad_words_buf_},
                        {"runtime_top_k", h_runtime_top_k_},
                        {"runtime_top_p", h_runtime_top_p_},
                        {"temperature", h_temperature_},
                        {"repetition_penalty", h_repetition_penalty_},
                        {"random_seed", h_random_seed_}};

    // double-buffered batch states: token ids + sampling RNG state per slot
    for (auto& s : states_) {
        s.output_ids = (int*)allocator_->reMalloc(s.output_ids, sizeof(int) * max_batch_size * session_len_, true);
        s.top_k_curand_state = allocator_->reMalloc(s.top_k_curand_state, sizeof(curandState_t) * max_batch_size, true);
        s.top_p_curand_state = allocator_->reMalloc(s.top_p_curand_state, sizeof(curandState_t) * max_batch_size, true);
    }

    const size_t max_block_count = sequence_manager_->max_block_count();

    {
        // NOTE(review): barrier presumably serializes host allocations across
        // tensor-parallel ranks — confirm against NcclGuard semantics
        NcclGuard barrier(model_->tensor_para_, stream_, true);
        h_input_ids_buf_ =
            (int*)allocator_->reMalloc(h_input_ids_buf_, sizeof(int) * max_batch_size * session_len_, false, true);
        h_input_length_buf_ =
            (int*)allocator_->reMalloc(h_input_length_buf_, sizeof(int) * max_batch_size, false, true);

        h_tmp_k_ptrs_ = (void**)allocator_->reMalloc(h_tmp_k_ptrs_, sizeof(void*) * max_batch_size, false, true);
        h_tmp_v_ptrs_ = (void**)allocator_->reMalloc(h_tmp_v_ptrs_, sizeof(void*) * max_batch_size, false, true);

        // host mirrors of the paged-KV addressing tables built in Initialize()
        h_cu_block_counts_ =
            (int*)allocator_->reMalloc(h_cu_block_counts_, sizeof(int) * (max_batch_size + 1), false, true);
        h_k_block_ptrs_ =
            (uintptr_t*)allocator_->reMalloc(h_k_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
        h_v_block_ptrs_ =
            (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);

        for (auto& s : states_) {
            s.h_context_length =
                (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
            s.h_finished   = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
            s.h_rope_theta = (float*)allocator_->reMalloc(s.h_rope_theta, sizeof(float) * max_batch_size, false, true);
        }

        h_seq_limit_len_ =
            (uint32_t*)allocator_->reMalloc(h_seq_limit_len_, sizeof(uint32_t) * max_batch_size, false, true);

        // host staging for per-request output tensor addresses
        h_request_output_ids_ptrs_ =
            (int**)allocator_->reMalloc(h_request_output_ids_ptrs_, sizeof(int*) * max_batch_size, true, true);
        h_request_output_ids_lens_ =
            (int*)allocator_->reMalloc(h_request_output_ids_lens_, sizeof(int) * max_batch_size, true, true);
        h_request_seqlen_ptrs_ =
            (int**)allocator_->reMalloc(h_request_seqlen_ptrs_, sizeof(int*) * max_batch_size, true, true);
    }

    is_allocate_persistant_buffer_ = true;
}

template<typename T>
void LlamaBatch<T>::FreeBuffer()
{
    // Releases every buffer owned by the batch; the inverse of `AllocateBuffer` /
    // `AllocatePersistantBuffer`. Safe to call when nothing was allocated: both
    // branches are guarded by their respective `is_allocate_*` flags, which are
    // reset afterwards so a second call is a no-op.
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (is_allocate_buffer_) {
        allocator_->free((void**)&context_decoder_input_buf_);
        allocator_->free((void**)&context_decoder_output_buf_);
        allocator_->free((void**)&context_decoder_ids_buf_);

        allocator_->free((void**)&tmp_k_cache_buf_);
        allocator_->free((void**)&tmp_v_cache_buf_);
        allocator_->free((void**)&tmp_k_ptrs_);
        allocator_->free((void**)&tmp_v_ptrs_);

        allocator_->free((void**)&decoder_input_buf_);
        allocator_->free((void**)&decoder_output_buf_);

        allocator_->free((void**)&input_ids_buf_);
        allocator_->free((void**)&input_length_buf_);
        allocator_->free((void**)&context_length_buf_);

        allocator_->free((void**)&sequence_lengths_);

        allocator_->free((void**)&cu_block_counts_);
        allocator_->free((void**)&k_block_ptrs_);
        allocator_->free((void**)&v_block_ptrs_);

        allocator_->free((void**)&logits_buf_);
        allocator_->free((void**)&local_logits_buf_);

        // The context-logits buffers are allocated lazily in `OutputContextLogits`,
        // so they may legitimately still be null here.
        if (local_context_logits_buf_) {
            allocator_->free((void**)&local_context_logits_buf_);
        }
        if (context_logits_buf_) {
            allocator_->free((void**)&context_logits_buf_);
        }

        allocator_->free((void**)&token_ids_buf_);

        allocator_->free((void**)&end_ids_buf_);
        allocator_->free((void**)&finished_buf_);
        allocator_->free((void**)&seq_limit_len_);

        allocator_->free((void**)&request_output_ids_ptrs_);
        allocator_->free((void**)&request_output_ids_lens_);
        allocator_->free((void**)&request_seqlen_ptrs_);

        allocator_->free((void**)&rope_theta_);

        is_allocate_buffer_ = false;
    }

    if (is_allocate_persistant_buffer_) {
        // The `true` flag mirrors the `reMalloc(..., true)` calls in
        // `AllocatePersistantBuffer` — presumably it selects the host-side
        // (pinned) pool; verify against the allocator's API if in doubt.
        for (auto& s : states_) {
            allocator_->free((void**)&s.h_context_length, true);
            allocator_->free((void**)&s.h_finished, true);
            allocator_->free((void**)&s.h_rope_theta, true);
            allocator_->free((void**)&s.output_ids);
        }
        allocator_->free((void**)&h_tmp_k_ptrs_, true);
        allocator_->free((void**)&h_tmp_v_ptrs_, true);
        allocator_->free((void**)&h_cu_block_counts_, true);
        allocator_->free((void**)&h_k_block_ptrs_, true);
        allocator_->free((void**)&h_v_block_ptrs_, true);
        allocator_->free((void**)&h_input_ids_buf_, true);
        allocator_->free((void**)&h_input_length_buf_, true);
        allocator_->free((void**)&h_seq_limit_len_, true);

        allocator_->free((void**)&h_request_output_ids_ptrs_, true);
        allocator_->free((void**)&h_request_output_ids_lens_, true);
        allocator_->free((void**)&h_request_seqlen_ptrs_, true);

        is_allocate_persistant_buffer_ = false;
    }
}

template<typename T>
// Constructs the batch scheduler for `model`.
// - `max_batch_size`:        upper bound on concurrently active sequences;
//                            sizes all per-slot state vectors and buffers.
// - `max_context_token_num`: token budget per context-decode sub-batch
//                            (see `ContextDecode`).
// - `session_len`:           maximum sequence length per session.
// - `sequence_manager`:      takes ownership of the KV-block manager.
// Stream/allocator/cublas handles are borrowed from `model`; `model` must
// outlive this object.
LlamaBatch<T>::LlamaBatch(int                              max_batch_size,
                          int                              max_context_token_num,
                          int                              session_len,
                          std::unique_ptr<SequenceManager> sequence_manager,
                          LlamaV2<T>*                      model):
    max_batch_size_(max_batch_size),
    max_context_token_num_(max_context_token_num),
    session_len_(session_len),
    rank_(model->tensor_para_.rank_),
    debug_(model->debug_),
    step_length_(model->step_length_),
    sequence_manager_(std::move(sequence_manager)),
    model_(model),
    data_type_(getTensorType<T>())
{
    stream_         = model_->stream_;
    allocator_      = model_->allocator_;
    cublas_wrapper_ = model_->cublas_wrapper_;

    // Pre-size per-slot bookkeeping for every batch state.
    for (auto& s : states_) {
        s.requests.resize(max_batch_size);
        s.sequences.resize(max_batch_size);
        s.seq_len_limit.resize(max_batch_size);
        s.is_swap_in.resize(max_batch_size);
    }

    // Triple-buffered states: current, back (swap target) and incoming requests.
    state_    = &states_[0];
    back_     = &states_[1];
    incoming_ = &states_[2];

    AllocateBuffer(max_batch_size, session_len_);
    AllocatePersistantBuffer(max_batch_size);
}

template<typename T>
Li Zhang's avatar
Li Zhang committed
724
void LlamaBatch<T>::InitializeSampling()
Li Zhang's avatar
Li Zhang committed
725
{
Li Zhang's avatar
Li Zhang committed
726
    const int batch_size = state_->active_size;
Li Zhang's avatar
Li Zhang committed
727
728
    TensorMap inputs;
    for (const auto& param : sampling_params_) {
Li Zhang's avatar
Li Zhang committed
729
        // find an exemplar that matches the param name
Li Zhang's avatar
Li Zhang committed
730
        const Tensor* ptr{};
Li Zhang's avatar
Li Zhang committed
731
732
733
        for (int i = 0; i < batch_size; ++i) {
            if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
                ptr = &state_->requests[i]->inputs[rank_].at(param.first);
Li Zhang's avatar
Li Zhang committed
734
735
736
                break;
            }
        }
Li Zhang's avatar
Li Zhang committed
737
        // fill the batch of the param
Li Zhang's avatar
Li Zhang committed
738
739
740
741
        if (ptr) {
            const auto& ref   = *ptr;
            auto        shape = ref.shape;
            FT_CHECK(shape[0] == 1);
Li Zhang's avatar
Li Zhang committed
742
            shape[0]                = batch_size;
Li Zhang's avatar
Li Zhang committed
743
            const int size_in_bytes = ref.sizeBytes();
Li Zhang's avatar
Li Zhang committed
744
745
746
747
            Clear((std::byte*)param.second, size_in_bytes * batch_size);
            for (int i = 0; i < batch_size; ++i) {
                if (state_->requests[i]->inputs[rank_].isExist(param.first)) {
                    auto& src = state_->requests[i]->inputs[rank_].at(param.first);
Li Zhang's avatar
Li Zhang committed
748
                    FT_CHECK(ref.shape == src.shape);
Li Zhang's avatar
Li Zhang committed
749
                    Copy(src.getPtr<std::byte>(), size_in_bytes, (std::byte*)param.second + size_in_bytes * i);
Li Zhang's avatar
Li Zhang committed
750
751
752
753
                }
            }
            inputs.insert({param.first, {ref.where, ref.type, shape, param.second}});
            if (debug_ && rank_ == 0) {
lvhan028's avatar
lvhan028 committed
754
                TM_LOG_INFO("[initializeSampling] %s", format({param.first, inputs.at(param.first)}).c_str());
Li Zhang's avatar
Li Zhang committed
755
756
757
758
759
760
            }
        }
    }

    inputs_ = std::move(inputs);

Li Zhang's avatar
Li Zhang committed
761
762
763
764
765
766
    model_->dynamic_decode_layer_->setup(batch_size, 1, &inputs_);

    // recover random states if not a new request
    for (int i = 0; i < batch_size; ++i) {
        if (!state_->requests[i]->start_flag && state_->is_swap_in[i]) {
            LoadRandomState(*state_, i);
Li Zhang's avatar
Li Zhang committed
767
768
769
        }
    }

Li Zhang's avatar
Li Zhang committed
770
    handleOptArg(&inputs_, "end_id", end_ids_buf_, model_->end_id_, batch_size);
Li Zhang's avatar
Li Zhang committed
771
772
773
774
    cudaStreamSynchronize(0);
}

template<typename T>
// Prepares device-side generation state for the active batch: transposes
// output ids into step-major `token_ids_buf_`, right-aligns each sequence's
// last token to the common starting step, stages lengths/limits/finish flags
// and per-request output pointers on device, and returns the initial
// `GenerationState` counters used by `Generate`.
auto LlamaBatch<T>::InitializeGeneration() -> GenerationState
{
    const int batch_size      = state_->active_size;
    const int max_context_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size);

    // token_ids_buf_ is laid out step-major: [session_len, batch]
    Clear(token_ids_buf_, batch_size * session_len_);
    invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
    sync_check_cuda_error();

    // Move every sequence's last token to row `max_context_len - 1` so all
    // sequences share a common first decoding step.
    // token_ids_buf_[s, b]
    // ABCDe            ABCDe     e
    // ABCDEFGHIJk      ABCDEFGHIJk
    // ABCDEFGHi    ->  ABCDEFGHi i
    // ABCDEFGh         ABCDEFGh  h
    // ABCd             ABCd      d
    for (int i = 0; i < batch_size; ++i) {
        auto token_ids = token_ids_buf_ + i;
        auto p_src     = state_->h_context_length[i] - 1;
        auto p_dst     = max_context_len - 1;
        if (p_src != p_dst) {  // dst and src of `cudaMemcpyAsync` must not overlap
            Copy(token_ids + p_src * batch_size, 1, token_ids + p_dst * batch_size);
        }
    }

    Copy(state_->h_context_length, batch_size, context_length_buf_);  // also referenced in `SetOutputTensors`
    Copy(context_length_buf_, batch_size, sequence_lengths_);
    // `sequence_lengths_` will be increased by dynamic decode
    // note that in decoder and in output "sequence length" has different semantic
    // - in decoder it means length of sequence that has kv cache already computed
    // - in output it means length of all tokens (the last generated token does not have k/v cache computed yet)
    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
    sync_check_cuda_error();

    // used for dispatching split-k decoding kernels
    // (-batch_size converts the sum of context lengths into a sum of cache lengths)
    const int sum_seq_len =
        std::accumulate(state_->h_context_length, state_->h_context_length + batch_size, -batch_size);
    const int max_seq_len = *std::max_element(state_->h_context_length, state_->h_context_length + batch_size) - 1;

    // seq_limit_len_, will be compared to `step` instead of `sequence_length`, so padding len should be accounted
    // for
    for (int i = 0; i < batch_size; ++i) {
        h_seq_limit_len_[i] = state_->seq_len_limit[i] + (max_context_len - state_->h_context_length[i]);
    }
    Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
    Copy(state_->h_finished, batch_size, finished_buf_);

    // Stage per-request output pointers so device kernels can scatter results
    // directly into each request's output tensors.
    for (int i = 0; i < batch_size; ++i) {
        Tensor& output_ids         = state_->requests[i]->outputs[rank_].at("output_ids");
        int*    req_output_ids_ptr = output_ids.getPtr<int>();
        int*    req_seqlen_ptr     = state_->requests[i]->outputs[rank_].getPtr<int>("sequence_length");

        h_request_output_ids_ptrs_[i] = req_output_ids_ptr;
        h_request_output_ids_lens_[i] = output_ids.shape.at(2);
        h_request_seqlen_ptrs_[i]     = req_seqlen_ptr;

        FT_CHECK(h_request_output_ids_ptrs_[i]);
        FT_CHECK(h_request_output_ids_lens_[i]);
        FT_CHECK(h_request_seqlen_ptrs_[i]);
    }
    Copy(h_request_output_ids_ptrs_, batch_size, request_output_ids_ptrs_);
    Copy(h_request_output_ids_lens_, batch_size, request_output_ids_lens_);
    Copy(h_request_seqlen_ptrs_, batch_size, request_seqlen_ptrs_);

    Copy(state_->h_rope_theta, batch_size, rope_theta_);

    // ! range of step_ [1, 2 * session_len]
    // consider a sequence with context_len == session_len and another sequence with context_len == 1 and
    // request_output_len == session_len - 1 => step_ will loop in [session_len, 2 * session_len)
    const int start_step = max_context_len;

    if (rank_ == 0) {
        TM_LOG_INFO("[initGen] batch_size = %d", (int)batch_size);
        TM_LOG_INFO("[initGen] max_context_len = %d", (int)max_context_len);

        TM_LOG_INFO("[initGen] slot  sequence_id  context_len  seq_limit_len  finished");
        for (int i = 0; i < batch_size; ++i) {
            TM_LOG_INFO("[initGen] %4d  %11ld  %11d  %13d  %8d",
                        i,
                        (long)state_->sequences[i]->id,
                        state_->h_context_length[i],
                        (int)h_seq_limit_len_[i],
                        (int)state_->h_finished[i]);
        }
    }

    // for (int i = 0; i < batch_size; ++i) {
    //     gSequenceIds(i) = state_->requests[i]->id;
    // }

    return GenerationState{max_context_len, start_step, sum_seq_len, max_seq_len};
}

template<typename T>
// Runs one decoding step for the whole active batch:
// embedding lookup -> decoder forward -> logits -> dynamic decode (sampling,
// stop-criteria). Updates the counters in `g` and returns false when the
// decode layer signals that generation should stop.
bool LlamaBatch<T>::Generate(GenerationState& g)
{
    NvtxScope scope("Generate");
    const int batch_size = state_->active_size;

    constexpr int kLogInterval = 10;
    if (rank_ == 0 && (g.step - 1) % kLogInterval == 0) {
        TM_LOG_INFO("------------------------- step = %d -------------------------", g.step - 1);
    }

    // first step right after context/prefill (see `InitializeGeneration`)
    const bool is_first_step = (g.step == g.max_init_ctx_len);

    // Snapshot of the input tokens for debug logging below.
    std::vector<int> prev;
    if (debug_ && rank_ == 0 && is_first_step) {
        prev.resize(batch_size);
        Copy(token_ids_buf_ + (g.step - 1) * batch_size, batch_size, prev.data());
    }

    // embeddingLookup(step_ - 1);
    model_->embeddingLookup(decoder_input_buf_,  //
                            token_ids_buf_,
                            batch_size,
                            g.step - 1);

    model_->decoderForward(decoder_output_buf_,
                           k_block_ptrs_,
                           v_block_ptrs_,
                           decoder_input_buf_,
                           sequence_lengths_,
                           finished_buf_,
                           cu_block_counts_,
                           rope_theta_,
                           g.step,
                           0,
                           g.sum_seq_len,
                           g.max_seq_len,
                           batch_size);

    model_->postDecodeEmbedding(logits_buf_,  //
                                local_logits_buf_,
                                decoder_output_buf_,
                                batch_size);

    // stop-words & bad-words require the matched tokens to be contiguous, so item size > 1 is
    // not supported yet.
    bool should_stop{};
    model_->dynamicDecode(token_ids_buf_,
                          finished_buf_,
                          sequence_lengths_,
                          &should_stop,
                          &inputs_,
                          &outputs_,
                          logits_buf_,
                          seq_limit_len_,
                          context_length_buf_,
                          end_ids_buf_,
                          g.step,
                          0,
                          g.max_init_ctx_len,
                          session_len_ * 2,
                          batch_size);

    if (debug_ && rank_ == 0) {
        std::vector<int> curr(batch_size);

        Copy(token_ids_buf_ + g.step * batch_size, batch_size, curr.data());
        cudaStreamSynchronize(stream_);

        if (is_first_step) {
            std::stringstream sprev;
            for (int k = 0; k < prev.size(); ++k) {
                sprev << std::setw(6) << prev[k];
            }
            TM_LOG_INFO("[ lookup ] step = %d, [%s]", g.step - 1, sprev.str().c_str());
        }

        std::stringstream scurr;
        for (int k = 0; k < curr.size(); ++k) {
            scurr << std::setw(6) << curr[k];
        }
        TM_LOG_INFO("[generate] step = %d, [%s]", g.step - 1, scurr.str().c_str());
    }

    ////////////////////////////////////////////////
    /// ! increase the counters
    g.step += 1;
    g.max_seq_len += 1;
    g.sum_seq_len += batch_size;

    return !should_stop;
}

template<typename T>
// Prefills KV cache for sequences that were just swapped in and are missing
// computed cache entries. Sequences needing prefill are gathered starting at
// slot `base`, partitioned into sub-batches that fit `max_context_token_num_`,
// and each sub-batch is run through `model_->contextDecode`. The last input
// token of every sequence is deliberately left to the main decoder loop.
void LlamaBatch<T>::ContextDecode()
{
    const auto batch_size = state_->active_size;

    // `base` = first slot that needs context decoding; slots before it are skipped.
    int base = -1;
    for (int i = 0; i < batch_size; ++i) {
        if (state_->is_swap_in[i]) {
            const auto& seq = *state_->sequences[i];
            dbg(std::tuple(i, state_->h_context_length[i], seq.cache_len));
            if (const int missing = state_->h_context_length[i] - seq.cache_len; missing > 1) {
                base = base < 0 ? i : base;
                dbg(seq.tokens, seq.cache_len);
                Copy(state_->output_ids + i * session_len_ + seq.cache_len, missing, input_ids_buf_ + i * session_len_);
                // subtract input/context len by 1 to skip last input token (will process with decoder later)
                h_input_length_buf_[i] = missing - 1;
            }
        }
    }
    if (base < 0) {
        // TM_LOG_INFO("[decodeContext] Context decoding is not needed.");
        return;
    }

    const int context_decode_count = batch_size - base;

    Copy(state_->h_context_length, batch_size, context_length_buf_);
    Copy(state_->h_rope_theta, batch_size, rope_theta_);
    Copy(h_input_length_buf_, batch_size, input_length_buf_);

    check_cuda_error(cudaStreamSynchronize(stream_));
    const auto tick = std::chrono::high_resolution_clock::now();

    if (rank_ == 0) {
        TM_LOG_INFO("[decodeContext] base = %d, count = %d", base, context_decode_count);
    }
    // subtract input/context len by 1 to skip last input token (will process with decoder later)
    invokePlusScalar(context_length_buf_ + base, -1, context_decode_count, stream_);

    // find sub-batch offsets
    std::vector<int> offsets{base};
    std::vector<int> max_context_cnts;
    int              accum_size        = 0;
    int              accum_input_count = 0;
    int              max_context_count = 0;
    for (int i = base; i < batch_size; ++i) {
        // Greedily extend the current sub-batch while both budgets hold.
        int size          = accum_size + 1;
        int input_count   = accum_input_count + h_input_length_buf_[i];
        int context_count = std::max(max_context_count, state_->h_context_length[i] - 1);
        // we have `cu_seqlens` on q so no padding for input is needed
        // kernels are expecting uniform k/v cache length -> `max_context_count * size <= max_context_token_num_`
        if (input_count <= max_context_token_num_ && context_count * size <= max_context_token_num_) {
            accum_size        = size;
            accum_input_count = input_count;
            max_context_count = context_count;
        }
        else {
            // Budget exceeded: close the current sub-batch and start a new one at slot i.
            offsets.push_back(i);
            max_context_cnts.push_back(max_context_count);
            accum_size        = 1;
            accum_input_count = h_input_length_buf_[i];
            max_context_count = state_->h_context_length[i] - 1;
        }
    }
    offsets.push_back(batch_size);
    max_context_cnts.push_back(max_context_count);

    dbg(offsets, max_context_cnts);

    // context decode on sub-batches
    for (int k = 0; k < offsets.size() - 1; ++k) {
        int              first          = offsets[k];
        int              last           = offsets[k + 1];
        int              sub_batch_size = last - first;
        T*               k_ptr          = tmp_k_cache_buf_;
        T*               v_ptr          = tmp_v_cache_buf_;
        std::vector<int> decode_indices{};
        std::vector<int> decode_lengths{};
        int              max_input_len{};
        auto             input_ids = context_decoder_ids_buf_;
        TM_LOG_INFO("first = %d, last = %d", first, last);
        for (int i = first; i < last; ++i) {
            TM_LOG_INFO("session_len = %d, input_length = %d", session_len_, h_input_length_buf_[i]);
            // Pack the variable-length inputs contiguously (cu_seqlens-style layout).
            input_ids = Copy(input_ids_buf_ + i * session_len_, h_input_length_buf_[i], input_ids);
            dbg(i, h_input_length_buf_[i]);
            // Carve per-sequence slices out of the temporary K/V workspace; every
            // sequence in this sub-batch gets a uniform max_context_cnts[k]-sized slot.
            h_tmp_k_ptrs_[i] = k_ptr;
            h_tmp_v_ptrs_[i] = v_ptr;
            k_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
            v_ptr += model_->local_kv_head_num_ * max_context_cnts[k] * model_->size_per_head_;
            decode_indices.push_back(i);
            decode_lengths.push_back(h_input_length_buf_[i]);
            max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
        }
        int token_count = input_ids - context_decoder_ids_buf_;
        dbg(token_count, max_input_len, max_context_cnts[k]);

        Copy(h_tmp_k_ptrs_ + first, sub_batch_size, tmp_k_ptrs_ + first);
        Copy(h_tmp_v_ptrs_ + first, sub_batch_size, tmp_v_ptrs_ + first);

        if (rank_ == 0) {
            TM_LOG_INFO(
                "[decodeContext] offset = %d, batch_size = %d, token_num = %d, max_input_len = %d, max_context_len = %d",
                base,
                sub_batch_size,
                token_count,
                max_input_len,
                max_context_cnts[k]);
        }

        dbg(first, last);
        dbg(k_block_ptrs_, v_block_ptrs_);

        // Debug readback of the sub-batch's lengths (always on; `dbg` is presumably
        // compiled out in release builds — verify before relying on that).
        if (1) {
            std::vector<int> input_len(sub_batch_size);
            std::vector<int> context_len(sub_batch_size);
            Copy(input_length_buf_ + first, sub_batch_size, input_len.data());
            Copy(context_length_buf_ + first, sub_batch_size, context_len.data());
            cudaStreamSynchronize(stream_);
            dbg(input_len, context_len);
        }

        model_->contextDecode(nullptr,
                              k_block_ptrs_,
                              v_block_ptrs_,
                              tmp_k_ptrs_ + first,
                              tmp_v_ptrs_ + first,
                              context_decoder_input_buf_,
                              context_decoder_output_buf_,
                              context_decoder_ids_buf_,
                              input_length_buf_ + first,
                              context_length_buf_ + first,
                              cu_block_counts_ + first,
                              rope_theta_ + first,
                              token_count,
                              max_input_len,
                              max_context_cnts[k],
                              max_context_cnts[k],
                              sub_batch_size);

        // compute logits of inputs if requested
        OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
    }

    // Undo the earlier -1 so `context_length_buf_` holds the real lengths again.
    invokePlusScalar(context_length_buf_ + base, 1, context_decode_count, stream_);

    std::fill(h_input_length_buf_ + base, h_input_length_buf_ + batch_size, 0);

    // `SequenceManager` needs real-time value of cache length
    for (int i = base; i < batch_size; ++i) {
        if (state_->requests[i]) {
            FT_CHECK(state_->sequences[i]);
            state_->sequences[i]->cache_len = state_->h_context_length[i] - 1;  // -1 since we skip last token
        }
    }

    check_cuda_error(cudaStreamSynchronize(stream_));
    const auto tock = std::chrono::high_resolution_clock::now();
    if (rank_ == 0) {
        TM_LOG_INFO("[decodeContext] %.2f ms", std::chrono::duration<float, std::milli>(tock - tick).count());
    }
}

1122
template<typename T>
Li Zhang's avatar
Li Zhang committed
1123
void LlamaBatch<T>::OutputContextLogits(T*                      context_decoder_output,
1124
1125
1126
1127
1128
1129
1130
1131
                                        const std::vector<int>& indices,
                                        const std::vector<int>& lengths)
{
    std::vector<float*> output_logits;
    int                 num_token = 0;
    {
        bool is_return_logits = false;
        for (int k = 0; k < indices.size(); ++k) {
Li Zhang's avatar
Li Zhang committed
1132
            auto& request = state_->requests[indices[k]];
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
            output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
            num_token += lengths[k];
            if (output_logits.back()) {
                is_return_logits = true;
            }
        }
        if (!is_return_logits) {
            return;
        }
    }

    if (context_logits_buf_ == nullptr) {
Li Zhang's avatar
Li Zhang committed
1145
        NcclGuard guard(model_->tensor_para_, stream_, true);
Chen Xin's avatar
Chen Xin committed
1146
        context_logits_buf_ =
Li Zhang's avatar
Li Zhang committed
1147
1148
            (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * max_context_token_num_);
        const auto tp = model_->tensor_para_.world_size_;
1149
        if (tp > 1) {
Li Zhang's avatar
Li Zhang committed
1150
1151
            FT_CHECK(model_->vocab_size_padded_ % tp == 0);
            const auto local_vocab_size = model_->vocab_size_padded_ / tp;
1152
1153
1154
1155
1156
            local_context_logits_buf_ =
                (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
        }
    }

Li Zhang's avatar
Li Zhang committed
1157
    model_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
1158
1159
1160
1161
1162

    auto logits = context_logits_buf_;

    for (int k = 0; k < indices.size(); ++k) {
        if (output_logits[k]) {
Li Zhang's avatar
Li Zhang committed
1163
            Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]);
1164
        }
Li Zhang's avatar
Li Zhang committed
1165
        logits += model_->vocab_size_padded_ * lengths[k];
1166
1167
1168
    }
}

Li Zhang's avatar
Li Zhang committed
1169
template<typename T>
Li Zhang's avatar
Li Zhang committed
1170
auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
Li Zhang's avatar
Li Zhang committed
1171
{
Li Zhang's avatar
Li Zhang committed
1172
1173
    NvtxScope scope("Finish");
    const int batch_size = state_->active_size;
Li Zhang's avatar
Li Zhang committed
1174

Li Zhang's avatar
Li Zhang committed
1175
1176
    // secure info needed by `Initialize()`
    Copy(finished_buf_, batch_size, state_->h_finished);
Li Zhang's avatar
Li Zhang committed
1177

Li Zhang's avatar
Li Zhang committed
1178
1179
1180
1181
    // invariant: context_length = sequence_length + 1
    invokePlusScalar(sequence_lengths_, 1, batch_size, stream_);
    Copy(sequence_lengths_, batch_size, state_->h_context_length);
    invokePlusScalar(sequence_lengths_, -1, batch_size, stream_);
Li Zhang's avatar
Li Zhang committed
1182

Li Zhang's avatar
Li Zhang committed
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
    if constexpr (0) {
        std::unique_lock<std::mutex> lock;
        if (rank_ == 0) {
            NvtxScope _("acquire_outputs");
            // wait for previous output operations
            lock = std::unique_lock{output_mutex_};
            output_cv_.wait(lock, [&] { return output_reqs_.empty(); });
        }

        SetOutputTensors(g);
        check_cuda_error(cudaStreamSynchronize(stream_));

        if (rank_ == 0) {
            NvtxScope _("signal_output_thread");
            // enqueue new output requests
            for (int i = 0; i < batch_size; ++i) {
                FT_CHECK(state_->requests[i] != nullptr);
                if (state_->requests[i]->stream_cb) {
                    output_reqs_.push_back(state_->requests[i]);
                }
            }
            lock.unlock();
            // notify output thread when we do have stream cbs to call
            if (!output_reqs_.empty()) {
                output_cv_.notify_one();
            }
Li Zhang's avatar
Li Zhang committed
1209
1210
        }
    }
Li Zhang's avatar
Li Zhang committed
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
    else {
        SetOutputTensors(g);
        check_cuda_error(cudaStreamSynchronize(stream_));

        {
            NvtxScope _("output_cb");
            if (rank_ == 0 && model_->ffi_lock_) {
                model_->ffi_lock_(1);
            }
            for (int i = 0; i < batch_size; ++i) {
                FT_CHECK(state_->requests[i] != nullptr);
                if (state_->requests[i]->stream_cb && rank_ == 0) {
                    state_->requests[i]->stream_cb(&state_->requests[i]->outputs[rank_].get());
                }
            }
            if (rank_ == 0 && model_->ffi_lock_) {
                model_->ffi_lock_(0);
            }
        }
Chen Xin's avatar
Chen Xin committed
1230
    }
Li Zhang's avatar
Li Zhang committed
1231
1232
1233

    if (debug_ && rank_ == 0) {
        std::stringstream ss;
Li Zhang's avatar
Li Zhang committed
1234
1235
        for (int i = 0; i < batch_size; ++i) {
            ss << (i ? ", " : "") << "(" << state_->h_context_length[i] << "," << state_->h_finished[i] << ")";
Li Zhang's avatar
Li Zhang committed
1236
        }
lvhan028's avatar
lvhan028 committed
1237
        TM_LOG_INFO("[finish] [%s]", ss.str().c_str());
Li Zhang's avatar
Li Zhang committed
1238
1239
    }

Li Zhang's avatar
Li Zhang committed
1240
1241
1242
1243
1244
    // `SequenceManager` needs real-time value of cache length
    for (int i = 0; i < batch_size; ++i) {
        if (state_->requests[i]) {
            FT_CHECK(state_->sequences[i]);
            state_->sequences[i]->cache_len = state_->h_context_length[i];
Li Zhang's avatar
Li Zhang committed
1245
1246
1247
        }
    }

Li Zhang's avatar
Li Zhang committed
1248
1249
1250
1251
1252
1253
1254
    std::vector<Signal> signals;
    {
        NvtxScope _("prepare_completion_signal");
        for (int i = 0; i < batch_size; ++i) {
            if (state_->requests[i] && state_->h_finished[i]) {
                CompleteRequest(i, false, false);
                signals.push_back([r = std::move(state_->requests[i])] { r->signal.set_value(0); });
Li Zhang's avatar
Li Zhang committed
1255
1256
1257
            }
        }
    }
Li Zhang's avatar
Li Zhang committed
1258
    return signals;
Li Zhang's avatar
Li Zhang committed
1259
1260
1261
}

template<typename T>
Li Zhang's avatar
Li Zhang committed
1262
void LlamaBatch<T>::SetOutputTensors(const GenerationState& g)
Li Zhang's avatar
Li Zhang committed
1263
{
Li Zhang's avatar
Li Zhang committed
1264
1265
1266
    NvtxScope scope("SetOutputTensors");
    // dbg(g.max_init_ctx_len);
    const auto batch_size = state_->active_size;
Li Zhang's avatar
Li Zhang committed
1267
    // [s,b] -> [b,s] and skip padding in [context_len, max_context_len)
Li Zhang's avatar
Li Zhang committed
1268
    invokeGatherOutput(state_->output_ids,
Li Zhang's avatar
Li Zhang committed
1269
1270
                       token_ids_buf_,
                       context_length_buf_,
Li Zhang's avatar
Li Zhang committed
1271
1272
                       g.max_init_ctx_len,
                       g.step,
Li Zhang's avatar
Li Zhang committed
1273
                       session_len_,
Li Zhang's avatar
Li Zhang committed
1274
                       batch_size,
Li Zhang's avatar
Li Zhang committed
1275
1276
1277
                       stream_);
    sync_check_cuda_error();

Li Zhang's avatar
Li Zhang committed
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
    if constexpr (1) {
        invokeUpdateOutput(request_output_ids_ptrs_,
                           request_seqlen_ptrs_,
                           state_->output_ids,
                           sequence_lengths_,
                           request_output_ids_lens_,
                           session_len_,
                           g.step > g.max_init_ctx_len,
                           batch_size,
                           stream_);
        sync_check_cuda_error();
    }
    else {
        // for (int i = 0; i < batch_size; ++i) {
        //     if (state_->requests[i]) {
        //         auto& output_ids      = state_->requests[i]->outputs[rank_].at("output_ids");
        //         auto& sequence_length = state_->requests[i]->outputs[rank_].at("sequence_length");
        //         Copy(state_->output_ids + i * session_len_, output_ids.shape.at(2), output_ids.getPtr<int>());
        //         Copy(sequence_lengths_ + i, 1, sequence_length.getPtr<int>());
        //         if (g.step > g.max_init_ctx_len) {  // +1 for newly generated token
        //             invokePlusScalar(sequence_length.getPtr<int>(), 1, 1, stream_);
        //         }
        //     }
        // }
Li Zhang's avatar
Li Zhang committed
1302
1303
1304
1305
    }
}

template<typename T>
Li Zhang's avatar
Li Zhang committed
1306
void LlamaBatch<T>::CompleteRequest(int index, bool is_stop_request, bool is_force_end)
Li Zhang's avatar
Li Zhang committed
1307
1308
{
    if (rank_ == 0) {
Li Zhang's avatar
Li Zhang committed
1309
        TM_LOG_INFO("[CompleteRequest] slot = %d, id = %lu", index, (long)state_->requests[index]->id);
Li Zhang's avatar
Li Zhang committed
1310
1311
1312
    }

    if (debug_ && rank_ == 0) {
Li Zhang's avatar
Li Zhang committed
1313
1314
        std::vector<int> tokens(state_->h_context_length[index]);
        Copy(state_->output_ids + index * session_len_, tokens.size(), tokens.data());
Li Zhang's avatar
Li Zhang committed
1315
1316
1317
1318
1319
        cudaStreamSynchronize(stream_);
        std::stringstream ss;
        for (const auto& t : tokens) {
            ss << " " << t;
        }
Li Zhang's avatar
Li Zhang committed
1320
        TM_LOG_INFO("[CompleteRequest] slot %d, tokens [%s]", index, ss.str().c_str());
Li Zhang's avatar
Li Zhang committed
1321
1322
    }

Li Zhang's avatar
Li Zhang committed
1323
1324
    if (state_->requests[index]->end_flag || is_force_end) {
        sequence_manager_->Erase(state_->requests[index]->id);
Li Zhang's avatar
Li Zhang committed
1325
1326
    }
    else {
Li Zhang's avatar
Li Zhang committed
1327
1328
        // account for the last generated token if not a stop request (which doesn't generate)
        const int output_len = state_->h_context_length[index] + 1 - static_cast<int>(is_stop_request);
Li Zhang's avatar
Li Zhang committed
1329

Li Zhang's avatar
Li Zhang committed
1330
        auto& seq = *state_->sequences[index];
Li Zhang's avatar
Li Zhang committed
1331
1332

        // update token IDs
Li Zhang's avatar
Li Zhang committed
1333
1334
1335
1336
        seq.tokens.resize(output_len);

        const auto output_ids_data = state_->requests[index]->outputs[rank_].at("output_ids").getPtr<int>();
        Copy(output_ids_data, output_len, seq.tokens.data());
Li Zhang's avatar
Li Zhang committed
1337
1338

        // update random states
Li Zhang's avatar
Li Zhang committed
1339
1340
1341
1342
1343
1344
1345
        seq.random_state.resize(sizeof(curandState_t) * 2);

        // save random state in host memory
        if (auto ptr = (curandState_t*)seq.random_state.data()) {
            ptr = Copy(model_->GetTopKState(index), 1, ptr);
            ptr = Copy(model_->GetTopPState(index), 1, ptr);
        }
Li Zhang's avatar
Li Zhang committed
1346
1347
1348

        check_cuda_error(cudaStreamSynchronize(stream_));

Li Zhang's avatar
Li Zhang committed
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
        sequence_manager_->UpdateAndSetUnlock(seq);
    }

    state_->sequences[index] = nullptr;
}

template<typename T>
void LlamaBatch<T>::InternalThreadEntry(int device_id)
{
    // Main inference loop of one rank. All ranks run this in lock-step,
    // synchronized via `shared_state->barrier`; only rank 0 touches the
    // request queue.
    TM_LOG_INFO("[InternalThreadEntry] %d", (int)rank_);
    check_cuda_error(cudaSetDevice(device_id));

    auto& shared_state = model_->shared_state_;

    auto& request_queue  = shared_state->request_queue;
    auto& infer_requests = shared_state->infer_requests;
    auto& stop_requests  = shared_state->stop_requests;

    int finished_count = 0;

    GenerationState g{};

    while (1) {
        if (rank_ == 0) {
            // Slots freed by requests finished in the previous iteration are
            // reusable but not yet reflected in `state_->size`.
            const int  free_slot_count = max_batch_size_ - state_->size + finished_count;
            const bool is_empty        = (free_slot_count == max_batch_size_);

            // will block if batch is empty
            request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state->abort);

            if (!shared_state->abort) {
                RejectInvalidRequests(stop_requests, infer_requests);
            }
        }

        NvtxScope scope("mainloop");

        // wait while rank-0 is dequeueing
        shared_state->barrier->wait();

        if (shared_state->abort) {
            TM_LOG_INFO("[InternalThreadEntry] stop requested.");
            return;
        }

        auto signals = ProcessStopRequests(stop_requests);
        BarrierSignalRequests(*shared_state->barrier, signals);

        ProcessInferRequests(infer_requests);

        // wait while shared stop/infer_requests is being used
        shared_state->barrier->wait();

        auto modified = Initialize();
        // finished sequences is handled by `Initialize()`
        finished_count = 0;

        ContextDecode();

        if (state_->active_size) {
            if (modified) {
                g = InitializeGeneration();
                InitializeSampling();
            }
            for (int i = 0; i < step_length_; ++i) {
                if (!Generate(g)) {
                    break;
                }
            }
            auto signals   = Finish(g);
            finished_count = static_cast<int>(signals.size());
            BarrierSignalRequests(*shared_state->barrier, signals);
        }
    }

    // unreachable: the loop only exits via the abort `return` above
    FT_CHECK(0);
}

template<typename T>
void LlamaBatch<T>::BarrierSignalRequests(Barrier& barrier, const std::vector<Signal>& signals)
{
    // Fulfill completion signals on rank 0 only, bracketed by barriers so
    // every rank has stopped using the requests before (and after) clients
    // are notified.
    if (signals.empty()) {
        return;
    }
    barrier.wait();
    if (rank_ == 0) {
        for (const auto& signal : signals) {
            signal();
        }
    }
    barrier.wait();
}

template<typename T>
void LlamaBatch<T>::Start()
{
    // Routine startup notice — INFO severity, not ERROR.
    TM_LOG_INFO("LlamaBatch<T>::Start()");
    int device_id = -1;
    check_cuda_error(cudaGetDevice(&device_id));
    // Launch the inference loop on the device that is current on this thread.
    internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this, device_id);
    // Only rank 0 drives stream callbacks, so only it needs an output thread.
    if (rank_ == 0) {
        output_thread_ = std::thread(&LlamaBatch::OutputThreadEntry, this);
    }
}
Li Zhang's avatar
Li Zhang committed
1453

Li Zhang's avatar
Li Zhang committed
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
template<typename T>
void LlamaBatch<T>::OutputThreadEntry()
{
    // Dedicated thread that invokes per-request stream callbacks so the
    // inference loop is not blocked by (potentially slow) client code.
    while (true) {
        {
            // wait for requests with stream cbs, or a stop request
            std::unique_lock lock(output_mutex_);
            output_cv_.wait(lock, [&] { return !output_reqs_.empty() || output_stop_token_; });
            if (output_stop_token_) {
                TM_LOG_INFO("[OutputThreadEntry] stop requested.");
                return;
            }

            // Callbacks may re-enter the FFI layer; hold the FFI (GIL) lock
            // across the whole sweep.
            if (rank_ == 0 && model_->ffi_lock_) {
                TM_LOG_INFO("acquire GIL");
                model_->ffi_lock_(1);
                TM_LOG_INFO("acquire GIL success");
            }
            for (const auto& r : output_reqs_) {
                r->stream_cb(&r->outputs[rank_].get());
            }
            if (rank_ == 0 && model_->ffi_lock_) {
                TM_LOG_INFO("release GIL");
                model_->ffi_lock_(0);
                TM_LOG_INFO("release GIL success");
            }
            output_reqs_.clear();
        }
        // NOTE(review): this reads `output_reqs_` outside the lock; it is only
        // safe because the producer waits for the notification below before
        // refilling the queue — confirm against `Finish()`'s enqueue path.
        FT_CHECK(output_reqs_.empty());
        // notify infer thread 0
        output_cv_.notify_one();
    }
}

template class LlamaBatch<half>;
template class LlamaBatch<float>;

lvhan028's avatar
lvhan028 committed
1494
}  // namespace turbomind