"tests/vscode:/vscode.git/clone" did not exist on "e50c25d8081d483658245d792aca816e2fec49dd"
Commit 71ed68c9 authored by LDOUBLEV's avatar LDOUBLEV
Browse files

add comparion script

parent f262e33e
import numpy as np
import os
import subprocess
import json
import argparse
def init_args():
parser = argparse.ArgumentParser()
# params for prediction engine
parser.add_argument("--atol", type=float, default=1e-3)
parser.add_argument("--rtol", type=float, default=1e-3)
parser.add_argument("--gt_file", type=str, default="")
parser.add_argument("--log_file", type=str, default="")
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
def run_shell_command(cmd):
p = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
out, err = p.communicate()
if p.returncode == 0:
return out.decode('utf-8')
else:
return None
def parser_results_from_log_by_name(log_path, names_list):
if not os.path.exists(log_path):
raise ValueError("The log file {} does not exists!".format(log_path))
if names_list is None or len(names_list) < 1:
return []
parser_results = {}
for name in names_list:
cmd = "grep {} {}".format(name, log_path)
outs = run_shell_command(cmd)
outs = outs.split("\n")[0]
result = outs.split("{}".format(name))[-1]
result = json.loads(result)
parser_results[name] = result
return parser_results
def load_gt_from_file(gt_file):
if not os.path.exists(gt_file):
raise ValueError("The log file {} does not exists!".format(gt_file))
with open(gt_file, 'r') as f:
data = f.readlines()
f.close()
parser_gt = {}
for line in data:
image_name, result = line.strip("\n").split("\t")
result = json.loads(result)
parser_gt[image_name] = result
return parser_gt
def testing_assert_allclose(dict_x, dict_y, atol=1e-7, rtol=1e-7):
for k in dict_x:
np.testing.assert_allclose(
np.array(dict_x[k]), np.array(dict_y[k]), atol=atol, rtol=rtol)
if __name__ == "__main__":
# Usage:
# python3.7 tests/compare_results.py --gt_file=./det_results_gpu_fp32.txt --log_file=./test_log.log
args = parse_args()
gt_dict = load_gt_from_file(args.gt_file)
key_list = list(gt_dict.keys())
pred_dict = parser_results_from_log_by_name(args.log_file, key_list)
testing_assert_allclose(gt_dict, pred_dict, atol=args.atol, rtol=args.rtol)
This diff is collapsed.
...@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger ...@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
import json
logger = get_logger() logger = get_logger()
...@@ -242,6 +242,7 @@ if __name__ == "__main__": ...@@ -242,6 +242,7 @@ if __name__ == "__main__":
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
save_results = []
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -255,7 +256,10 @@ if __name__ == "__main__": ...@@ -255,7 +256,10 @@ if __name__ == "__main__":
if count > 0: if count > 0:
total_time += elapse total_time += elapse
count += 1 count += 1
save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps(np.array(dt_boxes).astype(np.int32).tolist())) + "\n"
save_results.append(save_pred)
logger.info(save_pred)
logger.info("Predict time of {}: {}".format(image_file, elapse)) logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file) src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1] img_name_pure = os.path.split(image_file)[-1]
...@@ -264,5 +268,8 @@ if __name__ == "__main__": ...@@ -264,5 +268,8 @@ if __name__ == "__main__":
cv2.imwrite(img_path, src_im) cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
f.writelines(save_results)
f.close()
if args.benchmark: if args.benchmark:
text_detector.autolog.report() text_detector.autolog.report()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment