GPT2.cpp 3.75 KB
Newer Older
liucong's avatar
liucong committed
1
2
#include <GPT2.h>

3
4
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
liucong's avatar
liucong committed
5
6

#include <Filesystem.h>
7
8
9
10
11
12
13
14
#include <SimpleLog.h>
#include <algorithm>
#include <stdexcept>
#include <tokenization.h>

namespace migraphxSamples
{

liucong's avatar
liucong committed
15
// Default constructor: performs no work of its own; model loading happens in
// Initialize().
GPT2::GPT2() = default;

// Destructor: no explicit cleanup is performed (the body is empty).
GPT2::~GPT2() = default;

liucong's avatar
liucong committed
25
// Loads the GPT-2 ONNX model, records its input parameter name and shape in
// inputName/inputShape, and compiles the program for GPU device 0.
// Returns SUCCESS, or MODEL_NOT_EXIST if the model file is missing or the
// parsed model exposes no input parameter.
ErrorCode GPT2::Initialize()
{
    // Model file location, relative to the working directory
    const std::string modelPath="../Resource/GPT2_shici.onnx";

    // Name of the token-id input in the ONNX graph
    const std::string modelInputName="input";

    // Maximum input shape: batch 1, at most 1000 tokens
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims[modelInputName]={1,1000};

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Resolve the input parameter. get_parameter_shapes() returns an
    // unordered_map, whose begin() is arbitrary when the model has several
    // parameters (and invalid when it has none), so look up the configured
    // input name first and only fall back to an arbitrary entry otherwise.
    std::unordered_map<std::string, migraphx::shape> inputMap=net.get_parameter_shapes();
    if(inputMap.empty())
    {
        // No dedicated error code is visible in this file; reuse MODEL_NOT_EXIST.
        LOG_ERROR(stdout,"%s has no input parameter!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    auto inputIt=inputMap.find(modelInputName);
    if(inputIt==inputMap.end())
    {
        inputIt=inputMap.begin();
    }
    inputName=inputIt->first;
    inputShape=inputIt->second;

    // Run on the GPU
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id=0;                          // GPU device id (device 0 by default)
    options.offload_copy=true;                    // inputs/outputs stay in host memory; MIGraphX handles device copies
    net.compile(gpuTarget,options);
    LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    return SUCCESS;
}

static bool CompareM(Predictions a, Predictions b)
{
	return a.predictionvalue > b.predictionvalue;
}

long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id)
{

    long unsigned int input[1][input_id.size()];
    for (int j=0;j<input_id.size();++j)
    {
        input[0][j] = input_id[j];
    }

    // 设置输入shape
    std::vector<std::vector<std::size_t>> inputShapes;
    inputShapes.push_back({1,input_id.size()});

liucong's avatar
liucong committed
79
    // 创建输入数据
liucong's avatar
liucong committed
80
    std::unordered_map<std::string, migraphx::argument> inputData;
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
    inputData[inputName]=migraphx::argument{migraphx::shape(inputShape.type(),inputShapes[0]),(long unsigned int*)input};

    // 推理
    std::vector<migraphx::argument> results = net.eval(inputData);

    // 获取输出节点的属性
    migraphx::argument result = results[0];
    migraphx::shape outputShape = result.get_shape();       // 输出节点的shape
    int numberOfOutput=outputShape.elements();              // 输出节点元素的个数
    float *data = (float *)result.data();                   // 输出节点数据指针

    // 保存推理结果
    long unsigned int n = 0;
    std::vector<Predictions> resultsOfPredictions(22557);
    for(int i=(input_id.size()-1)*22557; i<input_id.size()*22557; ++i)
    {
        resultsOfPredictions[n].index = n;
        resultsOfPredictions[n].predictionvalue = data[i];
        ++n;
    }

    // 对于[UNK]的概率设为无穷小,模型的预测结果不可能是[UNK]
    resultsOfPredictions[100].predictionvalue = -10000;

    // 排序
    std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);

    return resultsOfPredictions[0].index;
}

// Tokenizes `question` and appends its token ids to input_id: a leading
// "[CLS]" id followed by one id per token. Existing contents of input_id are
// kept. Always returns SUCCESS.
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
                             char *question,
                             std::vector<long unsigned int> &input_id)
{
    // Tokenize, capped at max_seq_length tokens.
    // NOTE(review): 1024 tokens plus [CLS] exceeds the 1000-token maximum input
    // shape set in Initialize(); confirm the intended cap.
    const int max_seq_length =1024;
    std::vector<std::string> tokens_question;
    tokens_question.reserve(max_seq_length);
    tokenizer.tokenize(question, &tokens_question, max_seq_length);

    // Store the encoded ids: [CLS] first, then each question token
    input_id.reserve(input_id.size() + tokens_question.size() + 1);
    input_id.push_back(tokenizer.convert_token_to_id("[CLS]"));
    for (const std::string &token : tokens_question)
    {
        input_id.push_back(tokenizer.convert_token_to_id(token));
    }

    return SUCCESS;
}

}