Commit 7c6309db authored by LDOUBLEV's avatar LDOUBLEV
Browse files

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into test

parents 58ca7639 d93a445d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import json
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
from ppstructure.utility import init_args
def parse_args():
parser = init_args()
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16)
text_sys = TableSystem(args)
jsons_gt = json.load(open(gt_path)) # gt
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
# read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, contents_with_block)
gt_htmls.append(gt_html)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
print('teds:', sum(scores) / len(scores))
def get_gt_html(gt_structures, contents_with_block):
end_html = []
td_index = 0
for tag in gt_structures:
if '</td>' in tag:
if contents_with_block[td_index] != []:
end_html.extend(contents_with_block[td_index])
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
if __name__ == '__main__':
args = parse_args()
main(args.gt_path,args.image_dir, args)
import json
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
# computing the sum_area
sum_area = S_rec1 + S_rec2
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(ocr_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
# compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
def complex_num(pred_bboxes):
complex_nums = []
for bbox in pred_bboxes:
distances = []
temp_ious = []
for pred_bbox in pred_bboxes:
if bbox != pred_bbox:
distances.append(distance(bbox, pred_bbox))
temp_ious.append(compute_iou(bbox, pred_bbox))
complex_nums.append(temp_ious[distances.index(min(distances))])
return sum(complex_nums) / len(complex_nums)
def get_rows(pred_bboxes):
pre_bbox = pred_bboxes[0]
res = []
step = 0
for i in range(len(pred_bboxes)):
bbox = pred_bboxes[i]
if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
break
else:
res.append(bbox)
step += 1
for i in range(step):
pred_bboxes.pop(0)
return res, pred_bboxes
def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
ys_1 = []
ys_2 = []
for box in pred_bboxes:
ys_1.append(box[1])
ys_2.append(box[3])
min_y_1 = sum(ys_1) / len(ys_1)
min_y_2 = sum(ys_2) / len(ys_2)
re_boxes = []
for box in pred_bboxes:
box[1] = min_y_1
box[3] = min_y_2
re_boxes.append(box)
return re_boxes
def matcher_refine_row(gt_bboxes, pred_bboxes):
before_refine_pred_bboxes = pred_bboxes.copy()
pred_bboxes = []
while(len(before_refine_pred_bboxes) != 0):
row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
print(row_bboxes)
pred_bboxes.extend(refine_rows(row_bboxes))
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(gt_bboxes):
distances = []
#temp_ious = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(distance(gt_box, pred_box))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if distances.index(min(distances)) not in matched.keys():
matched[distances.index(min(distances))] = [i]
else:
matched[distances.index(min(distances))].append(i)
return matched#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
gt_box_index = 0
delete_gt_bboxes = gt_bboxes.copy()
match_bboxes_ready = []
matched = {}
while(len(delete_gt_bboxes) != 0):
row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
row_bboxes = sorted(row_bboxes, key = lambda key: key[0])
if len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
print(row_bboxes)
for i, gt_box in enumerate(row_bboxes):
#print(gt_box)
pred_distances = []
distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print('index', index)
if index not in matched.keys():
matched[index] = [gt_box_index]
else:
matched[index].append(gt_box_index)
gt_box_index += 1
return matched
def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox = gt_bboxes[0]
matched = {}
match_bboxes_ready = []
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
for i, gt_box in enumerate(gt_bboxes):
pred_distances = []
for pred_bbox in pred_bboxes:
pred_distances.append(distance(gt_box, pred_bbox))
distances = []
gap_pre = gt_box[1] - pre_bbox[1]
gap_pre_1 = gt_box[0] - pre_bbox[2]
#print(gap_pre, len(pred_bboxes_rows))
if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0):
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(pred_bboxes_rows) == 1:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
break
#print(match_bboxes_ready)
for j, pred_box in enumerate(match_bboxes_ready):
distances.append(distance(gt_box, pred_box))
index = pred_distances.index(min(distances))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print(gt_box, match_bboxes_ready[distances.index(min(distances))])
if index not in matched.keys():
matched[index] = [i]
else:
matched[index].append(i)
pre_bbox = gt_box
return matched
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import math
import time
import traceback
import paddle
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TableStructurer(object):
def __init__(self, args):
pre_process_list = [{
'ResizeTableImage': {
'max_len': args.structure_max_len
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'PaddingTableImage': None
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
}
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_type": args.structure_char_type,
"character_dict_path": args.structure_char_dict_path,
"max_text_length": args.structure_max_text_length,
"max_elem_length": args.structure_max_elem_length,
"max_cell_num": args.structure_max_cell_num
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'structure', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {}
preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds)
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc']
imgh, imgw = ori_im.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
table_structurer = TableStructurer(args)
count = 0
total_time = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_res, elapse = table_structurer(img)
logger.info("result: {}".format(structure_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(utility.parse_args())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from ppstructure.utility import parse_args
import ppstructure.table.predict_structure as predict_strture
logger = get_logger()
def expand(pix, det_box, shape):
x0, y0, x1, y1 = det_box
# print(shape)
h, w, c = shape
tmp_x0 = x0 - pix
tmp_x1 = x1 + pix
tmp_y0 = y0 - pix
tmp_y1 = y1 + pix
x0_ = tmp_x0 if tmp_x0 >= 0 else 0
x1_ = tmp_x1 if tmp_x1 <= w else w
y0_ = tmp_y0 if tmp_y0 >= 0 else 0
y1_ = tmp_y1 if tmp_y1 <= h else h
return x0_, y0_, x1_, y1_
class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None):
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args)
def __call__(self, img):
ori_im = img.copy()
structure_res, elapse = self.table_structurer(copy.deepcopy(img))
dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes)
r_boxes = []
for box in dt_boxes:
x_min = box[:, 0].min() - 1
x_max = box[:, 0].max() + 1
y_min = box[:, 1].min() - 1
y_max = box[:, 1].max() + 1
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
img_crop_list = []
for i in range(len(dt_boxes)):
det_box = dt_boxes[i]
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect)
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
return pred_html
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html, pred
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append(
(distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if '</td>' in tag:
if td_index in matched_index.keys():
b_with = False
if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1:
b_with = True
end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == ' ':
content = content[1:]
if '<b>' in content:
content = content[3:]
if '</b>' in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
content += ' '
end_html.extend(content)
if b_with:
end_html.extend('</b>')
end_html.append(tag)
td_index += 1
else:
end_html.append(tag)
return ''.join(end_html), end_html
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl
tablepyxl.document_to_xl(html_table, excel_path)
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
os.makedirs(args.output, exist_ok=True)
text_sys = TableSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
pred_html = text_sys(img)
to_excel(pred_html, excel_path)
logger.info('excel saved to {}'.format(excel_path))
logger.info(pred_html)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
if __name__ == "__main__":
args = parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['TEDS']
from .table_metric import TEDS
\ No newline at end of file
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
"""
A parallel version of the map function with a progress bar.
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
n_jobs (int, default=16): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if front_num > 0:
front = [function(**a) if use_kwargs else function(a)
for a in array[:front_num]]
else:
front = []
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if n_jobs == 1:
return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
# Assemble the workers
with ProcessPoolExecutor(max_workers=n_jobs) as pool:
# Pass the elements of array into function
if use_kwargs:
futures = [pool.submit(function, **a) for a in array[front_num:]]
else:
futures = [pool.submit(function, a) for a in array[front_num:]]
kwargs = {
'total': len(futures),
'unit': 'it',
'unit_scale': True,
'leave': True
}
# Print out the progress as tasks complete
for f in tqdm(as_completed(futures), **kwargs):
pass
out = []
# Get the results from the futures.
for i, future in tqdm(enumerate(futures)):
try:
out.append(future.result())
except Exception as e:
out.append(e)
return front + out
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.
import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from collections import deque
from .parallel import parallel_process
from tqdm import tqdm
class TableTree(Tree):
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.tag = tag
self.colspan = colspan
self.rowspan = rowspan
self.content = content
self.children = list(children)
def bracket(self):
"""Show tree using brackets notation"""
if self.tag == 'td':
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
(self.tag, self.colspan, self.rowspan, self.content)
else:
result = '"tag": %s' % self.tag
for child in self.children:
result += child.bracket()
return "{{{}}}".format(result)
class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
#print(node1.tag)
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
#print(node1.content, )
return self.normalized_distance(node1.content, node2.content)
return 0.
class CustomConfig_del_short(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
#print('before')
#print(node1.content, node2.content)
#print('after')
node1_content = node1.content
node2_content = node2.content
if len(node1_content) < 3:
node1_content = ['####']
if len(node2_content) < 3:
node2_content = ['####']
return self.normalized_distance(node1_content, node2_content)
return 0.
class CustomConfig_del_block(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
if node1.content or node2.content:
node1_content = node1.content
node2_content = node2.content
while ' ' in node1_content:
print(node1_content.index(' '))
node1_content.pop(node1_content.index(' '))
while ' ' in node2_content:
print(node2_content.index(' '))
node2_content.pop(node2_content.index(' '))
return self.normalized_distance(node1_content, node2_content)
return 0.
class TEDS(object):
''' Tree Edit Distance basead Similarity
'''
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
assert isinstance(n_jobs, int) and (
n_jobs >= 1), 'n_jobs must be an integer greather than 1'
self.structure_only = structure_only
self.n_jobs = n_jobs
self.ignore_nodes = ignore_nodes
self.__tokens__ = []
def tokenize(self, node):
''' Tokenizes table cells
'''
self.__tokens__.append('<%s>' % node.tag)
if node.text is not None:
self.__tokens__ += list(node.text)
for n in node.getchildren():
self.tokenize(n)
if node.tag != 'unk':
self.__tokens__.append('</%s>' % node.tag)
if node.tag != 'td' and node.tail is not None:
self.__tokens__ += list(node.tail)
def load_html_tree(self, node, parent=None):
''' Converts HTML tree to the format required by apted
'''
global __tokens__
if node.tag == 'td':
if self.structure_only:
cell = []
else:
self.__tokens__ = []
self.tokenize(node)
cell = self.__tokens__[1:-1].copy()
new_node = TableTree(node.tag,
int(node.attrib.get('colspan', '1')),
int(node.attrib.get('rowspan', '1')),
cell, *deque())
else:
new_node = TableTree(node.tag, None, None, None, *deque())
if parent is not None:
parent.children.append(new_node)
if node.tag != 'td':
for n in node.getchildren():
self.load_html_tree(n, new_node)
if parent is None:
return new_node
def evaluate(self, pred, true):
''' Computes TEDS score between the prediction and the ground truth of a
given sample
'''
if (not pred) or (not true):
return 0.0
parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
pred = html.fromstring(pred, parser=parser)
true = html.fromstring(true, parser=parser)
if pred.xpath('body/table') and true.xpath('body/table'):
pred = pred.xpath('body/table')[0]
true = true.xpath('body/table')[0]
if self.ignore_nodes:
etree.strip_tags(pred, *self.ignore_nodes)
etree.strip_tags(true, *self.ignore_nodes)
n_nodes_pred = len(pred.xpath(".//*"))
n_nodes_true = len(true.xpath(".//*"))
n_nodes = max(n_nodes_pred, n_nodes_true)
tree_pred = self.load_html_tree(pred)
tree_true = self.load_html_tree(true)
distance = APTED(tree_pred, tree_true,
CustomConfig()).compute_edit_distance()
return 1.0 - (float(distance) / n_nodes)
else:
return 0.0
def batch_evaluate(self, pred_json, true_json):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
@params pred_json: {'FILENAME': 'HTML CODE', ...}
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
@output: {'FILENAME': 'TEDS SCORE', ...}
'''
samples = true_json.keys()
if self.n_jobs == 1:
scores = [self.evaluate(pred_json.get(
filename, ''), true_json[filename]['html']) for filename in tqdm(samples)]
else:
inputs = [{'pred': pred_json.get(
filename, ''), 'true': true_json[filename]['html']} for filename in samples]
scores = parallel_process(
inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
scores = dict(zip(samples, scores))
return scores
def batch_evaluate_html(self, pred_htmls, true_htmls):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
'''
if self.n_jobs == 1:
scores = [self.evaluate(pred_html, true_html) for (
pred_html, true_html) in zip(pred_htmls, true_htmls)]
else:
inputs = [{"pred": pred_html, "true": true_html} for(
pred_html, true_html) in zip(pred_htmls, true_htmls)]
scores = parallel_process(
inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
return scores
if __name__ == '__main__':
import json
import pprint
with open('sample_pred.json') as fp:
pred_json = json.load(fp)
with open('sample_gt.json') as fp:
true_json = json.load(fp)
teds = TEDS(n_jobs=4)
scores = teds.batch_evaluate(pred_json, true_json)
pp = pprint.PrettyPrinter()
pp.pprint(scores)
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
# This is where we handle translating css styles into openpyxl styles
# and cascading those from parent to child in the dom.
from openpyxl.cell import cell
from openpyxl.styles import Font, Alignment, PatternFill, NamedStyle, Border, Side, Color
from openpyxl.styles.fills import FILL_SOLID
from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
from openpyxl.styles.colors import BLACK
FORMAT_DATE_MMDDYYYY = 'mm/dd/yyyy'
def colormap(color):
"""
Convenience for looking up known colors
"""
cmap = {'black': BLACK}
return cmap.get(color, color)
def style_string_to_dict(style):
"""
Convert css style string to a python dictionary
"""
def clean_split(string, delim):
return (s.strip() for s in string.split(delim))
styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
return dict(styles)
def get_side(style, name):
return {'border_style': style.get('border-{}-style'.format(name)),
'color': colormap(style.get('border-{}-color'.format(name)))}
known_styles = {}
def style_dict_to_named_style(style_dict, number_format=None):
"""
Change css style (stored in a python dictionary) to openpyxl NamedStyle
"""
style_and_format_string = str({
'style_dict': style_dict,
'parent': style_dict.parent,
'number_format': number_format,
})
if style_and_format_string not in known_styles:
# Font
font = Font(bold=style_dict.get('font-weight') == 'bold',
color=style_dict.get_color('color', None),
size=style_dict.get('font-size'))
# Alignment
alignment = Alignment(horizontal=style_dict.get('text-align', 'general'),
vertical=style_dict.get('vertical-align'),
wrap_text=style_dict.get('white-space', 'nowrap') == 'normal')
# Fill
bg_color = style_dict.get_color('background-color')
fg_color = style_dict.get_color('foreground-color', Color())
fill_type = style_dict.get('fill-type')
if bg_color and bg_color != 'transparent':
fill = PatternFill(fill_type=fill_type or FILL_SOLID,
start_color=bg_color,
end_color=fg_color)
else:
fill = PatternFill()
# Border
border = Border(left=Side(**get_side(style_dict, 'left')),
right=Side(**get_side(style_dict, 'right')),
top=Side(**get_side(style_dict, 'top')),
bottom=Side(**get_side(style_dict, 'bottom')),
diagonal=Side(**get_side(style_dict, 'diagonal')),
diagonal_direction=None,
outline=Side(**get_side(style_dict, 'outline')),
vertical=None,
horizontal=None)
name = 'Style {}'.format(len(known_styles) + 1)
pyxl_style = NamedStyle(name=name, font=font, fill=fill, alignment=alignment, border=border,
number_format=number_format)
known_styles[style_and_format_string] = pyxl_style
return known_styles[style_and_format_string]
class StyleDict(dict):
"""
It's like a dictionary, but it looks for items in the parent dictionary
"""
def __init__(self, *args, **kwargs):
self.parent = kwargs.pop('parent', None)
super(StyleDict, self).__init__(*args, **kwargs)
def __getitem__(self, item):
if item in self:
return super(StyleDict, self).__getitem__(item)
elif self.parent:
return self.parent[item]
else:
raise KeyError('{} not found'.format(item))
def __hash__(self):
return hash(tuple([(k, self.get(k)) for k in self._keys()]))
# Yielding the keys avoids creating unnecessary data structures
# and happily works with both python2 and python3 where the
# .keys() method is a dictionary_view in python3 and a list in python2.
def _keys(self):
yielded = set()
for k in self.keys():
yielded.add(k)
yield k
if self.parent:
for k in self.parent._keys():
if k not in yielded:
yielded.add(k)
yield k
def get(self, k, d=None):
try:
return self[k]
except KeyError:
return d
def get_color(self, k, d=None):
"""
Strip leading # off colors if necessary
"""
color = self.get(k, d)
if hasattr(color, 'startswith') and color.startswith('#'):
color = color[1:]
if len(color) == 3: # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
color = ''.join(2 * c for c in color)
return color
class Element(object):
"""
Our base class for representing an html element along with a cascading style.
The element is created along with a parent so that the StyleDict that we store
can point to the parent's StyleDict.
"""
def __init__(self, element, parent=None):
self.element = element
self.number_format = None
parent_style = parent.style_dict if parent else None
self.style_dict = StyleDict(style_string_to_dict(element.get('style', '')), parent=parent_style)
self._style_cache = None
def style(self):
"""
Turn the css styles for this element into an openpyxl NamedStyle.
"""
if not self._style_cache:
self._style_cache = style_dict_to_named_style(self.style_dict, number_format=self.number_format)
return self._style_cache
def get_dimension(self, dimension_key):
"""
Extracts the dimension from the style dict of the Element and returns it as a float.
"""
dimension = self.style_dict.get(dimension_key)
if dimension:
if dimension[-2:] in ['px', 'em', 'pt', 'in', 'cm']:
dimension = dimension[:-2]
dimension = float(dimension)
return dimension
class Table(Element):
"""
The concrete implementations of Elements are semantically named for the types of elements we are interested in.
This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
"""
def __init__(self, table):
"""
takes an html table object (from lxml)
"""
super(Table, self).__init__(table)
table_head = table.find('thead')
self.head = TableHead(table_head, parent=self) if table_head is not None else None
table_body = table.find('tbody')
self.body = TableBody(table_body if table_body is not None else table, parent=self)
class TableHead(Element):
"""
This class maps to the `<th>` element of the html table.
"""
def __init__(self, head, parent=None):
super(TableHead, self).__init__(head, parent=parent)
self.rows = [TableRow(tr, parent=self) for tr in head.findall('tr')]
class TableBody(Element):
"""
This class maps to the `<tbody>` element of the html table.
"""
def __init__(self, body, parent=None):
super(TableBody, self).__init__(body, parent=parent)
self.rows = [TableRow(tr, parent=self) for tr in body.findall('tr')]
class TableRow(Element):
"""
This class maps to the `<tr>` element of the html table.
"""
def __init__(self, tr, parent=None):
super(TableRow, self).__init__(tr, parent=parent)
self.cells = [TableCell(cell, parent=self) for cell in tr.findall('th') + tr.findall('td')]
def element_to_string(el):
return _element_to_string(el).strip()
def _element_to_string(el):
string = ''
for x in el.iterchildren():
string += '\n' + _element_to_string(x)
text = el.text.strip() if el.text else ''
tail = el.tail.strip() if el.tail else ''
return text + string + '\n' + tail
class TableCell(Element):
"""
This class maps to the `<td>` element of the html table.
"""
CELL_TYPES = {'TYPE_STRING', 'TYPE_FORMULA', 'TYPE_NUMERIC', 'TYPE_BOOL', 'TYPE_CURRENCY', 'TYPE_PERCENTAGE',
'TYPE_NULL', 'TYPE_INLINE', 'TYPE_ERROR', 'TYPE_FORMULA_CACHE_STRING', 'TYPE_INTEGER'}
def __init__(self, cell, parent=None):
super(TableCell, self).__init__(cell, parent=parent)
self.value = element_to_string(cell)
self.number_format = self.get_number_format()
def data_type(self):
cell_types = self.CELL_TYPES & set(self.element.get('class', '').split())
if cell_types:
if 'TYPE_FORMULA' in cell_types:
# Make sure TYPE_FORMULA takes precedence over the other classes in the set.
cell_type = 'TYPE_FORMULA'
elif cell_types & {'TYPE_CURRENCY', 'TYPE_INTEGER', 'TYPE_PERCENTAGE'}:
cell_type = 'TYPE_NUMERIC'
else:
cell_type = cell_types.pop()
else:
cell_type = 'TYPE_STRING'
return getattr(cell, cell_type)
def get_number_format(self):
if 'TYPE_CURRENCY' in self.element.get('class', '').split():
return FORMAT_CURRENCY_USD_SIMPLE
if 'TYPE_INTEGER' in self.element.get('class', '').split():
return '#,##0'
if 'TYPE_PERCENTAGE' in self.element.get('class', '').split():
return FORMAT_PERCENTAGE
if 'TYPE_DATE' in self.element.get('class', '').split():
return FORMAT_DATE_MMDDYYYY
if self.data_type() == cell.TYPE_NUMERIC:
try:
int(self.value)
except ValueError:
return '#,##0.##'
else:
return '#,##0'
def format(self, cell):
cell.style = self.style()
data_type = self.data_type()
if data_type:
cell.data_type = data_type
\ No newline at end of file
# Do imports like python3 so our package works for 2 and 3
from __future__ import absolute_import
from lxml import html
from openpyxl import Workbook
from openpyxl.utils import get_column_letter
from premailer import Premailer
from tablepyxl.style import Table
def string_to_int(s):
if s.isdigit():
return int(s)
return 0
def get_Tables(doc):
tree = html.fromstring(doc)
comments = tree.xpath('//comment()')
for comment in comments:
comment.drop_tag()
return [Table(table) for table in tree.xpath('//table')]
def write_rows(worksheet, elem, row, column=1):
"""
Writes every tr child element of elem to a row in the worksheet
returns the next row after all rows are written
"""
from openpyxl.cell.cell import MergedCell
initial_column = column
for table_row in elem.rows:
for table_cell in table_row.cells:
cell = worksheet.cell(row=row, column=column)
while isinstance(cell, MergedCell):
column += 1
cell = worksheet.cell(row=row, column=column)
colspan = string_to_int(table_cell.element.get("colspan", "1"))
rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
if rowspan > 1 or colspan > 1:
worksheet.merge_cells(start_row=row, start_column=column,
end_row=row + rowspan - 1, end_column=column + colspan - 1)
cell.value = table_cell.value
table_cell.format(cell)
min_width = table_cell.get_dimension('min-width')
max_width = table_cell.get_dimension('max-width')
if colspan == 1:
# Initially, when iterating for the first time through the loop, the width of all the cells is None.
# As we start filling in contents, the initial width of the cell (which can be retrieved by:
# worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
# cell in the same column (i.e. width of A2 = width of A1)
width = max(worksheet.column_dimensions[get_column_letter(column)].width or 0, len(table_cell.value) + 2)
if max_width and width > max_width:
width = max_width
elif min_width and width < min_width:
width = min_width
worksheet.column_dimensions[get_column_letter(column)].width = width
column += colspan
row += 1
column = initial_column
return row
def table_to_sheet(table, wb):
"""
Takes a table and workbook and writes the table to a new sheet.
The sheet title will be the same as the table attribute name.
"""
ws = wb.create_sheet(title=table.element.get('name'))
insert_table(table, ws, 1, 1)
def document_to_workbook(doc, wb=None, base_url=None):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document.
The workbook is returned
"""
if not wb:
wb = Workbook()
wb.remove(wb.active)
inline_styles_doc = Premailer(doc, base_url=base_url, remove_classes=False).transform()
tables = get_Tables(inline_styles_doc)
for table in tables:
table_to_sheet(table, wb)
return wb
def document_to_xl(doc, filename, base_url=None):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document. The workbook is written out to a file called filename
"""
wb = document_to_workbook(doc, base_url=base_url)
wb.save(filename)
def insert_table(table, worksheet, column, row):
if table.head:
row = write_rows(worksheet, table.head, row, column)
if table.body:
row = write_rows(worksheet, table.body, row, column)
def insert_table_at_cell(table, cell):
"""
Inserts a table at the location of an openpyxl Cell object.
"""
ws = cell.parent
column, row = cell.column, cell.row
insert_table(table, ws, column, row)
\ No newline at end of file
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args
def init_args():
parser = infer_args()
# params for output
parser.add_argument("--output", type=str, default='./output/table')
# params for table structure
parser.add_argument("--structure_max_len", type=int, default=488)
parser.add_argument("--structure_max_text_length", type=int, default=100)
parser.add_argument("--structure_max_elem_length", type=int, default=800)
parser.add_argument("--structure_max_cell_num", type=int, default=500)
parser.add_argument("--structure_model_dir", type=str)
parser.add_argument("--structure_char_type", type=str, default='en')
parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout detector
parser.add_argument("--layout_model_dir", type=str)
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
def draw_result(image, result, font_path):
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
boxes, txts, scores = [], [], []
for region in result:
if region['type'] == 'Table':
pass
elif region['type'] == 'Figure':
pass
else:
for box, rec_res in zip(region['res'][0], region['res'][1]):
boxes.append(np.array(box).reshape(-1, 2))
txts.append(rec_res[0])
scores.append(rec_res[1])
im_show = draw_ocr_box_txt(image, boxes, txts, scores, font_path=font_path,drop_score=0)
return im_show
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import logging
import paddle
import paddle.inference as paddle_infer
from pathlib import Path
CUR_DIR = os.path.dirname(os.path.abspath(__file__))
class PaddleInferBenchmark(object):
def __init__(self,
config,
model_info: dict={},
data_info: dict={},
perf_info: dict={},
resource_info: dict={},
save_log_path: str="",
**kwargs):
"""
Construct PaddleInferBenchmark Class to format logs.
args:
config(paddle.inference.Config): paddle inference config
model_info(dict): basic model info
{'model_name': 'resnet50'
'precision': 'fp32'}
data_info(dict): input data info
{'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info(dict): performance result
{'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info(dict):
cpu and gpu resources
{'cpu_rss': 100
'gpu_rss': 100
'gpu_util': 60}
"""
# PaddleInferBenchmark Log Version
self.log_version = 1.0
# Paddle Version
self.paddle_version = paddle.__version__
self.paddle_commit = paddle.__git_commit__
paddle_infer_info = paddle_infer.get_version()
self.paddle_branch = paddle_infer_info.strip().split(': ')[-1]
# model info
self.model_info = model_info
# data info
self.data_info = data_info
# perf info
self.perf_info = perf_info
try:
self.model_name = model_info['model_name']
self.precision = model_info['precision']
self.batch_size = data_info['batch_size']
self.shape = data_info['shape']
self.data_num = data_info['data_num']
self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4)
self.inference_time_s = round(perf_info['inference_time_s'], 4)
self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4)
self.total_time_s = round(perf_info['total_time_s'], 4)
except:
self.print_help()
raise ValueError(
"Set argument wrong, please check input argument and its type")
# conf info
self.config_status = self.parse_config(config)
self.save_log_path = save_log_path
# mem info
if isinstance(resource_info, dict):
self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0))
self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0))
self.gpu_util = round(resource_info.get('gpu_util', 0), 2)
else:
self.cpu_rss_mb = 0
self.gpu_rss_mb = 0
self.gpu_util = 0
# init benchmark logger
self.benchmark_logger()
def benchmark_logger(self):
"""
benchmark logger
"""
# Init logger
FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
log_output = f"{self.save_log_path}/{self.model_name}.log"
Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format=FORMAT,
handlers=[
logging.FileHandler(
filename=log_output, mode='w'),
logging.StreamHandler(),
])
self.logger = logging.getLogger(__name__)
self.logger.info(
f"Paddle Inference benchmark log will be saved to {log_output}")
def parse_config(self, config) -> dict:
"""
parse paddle predictor config
args:
config(paddle.inference.Config): paddle inference config
return:
config_status(dict): dict style config info
"""
config_status = {}
config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu"
config_status['ir_optim'] = config.ir_optim()
config_status['enable_tensorrt'] = config.tensorrt_engine_enabled()
config_status['precision'] = self.precision
config_status['enable_mkldnn'] = config.mkldnn_enabled()
config_status[
'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads(
)
return config_status
def report(self, identifier=None):
"""
print log report
args:
identifier(string): identify log
"""
if identifier:
identifier = f"[{identifier}]"
else:
identifier = ""
self.logger.info("\n")
self.logger.info(
"---------------------- Paddle info ----------------------")
self.logger.info(f"{identifier} paddle_version: {self.paddle_version}")
self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}")
self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}")
self.logger.info(f"{identifier} log_api_version: {self.log_version}")
self.logger.info(
"----------------------- Conf info -----------------------")
self.logger.info(
f"{identifier} runtime_device: {self.config_status['runtime_device']}"
)
self.logger.info(
f"{identifier} ir_optim: {self.config_status['ir_optim']}")
self.logger.info(f"{identifier} enable_memory_optim: {True}")
self.logger.info(
f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}"
)
self.logger.info(
f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}")
self.logger.info(
f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}"
)
self.logger.info(
"----------------------- Model info ----------------------")
self.logger.info(f"{identifier} model_name: {self.model_name}")
self.logger.info(f"{identifier} precision: {self.precision}")
self.logger.info(
"----------------------- Data info -----------------------")
self.logger.info(f"{identifier} batch_size: {self.batch_size}")
self.logger.info(f"{identifier} input_shape: {self.shape}")
self.logger.info(f"{identifier} data_num: {self.data_num}")
self.logger.info(
"----------------------- Perf info -----------------------")
self.logger.info(
f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%"
)
self.logger.info(
f"{identifier} total time spent(s): {self.total_time_s}")
self.logger.info(
f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}"
)
def print_help(self):
"""
print function help
"""
print("""Usage:
==== Print inference benchmark logs. ====
config = paddle.inference.Config()
model_info = {'model_name': 'resnet50'
'precision': 'fp32'}
data_info = {'batch_size': 1
'shape': '3,224,224'
'data_num': 1000}
perf_info = {'preprocess_time_s': 1.0
'inference_time_s': 2.0
'postprocess_time_s': 1.0
'total_time_s': 4.0}
resource_info = {'cpu_rss_mb': 100
'gpu_rss_mb': 100
'gpu_util': 60}
log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info)
log('Test')
""")
def __call__(self, identifier=None):
"""
__call__
args:
identifier(string): identify log
"""
self.report(identifier)
......@@ -45,9 +45,11 @@ class TextClassifier(object):
"label_list": args.label_list,
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
self.predictor, self.input_tensor, self.output_tensors, _ = \
utility.create_predictor(args, 'cls', logger)
self.cls_times = utility.Timer()
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
h = img.shape[0]
......@@ -83,7 +85,9 @@ class TextClassifier(object):
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
elapse = 0
self.cls_times.total_time.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
......@@ -91,6 +95,7 @@ class TextClassifier(object):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
self.cls_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
......@@ -98,11 +103,17 @@ class TextClassifier(object):
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
self.cls_times.preprocess_time.end()
self.cls_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
prob_out = self.output_tensors[0].copy_to_cpu()
self.cls_times.inference_time.end()
self.cls_times.postprocess_time.start()
self.predictor.try_shrink_memory()
cls_result = self.postprocess_op(prob_out)
self.cls_times.postprocess_time.end()
elapse += time.time() - starttime
for rno in range(len(cls_result)):
label, score = cls_result[rno]
......@@ -110,6 +121,9 @@ class TextClassifier(object):
if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
self.cls_times.total_time.end()
self.cls_times.img_num += img_num
elapse = self.cls_times.total_time.value()
return img_list, cls_res, elapse
......@@ -141,8 +155,9 @@ def main(args):
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
cls_res[ino]))
logger.info("Total predict time for {} images, cost: {:.3f}".format(
len(img_list), predict_time))
logger.info(
"The predict time about text angle classify module is as follows: ")
text_classifier.cls_times.info(average=False)
if __name__ == "__main__":
......
......@@ -31,6 +31,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger()
......@@ -41,7 +43,7 @@ class TextDetector(object):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
'limit_type': args.det_limit_type,
}
}, {
'NormalizeImage': {
......@@ -95,9 +97,10 @@ class TextDetector(object):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
args, 'det', logger) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'det', logger)
self.det_times = utility.Timer()
def order_points_clockwise(self, pts):
"""
......@@ -155,6 +158,8 @@ class TextDetector(object):
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
self.det_times.total_time.start()
self.det_times.preprocess_time.start()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
......@@ -162,7 +167,9 @@ class TextDetector(object):
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time()
self.det_times.preprocess_time.end()
self.det_times.inference_time.start()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
......@@ -170,6 +177,7 @@ class TextDetector(object):
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
self.det_times.inference_time.end()
preds = {}
if self.det_algorithm == "EAST":
......@@ -184,6 +192,9 @@ class TextDetector(object):
preds['maps'] = outputs[0]
else:
raise NotImplementedError
self.det_times.postprocess_time.start()
self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
......@@ -191,8 +202,11 @@ class TextDetector(object):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, elapse
self.det_times.postprocess_time.end()
self.det_times.total_time.end()
self.det_times.img_num += 1
return dt_boxes, self.det_times.total_time.value()
if __name__ == "__main__":
......@@ -202,6 +216,13 @@ if __name__ == "__main__":
count = 0
total_time = 0
draw_img_save = "./inference_results"
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
# warmup 10 times
fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32)
for i in range(10):
dt_boxes, _ = text_detector(fake_img)
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
for image_file in image_file_list:
......@@ -211,16 +232,56 @@ if __name__ == "__main__":
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
dt_boxes, elapse = text_detector(img)
st = time.time()
dt_boxes, _ = text_detector(img)
elapse = time.time() - st
if count > 0:
total_time += elapse
count += 1
if args.benchmark:
cm, gm, gu = utility.get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"det_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if count > 1:
logger.info("Avg Time: {}".format(total_time / (count - 1)))
# print the information about memory and time-spent
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
mems = None
logger.info("The predict time about detection module is as follows: ")
det_time_dict = text_detector.det_times.report(average=True)
det_model_name = args.det_model_dir
if args.benchmark:
# construct log information
model_info = {
'model_name': args.det_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': 1,
'shape': 'dynamic_shape',
'data_num': det_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': det_time_dict['preprocess_time'],
'inference_time_s': det_time_dict['inference_time'],
'postprocess_time_s': det_time_dict['postprocess_time'],
'total_time_s': det_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_detector.config, model_info, data_info, perf_info, mems)
benchmark_log("Det")
......@@ -28,6 +28,7 @@ import traceback
import paddle
import tools.infer.utility as utility
import tools.infer.benchmark_utils as benchmark_utils
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
......@@ -41,7 +42,6 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
self.max_text_length = args.max_text_length
postprocess_params = {
'name': 'CTCLabelDecode',
"character_type": args.rec_char_type,
......@@ -63,9 +63,11 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.rec_times = utility.Timer()
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
......@@ -166,17 +168,15 @@ class TextRecognizer(object):
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
# rec_res = []
self.rec_times.total_time.start()
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
self.rec_times.preprocess_time.start()
for ino in range(beg_img_no, end_img_no):
# h, w = img_list[ino].shape[0:2]
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
......@@ -187,9 +187,8 @@ class TextRecognizer(object):
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8,
self.max_text_length)
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
......@@ -203,7 +202,6 @@ class TextRecognizer(object):
norm_img_batch = norm_img_batch.copy()
if self.rec_algorithm == "SRN":
starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
......@@ -218,19 +216,23 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
self.rec_times.preprocess_time.end()
self.rec_times.inference_time.start()
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
self.rec_times.inference_time.end()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {"predict": outputs[2]}
else:
starttime = time.time()
self.rec_times.preprocess_time.end()
self.rec_times.inference_time.start()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
......@@ -239,22 +241,31 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
self.predictor.try_shrink_memory()
self.rec_times.inference_time.end()
self.rec_times.postprocess_time.start()
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
elapse += time.time() - starttime
return rec_res, elapse
self.rec_times.postprocess_time.end()
self.rec_times.img_num += int(norm_img_batch.shape[0])
self.rec_times.total_time.end()
return rec_res, self.rec_times.total_time.value()
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args)
total_run_time = 0.0
total_images_num = 0
valid_image_file_list = []
img_list = []
for idx, image_file in enumerate(image_file_list):
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
# warmup 10 times
fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32)
for i in range(10):
dt_boxes, _ = text_recognizer(fake_img)
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
......@@ -263,29 +274,54 @@ def main(args):
continue
valid_image_file_list.append(image_file)
img_list.append(img)
if len(img_list) >= args.rec_batch_num or idx == len(
image_file_list) - 1:
try:
rec_res, predict_time = text_recognizer(img_list)
total_run_time += predict_time
except:
logger.info(traceback.format_exc())
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[
ino], rec_res[ino]))
total_images_num += len(valid_image_file_list)
valid_image_file_list = []
img_list = []
logger.info("Total predict time for {} images, cost: {:.3f}".format(
total_images_num, total_run_time))
try:
rec_res, _ = text_recognizer(img_list)
if args.benchmark:
cm, gm, gu = utility.get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
count += 1
except Exception as E:
logger.info(traceback.format_exc())
logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
rec_res[ino]))
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
mems = None
logger.info("The predict time about recognizer module is as follows: ")
rec_time_dict = text_recognizer.rec_times.report(average=True)
rec_model_name = args.rec_model_dir
if args.benchmark:
# construct log information
model_info = {
'model_name': args.rec_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': args.rec_batch_num,
'shape': 'dynamic_shape',
'data_num': rec_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': rec_time_dict['preprocess_time'],
'inference_time_s': rec_time_dict['inference_time'],
'postprocess_time_s': rec_time_dict['postprocess_time'],
'total_time_s': rec_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_recognizer.config, model_info, data_info, perf_info, mems)
benchmark_log("Rec")
if __name__ == "__main__":
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
......@@ -32,8 +31,8 @@ import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.utility import draw_ocr_box_txt
from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger()
......@@ -88,7 +87,8 @@ class TextSystem(object):
def __call__(self, img, cls=True):
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format(
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
......@@ -103,11 +103,11 @@ class TextSystem(object):
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
logger.info("cls num : {}, elapse : {}".format(
logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format(
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
......@@ -142,23 +142,34 @@ def sorted_boxes(dt_boxes):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args)
is_visualize = True
font_path = args.vis_font_path
drop_score = args.drop_score
for image_file in image_file_list:
total_time = 0
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time()
count = 0
for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
total_time += elapse
if args.benchmark and idx % 20 == 0:
cm, gm, gu = get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
count += 1
logger.info(
str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
for text, score in rec_res:
logger.info("{}, {:.3f}".format(text, score))
......@@ -178,26 +189,74 @@ def main(args):
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
if flag:
image_file = image_file[:-3] + "png"
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
if __name__ == "__main__":
args = utility.parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
img_num = text_sys.text_detector.det_times.img_num
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
main(args)
mems = None
det_time_dict = text_sys.text_detector.det_times.report(average=True)
rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
det_model_name = args.det_model_dir
rec_model_name = args.rec_model_dir
# construct det log information
model_info = {
'model_name': args.det_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': 1,
'shape': 'dynamic_shape',
'data_num': det_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': det_time_dict['preprocess_time'],
'inference_time_s': det_time_dict['inference_time'],
'postprocess_time_s': det_time_dict['postprocess_time'],
'total_time_s': det_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_sys.text_detector.config, model_info, data_info, perf_info, mems,
args.save_log_path)
benchmark_log("Det")
# construct rec log information
model_info = {
'model_name': args.rec_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': args.rec_batch_num,
'shape': 'dynamic_shape',
'data_num': rec_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': rec_time_dict['preprocess_time'],
'inference_time_s': rec_time_dict['inference_time'],
'postprocess_time_s': rec_time_dict['postprocess_time'],
'total_time_s': rec_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
args.save_log_path)
benchmark_log("Rec")
if __name__ == "__main__":
main(utility.parse_args())
......@@ -37,7 +37,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
# params for text detector
......@@ -109,6 +109,11 @@ def init_args():
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--benchmark", type=bool, default=False)
parser.add_argument("--save_log_path", type=str, default="./log_output/")
parser.add_argument("--show_log", type=str2bool, default=True)
return parser
......@@ -118,6 +123,76 @@ def parse_args():
return parser.parse_args()
class Times(object):
def __init__(self):
self.time = 0.
self.st = 0.
self.et = 0.
def start(self):
self.st = time.time()
def end(self, accumulative=True):
self.et = time.time()
if accumulative:
self.time += self.et - self.st
else:
self.time = self.et - self.st
def reset(self):
self.time = 0.
self.st = 0.
self.et = 0.
def value(self):
return round(self.time, 4)
class Timer(Times):
def __init__(self):
super(Timer, self).__init__()
self.total_time = Times()
self.preprocess_time = Times()
self.inference_time = Times()
self.postprocess_time = Times()
self.img_num = 0
def info(self, average=False):
logger.info("----------------------- Perf info -----------------------")
logger.info("total_time: {}, img_num: {}".format(self.total_time.value(
), self.img_num))
preprocess_time = round(self.preprocess_time.value() / self.img_num,
4) if average else self.preprocess_time.value()
postprocess_time = round(
self.postprocess_time.value() / self.img_num,
4) if average else self.postprocess_time.value()
inference_time = round(self.inference_time.value() / self.img_num,
4) if average else self.inference_time.value()
average_latency = self.total_time.value() / self.img_num
logger.info("average_latency(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, 1 / average_latency))
logger.info(
"preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000))
def report(self, average=False):
dic = {}
dic['preprocess_time'] = round(
self.preprocess_time.value() / self.img_num,
4) if average else self.preprocess_time.value()
dic['postprocess_time'] = round(
self.postprocess_time.value() / self.img_num,
4) if average else self.postprocess_time.value()
dic['inference_time'] = round(
self.inference_time.value() / self.img_num,
4) if average else self.inference_time.value()
dic['img_num'] = self.img_num
dic['total_time'] = round(self.total_time.value(), 4)
return dic
def create_predictor(args, mode, logger):
if mode == "det":
model_dir = args.det_model_dir
......@@ -125,6 +200,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir
elif mode == 'rec':
model_dir = args.rec_model_dir
elif mode == 'structure':
model_dir = args.structure_model_dir
else:
model_dir = args.e2e_model_dir
......@@ -142,6 +219,16 @@ def create_predictor(args, mode, logger):
config = inference.Config(model_file_path, params_file_path)
if hasattr(args, 'precision'):
if args.precision == "fp16" and args.use_tensorrt:
precision = inference.PrecisionType.Half
elif args.precision == "int8":
precision = inference.PrecisionType.Int8
else:
precision = inference.PrecisionType.Float32
else:
precision = inference.PrecisionType.Float32
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
if args.use_tensorrt:
......@@ -244,7 +331,9 @@ def create_predictor(args, mode, logger):
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
config.switch_ir_optim(True)
if mode == 'structure':
config.switch_ir_optim(False)
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
......@@ -255,7 +344,7 @@ def create_predictor(args, mode, logger):
for output_name in output_names:
output_tensor = predictor.get_output_handle(output_name)
output_tensors.append(output_tensor)
return predictor, input_tensor, output_tensors
return predictor, input_tensor, output_tensors, config
def draw_e2e_res(dt_boxes, strs, img_path):
......@@ -506,5 +595,30 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5):
return image
def get_current_memory_mb(gpu_id=None):
"""
It is used to Obtain the memory usage of the CPU and GPU during the running of the program.
And this function Current program is time-consuming.
"""
import pynvml
import psutil
import GPUtil
pid = os.getpid()
p = psutil.Process(pid)
info = p.memory_full_info()
cpu_mem = info.uss / 1024. / 1024.
gpu_mem = 0
gpu_percent = 0
if gpu_id is not None:
GPUs = GPUtil.getGPUs()
gpu_load = GPUs[gpu_id].load
gpu_percent = gpu_load
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_mem = meminfo.used / 1024. / 1024.
return round(cpu_mem, 4), round(gpu_mem, 4), round(gpu_percent, 4)
if __name__ == '__main__':
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment