Commit 1d05ebdc authored by liuhy's avatar liuhy
Browse files

优化代码

parent 1089716d
...@@ -26,9 +26,7 @@ def LPRNetPreprocess(image): ...@@ -26,9 +26,7 @@ def LPRNetPreprocess(image):
return img return img
def LPRNetPostprocess(infer_res): def LPRNetPostprocess(infer_res):
preb_label = [] preb_label = np.argmax(infer_res, axis=0)
for j in range(infer_res.shape[1]):
preb_label.append(np.argmax(infer_res[:, j], axis=0))
no_repeat_blank_label = [] no_repeat_blank_label = []
pre_c = preb_label[0] pre_c = preb_label[0]
if pre_c != len(CHARS) - 1: if pre_c != len(CHARS) - 1:
...@@ -51,7 +49,9 @@ def LPRNetInference(args): ...@@ -51,7 +49,9 @@ def LPRNetInference(args):
if os.path.isdir(args.imgpath): if os.path.isdir(args.imgpath):
images = os.listdir(args.imgpath) images = os.listdir(args.imgpath)
count = 0 Tp = 0
Tn_1 = 0
Tn_2 = 0
time1 = time.perf_counter() time1 = time.perf_counter()
for image in images: for image in images:
img = LPRNetPreprocess(os.path.join(args.imgpath, image)) img = LPRNetPreprocess(os.path.join(args.imgpath, image))
...@@ -59,11 +59,16 @@ def LPRNetInference(args): ...@@ -59,11 +59,16 @@ def LPRNetInference(args):
preb = sess.run(None, input_feed={sess.get_inputs()[0].name: img})[0] preb = sess.run(None, input_feed={sess.get_inputs()[0].name: img})[0]
result = LPRNetPostprocess(preb) result = LPRNetPostprocess(preb)
if result == image[:-4]: if result == image[:-4]:
count += 1 Tp += 1
print('Inference Result:', result) elif len(result) != len(image[:-4]):
time2 = time.perf_counter() Tn_1 += 1
print('accuracy rate:', count / len(images)) else:
print('average time', (time2 - time1)/count*1000) Tn_2 += 1
print(image + ' Inference Result:', result)
time2 = time.perf_counter()
Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, (Tp+Tn_1+Tn_2)))
print("[Info] Test Speed: {}s 1/{}]".format((time2 - time1) / len(images), len(images)))
else: else:
img = LPRNetPreprocess(args.imgpath) img = LPRNetPreprocess(args.imgpath)
intput = sess.get_inputs()[0].shape intput = sess.get_inputs()[0].shape
......
...@@ -28,9 +28,7 @@ def LPRNetPreprocess(image): ...@@ -28,9 +28,7 @@ def LPRNetPreprocess(image):
return img return img
def LPRNetPostprocess(infer_res): def LPRNetPostprocess(infer_res):
preb_label = [] preb_label = np.argmax(infer_res, axis=0)
for j in range(infer_res.shape[1]):
preb_label.append(np.argmax(infer_res[:, j], axis=0))
no_repeat_blank_label = [] no_repeat_blank_label = []
pre_c = preb_label[0] pre_c = preb_label[0]
if pre_c != len(CHARS) - 1: if pre_c != len(CHARS) - 1:
...@@ -57,7 +55,9 @@ def LPRNetInference(args): ...@@ -57,7 +55,9 @@ def LPRNetInference(args):
if os.path.isdir(args.imgpath): if os.path.isdir(args.imgpath):
images = os.listdir(args.imgpath) images = os.listdir(args.imgpath)
count = 0 Tp = 0
Tn_1 = 0
Tn_2 = 0
time1 = time.perf_counter() time1 = time.perf_counter()
for image in images: for image in images:
img = LPRNetPreprocess(os.path.join(args.imgpath, image)) img = LPRNetPreprocess(os.path.join(args.imgpath, image))
...@@ -67,11 +67,16 @@ def LPRNetInference(args): ...@@ -67,11 +67,16 @@ def LPRNetInference(args):
results = model.run({inputName: migraphx.argument(img)}) results = model.run({inputName: migraphx.argument(img)})
result = LPRNetPostprocess(np.array(results[0])) result = LPRNetPostprocess(np.array(results[0]))
if result == image[:-4]: if result == image[:-4]:
count += 1 Tp += 1
print('Inference Result:', result) elif len(result) != len(image[:-4]):
Tn_1 += 1
else:
Tn_2 += 1
print(image + ' Inference Result:', result)
time2 = time.perf_counter() time2 = time.perf_counter()
print('accuracy rate:', count / len(images)) Acc = Tp * 1.0 / (Tp + Tn_1 + Tn_2)
print('average time', (time2 - time1)/count*1000) print("[Info] Test Accuracy: {} [{}:{}:{}:{}]".format(Acc, Tp, Tn_1, Tn_2, (Tp+Tn_1+Tn_2)))
print("[Info] Test Speed: {}s 1/{}]".format((time2 - time1) / len(images), len(images)))
else: else:
img = LPRNetPreprocess(args.imgpath) img = LPRNetPreprocess(args.imgpath)
inputName=model.get_parameter_names()[0] inputName=model.get_parameter_names()[0]
......
No preview for this file type
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment