Unverified Commit 69269c16 authored by zhoujun's avatar zhoujun Committed by GitHub
Browse files

Merge pull request #5840 from WenmuZhou/cpp_infer

add PP-Structure to hubserving
parents 33f9b1d5 97f0a2d5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
sys.path.insert(0, ".")
import copy
import time
import paddlehub
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, runnable, serving
import cv2
import numpy as np
import paddlehub as hub
from tools.infer.utility import base64_to_cv2
from ppstructure.table.predict_table import TableSystem as _TableSystem
from ppstructure.predict_system import save_structure_res
from ppstructure.utility import parse_args
from deploy.hubserving.structure_table.params import read_params
@moduleinfo(
name="structure_table",
version="1.0.0",
summary="PP-Structure table service",
author="paddle-dev",
author_email="paddle-dev@baidu.com",
type="cv/structure_table")
class TableSystem(hub.Module):
def _initialize(self, use_gpu=False, enable_mkldnn=False):
"""
initialize with the necessary elements
"""
cfg = self.merge_configs()
cfg.use_gpu = use_gpu
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
print("use gpu: ", use_gpu)
print("CUDA_VISIBLE_DEVICES: ", _places)
cfg.gpu_mem = 8000
except:
raise RuntimeError(
"Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES via export CUDA_VISIBLE_DEVICES=cuda_device_id."
)
cfg.ir_optim = True
cfg.enable_mkldnn = enable_mkldnn
self.table_sys = _TableSystem(cfg)
def merge_configs(self):
# deafult cfg
backup_argv = copy.deepcopy(sys.argv)
sys.argv = sys.argv[:1]
cfg = parse_args()
update_cfg_map = vars(read_params())
for key in update_cfg_map:
cfg.__setattr__(key, update_cfg_map[key])
sys.argv = copy.deepcopy(backup_argv)
return cfg
def read_images(self, paths=[]):
images = []
for img_path in paths:
assert os.path.isfile(
img_path), "The {} isn't a valid file.".format(img_path)
img = cv2.imread(img_path)
if img is None:
logger.info("error in loading image:{}".format(img_path))
continue
images.append(img)
return images
def predict(self, images=[], paths=[]):
"""
Get the chinese texts in the predicted images.
Args:
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths
paths (list[str]): The paths of images. If paths not images
Returns:
res (list): The result of chinese texts and save path of images.
"""
if images != [] and isinstance(images, list) and paths == []:
predicted_data = images
elif images == [] and isinstance(paths, list) and paths != []:
predicted_data = self.read_images(paths)
else:
raise TypeError("The input data is inconsistent with expectations.")
assert predicted_data != [], "There is not any image to be predicted. Please check the input data."
all_results = []
for img in predicted_data:
if img is None:
logger.info("error in loading image")
all_results.append([])
continue
starttime = time.time()
pred_html = self.table_sys(img)
elapse = time.time() - starttime
logger.info("Predict time: {}".format(elapse))
all_results.append({'html': pred_html})
return all_results
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.predict(images_decode, **kwargs)
return results
if __name__ == '__main__':
table_system = TableSystem()
table_system._initialize()
image_path = ['./doc/table/table.jpg']
res = table_system.predict(paths=image_path)
print(res)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from deploy.hubserving.ocr_system.params import read_params as pp_ocr_read_params
def read_params():
cfg = pp_ocr_read_params()
# params for table structure model
cfg.table_max_len = 488
cfg.table_model_dir = './inference/en_ppocr_mobile_v2.0_table_structure_infer/'
cfg.table_char_type = 'en'
cfg.table_char_dict_path = './ppocr/utils/dict/table_structure_dict.txt'
cfg.show_log = False
return cfg
...@@ -39,7 +39,7 @@ from ppocr.utils.utility import check_and_read_gif, get_image_file_list ...@@ -39,7 +39,7 @@ from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, str2bool, check_gpu from tools.infer.utility import draw_ocr, str2bool, check_gpu
from ppstructure.utility import init_args, draw_structure_result from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res from ppstructure.predict_system import StructureSystem, save_structure_res
__all__ = [ __all__ = [
'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result',
...@@ -398,7 +398,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -398,7 +398,7 @@ class PaddleOCR(predict_system.TextSystem):
return rec_res return rec_res
class PPStructure(OCRSystem): class PPStructure(StructureSystem):
def __init__(self, **kwargs): def __init__(self, **kwargs):
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
......
...@@ -22,6 +22,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) ...@@ -22,6 +22,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth' os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import json
import numpy as np import numpy as np
import time import time
import logging import logging
...@@ -35,7 +36,7 @@ from ppstructure.utility import parse_args, draw_structure_result ...@@ -35,7 +36,7 @@ from ppstructure.utility import parse_args, draw_structure_result
logger = get_logger() logger = get_logger()
class OCRSystem(object): class StructureSystem(object):
def __init__(self, args): def __init__(self, args):
self.mode = args.mode self.mode = args.mode
if self.mode == 'structure': if self.mode == 'structure':
...@@ -66,8 +67,7 @@ class OCRSystem(object): ...@@ -66,8 +67,7 @@ class OCRSystem(object):
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score self.drop_score = args.drop_score
elif self.mode == 'vqa': elif self.mode == 'vqa':
from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results raise NotImplementedError
self.vqa_engine = SerPredictor(args)
def __call__(self, img): def __call__(self, img):
if self.mode == 'structure': if self.mode == 'structure':
...@@ -82,24 +82,24 @@ class OCRSystem(object): ...@@ -82,24 +82,24 @@ class OCRSystem(object):
res = self.table_system(roi_img) res = self.table_system(roi_img)
else: else:
filter_boxes, filter_rec_res = self.text_system(roi_img) filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes = [x + [x1, y1] for x in filter_boxes]
filter_boxes = [
x.reshape(-1).tolist() for x in filter_boxes
]
# remove style char # remove style char
style_token = [ style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>', '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
'</b>', '<sub>', '</sup>', '<overline>', '</overline>', '</b>', '<sub>', '</sup>', '<overline>', '</overline>',
'<underline>', '</underline>', '<i>', '</i>' '<underline>', '</underline>', '<i>', '</i>'
] ]
filter_rec_res_tmp = [] res = []
for rec_res in filter_rec_res: for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res rec_str, rec_conf = rec_res
for token in style_token: for token in style_token:
if token in rec_str: if token in rec_str:
rec_str = rec_str.replace(token, '') rec_str = rec_str.replace(token, '')
filter_rec_res_tmp.append((rec_str, rec_conf)) box += [x1, y1]
res = (filter_boxes, filter_rec_res_tmp) res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
res_list.append({ res_list.append({
'type': region.type, 'type': region.type,
'bbox': [x1, y1, x2, y2], 'bbox': [x1, y1, x2, y2],
...@@ -107,7 +107,7 @@ class OCRSystem(object): ...@@ -107,7 +107,7 @@ class OCRSystem(object):
'res': res 'res': res
}) })
elif self.mode == 'vqa': elif self.mode == 'vqa':
res_list, _ = self.vqa_engine(img) raise NotImplementedError
return res_list return res_list
...@@ -123,15 +123,14 @@ def save_structure_res(res, save_folder, img_name): ...@@ -123,15 +123,14 @@ def save_structure_res(res, save_folder, img_name):
excel_path = os.path.join(excel_save_folder, excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox'])) '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path) to_excel(region['res'], excel_path)
if region['type'] == 'Figure': elif region['type'] == 'Figure':
roi_img = region['img'] roi_img = region['img']
img_path = os.path.join(excel_save_folder, img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox'])) '{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img) cv2.imwrite(img_path, roi_img)
else: else:
for box, rec_res in zip(region['res'][0], region['res'][1]): for text_result in region['res']:
f.write('{}\t{}\n'.format( f.write('{}\n'.format(json.dumps(text_result)))
np.array(box).reshape(-1).tolist(), rec_res))
def main(args): def main(args):
...@@ -139,7 +138,7 @@ def main(args): ...@@ -139,7 +138,7 @@ def main(args):
image_file_list = image_file_list image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num] image_file_list = image_file_list[args.process_id::args.total_process_num]
structure_sys = OCRSystem(args) structure_sys = StructureSystem(args)
img_num = len(image_file_list) img_num = len(image_file_list)
save_folder = os.path.join(args.output, structure_sys.mode) save_folder = os.path.join(args.output, structure_sys.mode)
os.makedirs(save_folder, exist_ok=True) os.makedirs(save_folder, exist_ok=True)
...@@ -162,8 +161,9 @@ def main(args): ...@@ -162,8 +161,9 @@ def main(args):
draw_img = draw_structure_result(img, res, args.vis_font_path) draw_img = draw_structure_result(img, res, args.vis_font_path)
img_save_path = os.path.join(save_folder, img_name, 'show.jpg') img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
elif structure_sys.mode == 'vqa': elif structure_sys.mode == 'vqa':
draw_img = draw_ser_results(img, res, args.vis_font_path) raise NotImplementedError
img_save_path = os.path.join(save_folder, img_name + '.jpg') # draw_img = draw_ser_results(img, res, args.vis_font_path)
# img_save_path = os.path.join(save_folder, img_name + '.jpg')
cv2.imwrite(img_save_path, draw_img) cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path)) logger.info('result save to {}'.format(img_save_path))
elapse = time.time() - starttime elapse = time.time() - starttime
......
...@@ -40,12 +40,6 @@ def init_args(): ...@@ -40,12 +40,6 @@ def init_args():
type=ast.literal_eval, type=ast.literal_eval,
default=None, default=None,
help='label map according to ppstructure/layout/README_ch.md') help='label map according to ppstructure/layout/README_ch.md')
# params for ser
parser.add_argument("--model_name_or_path", type=str)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument(
"--label_map_path", type=str, default='./vqa/labels/labels_ser.txt')
parser.add_argument( parser.add_argument(
"--mode", "--mode",
type=str, type=str,
...@@ -67,10 +61,10 @@ def draw_structure_result(image, result, font_path): ...@@ -67,10 +61,10 @@ def draw_structure_result(image, result, font_path):
if region['type'] == 'Table': if region['type'] == 'Table':
pass pass
else: else:
for box, rec_res in zip(region['res'][0], region['res'][1]): for text_result in region['res']:
boxes.append(np.array(box).reshape(-1, 2)) boxes.append(np.array(text_result['text_region']))
txts.append(rec_res[0]) txts.append(text_result['text'])
scores.append(rec_res[1]) scores.append(text_result['confidence'])
im_show = draw_ocr_box_txt( im_show = draw_ocr_box_txt(
image, boxes, txts, scores, font_path=font_path, drop_score=0) image, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show return im_show
...@@ -25,7 +25,9 @@ import numpy as np ...@@ -25,7 +25,9 @@ import numpy as np
import time import time
from PIL import Image from PIL import Image
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
from tools.infer.utility import draw_ocr, draw_boxes from tools.infer.utility import draw_ocr, draw_boxes, str2bool
from ppstructure.utility import draw_structure_result
from ppstructure.predict_system import to_excel
import requests import requests
import json import json
...@@ -69,8 +71,33 @@ def draw_server_result(image_file, res): ...@@ -69,8 +71,33 @@ def draw_server_result(image_file, res):
return draw_img return draw_img
def main(url, image_path): def save_structure_res(res, save_folder, image_file):
image_file_list = get_image_file_list(image_path) img = cv2.imread(image_file)
excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
os.makedirs(excel_save_folder, exist_ok=True)
# save res
with open(
os.path.join(excel_save_folder, 'res.txt'), 'w',
encoding='utf8') as f:
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
x1, y1, x2, y2 = region['bbox']
print(region['bbox'])
roi_img = img[y1:y2, x1:x2, :]
img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img)
else:
for text_result in region['res']:
f.write('{}\n'.format(json.dumps(text_result)))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
is_visualize = False is_visualize = False
headers = {"Content-type": "application/json"} headers = {"Content-type": "application/json"}
cnt = 0 cnt = 0
...@@ -80,38 +107,51 @@ def main(url, image_path): ...@@ -80,38 +107,51 @@ def main(url, image_path):
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
img_name = os.path.basename(image_file)
# 发送HTTP请求 # seed http request
starttime = time.time() starttime = time.time()
data = {'images': [cv2_to_base64(img)]} data = {'images': [cv2_to_base64(img)]}
r = requests.post(url=url, headers=headers, data=json.dumps(data)) r = requests.post(
url=args.server_url, headers=headers, data=json.dumps(data))
elapse = time.time() - starttime elapse = time.time() - starttime
total_time += elapse total_time += elapse
logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
res = r.json()["results"][0] res = r.json()["results"][0]
logger.info(res) logger.info(res)
if is_visualize: if args.visualize:
draw_img = None
if 'structure_table' in args.server_url:
to_excel(res['html'], './{}.xlsx'.format(img_name))
elif 'structure_system' in args.server_url:
save_structure_res(res['regions'], args.output, image_file)
else:
draw_img = draw_server_result(image_file, res) draw_img = draw_server_result(image_file, res)
if draw_img is not None: if draw_img is not None:
draw_img_save = "./server_results/" if not os.path.exists(args.output):
if not os.path.exists(draw_img_save): os.makedirs(args.output)
os.makedirs(draw_img_save)
cv2.imwrite( cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)), os.path.join(args.output, os.path.basename(image_file)),
draw_img[:, :, ::-1]) draw_img[:, :, ::-1])
logger.info("The visualized image saved in {}".format( logger.info("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file)))) os.path.join(args.output, os.path.basename(image_file))))
cnt += 1 cnt += 1
if cnt % 100 == 0: if cnt % 100 == 0:
logger.info("{} processed".format(cnt)) logger.info("{} processed".format(cnt))
logger.info("avg time cost: {}".format(float(total_time) / cnt)) logger.info("avg time cost: {}".format(float(total_time) / cnt))
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="args for hub serving")
parser.add_argument("--server_url", type=str, required=True)
parser.add_argument("--image_dir", type=str, required=True)
parser.add_argument("--visualize", type=str2bool, default=False)
parser.add_argument("--output", type=str, default='./hubserving_result')
args = parser.parse_args()
return args
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) != 3: args = parse_args()
logger.info("Usage: %s server_url image_path" % sys.argv[0]) main(args)
else:
server_url = sys.argv[1]
image_path = sys.argv[2]
main(server_url, image_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment