Commit 3f92af45 authored by MissPenguin's avatar MissPenguin
Browse files

Merge branch 'dygraph' of https://github.com/MissPenguin/PaddleOCR into dygraph

parents 79e83a8d 79c0a060
...@@ -20,9 +20,9 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) ...@@ -20,9 +20,9 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2 import cv2
import json import json
from tqdm import tqdm from tqdm import tqdm
from test.table.table_metric import TEDS from test1.table.table_metric import TEDS
from test.table.predict_table import TableSystem from test1.table.predict_table import TableSystem
from test.utility import init_args from test1.utility import init_args
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
......
...@@ -32,7 +32,7 @@ from ppocr.data import create_operators, transform ...@@ -32,7 +32,7 @@ from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from test.utility import parse_args from test1.utility import parse_args
logger = get_logger() logger = get_logger()
......
...@@ -30,9 +30,9 @@ import tools.infer.predict_rec as predict_rec ...@@ -30,9 +30,9 @@ import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from test.table.matcher import distance, compute_iou from test1.table.matcher import distance, compute_iou
from test.utility import parse_args from test1.utility import parse_args
import test.table.predict_structure as predict_strture import test1.table.predict_structure as predict_strture
logger = get_logger() logger = get_logger()
......
...@@ -44,8 +44,15 @@ def main(): ...@@ -44,8 +44,15 @@ def main():
# build model # build model
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len( char_num = len(getattr(post_process_class, 'character'))
getattr(post_process_class, 'character')) if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
......
...@@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif ...@@ -31,7 +31,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
import tools.infer.benchmark_utils as benchmark_utils # import tools.infer.benchmark_utils as benchmark_utils
logger = get_logger() logger = get_logger()
...@@ -100,8 +100,6 @@ class TextDetector(object): ...@@ -100,8 +100,6 @@ class TextDetector(object):
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'det', logger) args, 'det', logger)
self.det_times = utility.Timer()
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
""" """
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
...@@ -158,8 +156,8 @@ class TextDetector(object): ...@@ -158,8 +156,8 @@ class TextDetector(object):
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
data = {'image': img} data = {'image': img}
self.det_times.total_time.start()
self.det_times.preprocess_time.start() st = time.time()
data = transform(data, self.preprocess_op) data = transform(data, self.preprocess_op)
img, shape_list = data img, shape_list = data
if img is None: if img is None:
...@@ -168,16 +166,12 @@ class TextDetector(object): ...@@ -168,16 +166,12 @@ class TextDetector(object):
shape_list = np.expand_dims(shape_list, axis=0) shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy() img = img.copy()
self.det_times.preprocess_time.end()
self.det_times.inference_time.start()
self.input_tensor.copy_from_cpu(img) self.input_tensor.copy_from_cpu(img)
self.predictor.run() self.predictor.run()
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
self.det_times.inference_time.end()
preds = {} preds = {}
if self.det_algorithm == "EAST": if self.det_algorithm == "EAST":
...@@ -193,8 +187,6 @@ class TextDetector(object): ...@@ -193,8 +187,6 @@ class TextDetector(object):
else: else:
raise NotImplementedError raise NotImplementedError
self.det_times.postprocess_time.start()
self.predictor.try_shrink_memory() self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
...@@ -203,10 +195,8 @@ class TextDetector(object): ...@@ -203,10 +195,8 @@ class TextDetector(object):
else: else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
self.det_times.postprocess_time.end() et = time.time()
self.det_times.total_time.end() return dt_boxes, et - st
self.det_times.img_num += 1
return dt_boxes, self.det_times.total_time.value()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -216,12 +206,13 @@ if __name__ == "__main__": ...@@ -216,12 +206,13 @@ if __name__ == "__main__":
count = 0 count = 0
total_time = 0 total_time = 0
draw_img_save = "./inference_results" draw_img_save = "./inference_results"
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
# warmup 10 times if args.warmup:
fake_img = np.random.uniform(-1, 1, [640, 640, 3]).astype(np.float32) img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(10): for i in range(10):
dt_boxes, _ = text_detector(fake_img) res = text_detector(img)
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
if not os.path.exists(draw_img_save): if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save) os.makedirs(draw_img_save)
...@@ -239,12 +230,6 @@ if __name__ == "__main__": ...@@ -239,12 +230,6 @@ if __name__ == "__main__":
total_time += elapse total_time += elapse
count += 1 count += 1
if args.benchmark:
cm, gm, gu = utility.get_current_memory_mb(0)
cpu_mem += cm
gpu_mem += gm
gpu_util += gu
logger.info("Predict time of {}: {}".format(image_file, elapse)) logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file) src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1] img_name_pure = os.path.split(image_file)[-1]
...@@ -252,36 +237,3 @@ if __name__ == "__main__": ...@@ -252,36 +237,3 @@ if __name__ == "__main__":
"det_res_{}".format(img_name_pure)) "det_res_{}".format(img_name_pure))
logger.info("The visualized image saved in {}".format(img_path)) logger.info("The visualized image saved in {}".format(img_path))
# print the information about memory and time-spent
if args.benchmark:
mems = {
'cpu_rss_mb': cpu_mem / count,
'gpu_rss_mb': gpu_mem / count,
'gpu_util': gpu_util * 100 / count
}
else:
mems = None
logger.info("The predict time about detection module is as follows: ")
det_time_dict = text_detector.det_times.report(average=True)
det_model_name = args.det_model_dir
if args.benchmark:
# construct log information
model_info = {
'model_name': args.det_model_dir.split('/')[-1],
'precision': args.precision
}
data_info = {
'batch_size': 1,
'shape': 'dynamic_shape',
'data_num': det_time_dict['img_num']
}
perf_info = {
'preprocess_time_s': det_time_dict['preprocess_time'],
'inference_time_s': det_time_dict['inference_time'],
'postprocess_time_s': det_time_dict['postprocess_time'],
'total_time_s': det_time_dict['total_time']
}
benchmark_log = benchmark_utils.PaddleInferBenchmark(
text_detector.config, model_info, data_info, perf_info, mems)
benchmark_log("Det")
...@@ -257,13 +257,15 @@ def main(args): ...@@ -257,13 +257,15 @@ def main(args):
text_recognizer = TextRecognizer(args) text_recognizer = TextRecognizer(args)
valid_image_file_list = [] valid_image_file_list = []
img_list = [] img_list = []
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
# warmup 10 times # warmup 10 times
fake_img = np.random.uniform(-1, 1, [1, 32, 320, 3]).astype(np.float32) if args.warmup:
img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
for i in range(10): for i in range(10):
dt_boxes, _ = text_recognizer(fake_img) res = text_recognizer([img])
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -93,7 +94,6 @@ class TextSystem(object): ...@@ -93,7 +94,6 @@ class TextSystem(object):
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.debug("dt_boxes num : {}, elapse : {}".format( logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse)) len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
...@@ -147,15 +147,24 @@ def sorted_boxes(dt_boxes): ...@@ -147,15 +147,24 @@ def sorted_boxes(dt_boxes):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
text_sys = TextSystem(args) text_sys = TextSystem(args)
is_visualize = True is_visualize = True
font_path = args.vis_font_path font_path = args.vis_font_path
drop_score = args.drop_score drop_score = args.drop_score
# warm up 10 times
if args.warmup:
img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
for i in range(10):
res = text_sys(img)
total_time = 0 total_time = 0
cpu_mem, gpu_mem, gpu_util = 0, 0, 0 cpu_mem, gpu_mem, gpu_util = 0, 0, 0
_st = time.time() _st = time.time()
count = 0 count = 0
for idx, image_file in enumerate(image_file_list): for idx, image_file in enumerate(image_file_list):
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
...@@ -264,4 +273,18 @@ def main(args): ...@@ -264,4 +273,18 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
main(utility.parse_args()) args = utility.parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
for process_id in range(total_process_num):
cmd = [sys.executable, "-u"] + sys.argv + [
"--process_id={}".format(process_id),
"--use_mp={}".format(False)
]
p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
p_list.append(p)
for p in p_list:
p.wait()
else:
main(args)
...@@ -105,7 +105,9 @@ def init_args(): ...@@ -105,7 +105,9 @@ def init_args():
parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--cpu_threads", type=int, default=10)
parser.add_argument("--use_pdserving", type=str2bool, default=False) parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--warmup", type=str2bool, default=True)
# multi-process
parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0) parser.add_argument("--process_id", type=int, default=0)
...@@ -113,7 +115,6 @@ def init_args(): ...@@ -113,7 +115,6 @@ def init_args():
parser.add_argument("--benchmark", type=bool, default=False) parser.add_argument("--benchmark", type=bool, default=False)
parser.add_argument("--save_log_path", type=str, default="./log_output/") parser.add_argument("--save_log_path", type=str, default="./log_output/")
parser.add_argument("--show_log", type=str2bool, default=True) parser.add_argument("--show_log", type=str2bool, default=True)
return parser return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment