Commit 1ddb604f authored by liucong's avatar liucong
Browse files

修改C++示例代码

parent e6a8bf6a
Pipeline #2210 failed with stages
in 0 seconds
...@@ -40,10 +40,9 @@ ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializat ...@@ -40,10 +40,9 @@ ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializat
OrtMIGraphXProviderOptions migraphx_options; OrtMIGraphXProviderOptions migraphx_options;
migraphx_options.device_id = 0; migraphx_options.device_id = 0;
migraphx_options.migraphx_fp16_enable = 1; migraphx_options.migraphx_fp16_enable = 1;
migraphx_options.dynamic_model = 0;
sessionOptions.AppendExecutionProvider_MIGraphX(migraphx_options); sessionOptions.AppendExecutionProvider_MIGraphX(migraphx_options);
session = new Ort::Session(env, modelPath.c_str(), sessionOptions); dcu_session = new Ort::Session(env, modelPath.c_str(), sessionOptions);
return SUCCESS; return SUCCESS;
} }
...@@ -91,24 +90,42 @@ ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector ...@@ -91,24 +90,42 @@ ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector
image2BlobParams.swapRB=false; image2BlobParams.swapRB=false;
blobFromImagesWithParams(image,inputBlob,image2BlobParams); blobFromImagesWithParams(image,inputBlob,image2BlobParams);
// 设置onnx的输入和输出名 // 获取模型输入输出信息
std::vector<const char*> input_node_names = {"data"}; Ort::AllocatorWithDefaultOptions allocator;
std::vector<const char*> output_node_names = {"resnetv24_dense0_fwd"}; std::vector<Ort::AllocatedStringPtr> inputNamesPtr;
std::vector<Ort::AllocatedStringPtr> outputNamesPtr;
for ( size_t i=0; i<dcu_session->GetInputCount(); i++)
{
auto input_name = dcu_session->GetInputNameAllocated(i , allocator);
inputNamesPtr.push_back(std::move(input_name));
}
for ( size_t i=0; i<dcu_session->GetOutputCount(); i++)
{
auto out_name = dcu_session->GetOutputNameAllocated(i , allocator);
outputNamesPtr.push_back(std::move(out_name));
}
std::vector<const char *> inputNames = {inputNamesPtr.data()->get()};
std::vector<const char *> outputNames = {outputNamesPtr.data()->get()};
// 初始化输入数据 float* input_data = (float*)inputBlob.data;
std::array<int64_t, 4> inputShape{1, 3, 224, 224}; std::array<float, 3 * 224 * 224> input_data_len{};
auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::array<float, 3 * 224 * 224> input_image{}; Ort::MemoryInfo memoryInfo =Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
float* input_test = (float*)inputBlob.data;
Ort::Value inputTensor = Ort::Value::CreateTensor<float>(memoryInfo, input_test, input_image.size(), inputShape.data(), inputShape.size()); std::vector<Ort::Value> inputTensors;
std::vector<Ort::Value> intput_tensors; for(size_t i=0; i<inputNames.size(); i++)
intput_tensors.push_back(std::move(inputTensor)); {
Ort::TypeInfo inputTypeInfo = dcu_session->GetInputTypeInfo(i);
auto inputTensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
std::vector<int64_t> inputDims = inputTensorInfo.GetShape();
inputTensors.push_back(Ort::Value::CreateTensor<float>(memoryInfo,input_data,input_data_len.size(), inputDims.data(), inputDims.size()));
}
// 进行推理 // 进行推理
auto output_tensors = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), intput_tensors.data(), 1, output_node_names.data(), 1); auto output = dcu_session->Run(Ort::RunOptions{nullptr}, inputNames.data(), inputTensors.data(), inputNames.size(), outputNames.data(), outputNames.size());
// 解析输出结果 // 解析输出结果
const float* pdata = output_tensors[0].GetTensorMutableData<float>(); const float* pdata = output[0].GetTensorMutableData<float>();
int numberOfClasses = 1000 ; int numberOfClasses = 1000 ;
for(int i=0;i<srcImages.size();++i) for(int i=0;i<srcImages.size();++i)
{ {
......
...@@ -19,9 +19,7 @@ public: ...@@ -19,9 +19,7 @@ public:
private: private:
cv::FileStorage configurationFile; cv::FileStorage configurationFile;
Ort::Session *session; Ort::Session *dcu_session;
cv::Size inputSize;
std::string inputName;
Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNXRuntime"); Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNXRuntime");
Ort::SessionOptions sessionOptions = Ort::SessionOptions(); Ort::SessionOptions sessionOptions = Ort::SessionOptions();
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment