// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

// #include "src/turbomind/models/llama/LlamaCacheManager.h"
#include "src/turbomind/layers/sampling_layers/BaseSamplingLayer.h"
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"

#include <algorithm>
#include <array>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <vector>
namespace turbomind {
struct BatchState {
    int*  h_context_length;
    bool* h_finished;

Li Zhang's avatar
Li Zhang committed
25
26
    curandState_t* curand_state;
    int*           output_ids;  // output ids in [B, S]
Li Zhang's avatar
Li Zhang committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

    float* h_rope_theta;

    std::vector<int> seq_len_limit;
    std::vector<int> is_swap_in;

    std::vector<const Sequence*>          sequences;
    std::vector<std::shared_ptr<Request>> requests;

    // |<-- existing -->|<-- swap-in -->|
    // |<----------- active ----------->|<-- inactive -->|
    int active_size;
    int size;
};

template<typename T>
class LlamaV2;

template<typename T>
class LlamaBatch {
public:
Li Zhang's avatar
Li Zhang committed
48
49
50
    void AllocateBuffer(size_t batch_size, size_t session_len);
    void AllocatePersistantBuffer(size_t max_batch_size);
    void FreeBuffer();
Li Zhang's avatar
Li Zhang committed
51

Li Zhang's avatar
Li Zhang committed
52
53
    using Requests = std::vector<std::shared_ptr<Request>>;
    using Signal   = std::function<void()>;
Li Zhang's avatar
Li Zhang committed
54

Li Zhang's avatar
Li Zhang committed
55
56
57
58
59
    void RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs);

    [[nodiscard]] auto ProcessStopRequests(const Requests& requests) -> std::vector<Signal>;

    void ProcessInferRequests(const Requests& requests);
Li Zhang's avatar
Li Zhang committed
60

Li Zhang's avatar
Li Zhang committed
61
    [[nodiscard]] bool Initialize();
Li Zhang's avatar
Li Zhang committed
62

Li Zhang's avatar
Li Zhang committed
63
    void ContextDecode();
Li Zhang's avatar
Li Zhang committed
64

Li Zhang's avatar
Li Zhang committed
65
66
67
68
69
70
    struct GenerationState {
        int max_init_ctx_len;
        int step;
        int sum_seq_len;
        int max_seq_len;
    };
Li Zhang's avatar
Li Zhang committed
71

Li Zhang's avatar
Li Zhang committed
72
73
    void InitializeSampling();

Li Zhang's avatar
Li Zhang committed
74
    GenerationState InitializeGeneration();
Li Zhang's avatar
Li Zhang committed
75

Li Zhang's avatar
Li Zhang committed
76
    [[nodiscard]] bool Generate(GenerationState& g);
Li Zhang's avatar
Li Zhang committed
77

Li Zhang's avatar
Li Zhang committed
78
    [[nodiscard]] auto Finish(GenerationState& g, int& finished_count) -> std::vector<Signal>;
Li Zhang's avatar
Li Zhang committed
79

Li Zhang's avatar
Li Zhang committed
80
    [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
Li Zhang's avatar
Li Zhang committed
81

82
    void
Li Zhang's avatar
Li Zhang committed
83
    OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
84

Li Zhang's avatar
Li Zhang committed
85
86
87
88
89
    explicit LlamaBatch(int                              max_batch_size,
                        int                              max_context_token_num,
                        int                              session_len,
                        std::unique_ptr<SequenceManager> sequence_manager,
                        LlamaV2<T>*                      llama);
Li Zhang's avatar
Li Zhang committed
90
91
92

    ~LlamaBatch()
    {
Li Zhang's avatar
Li Zhang committed
93
        TM_LOG_INFO("~LlamaBatch()");
Li Zhang's avatar
Li Zhang committed
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
        model_->shared_state_->request_queue.close();

        internal_thread_.join();

        if (output_thread_.joinable()) {
            {
                std::lock_guard lock{output_mutex_};
                output_stop_token_ = true;
            }
            output_cv_.notify_one();
            output_thread_.join();
        }

        FreeBuffer();
    }

    void Start();

private:
    void InternalThreadEntry(int device_id);

    void OutputThreadEntry();

Li Zhang's avatar
Li Zhang committed
117
    void CopyState(const std::vector<std::tuple<BatchState*, BatchState*, int, int>>& desc);
Li Zhang's avatar
Li Zhang committed
118

Li Zhang's avatar
Li Zhang committed
119
    void SendSignals(std::vector<Signal> signals);
Li Zhang's avatar
Li Zhang committed
120
121
122
123
124
125
126
127
128
129
130
131
132
133

    // analogs to `std::copy_n`
    template<typename U>
    U* Copy(const U* src, size_t count, U* dst)
    {
        check_cuda_error(cudaMemcpyAsync(dst, src, sizeof(U) * count, cudaMemcpyDefault, stream_));
        return dst += count;
    }

    template<typename U>
    U* Clear(U* data, size_t count)
    {
        check_cuda_error(cudaMemsetAsync(data, 0, sizeof(U) * count, stream_));
        return data += count;
Li Zhang's avatar
Li Zhang committed
134
135
    }

Li Zhang's avatar
Li Zhang committed
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
    template<class... Ts>
    void IndexedCopyImpl(const int* src_idx, const int* dst_idx, int count, const std::tuple<Ts*, Ts*, int>&... cpys)
    {
        if (!count) {
            return;
        }
        constexpr int N = sizeof...(Ts);
        static_assert((!std::is_same_v<Ts, void> && ...));
        std::array<void*, N> src_ptr{std::get<0>(cpys)...};
        std::array<void*, N> dst_ptr{std::get<1>(cpys)...};
        std::array<int, N>   elem_sz{int(sizeof(Ts) * std::get<2>(cpys))...};
        invokeIndexedCopy(src_ptr.data(),  //
                          dst_ptr.data(),
                          elem_sz.data(),
                          src_idx,
                          dst_idx,
                          count,
                          N,
                          stream_);
        sync_check_cuda_error();
    }

    template<class... Ts>
    void IndexedCopy(const std::vector<int>& src_idx,
                     const std::vector<int>& dst_idx,
                     const std::tuple<Ts*, Ts*, int>&... cpys)
    {
        // has the same size, or one is empty
        FT_CHECK(src_idx.size() == dst_idx.size() || (src_idx.empty() ^ dst_idx.empty()));
        IndexedCopyImpl(src_idx.empty() ? nullptr : src_idx.data(),
                        dst_idx.empty() ? nullptr : dst_idx.data(),
                        std::max(src_idx.size(), dst_idx.size()),
                        cpys...);
    }

    template<class... Ts>
    void IndexedCopy(int count, const std::tuple<Ts*, Ts*, int>&... cpys)
    {
        IndexedCopyImpl(nullptr, nullptr, count, cpys...);
    }

Li Zhang's avatar
Li Zhang committed
177
178
179
180
181
182
private:
    const int  max_batch_size_;
    const int  max_context_token_num_;
    const int  session_len_;
    const int  rank_;
    const bool debug_;
Li Zhang's avatar
Li Zhang committed
183
    const int  step_length_;
Li Zhang's avatar
Li Zhang committed
184

Li Zhang's avatar
Li Zhang committed
185
    LlamaV2<T>* const model_;
Li Zhang's avatar
Li Zhang committed
186

Li Zhang's avatar
Li Zhang committed
187
    std::unique_ptr<SequenceManager> sequence_manager_;
Li Zhang's avatar
Li Zhang committed
188

Li Zhang's avatar
Li Zhang committed
189
190
191
192
193
    ///////////////////////////////////////////////////////////////////
    // k/v cache block buffers
    int*       cu_block_counts_{};
    uintptr_t* k_block_ptrs_{};
    uintptr_t* v_block_ptrs_{};
Li Zhang's avatar
Li Zhang committed
194

Li Zhang's avatar
Li Zhang committed
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
    ////////////////////////////////////////////////////////////////////
    // context decoding temp buffers
    T*   context_decoder_input_buf_{};
    T*   context_decoder_output_buf_{};
    int* context_decoder_ids_buf_{};
    int* input_ids_buf_{};
    // lengths
    int* input_length_buf_{};    // input + cache missed length
    int* context_length_buf_{};  // history length + input_length
    // temp buffers used for block->linear kv-cache conversion
    T*     tmp_k_cache_buf_{};
    T*     tmp_v_cache_buf_{};
    void** tmp_k_ptrs_{};
    void** tmp_v_ptrs_{};
    void** h_tmp_k_ptrs_{};
    void** h_tmp_v_ptrs_{};

    T*   decoder_input_buf_{};
    T*   decoder_output_buf_{};
    int* sequence_lengths_{};  // current sequence length
    int* init_ctx_lens_{};
Li Zhang's avatar
Li Zhang committed
216
217
218

    float* logits_buf_{};        // combined logits
    float* local_logits_buf_{};  // tensor parallel local logits
219
220
    float* context_logits_buf_{};
    float* local_context_logits_buf_{};
Li Zhang's avatar
Li Zhang committed
221

Li Zhang's avatar
Li Zhang committed
222
223
    float* rope_theta_{};

Li Zhang's avatar
Li Zhang committed
224
    // used by dynamic decoder
Li Zhang's avatar
Li Zhang committed
225
    int*      token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
Li Zhang's avatar
Li Zhang committed
226
227
    bool*     finished_buf_{};
    uint32_t* seq_limit_len_{};
Li Zhang's avatar
Li Zhang committed
228
229
    int*      h_end_ids_buf_{};
    int*      d_end_ids_buf_{};
Li Zhang's avatar
Li Zhang committed
230

Li Zhang's avatar
Li Zhang committed
231
232
233
234
235
236
237
    int** request_output_ids_ptrs_{};
    int*  request_output_ids_lens_{};
    int** request_seqlen_ptrs_{};
    int** h_request_output_ids_ptrs_{};
    int*  h_request_output_ids_lens_{};
    int** h_request_seqlen_ptrs_{};

Li Zhang's avatar
Li Zhang committed
238
239
240
241
    // pinned buffers
    int*       h_input_ids_buf_{};
    int*       h_input_length_buf_{};
    uint32_t*  h_seq_limit_len_{};
Li Zhang's avatar
Li Zhang committed
242
243
244
    int*       h_cu_block_counts_{};
    uintptr_t* h_k_block_ptrs_{};
    uintptr_t* h_v_block_ptrs_{};
Li Zhang's avatar
Li Zhang committed
245

Li Zhang's avatar
Li Zhang committed
246
247
248
249
250
251
252
253
254
255
256
257
258
259
    int*   h_runtime_top_k_{};
    float* h_runtime_top_p_{};
    float* h_temperature_{};
    float* h_repetition_penalty_{};
    int*   h_stop_words_{};  // [batch_size, 2, kMaxStopWordsLen]
    int*   h_bad_words_{};
    int*   d_stop_words_{};  // [batch_size, 2, kMaxStopWordsLen]
    int*   d_bad_words_{};

    unsigned long long* h_random_seed_{};
    unsigned long long* d_random_seed_{};

    curandState_t* h_curand_state_{};
    curandState_t* d_curand_state_{};
Li Zhang's avatar
Li Zhang committed
260

Li Zhang's avatar
Li Zhang committed
261
    std::array<BatchState, 3> states_{};
Li Zhang's avatar
Li Zhang committed
262

Li Zhang's avatar
Li Zhang committed
263
264
265
    BatchState* state_{};
    BatchState* back_{};
    BatchState* incoming_{};
Li Zhang's avatar
Li Zhang committed
266

Li Zhang's avatar
Li Zhang committed
267
    uint64_t request_count_{0};
Li Zhang's avatar
Li Zhang committed
268

Li Zhang's avatar
Li Zhang committed
269
270
    // hard limits for persistent buffers
    static constexpr int kMaxStopBadWordsLen = 32;
Li Zhang's avatar
Li Zhang committed
271
272
273
274
275
276
277
278
279

    const DataType data_type_{};

    bool is_allocate_persistant_buffer_ = false;
    bool is_allocate_buffer_            = false;

    TensorMap inputs_;
    TensorMap outputs_;

Li Zhang's avatar
Li Zhang committed
280
    std::vector<std::tuple<std::string, std::byte*, std::byte*>> sampling_params_;
Li Zhang's avatar
Li Zhang committed
281
282
283
284

    cudaStream_t     stream_{};
    cublasMMWrapper* cublas_wrapper_{};
    IAllocator*      allocator_{};
Li Zhang's avatar
Li Zhang committed
285
286
287
288
289
290
291

    std::thread internal_thread_;

    // async stream callback utils
    std::thread             output_thread_;
    std::mutex              output_mutex_;
    std::condition_variable output_cv_;
Li Zhang's avatar
Li Zhang committed
292
    std::vector<Signal>     output_signals_;
Li Zhang's avatar
Li Zhang committed
293
    bool                    output_stop_token_{false};
Li Zhang's avatar
Li Zhang committed
294
295

    int* h_output_ids_{};
Li Zhang's avatar
Li Zhang committed
296
297
};

}  // namespace turbomind