// GPT2.cpp — GPT-2 (shici) text-generation sample built on ONNX Runtime.
#include <GPT2.h>

#include <Filesystem.h>
#include <SimpleLog.h>
#include <algorithm>
#include <stdexcept>
#include <tokenization.h>

namespace ortSamples
{

/// Default constructor.
/// The ONNX Runtime session is created lazily in Initialize(); null the
/// pointer here so the destructor can unconditionally delete it even when
/// Initialize() was never called.
/// NOTE(review): redundant if the header already gives `session` a default
/// initializer — confirm against GPT2.h.
GPT2::GPT2()
{
    session = nullptr;
}

/// Destructor: releases the Ort::Session allocated with `new` in
/// Initialize(), which was previously leaked.
/// NOTE(review): assumes `session` is nullptr until Initialize() assigns it
/// (deleting an indeterminate pointer is UB) — confirm the header or the
/// constructor null-initializes the member.
GPT2::~GPT2()
{
    delete session;  // delete nullptr is a safe no-op
    session = nullptr;
}

/// Loads the GPT-2 ONNX model and creates the inference session on the
/// ROCM execution provider (GPU device 0).
/// @return SUCCESS on success, MODEL_NOT_EXIST if the model file is missing.
ErrorCode GPT2::Initialize()
{
    // Model path, relative to the process working directory.
    const std::string modelPath = "../Resource/GPT2_shici.onnx";

    // Verify the file exists before handing it to ONNX Runtime.
    if (!Exists(modelPath))
    {
        LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }

    // Configure the ROCM execution provider and graph optimization level.
    OrtROCMProviderOptions rocm_options;
    rocm_options.device_id = 0;
    sessionOptions.AppendExecutionProvider_ROCM(rocm_options);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);

    // Create the session first; only report success once the model has
    // actually been loaded (the original logged success beforehand, so a
    // failed load still printed a success message).
    session = new Ort::Session(env, modelPath.c_str(), sessionOptions);
    LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    return SUCCESS;
}

static bool CompareM(Predictions a, Predictions b)
{
yangql's avatar
yangql committed
47
    return a.predictionvalue > b.predictionvalue;
yangql's avatar
yangql committed
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
}

long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id)
{

    int64_t input[1][input_id.size()];
    for (int j=0;j<input_id.size();++j)
    {

        input[0][j] = static_cast<int64_t>(input_id[j]);;
    }

    // 获取模型输入属性
    input_node_names = {"input"};
  
    // 获取模型输出属性
    output_node_names = {"output","304","value.3","532","value.11","758","value.19","984","value.27","1210","value.35","1436","value.43","1662","value.51","1888","value.59","2114","value.67","2340","value.75","2566","value.83","2792","value.91"};

    // 设置输入shape
    std::array<int64_t, 2> inputShapes{1, (int)input_id.size()};
    
    // 创建输入数据
    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    auto inputTensorSize = 1 * input_id.size();
    std::vector<float> inputTensorValues(inputTensorSize);
    int64_t* input_test = (int64_t*)input;
    Ort::Value inputTensor = Ort::Value::CreateTensor<int64_t>(memoryInfo, input_test, inputTensorValues.size(), inputShapes.data(), inputShapes.size());
    std::vector<Ort::Value> intput_tensors;
    intput_tensors.push_back(std::move(inputTensor));

    // 推理  
    auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), intput_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());

    // 保存推理结果
    const float* data = output_tensors[0].GetTensorMutableData<float>();
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(int i=(input_id.size()-1)*22557; i<input_id.size()*22557; ++i)
    {
        resultsOfPredictions[n].index = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }

    // 对于[UNK]的概率设为无穷小,模型的预测结果不可能是[UNK]
    resultsOfPredictions[100].predictionvalue = -10000;

    // 排序
    std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);
    return resultsOfPredictions[0].index;
}

/// Tokenizes a question and appends its id-encoding to `input_id`:
/// a leading [CLS] followed by the id of each word piece.
/// @param tokenizer  tokenizer to use. NOTE(review): passed by value, so it
///                   is copied on every call — consider a reference in GPT2.h.
/// @param question   NUL-terminated text to encode.
/// @param input_id   output; encoded ids are appended.
/// @return SUCCESS always.
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
                             char *question,
                             std::vector<long unsigned int> &input_id)
{
    // Tokenize into at most max_seq_length word pieces.
    const int max_seq_length = 1000;
    std::vector<std::string> tokens_question;
    tokens_question.reserve(max_seq_length);
    tokenizer.tokenize(question, &tokens_question, max_seq_length);

    // Reserve once, then append [CLS] plus one id per token.
    input_id.reserve(input_id.size() + tokens_question.size() + 1);
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for (const std::string &token : tokens_question)
    {
        input_id.push_back(tokenizer.convert_token_to_id(token));
    }
    return SUCCESS;
}
} // namespace ortSamples