infer_migraphx.py 2.39 KB
Newer Older
lijian6's avatar
lijian6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from sys import argv
import json
import cv2
import numpy as np
import migraphx
import argparse
import os
import time
from PIL import Image

# Running totals mutated via `global` by Vit_Inference / Vit_Postprocess:
# img_count - number of images classified in directory mode (Vit_Inference);
# match_cnt - number of confident predictions whose class name was found in
#             the path passed to Vit_Postprocess.
img_count = 0
match_cnt = 0

def Vit_Preprocess(image):
    """Load an image file and convert it into a ViT input blob.

    ``cv2.dnn.blobFromImage`` resizes the image (keeping aspect ratio and
    center-cropping, because ``crop=True``) to 224x224, swaps BGR -> RGB and
    computes ``(pixel - 127.5) * (1 / 127.5)`` per channel. The result is a
    float32 NCHW blob of shape (1, 3, 224, 224) with values in [-1, 1].

    Args:
        image: Path to the image file.

    Returns:
        numpy.ndarray: float32 blob of shape (1, 3, 224, 224).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image``.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # without this check blobFromImage fails later with an opaque error.
        raise FileNotFoundError("cannot read image: '{}'".format(image))
    # (x - 127.5) / 127.5 == (x / 255 - 0.5) / 0.5, i.e. mean=std=0.5
    # normalization, so the blob is already in [-1, 1]. The former extra
    # "img -= 0.5; img /= 0.5" normalized a second time and produced [-3, 1].
    # NOTE(review): assumes the model was trained with mean=std=0.5 -- confirm.
    return cv2.dnn.blobFromImage(
        img,
        scalefactor=1 / 127.5,
        size=(224, 224),
        mean=[127.5, 127.5, 127.5],
        swapRB=True,
        crop=True,
        ddepth=cv2.CV_32F,
    )

def Vit_Postprocess(infer_res, class_indict, imgpath, threshold=0.5):
    """Print confident predictions and update the global ``match_cnt``.

    Applies a softmax over the class axis of ``infer_res`` and looks only at
    the first row (sample). For each class whose probability is
    ``>= threshold``, it prints the class name and probability. If that class
    name occurs as a substring of ``imgpath``, the prediction is counted as a
    match (``match_cnt += 1``).

    Args:
        infer_res: Logits array of shape (batch, num_classes); only row 0
            is evaluated.
        class_indict: Mapping from class index as a string ("0", "1", ...)
            to class name.
        imgpath: Path whose text is searched for the predicted class name.
        threshold: Minimum softmax probability for a class to be reported.
    """
    global match_cnt
    logits = np.asarray(infer_res, dtype=np.float64)
    # Subtract the per-row max before exp() so large logits cannot overflow
    # to inf (inf / inf -> NaN); keepdims makes the normalization row-wise
    # for any batch size instead of broadcasting across rows.
    exp_shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
    predict = exp_shifted / exp_shifted.sum(axis=1, keepdims=True)
    for i, prob in enumerate(predict[0]):
        if prob >= threshold:
            name = class_indict[str(i)]
            print("class: {:10}   prob: {:.3}".format(name, prob))
            # str.find() returns -1 (truthy) when absent and 0 (falsy) for a
            # hit at position 0, which inverted the test; use `in` instead.
            if name in imgpath:
                match_cnt += 1

def _run_vit(model, input_name, image_path, class_indict, match_path):
    """Preprocess one image, run the compiled model on it, post-process output 0."""
    img = Vit_Preprocess(image_path)
    results = model.run({input_name: migraphx.argument(img)})
    Vit_Postprocess(np.array(results[0]), class_indict, match_path)


def Vit_Inference(args, json_path='models/class_indices.json'):
    """Classify one image, or every file in a directory, with a ViT ONNX model.

    Reads class names from ``json_path``, then parses ``args.model`` and
    compiles it for the MIGraphX "gpu" target on device 0.

    If ``args.imgpath`` is a directory, each regular file in it (sorted by
    name) is classified. ``img_count`` is incremented per file, and the
    accuracy ``match_cnt / img_count`` is printed at the end. A prediction
    counts as a match when its class name occurs in the directory path.
    NOTE(review): this presumes one class per directory, named in its path;
    confirm against the dataset layout. Otherwise ``args.imgpath`` is
    classified as a single image file.

    Args:
        args: Namespace with ``model`` (ONNX path) and ``imgpath``.
        json_path: JSON file mapping class index strings to class names.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
    """
    global img_count

    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    # Checked before the (slow) model compile so a bad path fails fast.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    model = migraphx.parse_onnx(args.model)
    model.compile(t=migraphx.get_target("gpu"), device_id=0)
    # Loop-invariant: the model's first (assumed only) input name.
    input_name = model.get_parameter_names()[0]

    if os.path.isdir(args.imgpath):
        for image in sorted(os.listdir(args.imgpath)):
            image_path = os.path.join(args.imgpath, image)
            if not os.path.isfile(image_path):
                continue  # listdir also returns sub-directories
            img_count += 1
            _run_vit(model, input_name, image_path, class_indict, args.imgpath)
        # Guard against an empty directory (ZeroDivisionError).
        acc = match_cnt / img_count if img_count else 0.0
        print("Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(img_count, match_cnt, acc))
    else:
        _run_vit(model, input_name, args.imgpath, class_indict, args.imgpath)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to validate net')
    parser.add_argument('--model', default='models/model.onnx', help='model path to inference')
    parser.add_argument('--imgpath', default='', help='the image path (a file or a directory of images)')
    args = parser.parse_args()
    # An empty path would otherwise only fail after the model is compiled,
    # with an opaque OpenCV error; report it as a usage error up front.
    if not args.imgpath:
        parser.error('--imgpath is required (an image file or a directory of images)')

    Vit_Inference(args)