import argparse
import itertools
import os
import time

import cv2
import numpy as np
import onnxruntime as ort

# Print the device string reported by ONNX Runtime (runs at import time).
print('Runing Based On:', ort.get_device())

# Index-to-character table for the model's output classes:
#   - 31 Chinese province abbreviations;
#   - the digits 0-9;
#   - the 26 letters (note 'I' and 'O' are placed after 'Z');
#   - a final '-'.
# LPRNetPostprocess treats the last index (len(CHARS) - 1) as the blank label
# it drops while decoding.
# NOTE(review): the order presumably must match the label order the model was
# trained with -- confirm before reordering or extending this list.
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
         ]

def LPRNetPreprocess(image):
    """Load an image file and turn it into a batched LPRNet input tensor.

    The image is resized to 94x24 (width x height) and normalised with
    ``(pixel - 127.5) * 0.0078125``. It is then reordered from HWC to CHW,
    and a leading batch axis of size 1 is added.

    Args:
        image: Path to the image file.

    Returns:
        float32 ndarray of shape (1, 3, 24, 94).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image`` and the path does
            not exist.
        ValueError: If the file exists but cannot be decoded as an image.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising,
        # e.g. for paths containing non-ASCII characters on Windows. File
        # names here can be Chinese plate strings, since the evaluation loop
        # compares results against them. Retry by reading the raw bytes with
        # NumPy and decoding them in memory.
        raw = np.fromfile(image, dtype=np.uint8)  # raises FileNotFoundError if missing
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
        if img is None:
            raise ValueError('Cannot decode image file: {}'.format(image))
    img = cv2.resize(img, (94, 24)).astype('float32')
    img -= 127.5
    img *= 0.0078125
    # HWC -> CHW, then prepend the batch axis.
    return np.expand_dims(img.transpose(2, 0, 1), 0)

def LPRNetPostprocess(infer_res, chars=None):
    """Greedy-decode one LPRNet output into a plate string.

    The most likely class is taken at every sequence position. Runs of
    identical labels are collapsed into one, and the blank class (the last
    entry of ``chars``) is removed. A blank between two equal labels
    therefore keeps them as separate characters.

    Args:
        infer_res: Array of class scores with the class axis first, i.e.
            shape (num_classes, seq_len). A leading batch axis of size 1,
            shape (1, num_classes, seq_len), is also accepted and dropped.
        chars: Optional index-to-character table. Defaults to the module's
            ``CHARS``. Its last entry is treated as the blank symbol.

    Returns:
        The decoded string. It is empty if only blanks were predicted or the
        sequence is empty.
    """
    if chars is None:
        chars = CHARS
    blank = len(chars) - 1

    scores = np.asarray(infer_res)
    if scores.ndim == 3 and scores.shape[0] == 1:
        # Tolerate an un-squeezed batch axis. Otherwise argmax(axis=0) would
        # reduce over the batch instead of over the classes.
        scores = scores[0]
    labels = np.argmax(scores, axis=0)

    # groupby collapses consecutive duplicates; blanks are filtered afterwards.
    return ''.join(
        chars[label] for label, _ in itertools.groupby(labels) if label != blank
    )

def _create_session(model_path):
    """Open an ONNX Runtime session for ``model_path``.

    The ROCm execution provider is used when ``ort.get_device()`` reports
    ``"GPU-MIGRAPHX"``; otherwise the CPU provider is used.
    """
    if ort.get_device() == "GPU-MIGRAPHX":
        return ort.InferenceSession(model_path, providers=['ROCMExecutionProvider'])  # DCU build
    return ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])  # CPU build


def _recognize(sess, input_name, image_path):
    """Preprocess one image, run the model on it, and decode the first output."""
    img = LPRNetPreprocess(image_path)
    preb = sess.run(None, input_feed={input_name: img})[0]
    return LPRNetPostprocess(preb)


def _evaluate_directory(sess, input_name, img_dir):
    """Recognize every file in ``img_dir`` and print accuracy and timing.

    Each file's name without its extension is used as the ground-truth label.
    Mistakes are split into:
      - length mismatches;
      - same-length mismatches (wrong characters).
    """
    # Skip subdirectories so they don't reach the image loader; sort for a
    # deterministic processing and print order.
    images = sorted(
        name for name in os.listdir(img_dir)
        if os.path.isfile(os.path.join(img_dir, name))
    )
    if not images:
        # Avoids division by zero in the accuracy/speed summary below.
        print("[Info] No images found in {}".format(img_dir))
        return

    correct = wrong_length = wrong_chars = 0
    start = time.perf_counter()
    for image in images:
        result = _recognize(sess, input_name, os.path.join(img_dir, image))
        # splitext handles any extension length (".jpg", ".jpeg", ".png", ...).
        label = os.path.splitext(image)[0]
        if result == label:
            correct += 1
        elif len(result) != len(label):
            wrong_length += 1
        else:
            wrong_chars += 1
        print(image + ' Inference Result:', result)
    end = time.perf_counter()

    total = len(images)
    acc = correct / total
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(acc, correct, wrong_length, wrong_chars, total))
    print("[Info] Test Speed: {}s 1/{}]".format((end - start) / total, total))


def LPRNetInference(args):
    """Run LPRNet on a single image or evaluate it on a directory of images.

    Args:
        args: Namespace with ``model`` (path to the ONNX model) and
            ``imgpath`` (an image file, or a directory whose file names are
            the expected plate strings).
    """
    sess = _create_session(args.model)
    # The input name is constant for the session; look it up once.
    input_name = sess.get_inputs()[0].name
    if os.path.isdir(args.imgpath):
        _evaluate_directory(sess, input_name, args.imgpath)
    else:
        print('Inference Result:', _recognize(sess, input_name, args.imgpath))

def main():
    """Build the command-line interface, parse it, and run inference."""
    cli = argparse.ArgumentParser(description='parameters to vaildate net')
    cli.add_argument('--model', default='model/LPRNet.onnx',
                     help='model path to vaildate')
    cli.add_argument('--imgpath', default='imgs', help='the image path')
    LPRNetInference(cli.parse_args())


if __name__ == '__main__':
    main()