from sys import argv
import os
import cv2
import json
import numpy as np
import migraphx

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def _softmax_rows(logits):
    """Return the row-wise softmax of a 2-D array of logits."""
    logits = np.array(logits)
    # Subtracting the per-row maximum does not change the softmax value but
    # keeps exp() from overflowing on large logits.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exponentials = np.exp(shifted)
    # keepdims=True makes the division broadcast per row for any batch size.
    return exponentials / exponentials.sum(axis=1, keepdims=True)


def _accuracy(match_cnt, total):
    """Return match_cnt / total, or 0.0 when no images were processed."""
    return match_cnt / total if total else 0.0


def main(intputdir):
    """Compare PyTorch and MIGraphX predictions for every file in a directory.

    Each file in ``intputdir`` is preprocessed, run through the PyTorch model
    loaded from ``./models/model.pth`` and through the MIGraphX program parsed
    from ``./models/model.onnx``.  Every class whose probability is at least
    0.5 is printed, and a prediction counts as a match when the predicted
    class name occurs in the ``intputdir`` path.  A summary line with the
    image count, match count and accuracy is printed for each backend.

    Args:
        intputdir: Path of the directory holding the images.  The path is
            expected to contain the name of the true class.

    Raises:
        AssertionError: If the class index file or an image path is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # read class_indict
    json_path = './models/class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create pth model
    model = create_model(num_classes=5, has_logits=False).to(device)
    # load model weights
    model_weight_path = "./models/model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    # Switch to inference mode once instead of on every image.
    model.eval()

    # create onnx model
    model1 = migraphx.parse_onnx('./models/model.onnx')
    # NOTE(review): MIGraphX is compiled for device 1 while PyTorch uses
    # cuda:0 -- confirm that running on different devices is intended.
    model1.compile(t=migraphx.get_target("gpu"), device_id=1)
    # The program's input name does not change between images.
    input_name = model1.get_parameter_names()[0]

    Img_cnt = 0
    pytorch_match_cnt = 0
    migraphx_match_cnt = 0
    for filename in os.listdir(intputdir):
        Img_cnt += 1
        # os.path.join works whether or not intputdir ends with a separator.
        img_path = os.path.join(intputdir, filename)
        assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)

        # load image; the context manager closes the underlying file handle
        with Image.open(img_path) as image:
            # [C, H, W]
            img = data_transform(image)
        # expand batch dimension -> [N, C, H, W]
        img = torch.unsqueeze(img, dim=0)

        # migraphx run
        results = model1.run({input_name: migraphx.argument(img.numpy())})

        # migraphx postprocess
        migraphx_probs = _softmax_rows(results[0])[0]
        for i, prob in enumerate(migraphx_probs):
            if prob >= 0.5:
                class_name = class_indict[str(i)]
                print("migraphx class: {:10} prob: {:.3}".format(class_name, prob))
                # str.find() returns -1 (truthy) when absent and 0 (falsy)
                # when found at the start, so use a membership test instead.
                if class_name in intputdir:
                    migraphx_match_cnt += 1

        # pytorch postprocess
        with torch.no_grad():
            # predict class
            output = torch.squeeze(model(img.to(device))).cpu()
            pytorch_probs = torch.softmax(output, dim=0)
        for i, prob in enumerate(pytorch_probs.tolist()):
            if prob >= 0.5:
                class_name = class_indict[str(i)]
                print("pytorch class: {:10} prob: {:.3}".format(class_name, prob))
                if class_name in intputdir:
                    pytorch_match_cnt += 1

    print("Pytorch Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(
        Img_cnt, pytorch_match_cnt, _accuracy(pytorch_match_cnt, Img_cnt)))
    print("Migraphx Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(
        Img_cnt, migraphx_match_cnt, _accuracy(migraphx_match_cnt, Img_cnt)))


if __name__ == '__main__':
    if len(argv) < 2:
        raise SystemExit("usage: {} <image_dir>".format(argv[0]))
    main(argv[1])