"vscode:/vscode.git/clone" did not exist on "9465b6bf8158a7ba657c55e354c4492925809757"
test.py 4.44 KB
Newer Older
liuhy's avatar
liuhy committed
1
2
import argparse
import cv2
liuhy's avatar
liuhy committed
3
import os
liuhy's avatar
liuhy committed
4
5
6
import torch
import numpy as np
from lprnet import build_lprnet
liuhy's avatar
liuhy committed
7
from load_data import CHARS, LPRDataLoader
liuhy's avatar
liuhy committed
8
import time
liuhy's avatar
liuhy committed
9
10
from torch.utils.data import *
from torch.autograd import Variable
liuhy's avatar
liuhy committed
11

liuhy's avatar
liuhy committed
12
13
14
15
16
17
18
19
20
21
def collate_fn(batch):
    """Merge a list of dataset samples into a single batch.

    Args:
        batch: iterable of ``(img, label, length)`` samples. ``img`` is a
            NumPy array (every image in the batch must have the same shape so
            the images can be stacked), ``label`` is an iterable of class
            indices and ``length`` is stored as-is; the evaluation loop uses
            it to slice the sample's labels back out of the flat tensor.

    Returns:
        A ``(images, labels, lengths)`` tuple. ``images`` is the images
        stacked along a new leading dimension, ``labels`` is a 1-D float32
        tensor holding all samples' label indices back to back, and
        ``lengths`` is the list of per-sample ``length`` values.
    """
    imgs = []
    labels = []
    lengths = []
    for img, label, length in batch:
        imgs.append(torch.from_numpy(img))
        # Labels are concatenated, not padded; `lengths` records the split.
        labels.extend(label)
        lengths.append(length)
    labels = np.asarray(labels).flatten().astype(np.float32)
    return (torch.stack(imgs, 0), torch.from_numpy(labels), lengths)

def _ctc_greedy_collapse(indices, blank):
    """Collapse one greedy CTC path: merge repeated indices, then drop blanks."""
    decoded = []
    # Starting from `blank` means a leading blank is skipped and a leading
    # non-blank index is emitted, without special-casing the first element.
    previous = blank
    for index in indices:
        if index != previous and index != blank:
            decoded.append(index)
        previous = index
    return decoded


def Greedy_Decode_Eval(Net, datasets, args):
    """Evaluate ``Net`` on ``datasets`` with greedy decoding and print the result.

    Every sample is evaluated, including the final partial batch. A prediction
    counts as correct only when the decoded sequence matches the target in
    both length and content.

    Args:
        Net: callable model; it receives a batch of images and returns a
            tensor whose axis 1 is arg-maxed to get one class index per step.
        datasets: dataset yielding ``(img, label, length)`` samples that
            ``collate_fn`` can batch.
        args: namespace providing ``batch_size`` and ``num_workers``;
            ``cuda`` is only consulted when the device cannot be read from
            ``Net``'s parameters.

    Returns:
        The accuracy as a float in ``[0, 1]``; ``0.0`` when no sample was
        evaluated. The same figures are printed to stdout.
    """
    # Evaluation order does not affect accuracy, so no shuffling is needed.
    loader = DataLoader(
        datasets,
        args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )
    # The last entry of CHARS is treated as the blank symbol.
    blank = len(CHARS) - 1

    # Inputs must live on the same device as the model's weights.
    try:
        device = next(Net.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if args.cuda else "cpu")

    Tp = 0    # exact matches
    Tn_1 = 0  # decoded length differs from the target length
    Tn_2 = 0  # same length, different content
    t1 = time.time()
    with torch.no_grad():
        for images, labels, lengths in loader:
            # Split the flat label tensor back into one target per sample.
            # Targets stay in a list because their lengths may differ.
            targets = []
            start = 0
            for length in lengths:
                targets.append(labels[start:start + length].numpy())
                start += length

            # forward
            prebs = Net(images.to(device))
            # greedy decode: best class per step, then collapse the path
            prebs = prebs.cpu().detach().numpy()
            preb_max = np.argmax(prebs, axis=1)
            for target, preb in zip(targets, preb_max):
                label = _ctc_greedy_collapse(preb, blank)
                if len(label) != len(target):
                    Tn_1 += 1
                elif (np.asarray(target) == np.asarray(label)).all():
                    Tp += 1
                else:
                    Tn_2 += 1
    t2 = time.time()

    total = Tp + Tn_1 + Tn_2
    Acc = Tp * 1.0 / total if total else 0.0
    print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, total))
    # max(..., 1) keeps the per-sample time defined for an empty dataset.
    print("[Info] Test Speed: {}s 1/{}]".format((t2 - t1) / max(len(datasets), 1), len(datasets)))
    return Acc
liuhy's avatar
liuhy committed
80

liuhy's avatar
liuhy committed
81
def validation(args):
    """Build LPRNet, load its weights, evaluate it and optionally export ONNX.

    Args:
        args: namespace providing ``phase_train``, ``model`` (checkpoint
            path), ``device``, ``imgpath`` (comma-separated image
            directories), ``img_size`` (``[width, height]``), ``export_onnx``
            and ``dynamic``, plus the fields ``Greedy_Decode_Eval`` reads.

    When ``args.export_onnx`` is truthy the model is written to
    ``LPRNet.onnx`` in the current working directory.
    """
    lprnet = build_lprnet(class_num=len(CHARS), phase=args.phase_train)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded onto whichever device was requested.
    lprnet.load_state_dict(torch.load(args.model, map_location=args.device))
    lprnet.to(args.device)
    print("Successful to build network!")

    # Expand "~" in every directory, not only in the first one.
    test_img_dirs = [os.path.expanduser(path) for path in args.imgpath.split(',')]
    test_dataset = LPRDataLoader(test_img_dirs, args.img_size)
    Greedy_Decode_Eval(lprnet, test_dataset, args)

    if args.export_onnx:
        print('export pytorch model to onnx model...')
        # img_size is [width, height]; the dummy input is laid out as
        # (batch, channels, height, width), i.e. 1x3x24x94 by default.
        width, height = args.img_size
        onnx_input = torch.randn(1, 3, int(height), int(width), device=args.device)
        torch.onnx.export(
            lprnet,
            onnx_input,
            'LPRNet.onnx',
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None,
            opset_version=12,
            )

liuhy's avatar
liuhy committed
103
104
105
106
def str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` treats every non-empty string (including ``"False"``) as
    ``True``; this converter reads the text instead.

    Args:
        value: a bool, or a string such as ``true``/``false``, ``yes``/``no``,
            ``1``/``0`` (case-insensitive). The empty string maps to ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='parameters to train net')
    # Two integers (width height) so a command-line value has the same type
    # as the default list.
    parser.add_argument('--img_size', default=[94, 24], nargs=2, type=int,
                        metavar=('WIDTH', 'HEIGHT'), help='the image size')
    parser.add_argument('--imgpath', default="imgs", help='the image path')
    parser.add_argument('--model', default='model/lprnet.pth', help='model path to vaildate')
    parser.add_argument('--batch_size', default=100, type=int, help='testing batch size.')
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
    parser.add_argument('--device', default='cuda', help='Use cuda to vaildate model')
    parser.add_argument('--export_onnx', default=False, type=str2bool, help='export model to onnx')
    parser.add_argument('--dynamic', default=False, type=str2bool, help='use dynamic batch size')
    parser.add_argument('--phase_train', default=False, type=str2bool, help='train or test phase flag.')
    parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading')

    args = parser.parse_args()
    # The model is placed on args.device, so inputs may only go to CUDA when
    # that device is a CUDA device (e.g. --device cpu must disable it).
    args.cuda = args.cuda and str(args.device).startswith('cuda')
    validation(args)
liuhy's avatar
liuhy committed
118