GPT2.cpp 3.75 KB
Newer Older
liucong's avatar
liucong committed
1
2
#include <GPT2.h>

3
4
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
liucong's avatar
liucong committed
5
6

#include <Filesystem.h>
7
8
9
10
11
12
13
14
#include <SimpleLog.h>
#include <algorithm>
#include <stdexcept>
#include <tokenization.h>

namespace migraphxSamples
{

liucong's avatar
liucong committed
15
// Default constructor: no setup is performed here; the model is loaded and compiled
// later by Initialize().
GPT2::GPT2() = default;

// Destructor: no explicit cleanup; members are released by their own destructors.
GPT2::~GPT2() = default;

liucong's avatar
liucong committed
25
/**
 * Loads the GPT-2 ONNX model, records its input parameter, and compiles it for the GPU.
 *
 * @return MODEL_NOT_EXIST if the model file cannot be found, otherwise SUCCESS.
 */
ErrorCode GPT2::Initialize()
{
    // Model file, resolved relative to the current working directory.
    const std::string modelPath = "../Resource/GPT2_shici.onnx";
    if (!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }

    // Parse the model, giving the "input" parameter a maximum shape of {1, 1000}.
    migraphx::onnx_options parseOptions;
    parseOptions.map_input_dims["input"] = {1, 1000};
    net = migraphx::parse_onnx(modelPath, parseOptions);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());

    // Remember the name and shape of the model's input parameter.
    // NOTE(review): unordered_map iteration order is unspecified, so this relies on the
    // model exposing a single parameter.
    const std::unordered_map<std::string, migraphx::shape> parameterShapes = net.get_parameter_shapes();
    const auto firstParameter = parameterShapes.begin();
    inputName  = firstParameter->first;
    inputShape = firstParameter->second;

    // Compile for GPU device 0. offload_copy lets MIGraphX handle host<->device transfers.
    migraphx::compile_options compileOptions;
    compileOptions.device_id    = 0;
    compileOptions.offload_copy = true;
    net.compile(migraphx::gpu::target{}, compileOptions);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());

    return SUCCESS;
}

static bool CompareM(Predictions a, Predictions b)
{
	return a.predictionvalue > b.predictionvalue;
}

long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id)
{

    long unsigned int input[1][input_id.size()];
    for (int j=0;j<input_id.size();++j)
    {
        input[0][j] = input_id[j];
    }

    // 设置输入shape
    std::vector<std::vector<std::size_t>> inputShapes;
    inputShapes.push_back({1,input_id.size()});

liucong's avatar
liucong committed
79
    // 创建输入数据
liucong's avatar
liucong committed
80
    std::unordered_map<std::string, migraphx::argument> inputData;
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
    inputData[inputName]=migraphx::argument{migraphx::shape(inputShape.type(),inputShapes[0]),(long unsigned int*)input};

    // 推理
    std::vector<migraphx::argument> results = net.eval(inputData);

    // 获取输出节点的属性
    migraphx::argument result = results[0];
    migraphx::shape outputShape = result.get_shape();       // 输出节点的shape
    int numberOfOutput=outputShape.elements();              // 输出节点元素的个数
    float *data = (float *)result.data();                   // 输出节点数据指针

    // 保存推理结果
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(int i=(input_id.size()-1)*22557; i<input_id.size()*22557; ++i)
    {
        resultsOfPredictions[n].index = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }

    // 对于[UNK]的概率设为无穷小,模型的预测结果不可能是[UNK]
    resultsOfPredictions[100].predictionvalue = -10000;

    // 排序
    std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);

    return resultsOfPredictions[0].index;
}

/**
 * Tokenizes `question` and appends "[CLS]" followed by the question's token ids to input_id.
 *
 * @param tokenizer Tokenizer used for splitting and id lookup. It is taken by value, as
 *                  declared in the header, so every call copies it.
 * @param question  Input text. Presumably must be a non-null, NUL-terminated string — TODO confirm.
 * @param input_id  Output: ids are appended; existing contents are kept.
 * @return SUCCESS.
 */
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
                             char *question,
                             std::vector<long unsigned int> &input_id)
{
    // Maximum number of ids per request. This matches the {1, 1000} maximum shape of the
    // model's "input" parameter set in Initialize().
    const std::size_t max_seq_length = 1000;

    // One slot is reserved for the leading [CLS] id, so the question itself may contribute at
    // most max_seq_length - 1 tokens. Otherwise [CLS] + tokens could reach 1001 ids.
    const std::size_t max_question_tokens = max_seq_length - 1;

    // Tokenize the question.
    std::vector<std::string> tokens_question;
    tokens_question.reserve(max_question_tokens);
    tokenizer.tokenize(question, &tokens_question, max_question_tokens);

    // Store the encoded ids: [CLS] first, then the question tokens in order.
    input_id.reserve(input_id.size() + tokens_question.size() + 1);
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for (const std::string &token : tokens_question)
    {
        input_id.push_back(tokenizer.convert_token_to_id(token));
    }

    return SUCCESS;
}

}