"INSTALL/VentoyWebDeepin.sh" did not exist on "9f357f8ed138e92087a922812f9df3ca8aa5cbc6"
LlamaTritonModelInstance.h 3.86 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h

#pragma once

#include "src/fastertransformer/models/llama/LlamaV2.h"
#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
#include <memory>

namespace ft = fastertransformer;

/// State shared by LlamaTritonModelInstance objects (each one holds it through a
/// std::shared_ptr): the LlamaV2 model, its weights, and the CUDA allocator /
/// cuBLAS resources that the model is run with.
///
/// The struct declares no constructors and `session_len` is a const member with
/// no initializer, so it has no default constructor and can only be created by
/// aggregate initialization. The member declaration order below is therefore
/// part of its interface: initializer lists elsewhere must follow it.
///
/// NOTE(review): C++ destroys non-static members in reverse declaration order,
/// so `llm` is destroyed LAST — after `cublas_wrapper`, `cublas_algo_map`,
/// `allocator`, etc. If LlamaV2's destructor still uses the allocator or the
/// cuBLAS wrapper (e.g. to free its own buffers), that is a use-after-free.
/// Moving `llm_weight` and `llm` to the end of the struct would fix the order,
/// but every aggregate initializer of this struct must be reordered in the same
/// change — confirm against LlamaV2 and the call sites.
/// NOTE(review): std::mutex is used here but <mutex> is not included by this
/// header; it currently relies on a transitive include.
template<typename T>
struct LlamaTritonSharedModelInstance {
    std::unique_ptr<ft::LlamaV2<T>>                         llm;
    // Shared (not unique) ownership of the weights; presumably shared across
    // instances on the same device — confirm in LlamaTritonModel.
    std::shared_ptr<ft::LlamaWeight<T>>                     llm_weight;
    std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator;
    std::unique_ptr<ft::cublasAlgoMap>                      cublas_algo_map;
    // std::mutex is neither copyable nor movable; holding it through a
    // unique_ptr keeps this struct move-constructible.
    std::unique_ptr<std::mutex>                             cublas_wrapper_mutex;
    std::unique_ptr<ft::cublasMMWrapper>                    cublas_wrapper;
    std::unique_ptr<cudaDeviceProp>                         cuda_device_prop_ptr;
    // const: fixed at construction; also makes the struct non-assignable.
    const int                                               session_len;
};

/// A LLaMA model instance exposed to the Triton backend through the
/// AbstractTransformerModelInstance interface.
///
/// Each instance shares a LlamaTritonSharedModelInstance (model, weights, cuBLAS
/// wrapper) with other instances via std::shared_ptr, and exclusively owns its
/// own CUDA allocator plus the raw `d_*` / `h_*` buffer pointers declared in the
/// private section.
///
/// Copy and move are explicitly deleted. The instance owns raw buffer pointers,
/// which are presumably released by freeBuffer() / the destructor (TODO confirm
/// in the .cc), so a memberwise copy or move would double-free or leak them.
/// Previously these operations were suppressed only implicitly, as a side effect
/// of the `const std::unique_ptr` member.
template<typename T>
struct LlamaTritonModelInstance: AbstractTransformerModelInstance {

    /// @param instance  State shared with other instances (model, weights, cuBLAS wrapper).
    /// @param allocator CUDA allocator owned exclusively by this instance.
    LlamaTritonModelInstance(std::shared_ptr<LlamaTritonSharedModelInstance<T>>      instance,
                             std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator);
    ~LlamaTritonModelInstance();

    // Non-copyable and non-movable: the instance owns raw buffers (see above).
    LlamaTritonModelInstance(const LlamaTritonModelInstance&)            = delete;
    LlamaTritonModelInstance& operator=(const LlamaTritonModelInstance&) = delete;
    LlamaTritonModelInstance(LlamaTritonModelInstance&&)                 = delete;
    LlamaTritonModelInstance& operator=(LlamaTritonModelInstance&&)      = delete;

    /// Positional-tensor overload of the base-class forward().
    /// NOTE(review): semantics (supported or not) are defined in the .cc — confirm.
    std::shared_ptr<std::vector<triton::Tensor>>
    forward(std::shared_ptr<std::vector<triton::Tensor>> input_tensors) override;

    /// Forward pass over input tensors keyed by name; returns output tensors keyed by name.
    std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
    forward(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors) override;

    /// As above, additionally given an ft::AbstractInstanceComm.
    /// NOTE(review): presumably used for cross-instance / multi-GPU communication — confirm.
    std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
    forward(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors,
            ft::AbstractInstanceComm*) override;

    /// Converts a name -> ft::Tensor map into a shared name -> triton::Tensor map.
    static std::shared_ptr<std::unordered_map<std::string, triton::Tensor>>
    convert_outputs(const std::unordered_map<std::string, ft::Tensor>& output_tensors);

private:
    const std::shared_ptr<LlamaTritonSharedModelInstance<T>>      instance_;
    const std::unique_ptr<ft::Allocator<ft::AllocatorType::CUDA>> allocator_;

    /// Converts the incoming name -> triton::Tensor map into a name -> ft::Tensor map.
    std::unordered_map<std::string, ft::Tensor>
    convert_inputs(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors);

    /// Allocates the buffers below for the given batch size, beam width and
    /// session length (presumably through allocator_ — confirm in the .cc).
    void allocateBuffer(const size_t request_batch_size, const size_t beam_width, const size_t session_len);
    /// Releases the buffers acquired by allocateBuffer() (presumably — confirm in the .cc).
    void freeBuffer();

    // Request-side buffers. All start as nullptr. By naming convention, `d_`
    // presumably denotes device memory and `h_` host memory — confirm.
    int*   d_input_ids_                = nullptr;
    int*   d_input_lengths_            = nullptr;
    int*   d_input_bad_words_          = nullptr;
    int*   d_input_stop_words_         = nullptr;
    int*   d_request_prompt_lengths_   = nullptr;
    T*     d_request_prompt_embedding_ = nullptr;
    float* d_top_p_decay_              = nullptr;
    float* d_top_p_min_                = nullptr;
    int*   d_top_p_reset_ids_          = nullptr;

    // Result-side buffers.
    int*   d_output_ids_       = nullptr;
    int*   d_sequence_lengths_ = nullptr;
    float* d_output_log_probs_ = nullptr;
    float* d_cum_log_probs_    = nullptr;

    uint32_t* h_total_output_lengths_ = nullptr;
    // Holds a captured exception, if any.
    // NOTE(review): presumably set on a worker/callback path and rethrown to the caller — confirm.
    std::exception_ptr h_exception_ = nullptr;
};