"""Validate an LPRNet license-plate recognizer and optionally export it to ONNX.

The script runs the model on a single image, or on a directory of images.
In directory mode, each file name (without its extension) is taken as the
expected plate string, and the accuracy over the directory is printed too.
"""
import argparse
import itertools
import os

import cv2
import numpy as np
import torch

from lprnet import build_lprnet
from load_data import CHARS

# Network input size (note that cv2.resize takes (width, height)).
INPUT_WIDTH = 94
INPUT_HEIGHT = 24


def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', '1' or '0'.

    Without a ``type``, argparse keeps the raw string, so any non-empty value
    (including "False") would be truthy; this parser interprets the text.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


def _read_image(path):
    """Load *path* as a 3-channel BGR uint8 image.

    Raises:
        OSError: the file cannot be read.
        ValueError: the file is empty or cannot be decoded by OpenCV.
    """
    # cv2.imread cannot open non-ASCII paths on Windows, and plate images are
    # commonly named after their (e.g. Chinese) label, so read the raw bytes
    # with NumPy and decode them in memory instead.
    data = np.fromfile(path, dtype=np.uint8)
    if data.size == 0:
        # cv2.imdecode asserts on an empty buffer; report it as a bad image.
        raise ValueError('empty image file: ' + path)
    img = cv2.imdecode(data, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError('cannot decode image: ' + path)
    return img


def _preprocess(img):
    """Resize to the network input size, normalize, and convert HWC to CHW."""
    height, width = img.shape[:2]
    if height != INPUT_HEIGHT or width != INPUT_WIDTH:
        img = cv2.resize(img, (INPUT_WIDTH, INPUT_HEIGHT))
    # (x - 127.5) / 128 maps uint8 pixel values into roughly [-1, 1].
    img = (img.astype('float32') - 127.5) * 0.0078125
    return np.transpose(img, (2, 0, 1))


def _ctc_greedy_decode(scores, blank):
    """Greedy CTC decoding of a (num_classes, time_steps) score matrix.

    Takes the best class at every time step, collapses consecutive repeats,
    and drops the blank label. A repeat separated by a blank is therefore
    kept: the best path A A blank A decodes to A A.
    """
    best_path = np.argmax(scores, axis=0)
    return [int(label) for label, _ in itertools.groupby(best_path) if label != blank]


def infer(args, image, model):
    """Recognize the plate in the image file *image* and return it as a string.

    Args:
        args: parsed command-line namespace; only ``args.device`` is used.
        image: path of the image file.
        model: the recognizer, expected to be on ``args.device`` and in eval mode.

    Raises:
        OSError, ValueError: the image cannot be read or decoded.
    """
    img = _preprocess(_read_image(image))
    with torch.no_grad():
        batch = torch.from_numpy(img).unsqueeze(0).to(args.device)
        # Drop the batch axis. This assumes the model returns
        # (1, num_classes, time_steps) -- TODO confirm against lprnet.
        scores = model(batch).cpu().numpy().squeeze()
    # The last entry of CHARS is used as the CTC blank label.
    labels = _ctc_greedy_decode(scores, blank=len(CHARS) - 1)
    return ''.join(CHARS[label] for label in labels)


def validation(args):
    """Run the checkpoint ``args.model`` on ``args.imgpath``; optionally export to ONNX.

    If ``args.imgpath`` is a directory, every regular file in it is recognized
    and compared with its file name without the extension. Per-image results
    and the overall accuracy are printed, and unreadable files are reported
    and skipped. Otherwise ``args.imgpath`` is treated as a single image.

    When ``args.export_onnx`` is true, the model is also exported to
    ``LPRNet.onnx``; the batch axis is dynamic if ``args.dynamic`` is true.

    Returns:
        The most recently recognized plate string, or None if nothing was recognized.
    """
    model = build_lprnet(len(CHARS))
    # torch.load unpickles the checkpoint: only load trusted weight files.
    model.load_state_dict(torch.load(args.model, map_location=args.device))
    model.to(args.device)
    # Inference and ONNX export must run in eval mode, so that any dropout or
    # batch-norm layers behave deterministically. The mode build_lprnet leaves
    # the model in is not visible here, so set it explicitly.
    model.eval()

    res = None
    if os.path.isdir(args.imgpath):
        images = sorted(
            name for name in os.listdir(args.imgpath)
            if os.path.isfile(os.path.join(args.imgpath, name))
        )
        count = 0
        total = 0
        for image in images:
            try:
                res = infer(args, os.path.join(args.imgpath, image), model)
            except (OSError, ValueError) as err:
                print('Skipping ' + image + ': ' + str(err))
                continue
            total += 1
            # Ground truth is the file name without its extension (any length).
            if res == os.path.splitext(image)[0]:
                count += 1
            print('Image: ' + image + ' recongise result: '+ res)
        if total:
            print('acc rate:', count / total)
        else:
            print('acc rate: no readable images in ' + args.imgpath)
    else:
        res = infer(args, args.imgpath, model)
        print('Image: ' + args.imgpath + ' recongise result: '+ res)

    if args.export_onnx:
        print('export pytroch model to onnx model...')
        onnx_input = torch.randn(1, 3, INPUT_HEIGHT, INPUT_WIDTH, device=args.device)
        torch.onnx.export(
            model,
            onnx_input,
            'LPRNet.onnx',
            input_names=['input'],
            output_names=['output'],
            # Only the batch axis is made dynamic; height and width stay fixed.
            dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None,
            opset_version=12,
        )
    return res


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to vaildate net')
    parser.add_argument('--model', default='weights/Final_LPRNet_model.pth', help='model path to vaildate')
    # Single-image example: --imgpath imgs/川JK0707.jpg
    parser.add_argument('--imgpath', default='/code/lpr_ori/data/test', help='the image path')
    # Fall back to CPU when CUDA is unavailable instead of failing in torch.load.
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Use cuda to vaildate model')
    # Accepts "--export_onnx", "--export_onnx true" and "--export_onnx False".
    parser.add_argument('--export_onnx', type=_str2bool, nargs='?', const=True, default=False,
                        help='export model to onnx')
    parser.add_argument('--dynamic', type=_str2bool, nargs='?', const=True, default=False,
                        help='use dynamic batch size')
    args = parser.parse_args()
    result = validation(args)