# Crnn_infer_migraphx.py
import cv2
import numpy as np
import migraphx
import time
import argparse

class Crnn:
    """CRNN text recognizer compiled and executed with MIGraphX on the GPU.

    The ONNX model is parsed and compiled once in ``__init__``; ``infer`` then
    preprocesses an image, runs the network and greedily decodes the output.
    """

    # Output classes in model index order; index 0 ('-') is the CTC blank.
    ALPHABET = "-0123456789abcdefghijklmnopqrstuvwxyz"

    def __init__(self, path):
        """Parse the ONNX model at ``path`` and compile it for GPU device 0."""
        # Parse the inference model
        self.model = migraphx.parse_onnx(path)

        # Name of the model input (only the first parameter is used, so the
        # model is presumably single-input — confirm for other exports)
        self.inputName = self.model.get_parameter_names()[0]

        # Input dimensions; prepare_input reads them as [N, C, H, W]
        # (NCHW layout assumed — confirm against the exported model)
        self.inputShape = self.model.get_parameter_shapes()[self.inputName].lens()

        print("inputName:{0} \ninputShape:{1}".format(self.inputName, self.inputShape))

        # Compile the model; device_id selects the GPU (0 = first device)
        self.model.compile(t=migraphx.get_target("gpu"), device_id=0)
        print("Success to compile")

    def infer(self, image):
        """Recognize the text in ``image``.

        Args:
            image: image array as returned by ``cv2.imread`` (BGR, BGRA or
                single-channel grayscale).

        Returns:
            ``(raw, text)``: ``raw`` is the per-timestep decoded string including
            blanks and repeats; ``text`` is that string after CTC collapsing.
        """
        input_blob = self.prepare_input(image)

        # Run inference; the input name was resolved once in __init__.
        results = self.model.run({self.inputName: input_blob})

        # Data of the first output node (a migraphx.argument), as a numpy array
        scores = np.array(results[0])
        raw_text = self.decode(scores)
        final_text = self.map_rule(raw_text)

        return raw_text, final_text

    def prepare_input(self, image):
        """Convert ``image`` to a normalized single-channel network input blob.

        The image is converted to grayscale and resized to the model's (W, H).
        ``blobFromImage`` then subtracts 127.5 and scales by 1/127.5, which maps
        8-bit pixels to [-1, 1].
        """
        # Accept already-grayscale and BGRA input instead of failing in
        # cvtColor, which only handles 3-channel data for COLOR_BGR2GRAY.
        if image.ndim == 2:
            img_gray = image
        elif image.shape[2] == 1:
            img_gray = image[:, :, 0]
        elif image.shape[2] == 4:
            img_gray = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        else:
            img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # cv2 sizes are (width, height); inputShape is indexed as NCHW.
        input_size = (self.inputShape[3], self.inputShape[2])
        return cv2.dnn.blobFromImage(img_gray, scalefactor=1 / 127.5, size=input_size, mean=127.5)

    def decode(self, scores):
        """Greedy decode: pick the highest-scoring class at every timestep.

        Args:
            scores: array indexed as ``scores[t, b, c]``. Only batch entry 0
                is decoded. The (timesteps, batch, classes) layout is inferred
                from the indexing; confirm it against the model output.

        Returns:
            One alphabet character per timestep, blanks ('-') included.
        """
        # One vectorized argmax over the class axis instead of a Python loop;
        # ties resolve to the first index, exactly like per-row np.argmax.
        best = np.argmax(np.asarray(scores)[:, 0], axis=-1)
        return "".join(self.ALPHABET[int(c)] for c in best)

    def map_rule(self, text):
        """Collapse a raw CTC string: merge repeated characters, then drop blanks.

        A character repeated across a blank is kept twice, e.g.
        ``"aa-a"`` -> ``"aa"``.
        """
        blank = self.ALPHABET[0]
        chars = []
        prev = None
        for ch in text:
            # Keep only the first character of each run, and never the blank.
            if ch != blank and ch != prev:
                chars.append(ch)
            prev = ch
        return "".join(chars)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str, default='../Resource/Images/text.jpg', help="image path")
    parser.add_argument('--modelpath', type=str, default='../Resource/Models/Ocr/CRNN/crnn.onnx', help="onnx filepath")
    args = parser.parse_args()

    # Read the image before the (slow) model compile so a bad path fails fast.
    # cv2.imread returns None rather than raising when it cannot read the file.
    srcimg = cv2.imread(args.imgpath, 1)
    if srcimg is None:
        raise SystemExit("Failed to read image: {}".format(args.imgpath))

    crnn = Crnn(args.modelpath)

    # Run inference. The timing covers preprocessing, the forward pass and decoding.
    print("Start to inference")
    start = time.perf_counter()
    resultRaw, resultSim = crnn.infer(srcimg)
    print('net forward time: {:.4f}'.format(time.perf_counter() - start))
    print("============= Ocr Results =============")
    print('%-20s => %-20s' % (resultRaw, resultSim))