Commit 341e85be authored by Your Name's avatar Your Name
Browse files

提交crnn推理示例

parent 6e3ccfd7
Pipeline #280 failed with stages
in 0 seconds
import cv2
import numpy as np
import migraphx
import time
import argparse
class Crnn:
def __init__(self, path):
# 解析推理模型
self.model = migraphx.parse_onnx(path)
# 获取模型的输入name
self.inputName = self.model.get_parameter_names()[0]
# 获取模型的输入尺寸
self.inputShape = self.model.get_parameter_shapes()[self.inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(self.inputName, self.inputShape))
# 模型编译
self.model.compile(t=migraphx.get_target("gpu"), device_id=0) # device_id: 设置GPU设备,默认为0号设备
print("Success to compile")
def infer(self, image):
inputImage = self.prepare_input(image)
# 执行推理
results = self.model.run({self.model.get_parameter_names()[0]: migraphx.argument(inputImage)})
# 获取第一个输出节点的数据,migraphx.argument类型
result=results[0]
result=np.array(result)
text = self.decode(result)
final_text = self.map_rule(text)
return text, final_text
def prepare_input(self, image):
img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
inputSize = (self.inputShape[3], self.inputShape[2])
blob = cv2.dnn.blobFromImage(img_gray, scalefactor=1 / 127.5, size=inputSize, mean=127.5)
return blob
def decode(self, scores):
alphabet = "-0123456789abcdefghijklmnopqrstuvwxyz"
text = ""
# 获取模型预测的文本序列
for i in range(scores.shape[0]):
c = np.argmax(scores[i][0])
text += alphabet[c]
return text
def map_rule(self, text):
char_list = []
for i in range(len(text)):
if i == 0:
if text[i] != '-':
char_list.append(text[i])
else:
if text[i] != '-' and (not (text[i] == text[i - 1])):
char_list.append(text[i])
return ''.join(char_list)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--imgpath', type=str, default='./data/text.jpg', help="image path")
parser.add_argument('--modelpath', type=str, default='./weights/crnn.onnx', help="onnx filepath")
args = parser.parse_args()
crnn = Crnn(args.modelpath)
srcimg = cv2.imread(args.imgpath, 1)
# 执行推理
print("Start to inference")
start = time.time()
resultRaw, resultSim = crnn.infer(srcimg)
print('net forward time: {:.4f}'.format(time.time() - start))
print("============= Ocr Results =============")
print('%-20s => %-20s' % (resultRaw, resultSim))
# CRNN_MIGraphX # CRNN
This project constructs a CRNN character recognition inference example ## 模型介绍
\ No newline at end of file
CRNN是文本识别领域的一种经典算法,该算法的主要思想是认为文本识别需要对序列进行预测,所以采用了预测序列常用的RNN网络。算法通过CNN提取图片特征,然后采用RNN对序列进行预测,最终使用CTC方法得到最终结果。
## 模型结构
CRNN模型的主要结构包括基于CNN的图像特征提取模块以及基于双向LSTM的文字序列特征提取模块。
## 推理
### 环境配置
[光源](https://www.sourcefind.cn/#/image/dcu/custom)可拉取用于推理的docker镜像,CRNN 模型推理推荐的镜像如下:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort_dcu_1.14.0_migraphx2.5.2_dtk22.10.1
```
[光合开发者社区](https://cancon.hpccube.com:65024/4/main/)可下载MIGraphX安装包,python依赖安装:
```
pip install -r requirements.txt
```
### 运行示例
RetinaFace模型的推理示例程序是Crnn_infer_migraphx.py,使用如下命令运行该推理示例:
```
python Crnn_infer_migraphx.py
```
该示例输入样本图像为:
<img src="./data/text.jpg" alt="Result" />
文本识别结果为:
```
a-----v--a-i-l-a-bb-l-e--- => available
```
## 历史版本
​ https://developer.hpccube.com/codes/modelzoo/crnn_migraphx
## 参考
​ https://github.com/meijieru/crnn.pytorch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment