/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h

#pragma once

#include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/custom_ar_comm.h"
#include "src/turbomind/utils/nccl_utils.h"
#include <cuda_fp16.h>
#include <deque>
#include <mutex>

namespace ft = turbomind;

template<typename T>
struct LlamaTritonSharedModelInstance;

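// Triton backend wrapper around ft::LlamaV2<T>. It reads the model's
// hyper-parameters from `model_dir`, owns the per-device shared weights and
// engine state, and hands out AbstractTransformerModelInstance objects that
// serve individual requests.
//
// Rough single-GPU usage sketch (illustrative only; values and call order are
// assumptions based on the declarations below, not copied from a real caller):
//
//   auto model = std::make_shared<LlamaTritonModel<half>>(/*tensor_para_size=*/1,
//                                                         /*pipeline_para_size=*/1,
//                                                         /*enable_custom_all_reduce=*/0,
//                                                         "/path/to/model_dir");
//   auto nccl_params = model->createNcclParams(/*node_id=*/0, /*device_id_start=*/0, /*multi_node=*/false);
//   model->createSharedWeights(/*deviceId=*/0, /*rank=*/0);
//   cudaStream_t stream{};
//   cudaStreamCreate(&stream);
//   auto instance = model->createModelInstance(/*deviceId=*/0, /*rank=*/0, stream, nccl_params);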
template<typename T>
struct LlamaTritonModel: public AbstractTransformerModel {
    LlamaTritonModel(size_t      tensor_para_size,
                     size_t      pipeline_para_size,
                     int         enable_custom_all_reduce,
                     std::string model_dir);

    ~LlamaTritonModel() = default;

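    // Creates an inference instance bound to `deviceId`/`rank` and the given CUDA
    // stream. The heavyweight engine state comes from createSharedModelInstance()
    // and is reused for as long as any instance still holds a strong reference to
    // it (see shared_instances_ below).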
    std::unique_ptr<AbstractTransformerModelInstance>
    createModelInstance(int                                                               deviceId,
                        int                                                               rank,
                        cudaStream_t                                                      stream,
                        std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
                        std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr) override;

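    // Loads the weights for one device/rank into shared_weights_ so that every
    // instance created on that device can reuse them.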
    void createSharedWeights(int deviceId, int rank) override;

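    // The next two methods set up communication: createCustomComms builds the
    // optional custom all-reduce communicators for `world_size` ranks, and
    // createNcclParams builds the NCCL groups used for tensor and pipeline
    // parallelism.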
    void createCustomComms(std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms,
                           int                                                   world_size) override;

    std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>>
    createNcclParams(const int node_id, const int device_id_start, const bool multi_node) override;

    std::unique_ptr<ft::AbstractInstanceComm> createInstanceComm(int size) override;

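    // Handles configuration parameters that were not provided in the model
    // config (e.g. by falling back to default values).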
    void handleMissingParams();

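    // Registers the lock/unlock callback invoked around FFI calls (e.g. to manage
    // the Python GIL); stored in ffi_lock_.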
    void setFfiLock(ffi_api_lock_ctrl_t func)
    {
        ffi_lock_ = func;
    }

    std::string toString() override;
    int         getTensorParaSize() override;
    int         getPipelineParaSize() override;

private:
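    // Builds the per-device engine state that createModelInstance() wraps; the
    // result is cached in shared_instances_ and shared by all instances created
    // for that device.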
    std::unique_ptr<LlamaTritonSharedModelInstance<T>>
    createSharedModelInstance(int                                                               deviceId,
                              int                                                               rank,
                              std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
                              std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr);

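    // Model hyper-parameters, runtime limits, parallelism and quantization
    // settings, presumably parsed from the config under model_dir_.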
    size_t                          head_num_;
    size_t                          kv_head_num_;
    size_t                          size_per_head_;
    size_t                          inter_size_;
    size_t                          num_layer_;
    size_t                          vocab_size_;
    turbomind::LlamaAttentionParams attn_params_;
    float                           norm_eps_;
    int                             max_batch_size_;
    int                             max_context_token_num_;
    int                             session_len_;
    int                             step_length_;
    int                             start_id_;
    int                             end_id_;
    int                             cache_max_entry_count_;
    int                             cache_chunk_size_;
    int                             use_context_fmha_;
    size_t                          tensor_para_size_;
    size_t                          pipeline_para_size_;
    ft::WeightType                  weight_type_;
    bool                            attn_bias_;
    int                             quant_policy_;
    int                             group_size_;

    // shared weights for each device
    std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;

    std::shared_ptr<typename ft::LlamaV2<T>::SharedState> shared_state_;

    // weak_ptr is used so that the instances get released when all strong references are gone
    std::vector<std::weak_ptr<LlamaTritonSharedModelInstance<T>>> shared_instances_;
    std::deque<std::mutex>                                        shared_mutexes_;  // is locking really needed?

    bool is_fp16_;
    int  enable_custom_all_reduce_ = 0;

    std::string model_name_;
    std::string model_dir_;

    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
};