"""Run a ViT image classifier exported to ONNX through MIGraphX.

Usage: pass ``--model`` (ONNX file) and ``--imgpath`` (one image file, or a
directory of images).  For a directory, a simple accuracy figure is printed,
where a prediction counts as a match when the predicted class name occurs in
the ``--imgpath`` string.
"""
from sys import argv
import argparse
import json
import os
import time

import cv2
import migraphx
import numpy as np
from PIL import Image

# Running totals shared between Vit_Inference and Vit_Postprocess.
img_count = 0  # images submitted to the model (directory mode)
match_cnt = 0  # confident predictions whose class name occurs in the path


def Vit_Preprocess(image):
    """Read an image file and turn it into the network input blob.

    Args:
        image: Path of the image file to load.

    Returns:
        The float32 blob built by ``cv2.dnn.blobFromImage`` (224x224,
        centre-cropped, R/B channels swapped) after the extra
        ``(x - 0.5) / 0.5`` step below.

    Raises:
        ValueError: If OpenCV cannot read or decode ``image``.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a readable message rather than inside blobFromImage.
        raise ValueError("cannot read image: '{}'".format(image))

    img = cv2.dnn.blobFromImage(
        img,
        scalefactor=1 / 127.5,
        size=(224, 224),
        mean=[127.5, 127.5, 127.5],
        swapRB=True,
        crop=True,
        ddepth=cv2.CV_32F,
    )
    # NOTE(review): blobFromImage already applies mean/scale above, so this
    # looks like a second normalisation pass. Kept as-is because the exported
    # model may depend on it -- confirm against the training pipeline.
    img -= 0.5
    img /= 0.5
    return img


def Vit_Postprocess(infer_res, class_indict, imgpath):
    """Print confident predictions and update the global match counter.

    A softmax is applied to ``infer_res``; every class of the first row whose
    probability is at least 0.5 is printed, and ``match_cnt`` is incremented
    when that class name occurs as a substring of ``imgpath``.

    Args:
        infer_res: 2-D array of raw scores, indexed ``[row, class]``.
        class_indict: Mapping from class index (as a string) to class name.
        imgpath: Path string searched for the predicted class name.
    """
    global match_cnt

    scores = np.asarray(infer_res)
    # Subtracting the row maximum leaves the softmax unchanged mathematically
    # but keeps np.exp from overflowing on large scores.
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    # keepdims=True makes the division broadcast per row for any batch size.
    predict = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    for idx, prob in enumerate(predict[0]):
        if prob < 0.5:
            continue
        class_name = class_indict[str(idx)]
        print("class: {:10} prob: {:.3}".format(class_name, prob))
        # Substring test; str.find() must not be used as a boolean because it
        # returns -1 (truthy) when absent and 0 (falsy) for a match at index 0.
        if class_name in imgpath:
            match_cnt += 1


def Vit_Inference(args):
    """Compile the ONNX model for the GPU and classify the requested image(s).

    Args:
        args: Parsed command-line namespace providing ``model`` (ONNX file
            path) and ``imgpath`` (an image file or a directory of images).

    Raises:
        FileNotFoundError: If ``models/class_indices.json`` is missing.
        ValueError: If an image cannot be read (see ``Vit_Preprocess``).
    """
    global img_count

    model = migraphx.parse_onnx(args.model)
    model.compile(t=migraphx.get_target("gpu"), device_id=0)

    json_path = 'models/class_indices.json'
    # Explicit check instead of assert, which is stripped under ``python -O``.
    if not os.path.exists(json_path):
        raise FileNotFoundError(
            "file: '{}' dose not exist.".format(json_path))
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # The input name does not change between runs, so look it up once.
    inputName = model.get_parameter_names()[0]

    if os.path.isdir(args.imgpath):
        for image in os.listdir(args.imgpath):
            img_count += 1
            img = Vit_Preprocess(os.path.join(args.imgpath, image))
            results = model.run({inputName: migraphx.argument(img)})
            # The directory path (not the file name) is what gets searched
            # for the predicted class name.
            Vit_Postprocess(np.array(results[0]), class_indict, args.imgpath)
        # Guard the ratio so an empty directory does not divide by zero.
        accuracy = match_cnt / img_count if img_count else 0.0
        print("Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(
            img_count, match_cnt, accuracy))
    else:
        img = Vit_Preprocess(args.imgpath)
        results = model.run({inputName: migraphx.argument(img)})
        Vit_Postprocess(np.array(results[0]), class_indict, args.imgpath)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to vaildate net')
    parser.add_argument('--model',
                        default='models/model.onnx',
                        help='model path to inference')
    parser.add_argument('--imgpath', default='', help='the image path')
    args = parser.parse_args()
    Vit_Inference(args)