#include #include #include #include #include namespace ortSamples { Classifier::Classifier() { } Classifier::~Classifier() { configurationFile.release(); } ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializationParameterOfClassifier) { // 读取配置文件 std::string configFilePath=initializationParameterOfClassifier.configFilePath; if(Exists(configFilePath)==false) { LOG_ERROR(stdout, "no configuration file!\n"); return CONFIG_FILE_NOT_EXIST; } if(!configurationFile.open(configFilePath, cv::FileStorage::READ)) { LOG_ERROR(stdout, "fail to open configuration file\n"); return FAIL_TO_OPEN_CONFIG_FILE; } LOG_INFO(stdout, "succeed to open configuration file\n"); // 获取配置文件参数 cv::FileNode netNode = configurationFile["Classifier"]; std::string modelPath=(std::string)netNode["ModelPath"]; // 初始化session //设置DCU OrtROCMProviderOptions rocm_options; rocm_options.device_id = 0; sessionOptions.AppendExecutionProvider_ROCM(rocm_options); sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC); // sessionOptions.EnableProfiling("profile_prefix"); session = new Ort::Session(env, modelPath.c_str(), sessionOptions); return SUCCESS; } ErrorCode Classifier::Classify(const std::vector &srcImages,std::vector> &predictions) { if(srcImages.size()==0||srcImages[0].empty()||srcImages[0].depth()!=CV_8U) { LOG_ERROR(stdout, "image error!\n"); return IMAGE_ERROR; } // 数据预处理 std::vector image; for(int i =0;i imgRGB.cols) { cv::resize(imgRGB, shrink, cv::Size(256, int(ratio * imgRGB.rows)), 0, 0); } else { cv::resize(imgRGB, shrink, cv::Size(int(ratio * imgRGB.cols), 256), 0, 0); } // 裁剪中心窗口为224*224 int start_x = shrink.cols/2 - 224/2; int start_y = shrink.rows/2 - 224/2; cv::Rect rect(start_x, start_y, 224, 224); cv::Mat images = shrink(rect); image.push_back(images); } // normalize并转换为NCHW cv::Mat inputBlob; Image2BlobParams image2BlobParams; image2BlobParams.scalefactor=cv::Scalar(1/58.395, 1/57.12, 1/57.375); image2BlobParams.mean=cv::Scalar(123.675, 116.28, 103.53); image2BlobParams.swapRB=false; blobFromImagesWithParams(image,inputBlob,image2BlobParams); // 设置onnx的输入和输出名 std::vector input_node_names = {"data"}; std::vector output_node_names = {"resnetv24_dense0_fwd"}; // 初始化输入数据 std::array inputShape{1, 3, 224, 224}; auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); std::array input_image{}; float* input_test = (float*)inputBlob.data; Ort::Value inputTensor = Ort::Value::CreateTensor(memoryInfo, input_test, input_image.size(), inputShape.data(), inputShape.size()); std::vector intput_tensors; intput_tensors.push_back(std::move(inputTensor)); // 进行推理 auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), intput_tensors.data(), 1, output_node_names.data(), 1); // 解析输出结果 const float* pdata = output_tensors[0].GetTensorMutableData(); int numberOfClasses = 1000 ; for(int i=0;i logit; for(int j=0;j resultOfPredictions; for(int j=0;j