"vscode:/vscode.git/clone" did not exist on "95dc093b195e5999699cd7bdba60867c7e60fc92"
Commit ca719792 authored by liuhy's avatar liuhy
Browse files

编写readme

parent d0d41b4a
import onnxruntime as ort
import cv2
import numpy as np
import argparse
import os
print('Runing Based On:', ort.get_device())
......@@ -27,10 +29,7 @@ def LPRNetPostprocess(infer_res):
for j in range(infer_res.shape[1]):
preb_label.append(np.argmax(infer_res[:, j], axis=0))
no_repeat_blank_label = []
print(preb_label)
pre_c = preb_label[0]
print(pre_c)
if pre_c != len(CHARS) - 1:
no_repeat_blank_label.append(pre_c)
for c in preb_label: # dropout repeate label and blank label
......@@ -43,31 +42,32 @@ def LPRNetPostprocess(infer_res):
result = ''.join(list(map(lambda x: CHARS[x], no_repeat_blank_label)))
return result
def LPRNetInference(model, imgs):
img = LPRNetPreprocess(imgs)
def LPRNetInference(args):
if ort.get_device() == "GPU":
sess = ort.InferenceSession(model, providers=['ROCMExecutionProvider'],) #DCU版本
sess = ort.InferenceSession(args.model, providers=['ROCMExecutionProvider'],) #DCU版本
else:
sess = ort.InferenceSession(model, providers=['CPUExecutionProvider']) # CPU版本
print(sess.get_providers())
sess = ort.InferenceSession(args.model, providers=['CPUExecutionProvider']) # CPU版本
if os.path.isdir(args.imgpath):
images = os.listdir(args.imgpath)
for image in images:
img = LPRNetPreprocess(os.path.join(args.imgpath, image))
intput = sess.get_inputs()[0].shape
preb = sess.run(None, input_feed={sess.get_inputs()[0].name: img})[0]
result = LPRNetPostprocess(preb)
return result
print('Inference Result:', result)
else:
img = LPRNetPreprocess(args.imgpath)
intput = sess.get_inputs()[0].shape
preb = sess.run(None, input_feed={sess.get_inputs()[0].name: img})[0]
result = LPRNetPostprocess(preb)
print('Inference Result:', result)
if __name__ == '__main__':
model_name = 'model/LPRNet.onnx'
# model_name = 'LPRNet.onnx'
# image = 'imgs/川JK0707.jpg'
import os
images = os.listdir('/code/lpr_ori/data/test')
count = 0
for image in images:
label = image[:-4]
InferRes = LPRNetInference(model_name, os.path.join('/code/lpr_ori/data/test', image))
print(image, 'Inference Result:', InferRes)
if label == InferRes:
count += 1
print('acc rate:', count / len(images))
parser = argparse.ArgumentParser(description='parameters to vaildate net')
parser.add_argument('--model', default='model/LPRNet.onnx', help='model path to vaildate')
parser.add_argument('--imgpath', default='imgs', help='the image path')
args = parser.parse_args()
LPRNetInference(args)
......@@ -5,6 +5,8 @@ MIGraphX示例程序
import cv2
import numpy as np
import migraphx
import argparse
import os
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
'苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
......@@ -42,30 +44,40 @@ def LPRNetPostprocess(infer_res):
result = ''.join(list(map(lambda x: CHARS[x], no_repeat_blank_label)))
return result
def LPRNetInference(model_name, imgs):
img = LPRNetPreprocess(imgs)
def LPRNetInference(args):
# 加载模型
if model_name[-3:] == 'mxr':
model = migraphx.load(model_name)
if args.model[-3:] == 'mxr':
model = migraphx.load(args.model)
else:
print('convert onnx to mxr...')
model = migraphx.parse_onnx(model_name)
model = migraphx.parse_onnx(args.model)
model.compile(t=migraphx.get_target("gpu"),device_id=0) # device_id: 设置GPU设备,默认为0号设备(>=1.2版本中支持)
migraphx.save(model, 'model/LPRNet.mxr')
migraphx.save(model, args.savepath)
if os.path.isdir(args.imgpath):
images = os.listdir(args.imgpath)
for image in images:
img = LPRNetPreprocess(os.path.join(args.imgpath, image))
inputName=model.get_parameter_names()[0]
inputShape=model.get_parameter_shapes()[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName,inputShape))
# print("inputName:{0} \ninputShape:{1}".format(inputName,inputShape))
results = model.run({inputName: migraphx.argument(img)})
result = LPRNetPostprocess(np.array(results[0]))
return result
print('Inference Result:', result)
else:
img = LPRNetPreprocess(args.imgpath)
inputName=model.get_parameter_names()[0]
inputShape=model.get_parameter_shapes()[inputName].lens()
# print("inputName:{0} \ninputShape:{1}".format(inputName,inputShape))
results = model.run({inputName: migraphx.argument(img)})
result = LPRNetPostprocess(np.array(results[0]))
print('Inference Result:', result)
if __name__ == '__main__':
# model_name = 'model/LPRNet.onnx'
model_name = 'model/LPRNet.mxr'
image = 'imgs/川JK0707.jpg'
InferRes = LPRNetInference(model_name, image)
print(image, 'Inference Result:', InferRes)
parser = argparse.ArgumentParser(description='parameters to vaildate net')
parser.add_argument('--model', default='model/LPRNet.mxr', help='model path to inference')
parser.add_argument('--imgpath', default='imgs/川JK0707.jpg', help='the image path')
parser.add_argument('--savepath', default='model/LPRNet.mxr', help='mxr model save path and name')
args = parser.parse_args()
LPRNetInference(args)
......@@ -6,13 +6,52 @@ LPR是一个基于深度学习技术的车牌识别模型,主要识别目标
模型采用LPRNet,模型结构主要包含三部分:一个轻量级CNN主干网络、基于预定位置的字符分类头部、基于贪婪算法的序列解码。此外模型使用CTC Loss和RMSprop优化器。
## 数据集
推荐使用一个车牌数据集[CCPD](https://github.com/detectRecog/CCPD "CCPD官网GITHub"),也可参考[CCPD](https://blog.csdn.net/LuohenYJ/article/details/117752120 "CCPD中文版介绍"),该数据集由中科大收集,可用于车牌的检测与识别。我们提供了一个脚本cut_ccpd.py用于剪裁出CCPD数据集中的车牌位置,以便用于LPR模型的训练,在cut_ccpd.py中修改img_path和save_path即可,分别是CCPD数据集中ccpd_base文件夹的路径和剪裁出的图像保存路径。
推荐使用一个车牌数据集[CCPD](https://github.com/detectRecog/CCPD "CCPD官网GitHub"),也可参考[CCPD](https://blog.csdn.net/LuohenYJ/article/details/117752120 "CCPD中文版介绍"),该数据集由中科大收集,可用于车牌的检测与识别。我们提供了一个脚本cut_ccpd.py用于剪裁出CCPD数据集中的车牌位置,以便用于LPR模型的训练,在cut_ccpd.py中修改img_path和save_path即可,分别是CCPD数据集中ccpd_base文件夹的路径和剪裁出的图像保存路径。LPR用于训练的数据文件名就是图像的标签。
## 训练及推理
导出onnx模型:
python test.py --export_onnx true
### 训练与Fine-tunning
LPR模型的训练程序是train.py,初次训练模型使用以下命令:
'''
python train.py \
--train_img_dirs 训练集文件夹路径 \
--test_img_dirs 验证集文件夹路径
'''
Fine-tunning使用以下命令:
'''
python train.py \
--train_img_dirs 训练集文件夹路径 \
--test_img_dirs 验证集文件夹路径 \
--pretrained_model 预训练模型路径 \
--resume_epoch Fine-tuning训练的起始epoch \ #fine-tuning时只训练从起始epoch到最大epoch
--max_epoch 训练的最大epoch
'''
### 测试
LPR模型用test.py对训练出的模型进行测试,使用方法如下:
'''
python test.py \
--model 需要测试的pth模型路径 \
--imgpath 测试集路径 # 单张图像或文件夹皆可
--export_onnx 该参数用于选择是否需要将pth模型转为onnx模型
--dynamic 该参数用于选择onnx模型是否使用动态的batch size
'''
### 推理
我们分别提供了基于OnnxRuntime(ORT)和Migraphx的推理脚本
#### ORT
LPRNet_ORT_infer.py是基于ORT的的推理脚本,使用方法:
'''
python LPRNet_ORT_infer.py --model onnx模型路径 --imgpath 数据路径(文件夹图像皆可)
'''
#### Migraphx
LPRNet_migraphx_infer.py是基于Migraphx的推理脚本,使用需安装好Migraphx,支持onnx模型和mxr模型推理,mxr模型是migraphx将onnx模型保存成的离线推理引擎,初次使用onnx模型会保存对应的mxr模型。使用方法:
'''
python LPRNet_migraphx_infer.py --model mxr/onnx模型路径 --imgpath 数据路径(文件夹图像皆可) --savepath mxr模型的保存路径以及模型名称
'''
推理onnx模型:
python LPRNet_ORT_infer.py
## 性能和准确率数据
## 参考
\ No newline at end of file
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment