Unverified Commit d930b24e authored by dyning's avatar dyning Committed by GitHub
Browse files

Merge pull request #42 from dyning/develop

update doc of infer
parents b24868cd 8a56798c
...@@ -15,19 +15,17 @@ ...@@ -15,19 +15,17 @@
import utility import utility
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from ppocr.utils.utility import get_image_file_list
import cv2 import cv2
from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.east_process import EASTProcessTest
from ppocr.data.det.db_process import DBProcessTest from ppocr.data.det.db_process import DBProcessTest
from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.db_postprocess import DBPostProcess
from ppocr.postprocess.east_postprocess import EASTPostPocess from ppocr.postprocess.east_postprocess import EASTPostPocess
from ppocr.utils.utility import get_image_file_list
from tools.infer.utility import draw_ocr
import copy import copy
import numpy as np import numpy as np
import math import math
import time import time
import sys import sys
import os
class TextDetector(object): class TextDetector(object):
...@@ -79,27 +77,10 @@ class TextDetector(object): ...@@ -79,27 +77,10 @@ class TextDetector(object):
rect = np.array([tl, tr, br, bl], dtype="float32") rect = np.array([tl, tr, br, bl], dtype="float32")
return rect return rect
def expand_det_res(self, points, bbox_height, bbox_width, img_height, def clip_det_res(self, points, img_height, img_width):
img_width): for pno in range(4):
if bbox_height * 1.0 / bbox_width >= 2.0: points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
expand_w = bbox_width * 0.20 points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
expand_h = bbox_width * 0.20
elif bbox_width * 1.0 / bbox_height >= 3.0:
expand_w = bbox_height * 0.20
expand_h = bbox_height * 0.20
else:
expand_w = bbox_height * 0.1
expand_h = bbox_height * 0.1
points[0, 0] = int(max((points[0, 0] - expand_w), 0))
points[1, 0] = int(min((points[1, 0] + expand_w), img_width))
points[3, 0] = int(max((points[3, 0] - expand_w), 0))
points[2, 0] = int(min((points[2, 0] + expand_w), img_width))
points[0, 1] = int(max((points[0, 1] - expand_h), 0))
points[1, 1] = int(max((points[1, 1] - expand_h), 0))
points[3, 1] = int(min((points[3, 1] + expand_h), img_height))
points[2, 1] = int(min((points[2, 1] + expand_h), img_height))
return points return points
def filter_tag_det_res(self, dt_boxes, image_shape): def filter_tag_det_res(self, dt_boxes, image_shape):
...@@ -107,22 +88,11 @@ class TextDetector(object): ...@@ -107,22 +88,11 @@ class TextDetector(object):
dt_boxes_new = [] dt_boxes_new = []
for box in dt_boxes: for box in dt_boxes:
box = self.order_points_clockwise(box) box = self.order_points_clockwise(box)
left = int(np.min(box[:, 0])) box = self.clip_det_res(box, img_height, img_width)
right = int(np.max(box[:, 0]))
top = int(np.min(box[:, 1]))
bottom = int(np.max(box[:, 1]))
bbox_height = bottom - top
bbox_width = right - left
diffh = math.fabs(box[0, 1] - box[1, 1])
diffw = math.fabs(box[0, 0] - box[3, 0])
rect_width = int(np.linalg.norm(box[0] - box[1])) rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3])) rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 10 or rect_height <= 10: if rect_width <= 10 or rect_height <= 10:
continue continue
# if diffh <= 10 and diffw <= 10:
# box = self.expand_det_res(
# copy.deepcopy(box), bbox_height, bbox_width, img_height,
# img_width)
dt_boxes_new.append(box) dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new) dt_boxes = np.array(dt_boxes_new)
return dt_boxes return dt_boxes
...@@ -153,8 +123,6 @@ class TextDetector(object): ...@@ -153,8 +123,6 @@ class TextDetector(object):
return dt_boxes, elapse return dt_boxes, elapse
from tools.infer.utility import draw_text_det_res
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
...@@ -171,9 +139,8 @@ if __name__ == "__main__": ...@@ -171,9 +139,8 @@ if __name__ == "__main__":
total_time += elapse total_time += elapse
count += 1 count += 1
print("Predict time of %s:" % image_file, elapse) print("Predict time of %s:" % image_file, elapse)
img_draw = draw_text_det_res(dt_boxes, image_file, return_img=True) src_im = utility.draw_text_det_res(dt_boxes, image_file)
save_path = os.path.join("./inference_det/", img_name_pure = image_file.split("/")[-1]
os.path.basename(image_file)) cv2.imwrite("./inference_results/det_res_%s" % img_name_pure, src_im)
print("The visualized image saved in {}".format(save_path)) if count > 1:
print("Avg Time:", total_time / (count - 1)) print("Avg Time:", total_time / (count - 1))
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import utility import utility
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from ppocr.utils.utility import get_image_file_list
import cv2 import cv2
import copy import copy
import numpy as np import numpy as np
import math import math
...@@ -30,6 +30,7 @@ class TextRecognizer(object): ...@@ -30,6 +30,7 @@ class TextRecognizer(object):
utility.create_predictor(args, mode="rec") utility.create_predictor(args, mode="rec")
image_shape = [int(v) for v in args.rec_image_shape.split(",")] image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.rec_image_shape = image_shape self.rec_image_shape = image_shape
self.character_type = args.rec_char_type
char_ops_params = {} char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_type"] = args.rec_char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path char_ops_params["character_dict_path"] = args.rec_char_dict_path
...@@ -38,6 +39,7 @@ class TextRecognizer(object): ...@@ -38,6 +39,7 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
if self.character_type == "ch":
imgW = int(32 * max_wh_ratio) imgW = int(32 * max_wh_ratio)
h = img.shape[0] h = img.shape[0]
w = img.shape[1] w = img.shape[1]
...@@ -102,7 +104,7 @@ class TextRecognizer(object): ...@@ -102,7 +104,7 @@ class TextRecognizer(object):
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = utility.get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
text_recognizer = TextRecognizer(args) text_recognizer = TextRecognizer(args)
valid_image_file_list = [] valid_image_file_list = []
img_list = [] img_list = []
...@@ -114,6 +116,7 @@ if __name__ == "__main__": ...@@ -114,6 +116,7 @@ if __name__ == "__main__":
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
rec_res, predict_time = text_recognizer(img_list) rec_res, predict_time = text_recognizer(img_list)
rec_res, predict_time = text_recognizer(img_list)
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog): ...@@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var = sorted([outputs[name] for name in outputs]) fetches_var_name = sorted([name for name in outputs])
fetches_var_name = [name for name in fetches_var] fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var
return feeded_var_names, target_vars, fetches_var_name return feeded_var_names, target_vars, fetches_var_name
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment