/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 * Copyright (c) 2021, NAVER Corp.  Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h

#pragma once

#include "src/fastertransformer/layers/DynamicDecodeLayer.h"
#include "src/fastertransformer/models/llama/Barrier.h"
#include "src/fastertransformer/models/llama/LlamaBatch.h"
#include "src/fastertransformer/models/llama/LlamaContextDecoder.h"
#include "src/fastertransformer/models/llama/LlamaDecoder.h"
#include "src/fastertransformer/models/llama/LlamaWeight.h"
#include "src/fastertransformer/models/llama/Request.h"
#include "src/fastertransformer/utils/allocator.h"
#include "src/fastertransformer/utils/cublasMMWrapper.h"
#include "src/fastertransformer/utils/instance_comm.h"
#include "src/fastertransformer/utils/nccl_utils.h"
#include <unordered_map>

namespace fastertransformer {

template<typename T>
class LlamaV2 {
public:
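    // Requests and synchronization state shared between the user-facing API
    // and the internal decoding thread(s).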
    struct SharedState {
        std::vector<std::shared_ptr<Request>> infer_requests;
        std::vector<std::shared_ptr<Request>> stop_requests;
        RequestQueue                          request_queue;
        std::shared_ptr<Barrier>              barrier;
    };

    ~LlamaV2();

    LlamaV2(size_t                       head_num,
            size_t                       size_per_head,
            size_t                       inter_size,
            size_t                       num_layer,
            size_t                       vocab_size,
            size_t                       rotary_embedding_dim,
            float                        norm_eps,
            int                          max_batch_size,
            int                          max_context_token_num,
            int                          session_len,
            int                          step_length,
            int                          start_id,
            int                          end_id,
            int                          cache_max_entry_count,
            int                          cache_chunk_size,
            int                          quant_policy,
            bool                         use_context_fmha,
            std::shared_ptr<SharedState> shared_state,
            LlamaWeight<T>*              weights,
            NcclParam                    tensor_para,
            cudaStream_t                 stream,
            cublasMMWrapper*             cublas_wrapper,
            IAllocator*                  allocator,
            bool                         is_free_buffer_after_forward,
            cudaDeviceProp*              cuda_device_prop);

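    // Per-call controls passed to forward(): an optional instance-level
    // communicator and a user callback attached to the requests.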
    struct Control {
        AbstractInstanceComm* comm;
        Request::Callback     callback;
    };

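    // Submits one inference call described by the input/output tensor maps.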
    void forward(std::unordered_map<std::string, Tensor>*       outputs,
                 const std::unordered_map<std::string, Tensor>* inputs,
                 Control                                        control);

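    // Requests cancellation of the sequences with the given ids.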
    void stop(const std::vector<uint64_t>& seq_ids);

private:
    friend class Batch;

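    // Main loop of the per-rank decoding thread (runs on `device_id`).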
    void internalThreadEntry(int device_id);

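    // Creates the decoder, context decoder and dynamic decode layers.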
    void initialize(bool use_context_fmha, int quant_policy);

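    // Gathers the embedding vectors of the tokens selected at `step`.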
    void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);

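    // Prefill: runs the context decoder over the prompt tokens and populates
    // the K/V cache for each sequence in the batch.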
    void contextDecode(T*         decoder_output,
                       uintptr_t* k_cache_ptr,
                       uintptr_t* v_cache_ptr,
                       T*         context_decoder_input_buf,
                       T*         context_decoder_output_buf,
                       const int* input_ids,
                       const int* input_length,
                       const int* history_length,
                       const int* context_length,
                       size_t     token_num,
                       size_t     max_input_len,
                       size_t     max_context_len,
                       size_t     session_len,
                       size_t     batch_size);

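    // Incremental decoding: advances every sequence in the batch by one
    // generation step, reading from and appending to the K/V cache.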
    void decoderForward(T*         decoder_output,
                        uintptr_t* k_cache_ptr,
                        uintptr_t* v_cache_ptr,
                        T*         decoder_input,
                        const int* sequence_length,
                        const int* total_padding_count,
                        bool*      finished,
                        int        step,
                        int        ite,
                        size_t     session_len,
                        size_t     batch_size);

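    // Projects decoder hidden states to vocabulary logits; `local_logits`
    // holds the rank-local partial result when tensor parallelism is used.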
    void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);

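    // Sampling and stopping criteria: selects the next tokens from the
    // logits and updates the `finished` / `should_stop` flags.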
    void dynamicDecode(int*            token_ids,
                       bool*           finished,
                       int*            sequence_length,
                       bool*           should_stop,
                       TensorMap*      inputs,
                       TensorMap*      outputs,
                       const float*    logits,
                       const uint32_t* seq_limit_len,
                       const int*      context_length,
                       const int*      end_ids,
                       int             step,
                       int             ite,
                       size_t          max_context_len,
                       size_t          token_ids_len,
                       size_t          batch_size);

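    // Launches the internal decoding thread.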
    void start();

private:
    friend class LlamaBatch<T>;

    const size_t head_num_;
    const size_t size_per_head_;
    const size_t inter_size_;
    const size_t num_layer_;
    const size_t vocab_size_;
    const size_t rotary_embedding_dim_;
    float        rmsnorm_eps_ = 1e-6f;

    static constexpr bool neox_rotary_style_ = false;

    const int    start_id_;
    const int    end_id_;
    const size_t hidden_units_;

    const size_t local_head_num_;
    NcclParam    tensor_para_;

    cudaStream_t     stream_;
    cublasMMWrapper* cublas_wrapper_;
    IAllocator*      allocator_;
    bool             is_free_buffer_after_forward_;
    cudaDeviceProp*  cuda_device_prop_;

    const bool debug_{false};

    std::unique_ptr<LlamaCacheManager> kv_cache_mgr_;

    LlamaWeight<T>*            weights_{};
    LlamaDecoder<T>*           decoder_{};
    LlamaContextDecoder<T>*    context_decoder_{};
    DynamicDecodeLayer<float>* dynamic_decode_layer_{};

    const int                    step_length_;
    LlamaBatch<T>                batch_;
    std::shared_ptr<SharedState> shared_state_;

    std::thread internal_thread_;
};
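
// Minimal usage sketch (illustrative, not part of this header): construction
// of `weights`, `stream`, `cublas_wrapper`, `allocator` and `device_prop` is
// elided, and the Barrier constructor argument is an assumption.
//
//   auto shared_state     = std::make_shared<LlamaV2<half>::SharedState>();
//   shared_state->barrier = std::make_shared<Barrier>(1);  // 1 rank assumed
//
//   LlamaV2<half> model(/* model & scheduling hyper-parameters ... */,
//                       shared_state, weights, tensor_para, stream,
//                       cublas_wrapper, allocator,
//                       /*is_free_buffer_after_forward=*/false, &device_prop);
//
//   std::unordered_map<std::string, Tensor> inputs, outputs;
//   model.forward(&outputs, &inputs, {/*comm=*/nullptr, /*callback=*/{}});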

}  // namespace fastertransformer