# LPRNet_ORT_infer.py (2.43 KB) — LPRNet licence-plate recognition with ONNX Runtime; last committed by liuhy.
import functools
import itertools

import cv2
import numpy as np
import onnxruntime as ort

# Print, at import time, the device string reported by ort.get_device(); the
# inference code below selects the ROCm execution provider when it is "GPU"
# and the CPU execution provider otherwise.
print('Runing Based On:', ort.get_device())

# Output alphabet of the LPRNet model, listed in class-index order:
# 31 province abbreviations, the digits 0-9, the 26 Latin letters (with 'I'
# and 'O' moved after 'Z'), and finally '-', the blank class that
# LPRNetPostprocess discards while decoding.
CHARS = (
    list('京沪津渝冀晋蒙辽吉黑'
         '苏浙皖闽赣鲁豫鄂湘粤'
         '桂琼川贵云藏陕甘青宁'
         '新')
    + list('0123456789')
    + list('ABCDEFGHJKLMNPQRSTUVWXYZIO')
    + ['-']
)

def LPRNetPreprocess(image):
    """Load an image file and turn it into an LPRNet input tensor.

    Args:
        image: Path of the image file, read with ``cv2.imread`` (default
            colour mode, so the channel count C is 3, in BGR order).

    Returns:
        A ``float32`` array of shape ``(1, C, 24, 94)``. The image is resized
        to 94x24 (width x height), scaled as ``(pixel - 127.5) / 128`` and
        transposed from HWC to NCHW.

    Raises:
        OSError: If OpenCV cannot read ``image`` (missing file, no
            permission, or an unsupported/corrupt format).
    """
    img = cv2.imread(image)
    # cv2.imread does not raise on failure; it returns None, which would
    # otherwise only surface later as an opaque cv2.resize error.
    if img is None:
        raise OSError(f'cv2.imread could not read image: {image!r}')
    img = cv2.resize(img, (94, 24)).astype('float32')
    img -= 127.5
    img *= 0.0078125  # == 1 / 128
    # HWC -> CHW, then prepend a batch axis of size 1.
    img = np.expand_dims(img.transpose(2, 0, 1), 0)
    return img

def LPRNetPostprocess(infer_res, chars=None):
    """Greedy CTC-decode one LPRNet output into a plate string.

    The function works in three steps:

    1. Take the arg-max class at every sequence position.
    2. Collapse runs of identical consecutive classes.
    3. Drop the blank class, which is the last entry of the character set.

    A repeated character therefore survives only when a blank separates its
    occurrences.

    Args:
        infer_res: 2-D array of per-class scores laid out as
            ``(num_classes, seq_len)``. Axis 0 is the class axis and axis 1
            is the sequence position. NOTE(review): the first ONNX output is
            assumed to already have this 2-D layout (no batch axis); confirm
            against the exported model.
        chars: Character set indexed by class id; its last entry is the
            blank. Defaults to the module-level ``CHARS``.

    Returns:
        The decoded string. It is empty when every position is blank or when
        ``seq_len`` is 0.

    Raises:
        ValueError: If ``infer_res`` is not 2-D.
    """
    if chars is None:
        chars = CHARS
    scores = np.asarray(infer_res)
    if scores.ndim != 2:
        raise ValueError(
            f'expected a 2-D (num_classes, seq_len) array, got shape {scores.shape}'
        )
    blank = len(chars) - 1
    # Best class at each sequence position (ties resolve to the lowest index).
    best = np.argmax(scores, axis=0)
    # groupby merges runs of equal labels; blanks are discarded afterwards,
    # so a blank between two equal labels keeps both of them.
    return ''.join(
        chars[label] for label, _ in itertools.groupby(best) if label != blank
    )

@functools.lru_cache(maxsize=8)
def _get_session(model):
    """Create (once per distinct ``model``) and return an ORT InferenceSession.

    Uses ROCMExecutionProvider (DCU build) when ``ort.get_device()`` reports
    "GPU", otherwise CPUExecutionProvider. Prints the providers the new
    session ended up with.
    """
    if ort.get_device() == "GPU":
        providers = ['ROCMExecutionProvider']  # DCU build
    else:
        providers = ['CPUExecutionProvider']  # CPU build
    sess = ort.InferenceSession(model, providers=providers)
    print(sess.get_providers())
    return sess


def LPRNetInference(model, imgs):
    """Recognise the licence plate in one image file.

    Args:
        model: Model passed to ``ort.InferenceSession`` (an ONNX file path in
            this script). Sessions are cached per ``model`` value, so the
            model is loaded once even when this is called for every test
            image. NOTE: a model file changed on disk after the first call is
            not reloaded.
        imgs: Path of a single image file (see ``LPRNetPreprocess``).

    Returns:
        The decoded plate string produced by ``LPRNetPostprocess``.
    """
    img = LPRNetPreprocess(imgs)
    sess = _get_session(model)
    input_name = sess.get_inputs()[0].name
    # Only the first model output is decoded.
    preb = sess.run(None, input_feed={input_name: img})[0]
    return LPRNetPostprocess(preb)

if __name__ == '__main__':
    import os

    model_name = 'model/LPRNet.onnx'
    # Every file in this directory is named after its ground-truth plate
    # text, e.g. "川JK0707.jpg".
    test_dir = '/code/lpr_ori/data/test'
    images = sorted(os.listdir(test_dir))
    count = 0
    for image in images:
        # Strip the extension (of any length) to recover the expected label.
        label = os.path.splitext(image)[0]
        result = LPRNetInference(model_name, os.path.join(test_dir, image))
        print(image, 'Inference Result:', result)
        if label == result:
            count += 1
    if images:
        print('acc rate:', count / len(images))
    else:
        # Avoid ZeroDivisionError on an empty test directory.
        print('acc rate: no images found in', test_dir)