import cv2
import numpy as np
import migraphx
import time
import argparse

class Crnn:
    """CRNN text recognizer that runs an ONNX model through MIGraphX on the GPU.

    Typical use: ``raw, text = Crnn(model_path).infer(bgr_image)``.
    """

    # Characters addressed by the class index predicted by the model.
    # '-' is the symbol that map_rule() treats as the blank/separator.
    ALPHABET = "-0123456789abcdefghijklmnopqrstuvwxyz"
    BLANK = "-"

    def __init__(self, path):
        """Parse the ONNX model at *path*, record its input name/shape and compile it.

        Args:
            path: File path of the ONNX model.
        """
        # Parse the inference model
        self.model = migraphx.parse_onnx(path)

        # Name of the model input (first parameter name reported by the model)
        self.inputName = self.model.get_parameter_names()[0]

        # Dimensions of the model input
        self.inputShape = self.model.get_inputs()[self.inputName].lens()
        print("inputName:{0} \ninputShape:{1}".format(self.inputName, self.inputShape))

        # Compile the model; device_id selects the GPU device (device 0 here)
        self.model.compile(t=migraphx.get_target("gpu"), device_id=0)

    def infer(self, image):
        """Run recognition on *image*.

        Args:
            image: Image array as accepted by prepare_input().

        Returns:
            Tuple ``(text, final_text)``: the raw per-step character sequence
            and the sequence after map_rule() has collapsed it.
        """
        inputImage = self.prepare_input(image)

        # Run inference, reusing the input name cached in __init__
        results = self.model.run({self.inputName: inputImage})

        # Take the first output (a migraphx.argument) and convert it to a numpy array
        result = np.array(results[0])
        text = self.decode(result)
        final_text = self.map_rule(text)

        return text, final_text

    def prepare_input(self, image):
        """Convert *image* into the blob fed to the model.

        Args:
            image: BGR image, or an already single-channel 2-D image.

        Returns:
            The blob produced by cv2.dnn.blobFromImage: the grayscale image
            resized to the model input size with 127.5 subtracted and the
            result scaled by 1/127.5.
        """
        # cvtColor(BGR2GRAY) rejects an image that is already 2-D (single channel),
        # so such an image is used unchanged.
        if np.ndim(image) == 2:
            img_gray = image
        else:
            img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # blobFromImage expects size as (width, height); index 3 of the input
        # shape is used as width and index 2 as height.
        inputSize = (self.inputShape[3], self.inputShape[2])
        blob = cv2.dnn.blobFromImage(img_gray, scalefactor=1 / 127.5, size=inputSize, mean=127.5)

        return blob

    def decode(self, scores):
        """Map every step of *scores* to its highest-scoring character.

        Args:
            scores: Array iterated along its first axis; for each step the
                argmax of ``step[0]`` indexes ALPHABET
                (presumably shaped (steps, batch, classes) - TODO confirm).

        Returns:
            String with one character per step, blanks and repeats included.
        """
        alphabet = self.ALPHABET
        # Build the predicted sequence in one join instead of repeated string +=
        return "".join(alphabet[int(np.argmax(step[0]))] for step in scores)

    def map_rule(self, text):
        """Collapse a raw prediction: drop blanks and consecutive duplicates.

        A character is kept when it is not the blank symbol and differs from
        the character immediately before it, so "aa-a" becomes "aa".

        Args:
            text: Raw character sequence as returned by decode().

        Returns:
            The collapsed string.
        """
        blank = self.BLANK
        char_list = []
        previous = None  # no predecessor for the first character
        for ch in text:
            if ch != blank and ch != previous:
                char_list.append(ch)
            previous = ch

        return ''.join(char_list)

if __name__ == '__main__':
    # Command-line entry point: recognize the text in one image and print it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str, default='../Resource/Images/text.jpg', help="image path")
    parser.add_argument('--modelpath', type=str, default='../Resource/Models/crnn.onnx', help="onnx filepath")
    args = parser.parse_args()

    # cv2.imread returns None instead of raising when the file is missing or
    # cannot be decoded; check it here, before the costly model parse/compile,
    # so the failure is reported clearly.
    srcimg = cv2.imread(args.imgpath, 1)
    if srcimg is None:
        raise SystemExit("Failed to read image: {0}".format(args.imgpath))

    crnn = Crnn(args.modelpath)

    # Run inference
    resultRaw, resultSim = crnn.infer(srcimg)
    print("============= Ocr Results =============")
    print('%-20s => %-20s' % (resultRaw, resultSim))



