// NOTE: the original #include list was lost in extraction; the standard headers below are
// reconstructed from usage. The sample additionally needs its project-specific headers
// (the Bert class declaration, the MIGraphX API headers, the cuBERT tokenizer header,
// and the utility header providing Exists/GetFileName and the LOG_* macros), whose
// original names are unknown and therefore not guessed here.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

namespace migraphxSamples {

Bert::Bert() {}

Bert::~Bert() {}

ErrorCode Bert::Initialize()
{
    // Model file path
    std::string modelPath = "../Resource/bertsquad-10.onnx";

    // Set the maximum input shapes
    migraphx::onnx_options onnx_options;
    onnx_options.map_input_dims["unique_ids_raw_output___9:0"] = {1};
    onnx_options.map_input_dims["input_ids:0"]                 = {1, 256};
    onnx_options.map_input_dims["input_mask:0"]                = {1, 256};
    onnx_options.map_input_dims["segment_ids:0"]               = {1, 256};

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout, "%s not exist!\n", modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    net = migraphx::parse_onnx(modelPath, onnx_options);
    LOG_INFO(stdout, "succeed to load model: %s\n", GetFileName(modelPath).c_str());

    // Query the model's input/output node information
    std::unordered_map<std::string, migraphx::shape> inputs  = net.get_inputs();
    std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
    inputName1  = "unique_ids_raw_output___9:0";
    inputShape1 = inputs.at(inputName1);
    inputName2  = "segment_ids:0";
    inputShape2 = inputs.at(inputName2);
    inputName3  = "input_mask:0";
    inputShape3 = inputs.at(inputName3);
    inputName4  = "input_ids:0";
    inputShape4 = inputs.at(inputName4);

    // Select the GPU target
    migraphx::target gpuTarget = migraphx::gpu::target{};

    // Compile the model
    migraphx::compile_options options;
    options.device_id    = 0;    // GPU device id, device 0 by default
    options.offload_copy = true; // let MIGraphX copy inputs/outputs between host and device
    net.compile(gpuTarget, options);
    LOG_INFO(stdout, "succeed to compile model: %s\n", GetFileName(modelPath).c_str());

    // Warm up with arguments of the right shapes (data left uninitialized)
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName1] = migraphx::argument(inputShape1);
    inputData[inputName2] = migraphx::argument(inputShape2);
    inputData[inputName3] = migraphx::argument(inputShape3);
    inputData[inputName4] = migraphx::argument(inputShape4);
    net.eval(inputData);

    return SUCCESS;
}

// Element type of the token vectors is assumed to be long unsigned int, matching the
// casts below and the model's int64 inputs.
ErrorCode Bert::Inference(const std::vector<std::vector<long unsigned int>>& input_ids,
                          const std::vector<std::vector<long unsigned int>>& input_masks,
                          const std::vector<std::vector<long unsigned int>>& segment_ids,
                          std::vector<float>& start_position,
                          std::vector<float>& end_position)
{
    // Copy the preprocessed data into contiguous buffers
    int num = input_ids.size();
    long unsigned int input_id[num][256];
    long unsigned int input_mask[num][256];
    long unsigned int segment_id[num][256];
    long unsigned int position_id[num][1];
    for(int i = 0; i < input_ids.size(); ++i)
    {
        for(int j = 0; j < input_ids[0].size(); ++j)
        {
            input_id[i][j]   = input_ids[i][j];
            segment_id[i][j] = segment_ids[i][j];
            input_mask[i][j] = input_masks[i][j];
        }
        position_id[i][0] = 1;
    }

    std::unordered_map<std::string, migraphx::argument> inputData;
    std::vector<migraphx::argument> results;
    migraphx::argument start_prediction;
    migraphx::argument end_prediction;
    float* start_data;
    float* end_data;
    for(int i = 0; i < input_ids.size(); ++i)
    {
        // Wrap the host buffers as model inputs
        inputData[inputName1] = migraphx::argument{inputShape1, (long unsigned int*)position_id[i]};
        inputData[inputName2] = migraphx::argument{inputShape2, (long unsigned int*)segment_id[i]};
        inputData[inputName3] = migraphx::argument{inputShape3, (long unsigned int*)input_mask[i]};
        inputData[inputName4] = migraphx::argument{inputShape4, (long unsigned int*)input_id[i]};

        // Run inference
        results = net.eval(inputData);

        // Fetch the output tensors
        start_prediction = results[1];                       // logits for the answer start position
        start_data       = (float*)start_prediction.data();  // data pointer for the start logits
        end_prediction   = results[0];                       // logits for the answer end position
        end_data         = (float*)end_prediction.data();    // data pointer for the end logits

        // Save the inference results
        for(int j = 0; j < 256; ++j)
        {
            start_position.push_back(start_data[j]);
            end_position.push_back(end_data[j]);
        }
    }

    return SUCCESS;
}
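// ---------------------------------------------------------------------------
// Illustration (not part of the original sample, and not referenced by it):
// Inference() above copies the preprocessed vectors into stack VLAs and hands
// raw pointers to migraphx::argument. The hypothetical helper below sketches
// an alternative that flattens a batch of per-sample vectors into one
// contiguous, zero-padded host buffer. The name FlattenBatch and the default
// row width of 256 are assumptions made for this sketch.
// ---------------------------------------------------------------------------
static std::vector<long unsigned int>
FlattenBatch(const std::vector<std::vector<long unsigned int>>& batch, std::size_t rowWidth = 256)
{
    std::vector<long unsigned int> flat(batch.size() * rowWidth, 0); // zero padding
    for(std::size_t i = 0; i < batch.size(); ++i)
    {
        std::size_t n = std::min(batch[i].size(), rowWidth);
        std::copy(batch[i].begin(), batch[i].begin() + n, flat.begin() + i * rowWidth);
    }
    return flat; // row i starts at flat.data() + i * rowWidth
}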
ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer,
                              int batch_size,
                              int max_seq_length,
                              const char* text,
                              char* question,
                              std::vector<std::vector<long unsigned int>>& input_ids,
                              std::vector<std::vector<long unsigned int>>& input_masks,
                              std::vector<std::vector<long unsigned int>>& segment_ids)
{
    std::vector<long unsigned int> input_id(max_seq_length);
    std::vector<long unsigned int> input_mask(max_seq_length);
    std::vector<long unsigned int> segment_id(max_seq_length);

    // Tokenize the context text and the question
    tokens_text.reserve(max_seq_length);
    tokens_question.reserve(max_seq_length);
    tokenizer.tokenize(text, &tokens_text, max_seq_length);
    tokenizer.tokenize(question, &tokens_question, max_seq_length);

    // When context plus question exceed the maximum length (minus the 5 slots reserved
    // for special tokens), use a sliding window over the context
    if(tokens_text.size() + tokens_question.size() > max_seq_length - 5)
    {
        int windows_len = max_seq_length - 5 - tokens_question.size();
        std::vector<std::string> tokens_text_window(windows_len);
        std::vector<std::vector<std::string>> tokens_text_windows;
        int start_offset = 0;
        int position     = 0;
        int n;
        while(start_offset < tokens_text.size())
        {
            n = 0;
            if(start_offset + windows_len > tokens_text.size())
            {
                for(int i = start_offset; i < tokens_text.size(); ++i)
                {
                    tokens_text_window[n] = tokens_text[i];
                    ++n;
                }
                tokens_text_window.resize(n); // last window may be shorter than windows_len
            }
            else
            {
                for(int i = start_offset; i < start_offset + windows_len; ++i)
                {
                    tokens_text_window[n] = tokens_text[i];
                    ++n;
                }
            }
            tokens_text_windows.push_back(tokens_text_window);
            start_offset += windows_len;
            ++position;
        }
        for(int i = 0; i < position; ++i)
        {
            input_id[0]   = tokenizer.convert_token_to_id("[CLS]");
            segment_id[0] = 0;
            input_id[1]   = tokenizer.convert_token_to_id("[CLS]");
            segment_id[1] = 0;
            for(int j = 0; j < tokens_question.size(); ++j)
            {
                input_id[j + 2]   = tokenizer.convert_token_to_id(tokens_question[j]);
                segment_id[j + 2] = 0;
            }
            input_id[tokens_question.size() + 2]   = tokenizer.convert_token_to_id("[SEP]");
            segment_id[tokens_question.size() + 2] = 0;
            input_id[tokens_question.size() + 3]   = tokenizer.convert_token_to_id("[SEP]");
            segment_id[tokens_question.size() + 3] = 0;
            for(int j = 0; j < tokens_text_windows[i].size(); ++j)
            {
                input_id[j + tokens_question.size() + 4]   = tokenizer.convert_token_to_id(tokens_text_windows[i][j]);
                segment_id[j + tokens_question.size() + 4] = 1;
            }
            input_id[tokens_question.size() + tokens_text_windows[i].size() + 4]   = tokenizer.convert_token_to_id("[SEP]");
            segment_id[tokens_question.size() + tokens_text_windows[i].size() + 4] = 1;

            // A mask value of 1 marks a real token; 0 marks a padding token
            int len = tokens_text_windows[i].size() + tokens_question.size() + 5;
            std::fill(input_mask.begin(), input_mask.begin() + len, 1);
            std::fill(input_mask.begin() + len, input_mask.begin() + max_seq_length, 0);
            std::fill(input_id.begin() + len, input_id.begin() + max_seq_length, 0);
            std::fill(segment_id.begin() + len, segment_id.begin() + max_seq_length, 0);

            input_ids.push_back(input_id);
            input_masks.push_back(input_mask);
            segment_ids.push_back(segment_id);
        }
    }
    else
    {
        // When context plus question fit within the maximum length, concatenate directly
        input_id[0]   = tokenizer.convert_token_to_id("[CLS]");
        segment_id[0] = 0;
        input_id[1]   = tokenizer.convert_token_to_id("[CLS]");
        segment_id[1] = 0;
        for(int i = 0; i < tokens_question.size(); ++i)
        {
            input_id[i + 2]   = tokenizer.convert_token_to_id(tokens_question[i]);
            segment_id[i + 2] = 0;
        }
        input_id[tokens_question.size() + 2]   = tokenizer.convert_token_to_id("[SEP]");
        segment_id[tokens_question.size() + 2] = 0;
        input_id[tokens_question.size() + 3]   = tokenizer.convert_token_to_id("[SEP]");
        segment_id[tokens_question.size() + 3] = 0;
        for(int i = 0; i < tokens_text.size(); ++i)
        {
            input_id[i + tokens_question.size() + 4]   = tokenizer.convert_token_to_id(tokens_text[i]);
            segment_id[i + tokens_question.size() + 4] = 1;
        }
        input_id[tokens_question.size() + tokens_text.size() + 4]   = tokenizer.convert_token_to_id("[SEP]");
        segment_id[tokens_question.size() + tokens_text.size() + 4] = 1;

        // A mask value of 1 marks a real token; 0 marks a padding token
        int len = tokens_text.size() + tokens_question.size() + 5;
        std::fill(input_mask.begin(), input_mask.begin() + len, 1);
        std::fill(input_mask.begin() + len, input_mask.begin() + max_seq_length, 0);
        std::fill(input_id.begin() + len, input_id.begin() + max_seq_length, 0);
        std::fill(segment_id.begin() + len, segment_id.begin() + max_seq_length, 0);

        input_ids.push_back(input_id);
        input_masks.push_back(input_mask);
        segment_ids.push_back(segment_id);
    }

    return SUCCESS;
}
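// ---------------------------------------------------------------------------
// Layout produced by Preprocessing for each 256-token sample, with
// q = tokens_question.size() and t = number of context tokens in the window:
//   [0] [CLS]   [1] [CLS]   [2 .. q+1] question   [q+2] [SEP]   [q+3] [SEP]
//   [q+4 .. q+3+t] context   [q+4+t] [SEP]   remainder: zero padding
// Postprocessing below relies on this layout when it maps a predicted index
// back into tokens_text via (index - q - 4). The helper below only illustrates
// that mapping; its name is hypothetical and the sample does not call it.
// ---------------------------------------------------------------------------
static inline int ToContextIndex(int prediction_index, int question_size)
{
    // A negative result means the prediction fell on the [CLS]/question/[SEP] prefix
    return prediction_index - question_size - 4;
}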
static bool Compare(Sort_st a, Sort_st b)
{
    return a.value > b.value;
}

static bool CompareM(ResultOfPredictions a, ResultOfPredictions b)
{
    return a.start_predictionvalue + a.end_predictionvalue > b.start_predictionvalue + b.end_predictionvalue;
}

ErrorCode Bert::Postprocessing(int n_best_size,
                               int max_answer_length,
                               const std::vector<float>& start_position,
                               const std::vector<float>& end_position,
                               std::string& answer)
{
    // Take the indices of the n_best_size largest probability values
    std::vector<Sort_st> start_array(start_position.size());
    std::vector<Sort_st> end_array(end_position.size());
    for(int i = 0; i < start_position.size(); ++i)
    {
        start_array[i].index = i;
        start_array[i].value = start_position.at(i);
        end_array[i].index   = i;
        end_array[i].value   = end_position.at(i);
    }
    std::sort(start_array.begin(), start_array.end(), Compare);
    std::sort(end_array.begin(), end_array.end(), Compare);

    // Filter out candidate (start, end) index pairs that are not valid
    std::vector<ResultOfPredictions> resultsOfPredictions(400);
    int num = start_position.size() / 256;
    bool flag;
    int n = 0;
    for(int i = 0; i < n_best_size; ++i)
    {
        for(int j = 0; j < n_best_size; ++j)
        {
            flag = false;
            if(start_array[i].index > start_position.size())
            {
                continue;
            }
            if(end_array[j].index > end_position.size())
            {
                continue;
            }
            // Skip indices that fall on the [CLS]/question/[SEP] prefix of a window
            for(int t = 0; t < num; ++t)
            {
                if(start_array[i].index > t * 256 && start_array[i].index < tokens_question.size() + 4 + t * 256)
                {
                    flag = true;
                    break;
                }
                if(end_array[j].index > t * 256 && end_array[j].index < tokens_question.size() + 4 + t * 256)
                {
                    flag = true;
                    break;
                }
            }
            if(flag)
            {
                continue;
            }
            if(start_array[i].index > end_array[j].index)
            {
                continue;
            }
            int length = end_array[j].index - start_array[i].index + 1;
            if(length > max_answer_length)
            {
                continue;
            }
            resultsOfPredictions[n].start_index           = start_array[i].index;
            resultsOfPredictions[n].end_index             = end_array[j].index;
            resultsOfPredictions[n].start_predictionvalue = start_array[i].value;
            resultsOfPredictions[n].end_predictionvalue   = end_array[j].value;
            ++n;
        }
    }

    // Sort so that the pair with the largest sum of start and end probabilities comes first
    std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM);
    int start_index = 0;
    int end_index   = 0;
    for(int i = 0; i < 400; ++i)
    {
        if(resultsOfPredictions[i].start_predictionvalue == 0 && resultsOfPredictions[i].end_predictionvalue == 0)
        {
            continue;
        }
        start_index = resultsOfPredictions[i].start_index;
        end_index   = resultsOfPredictions[i].end_index;
        break;
    }

    // Map back to an index into the context tokens: (current index - question length - 4)
    int answer_start_index = start_index - tokens_question.size() - 4;
    int answer_end_index   = end_index - tokens_question.size() - 4 + 1;

    // Collect the tokens in the [answer_start_index, answer_end_index) range
    int j = 0;
    for(int i = answer_start_index; i < answer_end_index; ++i)
    {
        if(tokens_text[i].find('#') != std::string::npos)
        {
            j = i - 1;
            break;
        }
    }
    for(int i = answer_start_index; i < answer_end_index; ++i)
    {
        answer += tokens_text[i];
        if(tokens_text[i].find('#') != std::string::npos || i == j)
        {
            continue;
        }
        answer += " ";
    }
    // Strip the WordPiece '#' markers from the assembled answer
    std::string::size_type index = 0;
    while((index = answer.find('#', index)) != std::string::npos)
    {
        answer.erase(index, 1);
    }

    tokens_text.clear();
    tokens_question.clear();

    return SUCCESS;
}

} // namespace migraphxSamples
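// ---------------------------------------------------------------------------
// Illustrative driver (not part of the original sample): a minimal sketch of
// how the Bert class above might be exercised end to end, kept as a comment so
// it does not conflict with the sample's own entry point. The vocabulary path,
// the FullTokenizer constructor signature, and the n_best_size /
// max_answer_length values are assumptions; only the Bert method signatures
// come from this file.
// ---------------------------------------------------------------------------
// int main()
// {
//     using namespace migraphxSamples;
//
//     Bert bert;
//     if(bert.Initialize() != SUCCESS) // parse, compile and warm up the model
//     {
//         return -1;
//     }
//
//     cuBERT::FullTokenizer tokenizer("../Resource/vocab.txt"); // assumed constructor signature
//     const char* text = "MIGraphX is a graph optimization and inference engine.";
//     char question[]  = "What is MIGraphX?";
//
//     std::vector<std::vector<long unsigned int>> input_ids, input_masks, segment_ids;
//     bert.Preprocessing(tokenizer, 1, 256, text, question, input_ids, input_masks, segment_ids);
//
//     std::vector<float> start_position, end_position;
//     bert.Inference(input_ids, input_masks, segment_ids, start_position, end_position);
//
//     std::string answer;
//     bert.Postprocessing(20, 30, start_position, end_position, answer); // 20/30 are example values
//
//     printf("answer: %s\n", answer.c_str());
//     return 0;
// }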