#include #include #include #include #include #include #include #include namespace migraphxSamples { Bert::Bert() { } Bert::~Bert() { } ErrorCode Bert::Initialize() { // 获取模型文件 std::string modelPath="../Resource/bertsquad-10.onnx"; // 加载模型 if(Exists(modelPath)==false) { LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } net = migraphx::parse_onnx(modelPath); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); // 获取模型输入属性 std::unordered_map input = net.get_parameter_shapes(); inputName1 = "unique_ids_raw_output___9:0"; inputShape1 = input.at(inputName1); inputName2 = "segment_ids:0"; inputShape2 = input.at(inputName2); inputName3 = "input_mask:0"; inputShape3 = input.at(inputName3); inputName4 = "input_ids:0"; inputShape4 = input.at(inputName4); // 设置模型为GPU模式 migraphx::target gpuTarget = migraphx::gpu::target{}; // 编译模型 migraphx::compile_options options; options.device_id=0; // 设置GPU设备,默认为0号设备 options.offload_copy=true; net.compile(gpuTarget,options); LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); // warm up std::unordered_map inputData; inputData[inputName1]=migraphx::argument(inputShape1); inputData[inputName2]=migraphx::argument(inputShape2); inputData[inputName3]=migraphx::argument(inputShape3); inputData[inputName4]=migraphx::argument(inputShape4); net.eval(inputData); return SUCCESS; } ErrorCode Bert::Inference(const std::vector> &input_ids, const std::vector> &input_masks, const std::vector> &segment_ids, std::vector &start_position, std::vector &end_position) { // 保存预处理后的数据 int num = input_ids.size(); long unsigned int input_id[num][256]; long unsigned int input_mask[num][256]; long unsigned int segment_id[num][256]; long unsigned int position_id[num][1]; for(int i=0;i inputData; std::vector results; migraphx::argument start_prediction; migraphx::argument end_prediction; float* start_data; float* end_data; for(int i=0;i> &input_ids, std::vector> &input_masks, std::vector> &segment_ids) { std::vector input_id(max_seq_length); std::vector input_mask(max_seq_length); std::vector segment_id(max_seq_length); // 对上下文文本和问题进行分词操作 tokens_text.reserve(max_seq_length); tokens_question.reserve(max_seq_length); tokenizer.tokenize(text, &tokens_text, max_seq_length); tokenizer.tokenize(question, &tokens_question, max_seq_length); // 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作 if(tokens_text.size() + tokens_question.size() > max_seq_length - 5) { int windows_len = max_seq_length - 5 -tokens_question.size(); std::vector tokens_text_window(windows_len); std::vector> tokens_text_windows; int start_offset = 0; int position = 0; int n; while (start_offset < tokens_text.size()) { n = 0; if(start_offset+windows_len>tokens_text.size()) { for(int i=start_offset;i b.value; } static bool CompareM(ResultOfPredictions a, ResultOfPredictions b) { return a.start_predictionvalue + a.end_predictionvalue > b.start_predictionvalue + b.end_predictionvalue; } ErrorCode Bert::Postprocessing(int n_best_size, int max_answer_length, const std::vector &start_position, const std::vector &end_position, std::string &answer) { // 取前n_best_size个最大概率值的索引 std::vector start_array(start_position.size()); std::vector end_array(end_position.size()); for (int i=0;i resultsOfPredictions(400); int num = start_position.size() / 256; bool flag; int n=0; for(int i=0;i start_position.size()) { continue; } if(end_array[j].index > end_position.size()) { continue; } for(int t=0;t t*256 && start_array[i].index < tokens_question.size()+4+t*256) { flag = true; break; } if(end_array[j].index > t*256 && end_array[j].index < tokens_question.size()+4+t*256) { flag = true; break; } } if(flag) { continue; } if(start_array[i].index > end_array[j].index) { continue; } int length = end_array[j].index - start_array[i].index + 1; if(length > max_answer_length) { continue; } resultsOfPredictions[n].start_index = start_array[i].index; resultsOfPredictions[n].end_index = end_array[j].index; resultsOfPredictions[n].start_predictionvalue = start_array[i].value; resultsOfPredictions[n].end_predictionvalue = end_array[j].value; ++n; } } // 排序,将开始索引加结束索引的概率值和最大的排在前面 std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM); int start_index = 0; int end_index = 0; for(int i=0;i<400;++i) { if(resultsOfPredictions[i].start_predictionvalue==0 && resultsOfPredictions[i].end_predictionvalue==0) { continue; } start_index = resultsOfPredictions[i].start_index; end_index = resultsOfPredictions[i].end_index; break; } // 映射回上下文文本的索引,(当前的索引值-问题的长度-4) int answer_start_index = start_index - tokens_question.size()- 4; int answer_end_index = end_index - tokens_question.size() - 4 + 1; // 根据开始索引和结束索引,获取区间内的数据 int j=0; for(int i=answer_start_index;i