Commit 84f95d56 authored by liucong's avatar liucong
Browse files

修改python代码使用migraphx后端推理

parent 8070f4bc
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
分类器示例 分类器示例
""" """
import cv2 import cv2
import argparse
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
...@@ -57,20 +58,19 @@ def postprocess(scores,pathOfImage): ...@@ -57,20 +58,19 @@ def postprocess(scores,pathOfImage):
text = 'class=%s ' % (labels[a[0]]) text = 'class=%s ' % (labels[a[0]])
saveimage(pathOfImage,text) saveimage(pathOfImage,text)
def ort_seg_dcu(model_path,image): def ort_seg_dcu(model_path,image,staticInfer,dynamicInfer):
#创建sess_options provider_options=[]
sess_options = ort.SessionOptions() if staticInfer:
provider_options=[{'device_id':'0','migraphx_fp16_enable':'true','dynamic_model':'false'}]
#设置图优化 if dynamicInfer:
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC provider_options=[{'device_id':'0','migraphx_fp16_enable':'true','dynamic_model':'true', 'migraphx_profile_max_shapes':'data:1x3x224x224'}]
#是否开启profiling dcu_session = ort.InferenceSession(model_path, providers=['MIGraphXExecutionProvider'], provider_options=provider_options)
sess_options.enable_profiling = False
dcu_session = ort.InferenceSession(model_path,sess_options,providers=['ROCMExecutionProvider'],)
input_name=dcu_session.get_inputs()[0].name input_name=dcu_session.get_inputs()[0].name
results = dcu_session.run(None, input_feed={input_name:image }) results = dcu_session.run(None, input_feed={input_name:image})
scores=np.array(results[0]) scores=np.array(results[0])
print("ort result.shape:",scores.shape) print("ort result.shape:",scores.shape)
...@@ -88,13 +88,24 @@ def saveimage(pathOfImage,text): ...@@ -88,13 +88,24 @@ def saveimage(pathOfImage,text):
cv2.destroyAllWindows() cv2.destroyAllWindows()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--imgPath', type=str, default='../Resource/Images/ImageNet_01.jpg', help="image path")
parser.add_argument('--staticModelPath', type=str, default='../Resource/Models/resnet50_static.onnx', help="static onnx filepath")
parser.add_argument('--dynamicModelPath', type=str, default='../Resource/Models/resnet50_dynamic.onnx', help="dynamic onnx filepath")
parser.add_argument("--staticInfer",action="store_true",default=False,help="Performing static inference")
parser.add_argument("--dynamicInfer",action="store_true",default=False,help="Performing dynamic inference")
args = parser.parse_args()
pathOfImage ="../Resource/Images/ImageNet_01.jpg" # 数据预处理
image = Preprocessing(pathOfImage) image = Preprocessing(args.imgPath)
model_path = "../Resource/Models/resnet50-v2-7.onnx"
# 推理 # 静态推理
result = ort_seg_dcu(model_path,image) if args.staticInfer:
result = ort_seg_dcu(args.staticModelPath,image,args.staticInfer,args.dynamicInfer)
# 动态推理
if args.dynamicInfer:
result = ort_seg_dcu(args.dynamicModelPath,image,args.staticInfer,args.dynamicInfer)
# 解析分类结果 # 解析分类结果
postprocess(result,pathOfImage) postprocess(result,args.imgPath)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment