// NOTE(review): this text appears to have lost its line breaks and whatever
// was written between angle brackets (the #include targets, the template
// arguments of std::vector / std::array / GetTensorMutableData, and the loop
// text following "j" and "i"). Restore it from the original file before
// building. The notes below describe only what is still visible here.
//
// Implementation of ortSamples::GPT2. The names session, sessionOptions, env,
// input_node_names and output_node_names are used below but declared outside
// this text.
//
// GPT2::GPT2 / GPT2::~GPT2: both bodies are empty.
//
// GPT2::Initialize():
//   Returns MODEL_NOT_EXIST (after LOG_ERROR) when
//   "../Resource/GPT2_shici.onnx" does not exist. Otherwise appends the ROCm
//   execution provider with device_id 0, sets the graph optimization level to
//   ORT_ENABLE_BASIC, creates the Ort::Session and returns SUCCESS.
//   NOTE(review): the "succeed to load model" message is logged before the
//   session is constructed.
//   NOTE(review): session is allocated with new and the destructor shown here
//   is empty -- confirm it is released somewhere else.
//
// CompareM(a, b): true when a.predictionvalue > b.predictionvalue.
//
// GPT2::Inference(input_id):
//   Fills an int64_t buffer from input_id, wraps it in an Ort::Value tensor of
//   shape {1, input_id.size()} using CPU memory info, runs the session with
//   the single input "input" and the listed output names, and reads output 0
//   through a const float pointer. The rest of the function is missing from
//   this text.
//   NOTE(review): int64_t input[1][input_id.size()] has a run-time bound,
//   which is a compiler extension rather than standard C++.
//   NOTE(review): in the visible text inputTensorValues is used only for its
//   size().
//   NOTE(review): 22557 is presumably the vocabulary size -- TODO confirm.
//
// GPT2::Preprocessing (start of its signature is missing):
//   Tokenizes question with a limit of 1000 tokens and pushes the id of
//   "[CLS]" onto input_id; the remainder is cut off at the end of this view.
#include #include #include #include #include #include #include namespace ortSamples { GPT2::GPT2() { } GPT2::~GPT2() { } ErrorCode GPT2::Initialize() { // Get the model file std::string modelPath="../Resource/GPT2_shici.onnx"; // Check the path if(Exists(modelPath)==false) { LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str()); return MODEL_NOT_EXIST; } //Load the model OrtROCMProviderOptions rocm_options; rocm_options.device_id = 0; sessionOptions.AppendExecutionProvider_ROCM(rocm_options); sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC); LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str()); session = new Ort::Session(env, modelPath.c_str(), sessionOptions); return SUCCESS; } static bool CompareM(Predictions a, Predictions b) { return a.predictionvalue > b.predictionvalue; } long unsigned int GPT2::Inference(const std::vector &input_id) { int64_t input[1][input_id.size()]; for (int j=0;j(input_id[j]);; } // Get the model input attributes input_node_names = {"input"}; // Get the model output attributes output_node_names = {"output","304","value.3","532","value.11","758","value.19","984","value.27","1210","value.35","1436","value.43","1662","value.51","1888","value.59","2114","value.67","2340","value.75","2566","value.83","2792","value.91"}; // Set the input shape std::array inputShapes{1, (int)input_id.size()}; // Create the input data auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); auto inputTensorSize = 1 * input_id.size(); std::vector inputTensorValues(inputTensorSize); int64_t* input_test = (int64_t*)input; Ort::Value inputTensor = Ort::Value::CreateTensor(memoryInfo, input_test, inputTensorValues.size(), inputShapes.data(), inputShapes.size()); std::vector intput_tensors; intput_tensors.push_back(std::move(inputTensor)); // Run inference auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), intput_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size()); // Save the inference results const float* data = 
output_tensors[0].GetTensorMutableData(); long unsigned int n = 0; std::vector resultsOfPredictions(22557); for(int i=(input_id.size()-1)*22557; i &input_id) { // Tokenize the question int max_seq_length =1000; std::vector tokens_question; tokens_question.reserve(max_seq_length); tokenizer.tokenize(question, &tokens_question, max_seq_length); // Save the encoding information input_id.push_back(tokenizer.convert_token_to_id("[CLS]")); for (int i=0;i