# CRNN文本识别 本示例构建了CRNN的Python推理示例,利用MIGraphX框架推理可以正确的获得文本字符识别结果。 ## 模型简介 CRNN是文本识别领域的一种经典算法。该算法的主要思想是认为文本识别需要对序列进行预测,所以采用了预测序列常用的RNN网络。算法通过CNN提取图片特征,然后采用RNN对序列进行预测,最终使用CTC方法得到最终结果。模型的主要结构包括基于CNN的图像特征提取模块以及基于双向LSTM的文字序列特征提取模块,网络结构如下图所示。 CRNN_01 本示例采用了如下的开源实现:https://github.com/meijieru/crnn.pytorch, 作者提供了CRNN的预训练模型。 ## 预处理 将待识别的文本图像输入模型前,需要对图像做如下预处理: 1. 转换为单通道图像 2. resize到100x32 3. 将像素值归一化到[-1, 1] 4. 转换数据排布为NCHW 本示例代码采用了OpenCV的blobFromImage()函数实现了预处理操作,blobFromImage处理输入图像的顺序为:首先将输入图像resize到inputSize,然后减去均值mean,最后乘以缩放系数scalefactor并转换为NCHW排布,blob是预处理之后的输出图像。 ``` def prepare_input(self, image): img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) inputSize = (self.inputShape[3], self.inputShape[2]) blob = cv2.dnn.blobFromImage(img_gray, scalefactor=1 / 127.5, size=inputSize, mean=127.5) ``` 其中self.inputShape通过MIGraphX对CRNN模型进行解析获取: ``` class Crnn: def __init__(self, path): # 解析推理模型 self.model = migraphx.parse_onnx(path) # 获取模型的输入name self.inputName = self.model.get_parameter_names()[0] # 获取模型的输入尺寸 self.inputShape = self.model.get_parameter_shapes()[self.inputName].lens() ... ``` ## 推理 待识别文本图像经过预处理之后,将其输入到crnn.onnx模型中执行推理,利用migraphx推理计算得到CRNN模型的输出。 ``` def infer(self, image): inputImage = self.prepare_input(image) # 执行推理 results = self.model.run({self.model.get_parameter_names()[0]: inputImage}) # 获取第一个输出节点的数据,migraphx.argument类型 result=results[0] result=np.array(result) text = self.decode(result) final_text = self.map_rule(text) return text, final_text ``` 经过推理获取CRNN模型的输出结果,由于该模型只有一个输出,所以输出结果等于results[0],该结果是一个argument类型,维度为[26,1,37]。为了便于对其进行后处理将其转化为numpy数据类型。后处理主要包括两个步骤: 1、第一步通过解码获取模型预测的文本序列,解码过程首先获取输出特征向量中最高得分字符对应的位置索引信息,然后从alphabet中获取对应字符保存在text中。 ``` def decode(self, scores): alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz" text = "" # 获取模型预测的文本序列 for i in range(scores.shape[0]): c = np.argmax(scores[i][0]) text += alphabet[c] return text ``` 其中scores.shape[0]表示输入图像从左到右预测字符的次数,本示例中次数为scores.shape[0]=26,每次预测输出一个特征向量,使用np.argmax()判断最高得分字符对应的位置索引信息c,然后根据c从alphabet中获取对应字符,并将其保存到text中。 2、第二步通过对预测得到的文本序列进行去除空格和重复字符处理,从而得到最终的预测结果。 ``` def map_rule(self, text): char_list = [] for i in range(len(text)): if i == 0: if text[i] != '-': char_list.append(text[i]) else: if text[i] != '-' and (not (text[i] == text[i - 1])): char_list.append(text[i]) return ''.join(char_list) ```