Crnn_infer_migraphx.py 2.69 KB
Newer Older
Your Name's avatar
Your Name committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import cv2
import numpy as np
import migraphx
import time
import argparse

class Crnn:
    """CRNN text recognizer executed with MIGraphX on the GPU.

    The constructor parses an ONNX model and compiles it for the "gpu"
    target; :meth:`infer` turns an image into a raw per-timestep string and
    its CTC-collapsed form.
    """

    # Output class index -> character. Index 0 is the CTC blank symbol.
    ALPHABET = "-0123456789abcdefghijklmnopqrstuvwxyz"
    # CTC blank symbol that map_rule drops from the final text.
    BLANK = '-'

    def __init__(self, path):
        """Parse and compile the ONNX model stored at *path*.

        Args:
            path: filesystem path of the CRNN ONNX model.
        """
        # Parse the inference model
        self.model = migraphx.parse_onnx(path)

        # Name of the model's first input parameter; reused as the feed key in infer()
        self.inputName = self.model.get_parameter_names()[0]

        # Input dimensions; prepare_input reads index 2 as height and index 3 as
        # width (NCHW layout assumed — confirm against the exported model)
        self.inputShape = self.model.get_parameter_shapes()[self.inputName].lens()

        print("inputName:{0} \ninputShape:{1}".format(self.inputName, self.inputShape))

        # Compile the model; device_id selects the GPU device (default: device 0)
        self.model.compile(t=migraphx.get_target("gpu"), device_id=0)
        print("Success to compile")

    def infer(self, image):
        """Recognize the text in a single image.

        Args:
            image: image array as returned by cv2.imread (BGR/BGRA, or
                already single-channel).

        Returns:
            A tuple ``(raw, text)``: ``raw`` holds one character per output
            timestep (blanks and repeats included), ``text`` is ``raw`` after
            CTC collapsing (see :meth:`map_rule`).
        """
        inputImage = self.prepare_input(image)

        # Run inference, feeding the blob under the input name captured before
        # compilation rather than re-querying the compiled program's parameters.
        results = self.model.run({self.inputName: migraphx.argument(inputImage)})
        # First output node (a migraphx.argument), converted to a NumPy array
        result = np.array(results[0])
        text = self.decode(result)
        final_text = self.map_rule(text)

        return text, final_text

    def prepare_input(self, image):
        """Convert *image* into the float blob fed to the model.

        The image is converted to grayscale unless it already has a single
        channel, resized to the model's input width/height, and normalized as
        ``(pixel - 127.5) / 127.5`` by cv2.dnn.blobFromImage.
        """
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            # Already single-channel: COLOR_BGR2GRAY would reject this input.
            img_gray = image
        else:
            img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # OpenCV sizes are (width, height); inputShape is indexed as [N, C, H, W]
        inputSize = (self.inputShape[3], self.inputShape[2])
        blob = cv2.dnn.blobFromImage(img_gray, scalefactor=1 / 127.5, size=inputSize, mean=127.5)

        return blob

    def decode(self, scores):
        """Greedy-decode the network output into a raw per-timestep string.

        Args:
            scores: array indexed as ``scores[t][0]`` -> class scores of
                timestep ``t`` (presumably shape (T, 1, len(ALPHABET)), i.e. a
                batch of one — TODO confirm against the model output).

        Returns:
            One ALPHABET character per timestep; blanks and repeats are kept.
        """
        # Walk the first axis (timesteps) and pick the best class of batch item 0.
        return "".join(self.ALPHABET[int(np.argmax(step[0]))] for step in scores)

    def map_rule(self, text):
        """Collapse a raw CTC string: merge runs of equal characters, drop blanks.

        Example: ``"--hh-e-ll-lo--"`` -> ``"hello"``.
        """
        blank = self.BLANK
        # Keep a character only if it is not the blank and differs from its
        # predecessor; a blank between two equal characters keeps both.
        return ''.join(
            ch for i, ch in enumerate(text)
            if ch != blank and (i == 0 or ch != text[i - 1])
        )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str, default='../Resource/Images/text.jpg', help="image path")
    parser.add_argument('--modelpath', type=str, default='../Resource/Models/Ocr/CRNN/crnn.onnx', help="onnx filepath")
    args = parser.parse_args()

    # Read the image before compiling the model: cv2.imread returns None (it
    # does not raise) for a missing/unreadable file, so fail fast with a clear
    # message instead of an obscure cvtColor error after a costly compile.
    srcimg = cv2.imread(args.imgpath, cv2.IMREAD_COLOR)
    if srcimg is None:
        parser.error("could not read image: {0}".format(args.imgpath))

    crnn = Crnn(args.modelpath)

    # Run inference; the measured time covers preprocessing, the forward pass
    # and decoding. perf_counter is monotonic and high-resolution.
    print("Start to inference")
    start = time.perf_counter()
    resultRaw, resultSim = crnn.infer(srcimg)
    print('net forward time: {:.4f}'.format(time.perf_counter() - start))
    print("============= Ocr Results =============")
    print('%-20s => %-20s' % (resultRaw, resultSim))