#include #include #include #include #include #include #include namespace ortSamples { Bert::Bert() { } Bert::~Bert() { } ErrorCode Bert::Initialize() { // 获取模型文件 std::string modelPath="../Resource/bertsquad-10.onnx"; // 加载模型 if(Exists(modelPath)==false) { LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } //加载模型 OrtROCMProviderOptions rocm_options; rocm_options.device_id = 0; sessionOptions.AppendExecutionProvider_ROCM(rocm_options); sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); session = new Ort::Session(env, modelPath.c_str(), sessionOptions); LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); return SUCCESS; } ErrorCode Bert::Inference(const std::vector> &input_ids, const std::vector> &input_masks, const std::vector> &segment_ids, std::vector &start_position, std::vector &end_position) { // 保存预处理后的数据 int num = input_ids.size(); long unsigned int input_id[num][256]; long unsigned int input_mask[num][256]; long unsigned int segment_id[num][256]; long unsigned int position_id[num][1]; for(int i=0;i input_node_names = {"unique_ids_raw_output___9:0","segment_ids:0","input_mask:0","input_ids:0"}; // 获取模型输出属性 std::vector output_node_names = {"unstack:1","unstack:0","unique_ids:0"}; // 设置输入shape std::array inputShapes1{1}; std::array inputShapes2{1, 256}; std::array inputShapes3{1, 256}; std::array inputShapes4{1, 256}; std::vector inputTensorValues1(1); std::vector inputTensorValues2(256); std::vector inputTensorValues3(256); std::vector inputTensorValues4(256); auto memoryInfo1 = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); auto memoryInfo2 = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); auto memoryInfo3 = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); auto memoryInfo4 = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); float* start_data; float* end_data; for(int i=0;i( memoryInfo1, input_test1, inputTensorValues1.size(), inputShapes1.data(), inputShapes1.size()); Ort::Value inputTensor2 = Ort::Value::CreateTensor( memoryInfo2, input_test2, inputTensorValues2.size(), inputShapes2.data(), inputShapes2.size()); Ort::Value inputTensor3 = Ort::Value::CreateTensor( memoryInfo3, input_test3, inputTensorValues3.size(), inputShapes3.data(), inputShapes3.size()); Ort::Value inputTensor4 = Ort::Value::CreateTensor( memoryInfo4, input_test4, inputTensorValues4.size(), inputShapes4.data(), inputShapes4.size()); std::vector intput_tensors; intput_tensors.push_back(std::move(inputTensor1)); intput_tensors.push_back(std::move(inputTensor2)); intput_tensors.push_back(std::move(inputTensor3)); intput_tensors.push_back(std::move(inputTensor4)); // 推理 auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), intput_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size()); // 获取输出节点的属性 const float* start_data = output_tensors[1].GetTensorMutableData(); // 开始位置的数据指针 const float* end_data = output_tensors[0].GetTensorMutableData(); // 结束位置的数据指针 // 保存推理结果 for(int i=0;i<256;++i) { start_position.push_back(start_data[i]); end_position.push_back(end_data[i]); } } return SUCCESS; } ErrorCode Bert::Preprocessing(cuBERT::FullTokenizer tokenizer, int batch_size, int max_seq_length, const char *text, char *question, std::vector> &input_ids, std::vector> &input_masks, std::vector> &segment_ids) { std::vector input_id(max_seq_length); std::vector input_mask(max_seq_length); std::vector segment_id(max_seq_length); // 对上下文文本和问题进行分词操作 tokens_text.reserve(max_seq_length); tokens_question.reserve(max_seq_length); tokenizer.tokenize(text, &tokens_text, max_seq_length); tokenizer.tokenize(question, &tokens_question, max_seq_length); // 当上下文文本加问题文本的长度大于规定的最大长度,采用滑动窗口操作 if(tokens_text.size() + tokens_question.size() > max_seq_length - 5) { int windows_len = max_seq_length - 5 -tokens_question.size(); std::vector tokens_text_window(windows_len); std::vector> tokens_text_windows; int start_offset = 0; int position = 0; int n; while (start_offset < tokens_text.size()) { n = 0; if(start_offset+windows_len>tokens_text.size()) { for(int i=start_offset;i b.value; } static bool CompareM(ResultOfPredictions a, ResultOfPredictions b) { return a.start_predictionvalue + a.end_predictionvalue > b.start_predictionvalue + b.end_predictionvalue; } ErrorCode Bert::Postprocessing(int n_best_size, int max_answer_length, const std::vector &start_position, const std::vector &end_position, std::string &answer) { // 取前n_best_size个最大概率值的索引 std::vector start_array(start_position.size()); std::vector end_array(end_position.size()); for (int i=0;i resultsOfPredictions(400); int num = start_position.size() / 256; bool flag; int n=0; for(int i=0;i start_position.size()) { continue; } if(end_array[j].index > end_position.size()) { continue; } for(int t=0;t t*256 && start_array[i].index < tokens_question.size()+4+t*256) { flag = true; break; } if(end_array[j].index > t*256 && end_array[j].index < tokens_question.size()+4+t*256) { flag = true; break; } } if(flag) { continue; } if(start_array[i].index > end_array[j].index) { continue; } int length = end_array[j].index - start_array[i].index + 1; if(length > max_answer_length) { continue; } resultsOfPredictions[n].start_index = start_array[i].index; resultsOfPredictions[n].end_index = end_array[j].index; resultsOfPredictions[n].start_predictionvalue = start_array[i].value; resultsOfPredictions[n].end_predictionvalue = end_array[j].value; ++n; } } // 排序,将开始索引加结束索引的概率值和最大的排在前面 std::sort(resultsOfPredictions.begin(), resultsOfPredictions.end(), CompareM); int start_index = 0; int end_index = 0; for(int i=0;i<400;++i) { if(resultsOfPredictions[i].start_predictionvalue==0 && resultsOfPredictions[i].end_predictionvalue==0) { continue; } start_index = resultsOfPredictions[i].start_index; end_index = resultsOfPredictions[i].end_index; break; } // 映射回上下文文本的索引,(当前的索引值-问题的长度-4) int answer_start_index = start_index - tokens_question.size()- 4; int answer_end_index = end_index - tokens_question.size() - 4 + 1; // 根据开始索引和结束索引,获取区间内的数据 int j=0; for(int i=answer_start_index;i