Commit dfc4d118 authored by liuhy
Browse files

修改代码

parent 9adcf60d
......@@ -60,6 +60,14 @@ def LPRNetInference(model, imgs):
if __name__ == '__main__':
    import os

    # Exported ONNX model used for inference.
    model_name = 'model/LPRNet.onnx'

    # Batch evaluation over a directory of test plates. Each file is named
    # "<ground-truth-label>.<ext>", so the label is the filename minus its
    # 4-character extension (e.g. ".jpg").
    test_dir = '/code/lpr_ori/data/test'
    images = os.listdir(test_dir)
    count = 0
    for image in images:
        label = image[:-4]
        InferRes = LPRNetInference(model_name, os.path.join(test_dir, image))
        print(image, 'Inference Result:', InferRes)
        if label == InferRes:
            count += 1
    # Guard against an empty directory so the accuracy print cannot divide by zero.
    if images:
        print('acc rate:', count / len(images))
    else:
        print('no test images found in', test_dir)
import argparse
import cv2
import os
import torch
import numpy as np
from lprnet import build_lprnet
from load_data import CHARS
def validation(args):
model = build_lprnet(len(CHARS))
model.load_state_dict(torch.load(args.model, map_location=args.device))
model.to(args.device)
img = cv2.imread(args.img)
def infer(args, image, model):
img = cv2.imread(image)
height, width, _ = img.shape
if height != 24 or width != 94:
img = cv2.resize(img, (94, 24))
......@@ -38,7 +34,26 @@ def validation(args):
continue
no_repeat_blank_label.append(c)
pre_c = c
return ''.join(list(map(lambda x: CHARS[x], no_repeat_blank_label)))
def validation(args):
model = build_lprnet(len(CHARS))
model.load_state_dict(torch.load(args.model, map_location=args.device))
model.to(args.device)
if os.path.isdir(args.imgpath):
images = os.listdir(args.imgpath)
count = 0
for image in images:
res = infer(args, os.path.join(args.imgpath, image), model)
if res == image[:-4]:
count += 1
print('Image: ' + image + ' recongise result: '+ res)
print('acc rate:', count / len(images))
else:
res = infer(args, args.imgpath, model)
print('Image: ' + args.imgpath + ' recongise result: '+ res)
if args.export_onnx:
print('export pytroch model to onnx model...')
onnx_input = torch.randn(1, 3, 24, 94, device=args.device)
......@@ -51,16 +66,19 @@ def validation(args):
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None,
opset_version=12,
)
return ''.join(list(map(lambda x: CHARS[x], no_repeat_blank_label)))
return res
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to validate net')
    # Checkpoint produced by training (matches train.py's --save_folder output).
    parser.add_argument('--model', default='weights/Final_LPRNet_model.pth', help='model path to validate')
    # Either a single image file or a directory of test images; validation()
    # branches on os.path.isdir().
    parser.add_argument('--imgpath', default='/code/lpr_ori/data/test', help='the image path')
    parser.add_argument('--device', default='cuda', help='Use cuda to validate model')
    # NOTE(review): these two are untyped, so any non-empty CLI string
    # (including "False") is truthy — confirm whether they should use
    # action='store_true' instead.
    parser.add_argument('--export_onnx', default=False, help='export model to onnx')
    parser.add_argument('--dynamic', default=False, help='use dynamic batch size')
    args = parser.parse_args()

    result = validation(args)
    print('recognise result:', result)
......@@ -2,7 +2,6 @@
# /usr/bin/env/python3
from load_data import CHARS, CHARS_DICT, LPRDataLoader
from lprnet import build_lprnet
# import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import *
......@@ -14,7 +13,7 @@ import torch
import time
import os
# Report CUDA availability once at import time so training logs show
# whether the GPU path will be used (typo "Availabel" fixed).
print('Cuda Available:', torch.cuda.is_available())
def sparse_tuple_for_ctc(T_length, lengths):
input_lengths = []
......@@ -23,7 +22,6 @@ def sparse_tuple_for_ctc(T_length, lengths):
for ch in lengths:
input_lengths.append(T_length)
target_lengths.append(ch)
return tuple(input_lengths), tuple(target_lengths)
def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
......@@ -39,37 +37,8 @@ def adjust_learning_rate(optimizer, cur_epoch, base_lr, lr_schedule):
lr = base_lr
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
def get_parser():
    """Build and parse the command-line arguments for training LPRNet.

    Returns:
        argparse.Namespace: training hyper-parameters and data/checkpoint paths.
    """
    parser = argparse.ArgumentParser(description='parameters to train net')
    add = parser.add_argument

    # Optimisation schedule.
    add('--max_epoch', default=15, help='epoch to train the network')
    add('--learning_rate', default=0.1, help='base value of learning rate.')
    add('--lr_schedule', default=[4, 8, 12, 14, 16], help='schedule for learning rate.')
    add('--momentum', default=0.9, type=float, help='momentum')
    add('--weight_decay', default=2e-5, type=float, help='Weight decay for SGD')
    add('--dropout_rate', default=0.5, help='dropout rate.')

    # Data layout and loading.
    add('--img_size', default=[94, 24], help='the image size')
    add('--train_img_dirs', default="data/train", help='the train images path')
    add('--test_img_dirs', default="data/test", help='the test images path')
    add('--lpr_max_len', default=8, help='license plate number max length.')
    add('--train_batch_size', default=64, help='training batch size.')
    add('--test_batch_size', default=10, help='testing batch size.')
    add('--num_workers', default=8, type=int, help='Number of workers used in dataloading')

    # Run control and checkpointing.
    add('--phase_train', default=True, type=bool, help='train or test phase flag.')
    add('--cuda', default=True, type=bool, help='Use cuda to train model')
    add('--resume_epoch', default=10, type=int, help='resume iter for retraining')
    add('--save_interval', default=2000, type=int, help='interval for save model state dict')
    add('--test_interval', default=2000, type=int, help='interval for evaluate')
    add('--save_folder', default='./weights/', help='Location to save checkpoint models')
    # Empty default: training starts from scratch unless a checkpoint is given.
    add('--pretrained_model', default='', help='pretrained base model')

    return parser.parse_args()
def collate_fn(batch):
imgs = []
labels = []
......@@ -80,12 +49,71 @@ def collate_fn(batch):
labels.extend(label)
lengths.append(length)
labels = np.asarray(labels).flatten().astype(np.int16)
return (torch.stack(imgs, 0), torch.from_numpy(labels), lengths)
def train():
args = get_parser()
def Greedy_Decode_Eval(Net, datasets, args):
    """Evaluate *Net* on *datasets* using CTC greedy decoding.

    Runs full batches of ``args.test_batch_size``, greedily decodes the
    network output (argmax per time step, then collapsing repeats and the
    blank symbol ``len(CHARS) - 1``) and prints accuracy and speed.

    Args:
        Net: LPRNet model; switched to eval mode here.
        datasets: Dataset yielding (image, label, length) samples.
        args: namespace providing test_batch_size, num_workers and cuda.
    """
    # Switch to eval mode once, not on every batch iteration as before.
    Net.eval()
    epoch_size = len(datasets) // args.test_batch_size
    batch_iterator = iter(DataLoader(datasets, args.test_batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn))

    Tp = 0    # exact matches
    Tn_1 = 0  # predictions with the wrong length
    Tn_2 = 0  # right length but wrong characters
    t1 = time.time()
    for _ in range(epoch_size):  # renamed: was `i`, shadowed by inner loops
        images, labels, lengths = next(batch_iterator)

        # Split the flat label vector back into per-sample targets. Kept as a
        # list of arrays: samples may have different lengths, and np.array()
        # on ragged sequences raises in NumPy >= 1.24.
        targets = []
        start = 0
        for length in lengths:
            targets.append(labels[start:start + length].numpy())
            start += length

        if args.cuda:
            images = Variable(images.cuda())
        else:
            images = Variable(images)

        # forward — prebs is assumed (batch, n_classes, T); TODO confirm
        # against build_lprnet's output layout.
        prebs = Net(images).cpu().detach().numpy()

        # greedy decode each sample in the batch
        blank = len(CHARS) - 1  # CTC blank index
        preb_labels = []
        for preb in prebs:
            # argmax over classes at every time step
            preb_label = [np.argmax(preb[:, j], axis=0) for j in range(preb.shape[1])]

            # collapse repeated labels and drop the blank symbol
            no_repeat_blank_label = []
            pre_c = preb_label[0]
            if pre_c != blank:
                no_repeat_blank_label.append(pre_c)
            for c in preb_label:
                if (pre_c == c) or (c == blank):
                    if c == blank:
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)

        # Tally exact / wrong-length / wrong-character results.
        for target, label in zip(targets, preb_labels):
            if len(label) != len(target):
                Tn_1 += 1
            elif (np.asarray(target) == np.asarray(label)).all():
                Tp += 1
            else:
                Tn_2 += 1

    total = Tp + Tn_1 + Tn_2
    if total == 0:
        # Dataset smaller than one batch: nothing was evaluated; the old code
        # raised ZeroDivisionError here.
        print("[Info] Test Accuracy: n/a (no full batch to evaluate)")
        return
    Acc = Tp * 1.0 / total
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, total))
    t2 = time.time()
    print("[Info] Test Speed: {}s 1/{}]".format((t2 - t1) / len(datasets), len(datasets)))
def train(args):
T_length = 18 # args.lpr_max_len
epoch = 0 + args.resume_epoch
loss_val = 0
......@@ -121,8 +149,6 @@ def train():
print("initial net weights successful!")
# define optimizer
# optimizer = optim.SGD(lprnet.parameters(), lr=args.learning_rate,
# momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = optim.RMSprop(lprnet.parameters(), lr=args.learning_rate, alpha = 0.9, eps=1e-08,
momentum=args.momentum, weight_decay=args.weight_decay)
train_img_dirs = os.path.expanduser(args.train_img_dirs)
......@@ -148,17 +174,14 @@ def train():
epoch += 1
if iteration !=0 and iteration % args.save_interval == 0:
torch.save(lprnet.state_dict(), args.save_folder + 'LPRNet_' + '_iteration_' + repr(iteration) + '.pth')
torch.save(lprnet.state_dict(), args.save_folder + 'LPRNet_' + 'iteration_' + repr(iteration) + '.pth')
if (iteration + 1) % args.test_interval == 0:
Greedy_Decode_Eval(lprnet, test_dataset, args)
# lprnet.train() # should be switch to train mode
start_time = time.time()
# load train data
images, labels, lengths = next(batch_iterator)
# labels = np.array([el.numpy() for el in labels]).T
# print(labels)
# get ctc parameters
input_lengths, target_lengths = sparse_tuple_for_ctc(T_length, lengths)
# update lr
......@@ -173,12 +196,8 @@ def train():
# forward
logits = lprnet(images)
# print(logits.size())
log_probs = logits.permute(2, 0, 1) # for ctc loss: T x N x C
# print(labels.shape)
log_probs = log_probs.log_softmax(2).requires_grad_()
# log_probs = log_probs.detach().requires_grad_()
# print(log_probs.shape)
# backprop
optimizer.zero_grad()
loss = ctc_loss(log_probs, labels, input_lengths=input_lengths, target_lengths=target_lengths)
......@@ -199,67 +218,31 @@ def train():
# save final parameters
torch.save(lprnet.state_dict(), args.save_folder + 'Final_LPRNet_model.pth')
def Greedy_Decode_Eval(Net, datasets, args):
    """Evaluate Net on datasets with CTC greedy decoding; print accuracy and speed.

    NOTE(review): assumes Net's output has shape (batch, n_classes, T) —
    confirm against build_lprnet.
    """
    # TestNet = Net.eval()
    epoch_size = len(datasets) // args.test_batch_size
    batch_iterator = iter(DataLoader(datasets, args.test_batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn))

    Tp = 0    # exact matches
    Tn_1 = 0  # predictions with the wrong length
    Tn_2 = 0  # right length but wrong characters
    t1 = time.time()
    for i in range(epoch_size):
        # load train data
        images, labels, lengths = next(batch_iterator)
        # Rebuild per-sample targets from the flat label vector using lengths.
        start = 0
        targets = []
        for length in lengths:
            label = labels[start:start+length]
            targets.append(label)
            start += length
        # NOTE(review): ragged per-sample lengths make this np.array() call
        # raise on NumPy >= 1.24 — verify the installed NumPy version.
        targets = np.array([el.numpy() for el in targets])

        if args.cuda:
            images = Variable(images.cuda())
        else:
            images = Variable(images)

        # forward
        prebs = Net(images)
        # greedy decode
        prebs = prebs.cpu().detach().numpy()
        preb_labels = list()
        for i in range(prebs.shape[0]):  # NOTE(review): shadows the outer loop index `i`
            preb = prebs[i, :, :]
            preb_label = list()
            for j in range(preb.shape[1]):
                # argmax over classes at each time step
                preb_label.append(np.argmax(preb[:, j], axis=0))
            no_repeat_blank_label = list()
            pre_c = preb_label[0]
            # len(CHARS) - 1 is treated as the CTC blank index throughout.
            if pre_c != len(CHARS) - 1:
                no_repeat_blank_label.append(pre_c)
            for c in preb_label:  # drop repeated labels and the blank label
                if (pre_c == c) or (c == len(CHARS) - 1):
                    if c == len(CHARS) - 1:
                        pre_c = c
                    continue
                no_repeat_blank_label.append(c)
                pre_c = c
            preb_labels.append(no_repeat_blank_label)
        # Tally exact / wrong-length / wrong-character results per sample.
        for i, label in enumerate(preb_labels):
            if len(label) != len(targets[i]):
                Tn_1 += 1
                continue
            if (np.asarray(targets[i]) == np.asarray(label)).all():
                Tp += 1
            else:
                Tn_2 += 1
    # NOTE(review): ZeroDivisionError when the dataset has no full batch.
    Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, (Tp+Tn_1+Tn_2)))
    t2 = time.time()
    print("[Info] Test Speed: {}s 1/{}]".format((t2 - t1) / len(datasets), len(datasets)))
def get_parser():
    """Collect and parse the CLI hyper-parameters for LPRNet training.

    Returns:
        argparse.Namespace: the parsed arguments with training defaults.
    """
    parser = argparse.ArgumentParser(description='parameters to train net')

    # Data locations and shapes.
    parser.add_argument('--train_img_dirs', default="data/train", help='the train images path')
    parser.add_argument('--test_img_dirs', default="data/test", help='the test images path')
    parser.add_argument('--img_size', default=[94, 24], help='the image size')
    parser.add_argument('--lpr_max_len', default=8, help='license plate number max length.')

    # Batch sizes and data loading.
    parser.add_argument('--train_batch_size', default=64, help='training batch size.')
    parser.add_argument('--test_batch_size', default=10, help='testing batch size.')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading')

    # Optimiser settings.
    parser.add_argument('--max_epoch', default=15, help='epoch to train the network')
    parser.add_argument('--learning_rate', default=0.1, help='base value of learning rate.')
    parser.add_argument('--lr_schedule', default=[4, 8, 12, 14, 16], help='schedule for learning rate.')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay', default=2e-5, type=float, help='Weight decay for SGD')
    parser.add_argument('--dropout_rate', default=0.5, help='dropout rate.')

    # Run control and checkpointing.
    parser.add_argument('--phase_train', default=True, type=bool, help='train or test phase flag.')
    parser.add_argument('--cuda', default=True, type=bool, help='Use cuda to train model')
    parser.add_argument('--resume_epoch', default=10, type=int, help='resume iter for retraining')
    parser.add_argument('--save_interval', default=2000, type=int, help='interval for save model state dict')
    parser.add_argument('--test_interval', default=2000, type=int, help='interval for evaluate')
    parser.add_argument('--save_folder', default='./weights/', help='Location to save checkpoint models')
    parser.add_argument('--pretrained_model', default='./weights/Final_LPRNet_model.pth', help='pretrained base model')

    return parser.parse_args()
if __name__ == "__main__":
    # Parse CLI hyper-parameters once and pass them to the training loop;
    # train() now requires the args namespace as its parameter.
    args = get_parser()
    train(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment