// Classifier.cpp
#include <Classifier.h>
#include <Filesystem.h>
#include <SimpleLog.h>
#include <algorithm>
#include <CommonUtility.h>

namespace ortSamples
{

/// Constructs a classifier that has no inference session yet.
/// Initialize() has to succeed before Classify() may be called.
Classifier::Classifier()
{
    // Put the session pointer into a known state so that the delete in the
    // destructor is well-defined even when Initialize() is never called or
    // returns before the session has been created.
    session = nullptr;
}

/// Frees the inference session created by Initialize() and releases the
/// configuration file.
Classifier::~Classifier()
{
    // Initialize() allocates the session with `new`; nothing else in this file
    // frees it, so it has to be deleted here. Deleting a null pointer is a no-op.
    // NOTE(review): this assumes Classifier objects are never copied (a copy
    // would share the raw pointer) -- confirm against the class declaration.
    delete session;
    session = nullptr;

    configurationFile.release();
}

/// Loads the classifier configuration and creates the ONNX Runtime session.
///
/// The configuration file named by
/// initializationParameterOfClassifier.configFilePath is opened for reading,
/// the "ModelPath" entry of its "Classifier" node is taken as the model file,
/// and a session for that model is created with the ROCm execution provider
/// on device 0 and basic graph optimizations.
///
/// @param initializationParameterOfClassifier  Carries the configuration file path.
/// @return CONFIG_FILE_NOT_EXIST when Exists() reports the path as missing,
///         FAIL_TO_OPEN_CONFIG_FILE when the file cannot be opened,
///         SUCCESS otherwise.
ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializationParameterOfClassifier)
{
    // Locate and open the configuration file.
    const std::string pathToConfig = initializationParameterOfClassifier.configFilePath;
    if (!Exists(pathToConfig))
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }

    const bool opened = configurationFile.open(pathToConfig, cv::FileStorage::READ);
    if (!opened)
    {
        LOG_ERROR(stdout, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Read the model location from the "Classifier" section.
    const cv::FileNode classifierNode = configurationFile["Classifier"];
    const std::string modelFile = static_cast<std::string>(classifierNode["ModelPath"]);

    // Configure the session: run on DCU device 0 through the ROCm provider.
    OrtROCMProviderOptions rocmProviderOptions;
    rocmProviderOptions.device_id = 0;
    sessionOptions.AppendExecutionProvider_ROCM(rocmProviderOptions);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
    // Profiling can be switched on with:
    // sessionOptions.EnableProfiling("profile_prefix");

    // Create the inference session for the configured model.
    session = new Ort::Session(env, modelFile.c_str(), sessionOptions);

    return SUCCESS;
}

/// Classifies a batch of images and appends one score list per image to
/// `predictions`.
///
/// Each image is converted from BGR to RGB, resized so that its short side is
/// 256 pixels (aspect ratio kept), centre-cropped to 224x224, normalised and
/// packed into a single NCHW blob that is fed to the session created by
/// Initialize().
///
/// @param srcImages    Input images; every image must be non-empty with 8-bit depth.
/// @param predictions  Receives, per input image, one ResultOfPrediction per class
///                     (label = class index, confidence = raw network output).
///                     Results are appended; existing entries are kept.
/// @return IMAGE_ERROR when the batch is empty or any image is invalid,
///         SUCCESS otherwise.
ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector<std::vector<ResultOfPrediction>> &predictions)
{
    // Validate the whole batch, not only the first image: an empty or
    // non-8-bit image further down would otherwise fail inside OpenCV.
    if(srcImages.empty())
    {
        LOG_ERROR(stdout, "image error!\n");
        return IMAGE_ERROR;
    }
    for(const cv::Mat &srcImage : srcImages)
    {
        if(srcImage.empty()||srcImage.depth()!=CV_8U)
        {
            LOG_ERROR(stdout, "image error!\n");
            return IMAGE_ERROR;
        }
    }

    constexpr int resizedShortSide = 256;
    constexpr int cropSize = 224;

    // Pre-processing
    std::vector<cv::Mat> croppedImages;
    croppedImages.reserve(srcImages.size());
    for(const cv::Mat &srcImage : srcImages)
    {
        // Convert BGR to RGB
        cv::Mat imgRGB;
        cv::cvtColor(srcImage, imgRGB, cv::COLOR_BGR2RGB);

        // Resize so that the short side becomes 256 while keeping the aspect ratio
        cv::Mat shrink;
        const float ratio = static_cast<float>(resizedShortSide) / std::min(imgRGB.cols, imgRGB.rows);
        if(imgRGB.rows > imgRGB.cols)
        {
            cv::resize(imgRGB, shrink, cv::Size(resizedShortSide, static_cast<int>(ratio * imgRGB.rows)), 0, 0);
        }
        else
        {
            cv::resize(imgRGB, shrink, cv::Size(static_cast<int>(ratio * imgRGB.cols), resizedShortSide), 0, 0);
        }

        // Crop the central 224x224 window
        const int startX = shrink.cols/2 - cropSize/2;
        const int startY = shrink.rows/2 - cropSize/2;
        croppedImages.push_back(shrink(cv::Rect(startX, startY, cropSize, cropSize)));
    }

    // Normalise and convert to NCHW
    cv::Mat inputBlob;
    Image2BlobParams image2BlobParams;
    image2BlobParams.scalefactor=cv::Scalar(1/58.395, 1/57.12, 1/57.375);
    image2BlobParams.mean=cv::Scalar(123.675, 116.28, 103.53);
    image2BlobParams.swapRB=false;
    blobFromImagesWithParams(croppedImages,inputBlob,image2BlobParams);

    // Names of the ONNX input and output nodes
    std::vector<const char*> inputNodeNames = {"data"};
    std::vector<const char*> outputNodeNames = {"resnetv24_dense0_fwd"};

    // Wrap the blob in an input tensor. The batch dimension and element count
    // follow the actual number of images instead of being fixed to 1, so the
    // tensor describes all the data that the output loop below expects.
    // NOTE(review): this assumes the model accepts a batch dimension equal to
    // srcImages.size() -- confirm the exported model has a dynamic batch axis.
    const std::size_t batchSize = srcImages.size();
    const std::array<int64_t, 4> inputShape{static_cast<int64_t>(batchSize), 3, cropSize, cropSize};
    auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, inputBlob.ptr<float>(), inputBlob.total(), inputShape.data(), inputShape.size());
    std::vector<Ort::Value> inputTensors;
    inputTensors.push_back(std::move(inputTensor));

    // Run inference
    auto outputTensors = session->Run(Ort::RunOptions{nullptr}, inputNodeNames.data(), inputTensors.data(), inputTensors.size(), outputNodeNames.data(), outputNodeNames.size());

    // Parse the output. The number of classes per image is derived from the
    // size of the returned tensor so that the reads below can never run past
    // the end of the output buffer.
    const float* scores = outputTensors[0].GetTensorMutableData<float>();
    const std::size_t totalScores = outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount();
    const std::size_t numberOfClasses = totalScores / batchSize;
    for(std::size_t i=0;i<batchSize;++i)
    {
        const float* imageScores = scores + i*numberOfClasses;

        std::vector<ResultOfPrediction> resultOfPredictions;
        resultOfPredictions.reserve(numberOfClasses);
        for(std::size_t j=0;j<numberOfClasses;++j)
        {
            ResultOfPrediction prediction;
            prediction.label=static_cast<int>(j);
            prediction.confidence=imageScores[j];
            resultOfPredictions.push_back(prediction);
        }
        predictions.push_back(std::move(resultOfPredictions));
    }
    return SUCCESS;
}
}