Bert.cpp 14.7 KB
Newer Older
liucong's avatar
liucong committed
1
#include <Bert.h>

#include <Filesystem.h>
#include <SimpleLog.h>
#include <tokenization.h>

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <migraphx/gpu/target.hpp>
#include <migraphx/onnx.hpp>
10

liucong's avatar
liucong committed
11
namespace migraphxSamples {
12

liucong's avatar
liucong committed
13
// Nothing to set up at construction time; the model is loaded by Initialize().
Bert::Bert() = default;
14

liucong's avatar
liucong committed
15
// No explicit cleanup is performed here; members release themselves.
Bert::~Bert() = default;
16

liucong's avatar
liucong committed
17
/// Loads the BERT SQuAD ONNX model, compiles it for the GPU target and runs
/// one warm-up evaluation.
///
/// @return MODEL_NOT_EXIST when the model file cannot be found, SUCCESS otherwise.
ErrorCode Bert::Initialize()
{
    // Location of the model file, relative to the working directory.
    std::string onnxPath = "../Resource/bertsquad-10.onnx";

    // Bail out early when the model file is missing.
    if(!Exists(onnxPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", onnxPath.c_str());
        return MODEL_NOT_EXIST;
    }

    // Declare the maximum shape of every model input before parsing.
    migraphx::onnx_options parseOptions;
    parseOptions.map_input_dims["unique_ids_raw_output___9:0"] = {1};
    parseOptions.map_input_dims["input_ids:0"]                 = {1, 256};
    parseOptions.map_input_dims["input_mask:0"]                = {1, 256};
    parseOptions.map_input_dims["segment_ids:0"]               = {1, 256};

    // Parse the ONNX file into the program held by this object.
    net = migraphx::parse_onnx(onnxPath, parseOptions);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(onnxPath).c_str());

    // Query the input/output parameter descriptions of the parsed model.
    // The outputs are fetched as in the original code but are not used below.
    const auto modelInputs  = net.get_inputs();
    const auto modelOutputs = net.get_outputs();
    (void)modelOutputs;

    // Remember the name and shape of each of the four inputs; at() throws
    // std::out_of_range when the model does not expose the expected name.
    inputName1  = "unique_ids_raw_output___9:0";
    inputShape1 = modelInputs.at(inputName1);

    inputName2  = "segment_ids:0";
    inputShape2 = modelInputs.at(inputName2);

    inputName3  = "input_mask:0";
    inputShape3 = modelInputs.at(inputName3);

    inputName4  = "input_ids:0";
    inputShape4 = modelInputs.at(inputName4);

    // Compile for the GPU target.
    migraphx::target gpuBackend = migraphx::gpu::target{};
    migraphx::compile_options compileOptions;
    compileOptions.device_id    = 0; // GPU device to use; device 0 is the default
    compileOptions.offload_copy = true;
    net.compile(gpuBackend, compileOptions);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(onnxPath).c_str());

    // Warm up: evaluate once with arguments built only from the input shapes.
    std::unordered_map<std::string, migraphx::argument> warmupInputs;
    warmupInputs[inputName1] = migraphx::argument(inputShape1);
    warmupInputs[inputName2] = migraphx::argument(inputShape2);
    warmupInputs[inputName3] = migraphx::argument(inputShape3);
    warmupInputs[inputName4] = migraphx::argument(inputShape4);
    net.eval(warmupInputs);

    return SUCCESS;
}

liucong's avatar
liucong committed
75
76
77
78
79
/// Runs the compiled model once per preprocessed sample and appends the raw
/// start/end scores of every sample to the output vectors.
///
/// @param input_ids       Token ids, one row per sample.
/// @param input_masks     Attention masks, one row per sample.
/// @param segment_ids     Segment ids, one row per sample.
/// @param start_position  Receives 256 start scores per sample (appended).
/// @param end_position    Receives 256 end scores per sample (appended).
/// @return SUCCESS.
/// @throws std::invalid_argument when input_masks or segment_ids hold fewer
///         rows than input_ids.
/// @throws std::runtime_error when the model returns fewer than two outputs.
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
                          const std::vector<std::vector<long unsigned int>>& input_masks,
                          const std::vector<std::vector<long unsigned int>>& segment_ids,
                          std::vector<float>& start_position,
                          std::vector<float>& end_position)
{
    // Sequence length the model inputs were declared with in Initialize().
    constexpr std::size_t kSeqLength = 256;

    const std::size_t sampleCount = input_ids.size();
    if(input_masks.size() < sampleCount || segment_ids.size() < sampleCount)
    {
        throw std::invalid_argument(
            "Bert::Inference: input_masks/segment_ids have fewer rows than input_ids");
    }

    // Heap-allocated staging buffers, reused for every sample. They replace the
    // former variable-length stack arrays, which are not standard C++ and grew
    // with the batch size.
    std::vector<long unsigned int> input_id(kSeqLength);
    std::vector<long unsigned int> input_mask(kSeqLength);
    std::vector<long unsigned int> segment_id(kSeqLength);
    std::vector<long unsigned int> position_id(1, 1);

    // Copies at most dst.size() values from src and zero-pads the remainder, so
    // rows that are too long or too short can never cause out-of-range access.
    const auto loadRow = [](const std::vector<long unsigned int>& src,
                            std::vector<long unsigned int>& dst) {
        const std::size_t count = std::min(src.size(), dst.size());
        std::copy(src.begin(), src.begin() + count, dst.begin());
        std::fill(dst.begin() + count, dst.end(), 0);
    };

    start_position.reserve(start_position.size() + sampleCount * kSeqLength);
    end_position.reserve(end_position.size() + sampleCount * kSeqLength);

    std::unordered_map<std::string, migraphx::argument> inputData;
    for(std::size_t sample = 0; sample < sampleCount; ++sample)
    {
        // Stage the current sample.
        loadRow(input_ids[sample], input_id);
        loadRow(input_masks[sample], input_mask);
        loadRow(segment_ids[sample], segment_id);

        // Wrap the staging buffers as model inputs; the buffers stay alive for
        // the whole eval() call below.
        inputData[inputName1] = migraphx::argument{inputShape1, position_id.data()};
        inputData[inputName2] = migraphx::argument{inputShape2, segment_id.data()};
        inputData[inputName3] = migraphx::argument{inputShape3, input_mask.data()};
        inputData[inputName4] = migraphx::argument{inputShape4, input_id.data()};

        // Inference.
        const std::vector<migraphx::argument> results = net.eval(inputData);
        if(results.size() < 2)
        {
            throw std::runtime_error("Bert::Inference: model returned fewer than two outputs");
        }

        // Output 1 holds the answer start scores, output 0 the answer end scores.
        // NOTE(review): each output is assumed to hold at least 256 floats - confirm
        // against the model's output shapes.
        const float* start_data = reinterpret_cast<const float*>(results[1].data());
        const float* end_data   = reinterpret_cast<const float*>(results[0].data());

        // Append the scores of this sample.
        for(std::size_t k = 0; k < kSeqLength; ++k)
        {
            start_position.push_back(start_data[k]);
            end_position.push_back(end_data[k]);
        }
    }

    return SUCCESS;
}

/// Tokenizes the context and the question and converts them into model inputs.
///
/// Every produced sample has the layout
///   [CLS] [CLS] question... [SEP] [SEP] context... [SEP] padding...
/// with segment id 0 for the question part, 1 for the context part and its
/// closing [SEP], mask 1 for real tokens and 0 for padding.
/// NOTE(review): the doubled [CLS]/[SEP] is kept exactly as in the original
/// code; Postprocessing() relies on the resulting "question size + 4" offset.
///
/// When question + context do not fit into one sequence, the context is cut
/// into windows whose start offsets are 256 tokens apart and one sample is
/// produced per window.
///
/// @param tokenizer       Tokenizer used for splitting and id lookup.
/// @param batch_size      Unused; kept for interface compatibility.
/// @param max_seq_length  Length of every produced sequence.
/// @param text            Context text.
/// @param question        Question text.
/// @param input_ids       Receives one row of token ids per sample (appended).
/// @param input_masks     Receives one mask row per sample (appended).
/// @param segment_ids     Receives one segment-id row per sample (appended).
/// @return SUCCESS.
/// @throws std::length_error when max_seq_length cannot hold the five special
///         tokens, or when the question leaves no room for context tokens.
ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
                              int batch_size,
                              int max_seq_length,
                              const char* text,
                              char* question,
                              std::vector<std::vector<long unsigned int>>& input_ids,
                              std::vector<std::vector<long unsigned int>>& input_masks,
                              std::vector<std::vector<long unsigned int>>& segment_ids)
{
    (void)batch_size;

    // Number of special tokens in one sequence: two [CLS] and three [SEP].
    constexpr std::size_t kSpecialTokens = 5;
    // Distance between the start offsets of consecutive context windows.
    // Postprocessing() maps flat output indices back to token indices with the
    // same constant, so it must not be changed here alone.
    constexpr std::size_t kWindowStride = 256;

    if(max_seq_length < static_cast<int>(kSpecialTokens))
    {
        throw std::length_error("Bert::Preprocessing: max_seq_length is too small");
    }
    const std::size_t seq_length = static_cast<std::size_t>(max_seq_length);

    // Tokenize the context and the question. The member buffers are cleared
    // first so that a previous call cannot leak tokens into this one.
    tokens_text.clear();
    tokens_question.clear();
    tokens_text.reserve(max_seq_length);
    tokens_question.reserve(max_seq_length);
    tokenizer.tokenize(text, &tokens_text, max_seq_length);
    tokenizer.tokenize(question, &tokens_question, max_seq_length);

    const std::size_t question_length = tokens_question.size();
    const auto cls_id                 = tokenizer.convert_token_to_id("[CLS]");
    const auto sep_id                 = tokenizer.convert_token_to_id("[SEP]");

    // Builds one sample from the question and the context tokens
    // [text_begin, text_begin + text_count) and appends it to the outputs.
    // Callers guarantee question_length + text_count + kSpecialTokens <= seq_length.
    const auto appendSample = [&](std::size_t text_begin, std::size_t text_count) {
        std::vector<long unsigned int> input_id(seq_length, 0);
        std::vector<long unsigned int> input_mask(seq_length, 0);
        std::vector<long unsigned int> segment_id(seq_length, 0);

        std::size_t pos = 0;
        input_id[pos++] = cls_id;
        input_id[pos++] = cls_id;
        for(const std::string& token : tokens_question)
        {
            input_id[pos++] = tokenizer.convert_token_to_id(token);
        }
        input_id[pos++] = sep_id;
        input_id[pos++] = sep_id;

        // Context tokens and the closing [SEP] belong to segment 1.
        for(std::size_t k = 0; k < text_count; ++k)
        {
            input_id[pos]   = tokenizer.convert_token_to_id(tokens_text[text_begin + k]);
            segment_id[pos] = 1;
            ++pos;
        }
        input_id[pos]   = sep_id;
        segment_id[pos] = 1;
        ++pos;

        // Mask 1 marks real tokens; the zero-initialised tail is padding.
        std::fill(input_mask.begin(), input_mask.begin() + pos, 1);

        input_ids.push_back(input_id);
        input_masks.push_back(input_mask);
        segment_ids.push_back(segment_id);
    };

    if(tokens_text.size() + question_length + kSpecialTokens > seq_length)
    {
        // Question + context exceed one sequence: use sliding windows.
        if(question_length + kSpecialTokens >= seq_length)
        {
            throw std::length_error(
                "Bert::Preprocessing: question leaves no room for context tokens");
        }
        const std::size_t window_length = seq_length - kSpecialTokens - question_length;

        for(std::size_t start = 0; start < tokens_text.size(); start += kWindowStride)
        {
            // The last window may be shorter than window_length.
            const std::size_t count = std::min(window_length, tokens_text.size() - start);
            appendSample(start, count);
        }
    }
    else
    {
        // Everything fits into a single sequence.
        appendSample(0, tokens_text.size());
    }

    return SUCCESS;
}

liucong's avatar
liucong committed
273
// Ordering predicate for std::sort: places entries with larger values first.
static bool Compare(Sort_st a, Sort_st b)
{
    return b.value < a.value;
}
274
275
276

// Ordering predicate: places the prediction with the larger combined
// start + end score first.
static bool CompareM(ResultOfPredictions a, ResultOfPredictions b)
{
    const auto scoreA = a.start_predictionvalue + a.end_predictionvalue;
    const auto scoreB = b.start_predictionvalue + b.end_predictionvalue;
    return scoreA > scoreB;
}

/// Picks the best answer span from the start/end scores and appends its text
/// to `answer`.
///
/// The n_best_size highest start scores are paired with the n_best_size
/// highest end scores. A pair is kept when both positions map to a context
/// token, start <= end and the span is no longer than max_answer_length. The
/// pair with the highest start + end score wins. When no pair survives,
/// `answer` is left untouched.
///
/// Flat score indices are mapped back to context tokens as
/// "index - question size - 4", matching the sequence layout and the 256-token
/// window stride used by Preprocessing().
///
/// The token buffers filled by Preprocessing() are cleared before returning.
///
/// @param n_best_size        Number of top start/end positions to combine.
/// @param max_answer_length  Maximum number of tokens in an answer.
/// @param start_position     Start scores, 256 per sample.
/// @param end_position       End scores, 256 per sample.
/// @param answer             Receives the answer text (appended).
/// @return SUCCESS.
/// @throws std::invalid_argument when the two score vectors differ in size.
ErrorCode Bert::Postprocessing(int n_best_size,
                               int max_answer_length,
                               const std::vector<float>& start_position,
                               const std::vector<float>& end_position,
                               std::string& answer)
{
    // Number of scores per sample.
    constexpr std::size_t kSeqLength = 256;
    // Special tokens in front of the context: [CLS] [CLS] ... [SEP] [SEP].
    constexpr std::size_t kLeadingSpecialTokens = 4;

    if(start_position.size() != end_position.size())
    {
        throw std::invalid_argument(
            "Bert::Postprocessing: start_position and end_position differ in size");
    }
    const std::size_t score_count = start_position.size();

    // Rank all positions by score, highest first.
    std::vector<Sort_st> start_array(score_count);
    std::vector<Sort_st> end_array(score_count);
    for(std::size_t i = 0; i < score_count; ++i)
    {
        start_array[i].index = i;
        start_array[i].value = start_position[i];
        end_array[i].index   = i;
        end_array[i].value   = end_position[i];
    }
    std::sort(start_array.begin(), start_array.end(), Compare);
    std::sort(end_array.begin(), end_array.end(), Compare);

    // Never look at more candidates than there are scores.
    const std::size_t top_k =
        n_best_size > 0 ? std::min(static_cast<std::size_t>(n_best_size), score_count) : 0;

    // Offset of the first context token inside one sequence.
    const std::size_t context_offset = tokens_question.size() + kLeadingSpecialTokens;

    // True when a flat score index lies behind the question/special-token prefix
    // of its window and maps to an existing context token. This also rejects
    // the leading [CLS] position, which would map to a negative token index.
    const auto mapsToContextToken = [&](std::size_t flat_index) {
        return flat_index % kSeqLength >= context_offset &&
               flat_index - context_offset < tokens_text.size();
    };

    // Collect every admissible (start, end) pair. The list grows on demand, so
    // any n_best_size is safe.
    std::vector<ResultOfPredictions> candidates;
    candidates.reserve(top_k * top_k);
    for(std::size_t i = 0; i < top_k; ++i)
    {
        const std::size_t start_index = static_cast<std::size_t>(start_array[i].index);
        if(!mapsToContextToken(start_index))
        {
            continue;
        }

        for(std::size_t j = 0; j < top_k; ++j)
        {
            const std::size_t end_index = static_cast<std::size_t>(end_array[j].index);
            if(!mapsToContextToken(end_index) || start_index > end_index)
            {
                continue;
            }

            const long length = static_cast<long>(end_index - start_index) + 1;
            if(length > max_answer_length)
            {
                continue;
            }

            ResultOfPredictions candidate{};
            candidate.start_index           = start_array[i].index;
            candidate.end_index             = end_array[j].index;
            candidate.start_predictionvalue = start_array[i].value;
            candidate.end_predictionvalue   = end_array[j].value;
            candidates.push_back(candidate);
        }
    }

    if(!candidates.empty())
    {
        // CompareM orders higher combined scores first, so the minimum under
        // that ordering is the best candidate.
        const auto best = std::min_element(candidates.begin(), candidates.end(), CompareM);
        const std::size_t first_token =
            static_cast<std::size_t>(best->start_index) - context_offset;
        const std::size_t last_token =
            static_cast<std::size_t>(best->end_index) - context_offset;

        // Join the tokens: a piece starting with "##" continues the previous
        // word, every other token starts a new space-separated word.
        std::string span;
        for(std::size_t i = first_token; i <= last_token; ++i)
        {
            const std::string& token = tokens_text[i];
            if(token.compare(0, 2, "##") == 0)
            {
                span += token.substr(2);
            }
            else
            {
                if(!span.empty())
                {
                    span += ' ';
                }
                span += token;
            }
        }
        answer += span;
    }

    tokens_text.clear();
    tokens_question.clear();

    return SUCCESS;
}

liucong's avatar
liucong committed
414
} // namespace migraphxSamples