Commit 1fe2937e authored by liuhy's avatar liuhy
Browse files

修改test代码

parent 40622bae
...@@ -4,60 +4,89 @@ import os ...@@ -4,60 +4,89 @@ import os
import time

import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import *

from load_data import CHARS, LPRDataLoader
from lprnet import build_lprnet
def collate_fn(batch):
    """Collate (image, label, length) samples into a single batch.

    Args:
        batch: list of tuples ``(img, label, length)`` where ``img`` is a
            float32 numpy array (all images must share one shape), ``label``
            is a sequence of per-character class indices and ``length`` is
            ``len(label)``.

    Returns:
        tuple: ``(images, labels, lengths)`` — a stacked image tensor of
        shape ``(B, *img.shape)``, one flat 1-D float32 tensor holding every
        label concatenated (CTC-style), and the per-sample label lengths as
        a plain list. The consumer slices ``labels`` back apart with
        ``lengths``.
    """
    imgs = []
    labels = []
    lengths = []
    for img, label, length in batch:
        imgs.append(torch.from_numpy(img))
        labels.extend(label)
        lengths.append(length)
    # One flat float32 vector; downstream code re-splits it via `lengths`.
    labels = np.asarray(labels).flatten().astype(np.float32)
    return (torch.stack(imgs, 0), torch.from_numpy(labels), lengths)
img = torch.from_numpy(img).unsqueeze(0).to(args.device)
def Greedy_Decode_Eval(Net, datasets, args):
    """Greedy-decode CTC outputs of `Net` over `datasets`; print accuracy and speed.

    Args:
        Net: trained LPRNet whose forward returns per-image logits; argmax is
            taken over axis 1, so the output is assumed (B, n_classes, T) —
            TODO confirm against lprnet.py.
        datasets: dataset yielding (img, label, length) samples (see collate_fn).
        args: namespace providing batch_size, num_workers and cuda.
    """
    epoch_size = len(datasets) // args.batch_size
    batch_iterator = iter(DataLoader(datasets, args.batch_size, shuffle=True,
                                     num_workers=args.num_workers,
                                     collate_fn=collate_fn))

    Tp = 0    # exact matches
    Tn_1 = 0  # decoded length differs from target length
    Tn_2 = 0  # right length, wrong characters
    t1 = time.time()
    for i in range(epoch_size):
        # load test data
        images, labels, lengths = next(batch_iterator)
        # Re-split the flat label vector into one target per sample.
        start = 0
        targets = []
        for length in lengths:
            label = labels[start:start + length]
            targets.append(label)
            start += length
        targets = np.array([el.numpy() for el in targets])

        if args.cuda:
            images = Variable(images.cuda())
        else:
            images = Variable(images)

        # forward
        prebs = Net(images)
        # greedy decode: best class per time step, then collapse repeats/blanks
        prebs = prebs.cpu().detach().numpy()
        preb_max = np.argmax(prebs, axis=1)
        preb_labels = list()
        for preb in preb_max:
            no_repeat_blank_label = list()
            pre_c = preb[0]
            if pre_c != len(CHARS) - 1:  # len(CHARS)-1 is the CTC blank index
                no_repeat_blank_label.append(pre_c)
            for c in preb:  # drop repeated labels and blank labels
                if (pre_c == c) or (c == len(CHARS) - 1):
                    if c == len(CHARS) - 1:
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)
        # idx (not i) so the outer epoch counter is not shadowed
        for idx, label in enumerate(preb_labels):
            if len(label) != len(targets[idx]):
                Tn_1 += 1
                continue
            if (np.asarray(targets[idx]) == np.asarray(label)).all():
                Tp += 1
            else:
                Tn_2 += 1
    Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, (Tp + Tn_1 + Tn_2)))
    t2 = time.time()
    print("[Info] Test Speed: {}s 1/{}]".format((t2 - t1) / len(datasets), len(datasets)))
def validation(args): def validation(args):
model = build_lprnet(len(CHARS)) lprnet = build_lprnet(class_num=len(CHARS), phase=args.phase_train)
model.load_state_dict(torch.load(args.model, map_location=args.device)) lprnet.load_state_dict(torch.load(args.model))
model.to(args.device) lprnet.to(args.device)
print("Successful to build network!")
test_img_dirs = os.path.expanduser(args.imgpath)
test_dataset = LPRDataLoader(test_img_dirs.split(','), args.img_size)
Greedy_Decode_Eval(lprnet, test_dataset, args)
if os.path.isdir(args.imgpath):
images = os.listdir(args.imgpath)
count = 0
time1 = time.perf_counter()
for image in images:
result = infer(args, os.path.join(args.imgpath, image), model)
if result == image[:-4]:
count += 1
print('Image: ' + image + ' recongise result: '+ result)
time2 = time.perf_counter()
print('accuracy rate:', count / len(images))
print('average time', (time2 - time1)/count*1000)
else:
result = infer(args, args.imgpath, model)
print('Image: ' + args.imgpath + ' recongise result: '+ result)
if args.export_onnx: if args.export_onnx:
print('export pytorch model to onnx model...') print('export pytorch model to onnx model...')
onnx_input = torch.randn(1, 3, 24, 94, device=args.device) onnx_input = torch.randn(1, 3, 24, 94, device=args.device)
...@@ -70,16 +99,20 @@ def validation(args): ...@@ -70,16 +99,20 @@ def validation(args):
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None, dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None,
opset_version=12, opset_version=12,
) )
if __name__ == "__main__":
    # CLI entry point: parse validation/export options and run validation().
    parser = argparse.ArgumentParser(description='parameters to train net')
    parser.add_argument('--img_size', default=[94, 24], help='the image size')
    parser.add_argument('--imgpath', default="imgs", help='the image path')
    parser.add_argument('--model', default='model/lprnet.pth', help='model path to vaildate')
    parser.add_argument('--batch_size', default=100, type=int, help='testing batch size.')
    # NOTE(review): argparse `type=bool` treats any non-empty string
    # (including "False") as True — confirm intended before exposing widely.
    parser.add_argument('--cuda', default=True, type=bool, help='Use cuda to train model')
    parser.add_argument('--device', default='cuda', help='Use cuda to vaildate model')
    parser.add_argument('--export_onnx', default=False, help='export model to onnx')
    parser.add_argument('--dynamic', default=False, help='use dynamic batch size')
    parser.add_argument('--phase_train', default=False, type=bool, help='train or test phase flag.')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading')
    args = parser.parse_args()
    validation(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment