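"""Run a single licence-plate image through LPRNet and print the recognised text.

The script loads a trained LPRNet checkpoint, preprocesses one image (resize to
94x24, normalise pixel values to roughly [-1, 1]), greedily decodes the CTC
output and, optionally, exports the model to ONNX.
"""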
import argparse
import cv2
import torch
import numpy as np
from lprnet import build_lprnet
from load_data import CHARS


def validation(args):
    model = build_lprnet(len(CHARS))
    model.load_state_dict(torch.load(args.model, map_location=args.device))
    model.to(args.device)
    model.eval()  # inference mode: disable dropout and use running BatchNorm statistics

    img = cv2.imread(args.img)
    height, width, _ = img.shape
    if height != 24 or width != 94:
        img = cv2.resize(img, (94, 24))
    # normalise pixel values from [0, 255] to roughly [-1, 1], then HWC -> CHW
    img = img.astype('float32')
    img -= 127.5
    img *= 0.0078125
    img = np.transpose(img, (2, 0, 1))

    with torch.no_grad():
        img = torch.from_numpy(img).unsqueeze(0).to(args.device)
        preb = model(img)
        preb = preb.detach().cpu().numpy().squeeze()
    # greedy decoding: take the most probable character at every time step
    preb_label = []
    for j in range(preb.shape[1]):
        preb_label.append(np.argmax(preb[:, j], axis=0))
    # CTC-style post-processing: merge repeated characters and drop blanks (index len(CHARS) - 1)
    no_repeat_blank_label = []
    pre_c = preb_label[0]
    if pre_c != len(CHARS) - 1:
        no_repeat_blank_label.append(pre_c)
    for c in preb_label:
        if (pre_c == c) or (c == len(CHARS) - 1):
            if c == len(CHARS) - 1:
                pre_c = c
            continue
        no_repeat_blank_label.append(c)
        pre_c = c

    if args.export_onnx:
        print('exporting PyTorch model to ONNX...')
        onnx_input = torch.randn(1, 3, 24, 94, device=args.device)
        torch.onnx.export(
            model,
            onnx_input,
            'LPRNet.onnx',
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}} if args.dynamic else None,
            opset_version=12,
            )
    return ''.join(CHARS[i] for i in no_repeat_blank_label)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='parameters to validate the net')
    parser.add_argument('--model', default='model/lprnet.pth', help='model path to validate')
    parser.add_argument('--img', default='imgs/川JK0707.jpg', help='the image path')
    parser.add_argument('--device', default='cuda', help='device to run validation on, e.g. cuda or cpu')
    parser.add_argument('--export_onnx', action='store_true', help='export the model to onnx')
    parser.add_argument('--dynamic', action='store_true', help='use a dynamic batch size for the onnx export')
    args = parser.parse_args()

    result = validation(args)
    print('recognition result:', result)
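
# Example usage (a sketch; the default paths assume the repo's own model/ and
# imgs/ folders):
#   python test.py --model model/lprnet.pth --img imgs/川JK0707.jpg --device cpu
#   python test.py --export_onnx --dynamic    # also writes LPRNet.onnx
#
# To sanity-check the exported model, the optional onnxruntime package could be
# used roughly like the hypothetical snippet below (not part of this repo):
#   import onnxruntime as ort
#   sess = ort.InferenceSession('LPRNet.onnx')
#   dummy = np.random.randn(1, 3, 24, 94).astype('float32')
#   logits = sess.run(None, {'input': dummy})[0]   # shape: (1, len(CHARS), T)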