# -*- coding: utf-8 -*-
"""
MIGraphX示例程序
"""
import cv2
import numpy as np
import migraphx
import argparse
import os
import time

# Label table used to decode LPRNet output: class index i maps to CHARS[i].
# The final entry '-' is the blank label that the decoder drops
# (LPRNetPostprocess treats index len(CHARS) - 1 as blank).
CHARS = (
    list('京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新')  # 31 region prefix characters
    + list('0123456789')                                            # digits
    + list('ABCDEFGHJKLMNPQRSTUVWXYZIO')                            # letters; 'I' and 'O' come last
    + ['-']                                                         # blank label
)

def LPRNetPreprocess(image):
    """Load an image file and turn it into an LPRNet input tensor.

    The image is resized to 94x24 (width x height), normalised with
    ``(pixel - 127.5) * 0.0078125`` (i.e. divided by 128), converted from
    HWC to CHW layout and given a leading batch axis.

    Args:
        image: Path to the image file.

    Returns:
        A C-contiguous float32 array of shape (1, C, 24, 94), where C is the
        channel count returned by ``cv2.imread`` (3 channels, BGR order, with
        the default colour flag).

    Raises:
        OSError: if OpenCV cannot read ``image`` (missing file, unreadable
            or unsupported format, a directory, ...).
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error would surface later as an opaque
        # cv2.resize assertion.
        raise OSError('cannot read image: {}'.format(image))
    img = cv2.resize(img, (94, 24)).astype('float32')
    img -= 127.5
    img *= 0.0078125
    # transpose() only returns a strided view; give migraphx.argument a densely
    # packed buffer so the tensor layout does not depend on how the runtime
    # treats non-standard strides.
    img = np.ascontiguousarray(img.transpose(2, 0, 1))
    return np.expand_dims(img, 0)

def LPRNetPostprocess(infer_res, chars=None):
    """Greedy CTC decoding of LPRNet scores into a plate string.

    For every time step (column) the highest-scoring class is chosen. A
    label is kept only if it is not the blank label and differs from the
    label chosen at the immediately preceding time step. Repeats are
    collapsed, and a blank between two equal labels lets the second through.

    Args:
        infer_res: Score array laid out as (num_classes, time_steps). An array
            with a leading batch axis of size 1, i.e.
            (1, num_classes, time_steps), is also accepted and squeezed.
        chars: Character table indexed by class id. Defaults to the
            module-level ``CHARS``. Its last entry is the blank label.

    Returns:
        The decoded string. It is ``''`` when there are no time steps or every
        step is blank.

    Raises:
        ValueError: if ``infer_res`` is neither 2-D nor (1, C, T).
    """
    if chars is None:
        chars = CHARS
    blank = len(chars) - 1

    scores = np.asarray(infer_res)
    if scores.ndim == 3 and scores.shape[0] == 1:
        # Raw network output still carries its batch axis; decode the only item.
        scores = scores[0]
    if scores.ndim != 2:
        raise ValueError(
            'expected scores shaped (num_classes, time_steps) or '
            '(1, num_classes, time_steps), got {}'.format(scores.shape))

    # Best class per time step: argmax down each column.
    best_path = np.argmax(scores, axis=0).tolist()

    decoded = []
    prev = None
    for label in best_path:
        if label != blank and label != prev:
            decoded.append(chars[label])
        prev = label
    return ''.join(decoded)

def _LPRNetInferOne(model, input_name, image_path):
    """Run the model on one image file and return the decoded plate string."""
    img = LPRNetPreprocess(image_path)
    results = model.run({input_name: migraphx.argument(img)})
    return LPRNetPostprocess(np.array(results[0]))


def LPRNetInference(args):
    """Load (or convert) an LPRNet model and run it on an image or a folder.

    Args:
        args: Parsed command-line namespace with these fields:
            ``model``: a ``.mxr`` program file, loaded as-is; any other path
                is parsed as ONNX, compiled for GPU device 0 and saved to
                ``savepath``.
            ``savepath``: destination of the freshly compiled program.
            ``imgpath``: a single image file, or a directory of images.

    In directory mode every regular file is treated as an image whose name
    (without extension) is the ground-truth plate. The accuracy rate and the
    average wall-clock time per image are printed. The time is in
    milliseconds and includes pre- and post-processing.
    """
    # Load the model.
    if args.model.endswith('.mxr'):
        model = migraphx.load(args.model)
    else:
        print('convert onnx to mxr...')
        model = migraphx.parse_onnx(args.model)
        # device_id selects the GPU device (default 0); supported in version >= 1.2.
        model.compile(t=migraphx.get_target("gpu"), device_id=0)
        migraphx.save(model, args.savepath)

    # Loop-invariant: the model's first parameter is used as the image input.
    # Assumes a single-input network -- NOTE(review): confirm for other exports.
    input_name = model.get_parameter_names()[0]

    if not os.path.isdir(args.imgpath):
        result = _LPRNetInferOne(model, input_name, args.imgpath)
        print('Inference Result:', result)
        return

    images = sorted(
        name for name in os.listdir(args.imgpath)
        if os.path.isfile(os.path.join(args.imgpath, name))
    )
    if not images:
        print('no images found in', args.imgpath)
        return

    count = 0
    time1 = time.perf_counter()
    for image in images:
        result = _LPRNetInferOne(model, input_name, os.path.join(args.imgpath, image))
        # The file name without its extension is the expected plate string.
        if result == os.path.splitext(image)[0]:
            count += 1
        print('Inference Result:', result)
    time2 = time.perf_counter()
    print('accuracy rate:', count / len(images))
    # Average over every processed image (not only the correct ones), in ms.
    print('average time', (time2 - time1) / len(images) * 1000)

if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='parameters to vaildate net')
    # (flag, default value, help text) for every command-line option.
    # --model: a .mxr program is loaded directly; otherwise it is parsed as
    #          ONNX, compiled, and written to --savepath.
    # --imgpath: a single image, or a directory whose files are all evaluated.
    options = (
        ('--model', 'model/LPRNet.onnx', 'model path to inference'),
        ('--imgpath', 'imgs/京PL3N67.jpg', 'the image path'),
        ('--savepath', 'model/LPRNet.mxr', 'mxr model save path and name'),
    )
    for flag, default_value, help_text in options:
        cli.add_argument(flag, default=default_value, help=help_text)

    LPRNetInference(cli.parse_args())
