Unverified commit ee05c913, authored by zhoujun and committed by GitHub
Browse files

Merge pull request #5 from PaddlePaddle/develop

merge paddleocr
parents 7c09c97d 2bdaea56
...@@ -18,9 +18,9 @@ from __future__ import print_function ...@@ -18,9 +18,9 @@ from __future__ import print_function
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
......
...@@ -13,30 +13,36 @@ ...@@ -13,30 +13,36 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '../..')) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import copy
import numpy as np
import math
import time
import sys
import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2 from ppocr.data.det.sast_process import SASTProcessTest
from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.east_process import EASTProcessTest
from ppocr.data.det.db_process import DBProcessTest from ppocr.data.det.db_process import DBProcessTest
from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.db_postprocess import DBPostProcess
from ppocr.postprocess.east_postprocess import EASTPostPocess from ppocr.postprocess.east_postprocess import EASTPostPocess
import copy from ppocr.postprocess.sast_postprocess import SASTPostProcess
import numpy as np
import math
import time
import sys
class TextDetector(object): class TextDetector(object):
def __init__(self, args): def __init__(self, args):
max_side_len = args.det_max_side_len max_side_len = args.det_max_side_len
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
preprocess_params = {'max_side_len': max_side_len} preprocess_params = {'max_side_len': max_side_len}
postprocess_params = {} postprocess_params = {}
if self.det_algorithm == "DB": if self.det_algorithm == "DB":
...@@ -52,6 +58,20 @@ class TextDetector(object): ...@@ -52,6 +58,20 @@ class TextDetector(object):
postprocess_params["cover_thresh"] = args.det_east_cover_thresh postprocess_params["cover_thresh"] = args.det_east_cover_thresh
postprocess_params["nms_thresh"] = args.det_east_nms_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh
self.postprocess_op = EASTPostPocess(postprocess_params) self.postprocess_op = EASTPostPocess(postprocess_params)
elif self.det_algorithm == "SAST":
self.preprocess_op = SASTProcessTest(preprocess_params)
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon:
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
else:
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
self.postprocess_op = SASTPostProcess(postprocess_params)
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -84,7 +104,7 @@ class TextDetector(object): ...@@ -84,7 +104,7 @@ class TextDetector(object):
return rect return rect
def clip_det_res(self, points, img_height, img_width): def clip_det_res(self, points, img_height, img_width):
for pno in range(4): for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points return points
...@@ -103,6 +123,15 @@ class TextDetector(object): ...@@ -103,6 +123,15 @@ class TextDetector(object):
dt_boxes = np.array(dt_boxes_new) dt_boxes = np.array(dt_boxes_new)
return dt_boxes return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
im, ratio_list = self.preprocess_op(img) im, ratio_list = self.preprocess_op(img)
...@@ -110,8 +139,12 @@ class TextDetector(object): ...@@ -110,8 +139,12 @@ class TextDetector(object):
return None, 0 return None, 0
im = im.copy() im = im.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(im) self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
else:
im = fluid.core.PaddleTensor(im)
self.predictor.run([im])
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
...@@ -120,10 +153,19 @@ class TextDetector(object): ...@@ -120,10 +153,19 @@ class TextDetector(object):
if self.det_algorithm == "EAST": if self.det_algorithm == "EAST":
outs_dict['f_geo'] = outputs[0] outs_dict['f_geo'] = outputs[0]
outs_dict['f_score'] = outputs[1] outs_dict['f_score'] = outputs[1]
elif self.det_algorithm == 'SAST':
outs_dict['f_border'] = outputs[0]
outs_dict['f_score'] = outputs[1]
outs_dict['f_tco'] = outputs[2]
outs_dict['f_tvo'] = outputs[3]
else: else:
outs_dict['maps'] = outputs[0] outs_dict['maps'] = outputs[0]
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list]) dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0] dt_boxes = dt_boxes_list[0]
if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime elapse = time.time() - starttime
return dt_boxes, elapse return dt_boxes, elapse
......
...@@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2 import cv2
import copy import copy
import numpy as np import numpy as np
import math import math
import time import time
import paddle.fluid as fluid
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.character import CharacterOps from ppocr.utils.character import CharacterOps
...@@ -37,6 +40,7 @@ class TextRecognizer(object): ...@@ -37,6 +40,7 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
char_ops_params = { char_ops_params = {
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
...@@ -102,8 +106,12 @@ class TextRecognizer(object): ...@@ -102,8 +106,12 @@ class TextRecognizer(object):
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
if self.loss_type == "ctc": if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
......
...@@ -157,7 +157,6 @@ def main(args): ...@@ -157,7 +157,6 @@ def main(args):
boxes, boxes,
txts, txts,
scores, scores,
draw_txt=True,
drop_score=drop_score) drop_score=drop_score)
draw_img_save = "./inference_results/" draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
......
...@@ -53,6 +53,11 @@ def parse_args(): ...@@ -53,6 +53,11 @@ def parse_args():
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
#SAST params
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=bool, default=False)
#params for text recognizer #params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str) parser.add_argument("--rec_model_dir", type=str)
...@@ -66,6 +71,7 @@ def parse_args(): ...@@ -66,6 +71,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt") default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--enable_mkldnn", type=bool, default=False) parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
return parser.parse_args() return parser.parse_args()
...@@ -100,9 +106,12 @@ def create_predictor(args, mode): ...@@ -100,9 +106,12 @@ def create_predictor(args, mode):
#config.enable_memory_optim() #config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
# use zero copy if args.use_zero_copy_run:
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
else:
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0]) input_tensor = predictor.get_input_tensor(input_names[0])
...@@ -134,7 +143,12 @@ def resize_img(img, input_size=600): ...@@ -134,7 +143,12 @@ def resize_img(img, input_size=600):
return im return im
def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): def draw_ocr(image,
boxes,
txts=None,
scores=None,
drop_score=0.5,
font_path="./doc/simfang.ttf"):
""" """
Visualize the results of OCR detection and recognition Visualize the results of OCR detection and recognition
args: args:
...@@ -142,23 +156,29 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5): ...@@ -142,23 +156,29 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
boxes(list): boxes with shape(N, 4, 2) boxes(list): boxes with shape(N, 4, 2)
txts(list): the texts txts(list): the texts
scores(list): txts corresponding scores scores(list): txts corresponding scores
draw_txt(bool): whether draw text or not
drop_score(float): only boxes whose score is at least drop_score will be visualized drop_score(float): only boxes whose score is at least drop_score will be visualized
font_path: the path of font which is used to draw text
return(array): return(array):
the visualized img the visualized img
""" """
if scores is None: if scores is None:
scores = [1] * len(boxes) scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores): box_num = len(boxes)
if score < drop_score or math.isnan(score): for i in range(box_num):
if scores is not None and (scores[i] < drop_score or
math.isnan(scores[i])):
continue continue
box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64) box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
if txts is not None:
if draw_txt:
img = np.array(resize_img(image, input_size=600)) img = np.array(resize_img(image, input_size=600))
txt_img = text_visual( txt_img = text_visual(
txts, scores, img_h=img.shape[0], img_w=600, threshold=drop_score) txts,
scores,
img_h=img.shape[0],
img_w=600,
threshold=drop_score,
font_path=font_path)
img = np.concatenate([np.array(img), np.array(txt_img)], axis=1) img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
return img return img
return image return image
...@@ -236,7 +256,12 @@ def str_count(s): ...@@ -236,7 +256,12 @@ def str_count(s):
return s_len - math.ceil(en_dg_count / 2) return s_len - math.ceil(en_dg_count / 2)
def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): def text_visual(texts,
scores,
img_h=400,
img_w=600,
threshold=0.,
font_path="./doc/simfang.ttf"):
""" """
create new blank img and draw txt on it create new blank img and draw txt on it
args: args:
...@@ -244,6 +269,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): ...@@ -244,6 +269,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
scores(list|None): corresponding score of each txt scores(list|None): corresponding score of each txt
img_h(int): the height of blank img img_h(int): the height of blank img
img_w(int): the width of blank img img_w(int): the width of blank img
font_path: the path of font which is used to draw text
return(array): return(array):
""" """
...@@ -262,7 +288,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.): ...@@ -262,7 +288,7 @@ def text_visual(texts, scores, img_h=400, img_w=600, threshold=0.):
font_size = 20 font_size = 20
txt_color = (0, 0, 0) txt_color = (0, 0, 0)
font = ImageFont.truetype("./doc/simfang.ttf", font_size, encoding="utf-8") font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
gap = font_size + 5 gap = font_size + 5
txt_img_list = [] txt_img_list = []
...@@ -343,6 +369,6 @@ if __name__ == '__main__': ...@@ -343,6 +369,6 @@ if __name__ == '__main__':
txts.append(dic['transcription']) txts.append(dic['transcription'])
scores.append(round(dic['scores'], 3)) scores.append(round(dic['scores'], 3))
new_img = draw_ocr(image, boxes, txts, scores, draw_txt=True) new_img = draw_ocr(image, boxes, txts, scores)
cv2.imwrite(img_name, new_img) cv2.imwrite(img_name, new_img)
...@@ -22,9 +22,9 @@ import json ...@@ -22,9 +22,9 @@ import json
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
...@@ -134,8 +134,10 @@ def main(): ...@@ -134,8 +134,10 @@ def main():
dic = {'f_score': outs[0], 'f_geo': outs[1]} dic = {'f_score': outs[0], 'f_geo': outs[1]}
elif config['Global']['algorithm'] == 'DB': elif config['Global']['algorithm'] == 'DB':
dic = {'maps': outs[0]} dic = {'maps': outs[0]}
elif config['Global']['algorithm'] == 'SAST':
dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]}
else: else:
raise Exception("only support algorithm: ['EAST', 'DB']") raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")
dt_boxes_list = postprocess(dic, ratio_list) dt_boxes_list = postprocess(dic, ratio_list)
for ino in range(img_num): for ino in range(img_num):
dt_boxes = dt_boxes_list[ino] dt_boxes = dt_boxes_list[ino]
......
...@@ -19,9 +19,9 @@ from __future__ import print_function ...@@ -19,9 +19,9 @@ from __future__ import print_function
import numpy as np import numpy as np
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
...@@ -140,12 +140,12 @@ def main(): ...@@ -140,12 +140,12 @@ def main():
preds = preds.reshape(-1) preds = preds.reshape(-1)
preds_text = char_ops.decode(preds) preds_text = char_ops.decode(preds)
elif loss_type == "srn": elif loss_type == "srn":
cur_pred = [] char_num = char_ops.get_char_num()
preds = np.array(predict[0]) preds = np.array(predict[0])
preds = preds.reshape(-1) preds = preds.reshape(-1)
probs = np.array(predict[1]) probs = np.array(predict[1])
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != 37)[0] valid_ind = np.where(preds != int(char_num-1))[0]
if len(valid_ind) == 0: if len(valid_ind) == 0:
continue continue
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
......
...@@ -18,9 +18,9 @@ from __future__ import print_function ...@@ -18,9 +18,9 @@ from __future__ import print_function
import os import os
import sys import sys
__dir__ = os.path.dirname(__file__) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..')) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment