Commit eb244d7e authored by liucong's avatar liucong
Browse files

修改C++示例程序

parent 1ddb604f
...@@ -13,6 +13,9 @@ Classifier::Classifier() ...@@ -13,6 +13,9 @@ Classifier::Classifier()
Classifier::~Classifier() Classifier::~Classifier()
{ {
delete dcu_session;
inputNamesPtr.clear();
outputNamesPtr.clear();
configurationFile.release(); configurationFile.release();
} }
...@@ -40,6 +43,11 @@ ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializat ...@@ -40,6 +43,11 @@ ErrorCode Classifier::Initialize(InitializationParameterOfClassifier initializat
OrtMIGraphXProviderOptions migraphx_options; OrtMIGraphXProviderOptions migraphx_options;
migraphx_options.device_id = 0; migraphx_options.device_id = 0;
migraphx_options.migraphx_fp16_enable = 1; migraphx_options.migraphx_fp16_enable = 1;
migraphx_options.migraphx_int8_enable = 0;
migraphx_options.dynamic_model = 0;
migraphx_options.migraphx_profile_max_shapes = "";
migraphx_options.migraphx_load_compiled_model=0;
migraphx_options.migraphx_save_compiled_model=0;
sessionOptions.AppendExecutionProvider_MIGraphX(migraphx_options); sessionOptions.AppendExecutionProvider_MIGraphX(migraphx_options);
dcu_session = new Ort::Session(env, modelPath.c_str(), sessionOptions); dcu_session = new Ort::Session(env, modelPath.c_str(), sessionOptions);
...@@ -92,8 +100,6 @@ ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector ...@@ -92,8 +100,6 @@ ErrorCode Classifier::Classify(const std::vector<cv::Mat> &srcImages,std::vector
// 获取模型输入输出信息 // 获取模型输入输出信息
Ort::AllocatorWithDefaultOptions allocator; Ort::AllocatorWithDefaultOptions allocator;
std::vector<Ort::AllocatedStringPtr> inputNamesPtr;
std::vector<Ort::AllocatedStringPtr> outputNamesPtr;
for ( size_t i=0; i<dcu_session->GetInputCount(); i++) for ( size_t i=0; i<dcu_session->GetInputCount(); i++)
{ {
auto input_name = dcu_session->GetInputNameAllocated(i , allocator); auto input_name = dcu_session->GetInputNameAllocated(i , allocator);
......
...@@ -20,8 +20,11 @@ public: ...@@ -20,8 +20,11 @@ public:
private: private:
cv::FileStorage configurationFile; cv::FileStorage configurationFile;
Ort::Session *dcu_session; Ort::Session *dcu_session;
Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "ONNXRuntime"); Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "MIGraphX");
Ort::SessionOptions sessionOptions = Ort::SessionOptions(); Ort::SessionOptions sessionOptions = Ort::SessionOptions();
std::vector<Ort::AllocatedStringPtr> inputNamesPtr;
std::vector<Ort::AllocatedStringPtr> outputNamesPtr;
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment