// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/models/llama/LlamaCacheManager.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"

namespace turbomind {

template<typename T>
class LlamaV2;

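// Holds the state for one batch of concurrent inference requests driven by a
// LlamaV2<T> model: device/pinned buffers for token ids, lengths and KV-cache
// block pointers, per-request sampling parameters, and the bookkeeping needed
// to admit requests, run context (prefill) decoding, step the generator and
// report finished requests.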
template<typename T>
class LlamaBatch {
public:
    int size() const noexcept
    {
        return batch_size_;
    }

    int maxSize() const noexcept
    {
        return max_batch_size_;
    }

    int finishedCount() const noexcept
    {
        return finished_count_;
    }

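    // Request intake: validate incoming stop/infer requests against the
    // current sessions and process stop requests before the next iteration.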
    void verifyRequests(std::vector<std::shared_ptr<Request>>& stop_reqs,
                        std::vector<std::shared_ptr<Request>>& infer_reqs);
    void handleStopRequests(const std::vector<std::shared_ptr<Request>>& requests);

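    // Buffer management: per-iteration buffers are sized from the active
    // batch size and session length; persistent buffers are sized once from
    // `max_batch_size` and the hard limits declared below.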
    void allocateBuffer(size_t batch_size, size_t session_len);
    void allocatePersistantBuffer(size_t max_batch_size);
    void freeBuffer();

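    // Gathers per-request sampling parameters (top-k, top-p, temperature,
    // repetition penalty, random seed, stop/bad words) into the tensor maps
    // consumed by the dynamic decode layer.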
    void initializeSampling(int infer_request_count);

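    // Batch pipeline: admit new requests, run context (prefill) decoding over
    // the prompt tokens, set up generation state, then step the decoder with
    // `generate()` until the requests in the batch finish.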
    void initialize(const std::vector<std::shared_ptr<Request>>& infer_requests);
    void contextDecode();

    void initializeGeneration();
    bool generate();

    void finish();
    void finishRequest(int index, bool force_end);

    void synchronize();

    void setOutputTensors(int max_gen_step);

    explicit LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama);

    ~LlamaBatch()
    {
        freeBuffer();
    }

private:
    const int  max_batch_size_;
    const int  max_context_token_num_;
    const int  session_len_;
    const int  rank_;
    const bool debug_;

    LlamaV2<T>* const llama_;

    // active requests
    std::vector<std::shared_ptr<Request>> requests_;

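    // Device buffers. The CTXDEC / GENERATE tags mark which phase touches a
    // buffer: context (prefill) decoding over the prompt vs. token-by-token
    // generation.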
    T* context_decoder_input_buf_{};  // CTXDEC
    // T* context_decoder_output_buf_{};  // CTXDEC
    int* context_decoder_ids_buf_{};

    T* decoder_input_buf_{};   // CTXDEC, GENERATE
    T* decoder_output_buf_{};  // CTXDEC, GENERATE

    int* input_ids_buf_{};       // input token ids + cache missed token ids, CTXDEC
    int* input_length_buf_{};    // input + cache missed length, CTXDEC, GENERATE
    int* history_length_buf_{};  // history length, CTXDEC
    int* context_length_buf_{};  // history length + input_length, CTXDEC, GENERATE

    int* total_padding_count_{};  // GENERATE
    int* sequence_lengths_{};     // current sequence length

    uint64_t* k_cache_ptr_buf_{};
    uint64_t* v_cache_ptr_buf_{};

    float* logits_buf_{};        // combined logits
    float* local_logits_buf_{};  // tensor parallel local logits

    // used by dynamic decoder
    int*      token_ids_buf_{};   // all token IDs in [S, B], indexed using `step`
    int*      output_ids_buf_{};  // output ids in [B, S]
    int*      end_ids_buf_{};
    bool*     finished_buf_{};
    uint32_t* seq_limit_len_{};

    // pinned buffers
    int*       h_input_ids_buf_{};
    int*       h_input_length_buf_{};
    int*       h_history_length_buf_{};
    int*       h_context_length_buf_{};
    int*       h_sequence_lengths_{};
    bool*      h_finished_buf_{};
    uintptr_t* h_k_cache_ptr_buf_{};
    uintptr_t* h_v_cache_ptr_buf_{};
    uint32_t*  h_seq_limit_len_{};

    int*      stop_words_buf_{};  // [batch_size, 2, kMaxStopBadWordsLen]
    int*      bad_words_buf_{};
    int*      h_runtime_top_k_{};
    float*    h_runtime_top_p_{};
    float*    h_temperature_{};
    float*    h_repetition_penalty_{};
    uint64_t* h_random_seed_{};

    void* topk_curandstate_buf_{};
    void* topp_curandstate_buf_{};

    // hard limits for persistent buffers
    static constexpr int kMaxStopBadWordsLen = 32;

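    // KV-cache sequences owned by the active batch slots and each request's
    // sequence-length limit.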
    using CachedSeq = LlamaCacheManager::Sequence;

    std::vector<CachedSeq> cached_seq_;
    std::vector<int>       request_seq_len_limit_;

    const DataType data_type_{};

    int batch_size_{};
    int max_context_len_{};
    int step_{};
    int finished_count_{};

    bool is_allocate_persistant_buffer_ = false;
    bool is_allocate_buffer_            = false;

    TensorMap inputs_;
    TensorMap outputs_;

    std::unordered_map<std::string, void*> sampling_params_;

    cudaStream_t     stream_{};
    cublasMMWrapper* cublas_wrapper_{};
    IAllocator*      allocator_{};
};
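
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header). It is
// inferred from the public interface above; `llama`, `stop_reqs`, `infer_reqs`
// and the step bound `max_new_tokens` are assumed to be provided by the
// surrounding scheduling code.
//
//   LlamaBatch<half> batch(max_batch_size, max_context_token_num, session_len, llama);
//
//   batch.verifyRequests(stop_reqs, infer_reqs);  // drop invalid/conflicting requests
//   batch.handleStopRequests(stop_reqs);          // end or erase stopped sessions
//   batch.initialize(infer_reqs);                 // admit new requests into the batch
//   batch.contextDecode();                        // prefill over prompt tokens
//
//   batch.initializeGeneration();
//   for (int step = 0; step < max_new_tokens; ++step) {
//       if (!batch.generate()) {                  // one decoding step for the whole batch
//           break;
//       }
//   }
//   batch.finish();                               // copy outputs, signal finished requests
// ---------------------------------------------------------------------------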

}  // namespace turbomind