Commit 84f95d56 authored by liucong's avatar liucong
Browse files

修改python代码使用migraphx后端推理

parent 8070f4bc
......@@ -3,6 +3,7 @@
分类器示例
"""
import cv2
import argparse
import numpy as np
import onnxruntime as ort
......@@ -57,20 +58,19 @@ def postprocess(scores,pathOfImage):
text = 'class=%s ' % (labels[a[0]])
saveimage(pathOfImage,text)
def ort_seg_dcu(model_path,image):
def ort_seg_dcu(model_path,image,staticInfer,dynamicInfer):
#创建sess_options
sess_options = ort.SessionOptions()
provider_options=[]
if staticInfer:
provider_options=[{'device_id':'0','migraphx_fp16_enable':'true','dynamic_model':'false'}]
#设置图优化
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
if dynamicInfer:
provider_options=[{'device_id':'0','migraphx_fp16_enable':'true','dynamic_model':'true', 'migraphx_profile_max_shapes':'data:1x3x224x224'}]
#是否开启profiling
sess_options.enable_profiling = False
dcu_session = ort.InferenceSession(model_path,sess_options,providers=['ROCMExecutionProvider'],)
dcu_session = ort.InferenceSession(model_path, providers=['MIGraphXExecutionProvider'], provider_options=provider_options)
input_name=dcu_session.get_inputs()[0].name
results = dcu_session.run(None, input_feed={input_name:image })
results = dcu_session.run(None, input_feed={input_name:image})
scores=np.array(results[0])
print("ort result.shape:",scores.shape)
......@@ -88,13 +88,24 @@ def saveimage(pathOfImage,text):
cv2.destroyAllWindows()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--imgPath', type=str, default='../Resource/Images/ImageNet_01.jpg', help="image path")
parser.add_argument('--staticModelPath', type=str, default='../Resource/Models/resnet50_static.onnx', help="static onnx filepath")
parser.add_argument('--dynamicModelPath', type=str, default='../Resource/Models/resnet50_dynamic.onnx', help="dynamic onnx filepath")
parser.add_argument("--staticInfer",action="store_true",default=False,help="Performing static inference")
parser.add_argument("--dynamicInfer",action="store_true",default=False,help="Performing dynamic inference")
args = parser.parse_args()
pathOfImage ="../Resource/Images/ImageNet_01.jpg"
image = Preprocessing(pathOfImage)
model_path = "../Resource/Models/resnet50-v2-7.onnx"
# 数据预处理
image = Preprocessing(args.imgPath)
# 推理
result = ort_seg_dcu(model_path,image)
# 静态推理
if args.staticInfer:
result = ort_seg_dcu(args.staticModelPath,image,args.staticInfer,args.dynamicInfer)
# 动态推理
if args.dynamicInfer:
result = ort_seg_dcu(args.dynamicModelPath,image,args.staticInfer,args.dynamicInfer)
# 解析分类结果
postprocess(result,pathOfImage)
postprocess(result,args.imgPath)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment