GPT2.cpp 3.7 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <GPT2.h>

#include <onnxruntime/core/session/onnxruntime_cxx_api.h>

#include <Filesystem.h>
#include <SimpleLog.h>
#include <algorithm>
#include <stdexcept>
#include <tokenization.h>

namespace ortSamples
{

// Creates a wrapper with no session yet; Initialize() must be called before
// Inference(). The pointer is nulled so the destructor is safe even when
// Initialize() was never called or returned early.
GPT2::GPT2() : session(nullptr)
{
}

// Releases the inference session, if one was created.
GPT2::~GPT2()
{
    // Initialize() allocates the session with `new`; nothing else in this
    // file frees it, so it has to be released here to avoid a leak.
    delete session;
    session = nullptr;
}

// Checks that the model file exists, configures the session options for the
// ROCm execution provider (device 0, basic graph optimization) and creates
// the inference session.
//
// Returns MODEL_NOT_EXIST when the model file is missing, SUCCESS otherwise.
// Errors raised while the session is being created are not caught here and
// propagate to the caller.
ErrorCode GPT2::Initialize()
{
    // Path of the model file
    std::string modelPath="../Resource/GPT2_shici.onnx";

    // Make sure the model file is present before trying to load it
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }

    // Load the model
    OrtROCMProviderOptions rocm_options;
    rocm_options.device_id = 0;
    sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
    session = new Ort::Session(env, modelPath.c_str(), sessionOptions);

    // Report success only once the session really exists; logging this before
    // the load would claim success even when session creation fails.
    LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    return SUCCESS;
}

// Ordering predicate for descending order: returns true when the first
// entry's predictionvalue is strictly greater than the second's.
static bool CompareM(Predictions lhs, Predictions rhs)
{
    const bool lhsRanksFirst = lhs.predictionvalue > rhs.predictionvalue;
    return lhsRanksFirst;
}

long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id)
{

    int64_t input[1][input_id.size()];
    for (int j=0;j<input_id.size();++j)
    {

        input[0][j] = static_cast<int64_t>(input_id[j]);;
    }

    // 获取模型输入属性
    input_node_names = {"input"};
  
    // 获取模型输出属性
    output_node_names = {"output","304","value.3","532","value.11","758","value.19","984","value.27","1210","value.35","1436","value.43","1662","value.51","1888","value.59","2114","value.67","2340","value.75","2566","value.83","2792","value.91"};

    // 设置输入shape
    std::array<int64_t, 2> inputShapes{1, (int)input_id.size()};
    
    // 创建输入数据
    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    auto inputTensorSize = 1 * input_id.size();
    std::vector<float> inputTensorValues(inputTensorSize);
    int64_t* input_test = (int64_t*)input;
    Ort::Value inputTensor = Ort::Value::CreateTensor<int64_t>(memoryInfo, input_test, inputTensorValues.size(), inputShapes.data(), inputShapes.size());
    std::vector<Ort::Value> intput_tensors;
    intput_tensors.push_back(std::move(inputTensor));

    // 推理  
    auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), intput_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());

    // 保存推理结果
    const float* data = output_tensors[0].GetTensorMutableData<float>();
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(int i=(input_id.size()-1)*22557; i<input_id.size()*22557; ++i)
    {
        resultsOfPredictions[n].index = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }

    // 对于[UNK]的概率设为无穷小,模型的预测结果不可能是[UNK]
    resultsOfPredictions[100].predictionvalue = -10000;

    // 排序
    std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);
    return resultsOfPredictions[0].index;
}

// Tokenizes `question` and appends the resulting ids to `input_id`: first the
// id of "[CLS]", then one id per token in tokenizer order. Existing contents
// of `input_id` are kept. Always returns SUCCESS.
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
                             char *question,
                             std::vector<long unsigned int> &input_id)
{
    // Split the question into tokens (limit handed to the tokenizer: 1000)
    int tokenLimit = 1000;
    std::vector<std::string> questionTokens;
    questionTokens.reserve(tokenLimit);
    tokenizer.tokenize(question, &questionTokens, tokenLimit);

    // Store the encoded sequence, starting with the [CLS] marker
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for (const std::string &token : questionTokens)
    {
        input_id.push_back(tokenizer.convert_token_to_id(token));
    }

    return SUCCESS;
}
}