import argparse
import itertools
import time

import cv2
import numpy as np
import migraphx


class Crnn:
    """CRNN text recognizer that runs an ONNX model through MIGraphX on the GPU."""

    # Symbol table used by decode(): index 0 is the blank '-', followed by
    # digits and lowercase letters. Assumes the model's class axis follows
    # exactly this ordering — TODO confirm against the exported model.
    ALPHABET = "-0123456789abcdefghijklmnopqrstuvwxyz"

    def __init__(self, path):
        """Parse the ONNX model at *path* and compile it for GPU device 0.

        Args:
            path: Filesystem path to the CRNN ONNX model.
        """
        # Parse the inference model.
        self.model = migraphx.parse_onnx(path)

        # Input node information. The first parameter name is taken as the
        # model input (assumes a single-input model — TODO confirm).
        inputs = self.model.get_inputs()
        self.inputName = self.model.get_parameter_names()[0]
        self.inputShape = inputs[self.inputName].lens()
        print("inputName:{0} \ninputShape:{1}".format(self.inputName, self.inputShape))

        # Compile the model; device_id selects the GPU (device 0 here).
        self.model.compile(t=migraphx.get_target("gpu"), device_id=0)

    def infer(self, image):
        """Recognize the text in a BGR *image*.

        Returns:
            A tuple ``(raw, text)``: ``raw`` is the per-timestep greedy
            decoding (blanks and repeats included) and ``text`` is that
            string after ``map_rule`` collapses it.
        """
        inputImage = self.prepare_input(image)
        # Run inference.
        results = self.model.run({self.model.get_parameter_names()[0]: inputImage})
        # Data of the first output node (a migraphx.argument) as an ndarray.
        result = np.array(results[0])
        text = self.decode(result)
        final_text = self.map_rule(text)
        return text, final_text

    def prepare_input(self, image):
        """Convert a BGR *image* into the model's input blob.

        The image is converted to grayscale, resized to the model's spatial
        size and normalized as ``(pixel - 127.5) / 127.5`` by blobFromImage.
        """
        img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # OpenCV sizes are (width, height); assumes inputShape is NCHW — TODO confirm.
        inputSize = (self.inputShape[3], self.inputShape[2])
        blob = cv2.dnn.blobFromImage(img_gray, scalefactor=1 / 127.5, size=inputSize, mean=127.5)
        return blob

    def decode(self, scores):
        """Greedy (best-path) decode of the network output into a raw string.

        Args:
            scores: Output array where ``scores[t][0]`` holds the class scores
                for timestep ``t`` of the first batch item (presumably shape
                ``(T, batch, num_classes)`` — TODO confirm).

        Returns:
            One ``ALPHABET`` character per timestep, blanks included.
        """
        alphabet = self.ALPHABET
        return "".join(alphabet[int(np.argmax(step[0]))] for step in scores)

    def map_rule(self, text):
        """Collapse a raw CTC string: merge consecutive repeats, then drop blanks ('-').

        A character repeated across a blank (e.g. ``"a-a"``) is kept twice,
        because the blank separates the two runs.
        """
        return "".join(ch for ch, _ in itertools.groupby(text) if ch != "-")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgpath', type=str, default='../Resource/Images/text.jpg', help="image path")
    parser.add_argument('--modelpath', type=str, default='../Resource/Models/crnn.onnx', help="onnx filepath")
    args = parser.parse_args()

    crnn = Crnn(args.modelpath)

    srcimg = cv2.imread(args.imgpath, 1)
    # cv2.imread signals failure by returning None rather than raising.
    if srcimg is None:
        raise FileNotFoundError("Could not read image: {}".format(args.imgpath))

    # Run inference and report its wall-clock time.
    start = time.time()
    resultRaw, resultSim = crnn.infer(srcimg)
    elapsed = time.time() - start

    print("============= Ocr Results =============")
    print('%-20s => %-20s' % (resultRaw, resultSim))
    print("Inference time: {:.3f} s".format(elapsed))