/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/turbomind/models/multi_gpu_gpt/ParallelGpt.h

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "src/turbomind/layers/DynamicDecodeLayer.h"
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/models/llama/LlamaWeight.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/models/llama/unified_decoder.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/instance_comm.h"
#include "src/turbomind/utils/nccl_utils.h"
// Callback used to acquire (1) / release (0) the host-side FFI lock
// (e.g. the Python GIL) around blocking engine calls.
using ffi_api_lock_ctrl_t = std::function<void(int)>;

namespace turbomind {

template<typename T>
class LlamaV2 {
public:
    struct SharedState {
        std::vector<std::shared_ptr<Request>> infer_requests;
        std::vector<std::shared_ptr<Request>> stop_requests;
        RequestQueue                          request_queue;
        std::shared_ptr<Barrier>              barrier;
Li Zhang's avatar
Li Zhang committed
50
        bool                                  abort;
Li Zhang's avatar
Li Zhang committed
51
52
53
54
55
    };

    ~LlamaV2();

    LlamaV2(size_t                       head_num,
56
            size_t                       kv_head_num,
Li Zhang's avatar
Li Zhang committed
57
58
59
60
61
            size_t                       size_per_head,
            size_t                       inter_size,
            size_t                       num_layer,
            size_t                       vocab_size,
            float                        norm_eps,
62
            const LlamaAttentionParams&  attn_params,
Li Zhang's avatar
Li Zhang committed
63
64
            int                          start_id,
            int                          end_id,
Li Zhang's avatar
Li Zhang committed
65
            int                          cache_block_seq_len,
66
            int                          quant_policy,
Li Zhang's avatar
Li Zhang committed
67
            bool                         use_context_fmha,
68
            const EngineParams&          engine_params,
Li Zhang's avatar
Li Zhang committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
            std::shared_ptr<SharedState> shared_state,
            LlamaWeight<T>*              weights,
            NcclParam                    tensor_para,
            cudaStream_t                 stream,
            cublasMMWrapper*             cublas_wrapper,
            IAllocator*                  allocator,
            bool                         is_free_buffer_after_forward,
            cudaDeviceProp*              cuda_device_prop);

    struct Control {
        AbstractInstanceComm* comm;
        Request::Callback     callback;
    };

    void forward(std::unordered_map<std::string, Tensor>*       outputs,
                 const std::unordered_map<std::string, Tensor>* inputs,
                 Control                                        control);

    void stop(const std::vector<uint64_t>& seq_ids);

89
90
91
92
93
    size_t vocab_size() const noexcept
    {
        return vocab_size_;
    }

Chen Xin's avatar
Chen Xin committed
94
95
96
97
98
    void setFfiLock(ffi_api_lock_ctrl_t func)
    {
        ffi_lock_ = func;
    }

Li Zhang's avatar
Li Zhang committed
99
100
101
private:
    friend class Batch;

Li Zhang's avatar
Li Zhang committed
102
103
104
105
106
    void initialize(const LlamaAttentionParams& attn_params,
                    size_t                      kv_head_num,
                    bool                        use_context_fmha,
                    int                         cache_block_seq_len,
                    int                         quant_policy);
Li Zhang's avatar
Li Zhang committed
107
108
109

    void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);

Chen Xin's avatar
Chen Xin committed
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    void updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences);

    void forwardUnified(T*               out,
                        T*               decoder_output,
                        T*               decoder_input,
                        void**           k_block_ptrs,
                        void**           v_block_ptrs,
                        const int*       input_ids,
                        const int*       cu_block_cnts,
                        const float*     rope_theta,
                        const bool*      dc_finished,
                        const int*       pf_input_length,
                        const int*       pf_context_length,
                        T**              pf_tmp_k_ptrs,
                        T**              pf_tmp_v_ptrs,
                        size_t           token_num,
                        int              dc_batch_size,
                        int              dc_step,
                        int              dc_sum_seq_len,
                        int              dc_max_seq_len,
                        int              pf_batch_size,
                        int              pf_max_input_len,
                        int              pf_max_context_len,
                        int              pf_session_len,
                        const int*       h_input_length,
                        const Sequence** sequences);
Li Zhang's avatar
Li Zhang committed
136
137
138
139
140
141
142

    void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);

    void dynamicDecode(int*            token_ids,
                       bool*           finished,
                       int*            sequence_length,
                       bool*           should_stop,
Li Zhang's avatar
Li Zhang committed
143
                       curandState_t*  curand_state,
Li Zhang's avatar
Li Zhang committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
                       TensorMap*      inputs,
                       TensorMap*      outputs,
                       const float*    logits,
                       const uint32_t* seq_limit_len,
                       const int*      context_length,
                       const int*      end_ids,
                       int             step,
                       int             ite,
                       size_t          max_context_len,
                       size_t          token_ids_len,
                       size_t          batch_size);

private:
    friend class LlamaBatch<T>;

    const size_t head_num_;
    const size_t size_per_head_;
    const size_t inter_size_;
    const size_t num_layer_;
    const size_t vocab_size_;
164
    size_t       vocab_size_padded_;
Li Zhang's avatar
Li Zhang committed
165
166
    float        rmsnorm_eps_ = 1e-6f;

Li Zhang's avatar
Li Zhang committed
167
168
    const LlamaAttentionParams attn_params_;

Li Zhang's avatar
Li Zhang committed
169
170
171
172
173
174
175
    static constexpr bool neox_rotary_style_ = false;

    const int    start_id_;
    const int    end_id_;
    const size_t hidden_units_;

    const size_t local_head_num_;
Li Zhang's avatar
Li Zhang committed
176
    const size_t local_kv_head_num_;
Li Zhang's avatar
Li Zhang committed
177
178
179
180
181
182
183
184
185
186
    NcclParam    tensor_para_;

    cudaStream_t     stream_;
    cublasMMWrapper* cublas_wrapper_;
    IAllocator*      allocator_;
    bool             is_free_buffer_after_forward_;
    cudaDeviceProp*  cuda_device_prop_;

    const bool debug_{false};

187
188
189
190
    LlamaWeight<T>* weights_{};

    std::unique_ptr<UnifiedDecoder<T>> unified_decoder_;
    DynamicDecodeLayer<float>*         dynamic_decode_layer_{};
Li Zhang's avatar
Li Zhang committed
191

Li Zhang's avatar
Li Zhang committed
192
193
194
    std::shared_ptr<SharedState>   shared_state_;
    ffi_api_lock_ctrl_t            ffi_lock_;
    std::unique_ptr<LlamaBatch<T>> batch_;
Li Zhang's avatar
Li Zhang committed
195
196
};

}  // namespace turbomind