Commit 721c76b4 authored by LDOUBLEV

fix conflict

parents 98162be4 b77f9ec0
@@ -49,11 +49,19 @@ class TextSystem(object):
         if self.use_angle_cls:
             self.text_classifier = predict_cls.TextClassifier(args)
+        self.args = args
+        self.crop_image_res_index = 0
 
-    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
+    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
+        os.makedirs(output_dir, exist_ok=True)
         bbox_num = len(img_crop_list)
         for bno in range(bbox_num):
-            cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
-            logger.info(bno, rec_res[bno])
+            cv2.imwrite(
+                os.path.join(output_dir,
+                             f"mg_crop_{bno+self.crop_image_res_index}.jpg"),
+                img_crop_list[bno])
+            logger.debug(f"{bno}, {rec_res[bno]}")
+        self.crop_image_res_index += bbox_num
 
     def __call__(self, img, cls=True):
         ori_im = img.copy()
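The running crop_image_res_index is what lets draw_crop_rec_res be called once per input image without later crops overwriting earlier ones. A minimal standalone sketch of the same naming scheme (the class, paths, and dummy data below are illustrative, not part of this commit):

    import os

    import cv2
    import numpy as np

    class CropSaver:
        """Illustrates the persistent-index naming used by draw_crop_rec_res."""

        def __init__(self):
            self.crop_image_res_index = 0  # survives across calls

        def save(self, output_dir, img_crop_list):
            os.makedirs(output_dir, exist_ok=True)
            for bno, crop in enumerate(img_crop_list):
                # offset by the running index so earlier files are kept
                name = "img_crop_%d.jpg" % (bno + self.crop_image_res_index)
                cv2.imwrite(os.path.join(output_dir, name), crop)
            self.crop_image_res_index += len(img_crop_list)

    saver = CropSaver()
    crops = [np.zeros((8, 8, 3), dtype=np.uint8)] * 2
    saver.save("./output", crops)  # writes img_crop_0.jpg, img_crop_1.jpg
    saver.save("./output", crops)  # writes img_crop_2.jpg, img_crop_3.jpg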
@@ -80,7 +88,9 @@ class TextSystem(object):
         rec_res, elapse = self.text_recognizer(img_crop_list)
         logger.debug("rec_res num : {}, elapse : {}".format(
             len(rec_res), elapse))
-        # self.print_draw_crop_rec_res(img_crop_list, rec_res)
+        if self.args.save_crop_res:
+            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
+                                   rec_res)
         filter_boxes, filter_rec_res = [], []
         for box, rec_reuslt in zip(dt_boxes, rec_res):
             text, score = rec_reuslt
@@ -135,17 +145,17 @@ def main(args):
         if not flag:
             img = cv2.imread(image_file)
         if img is None:
-            logger.info("error in loading image:{}".format(image_file))
+            logger.debug("error in loading image:{}".format(image_file))
             continue
         starttime = time.time()
         dt_boxes, rec_res = text_sys(img)
         elapse = time.time() - starttime
         total_time += elapse
-        logger.info(
+        logger.debug(
             str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse))
         for text, score in rec_res:
-            logger.info("{}, {:.3f}".format(text, score))
+            logger.debug("{}, {:.3f}".format(text, score))
 
         if is_visualize:
             image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
@@ -160,19 +170,17 @@ def main(args):
                 scores,
                 drop_score=drop_score,
                 font_path=font_path)
-            draw_img_save = "./inference_results/"
-            if not os.path.exists(draw_img_save):
-                os.makedirs(draw_img_save)
+            draw_img_save_dir = args.draw_img_save_dir
+            os.makedirs(draw_img_save_dir, exist_ok=True)
             if flag:
                 image_file = image_file[:-3] + "png"
             cv2.imwrite(
-                os.path.join(draw_img_save, os.path.basename(image_file)),
+                os.path.join(draw_img_save_dir, os.path.basename(image_file)),
                 draw_img[:, :, ::-1])
-            logger.info("The visualized image saved in {}".format(
-                os.path.join(draw_img_save, os.path.basename(image_file))))
+            logger.debug("The visualized image saved in {}".format(
+                os.path.join(draw_img_save_dir, os.path.basename(image_file))))
 
     logger.info("The predict total time is {}".format(time.time() - _st))
-    logger.info("\nThe predict total time is {}".format(total_time))
 
     if args.benchmark:
         text_sys.text_detector.autolog.report()
         text_sys.text_recognizer.autolog.report()
...
@@ -17,7 +17,7 @@ import os
 import sys
 import cv2
 import numpy as np
-import json
+import paddle
 from PIL import Image, ImageDraw, ImageFont
 import math
 from paddle import inference
@@ -74,7 +74,6 @@ def init_args():
     parser.add_argument("--rec_algorithm", type=str, default='CRNN')
     parser.add_argument("--rec_model_dir", type=str)
     parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
-    parser.add_argument("--rec_char_type", type=str, default='ch')
     parser.add_argument("--rec_batch_num", type=int, default=6)
     parser.add_argument("--max_text_length", type=int, default=25)
     parser.add_argument(
@@ -97,7 +96,6 @@ def init_args():
     parser.add_argument(
         "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
     parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
-    parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
     parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
 
     # params for text classifier
@@ -111,7 +109,13 @@ def init_args():
     parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
     parser.add_argument("--cpu_threads", type=int, default=10)
     parser.add_argument("--use_pdserving", type=str2bool, default=False)
-    parser.add_argument("--warmup", type=str2bool, default=True)
+    parser.add_argument("--warmup", type=str2bool, default=False)
+
+    #
+    parser.add_argument(
+        "--draw_img_save_dir", type=str, default="./inference_results")
+    parser.add_argument("--save_crop_res", type=str2bool, default=False)
+    parser.add_argument("--crop_res_save_dir", type=str, default="./output")
 
     # multi-process
     parser.add_argument("--use_mp", type=str2bool, default=False)
@@ -122,6 +126,7 @@ def init_args():
     parser.add_argument("--save_log_path", type=str, default="./log_output/")
     parser.add_argument("--show_log", type=str2bool, default=True)
+    parser.add_argument("--use_onnx", type=str2bool, default=False)
     return parser
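The new options ride on the existing init_args()/str2bool plumbing, so they behave like every other flag. A quick illustrative check of the defaults (the argv list below is only an example):

    # Illustrative only: exercising the options added above.
    from tools.infer.utility import init_args  # module path as in this repository

    parser = init_args()
    args = parser.parse_args([
        "--save_crop_res", "True",  # str2bool maps "True" -> True
        "--crop_res_save_dir", "./output",
    ])
    assert args.save_crop_res is True
    assert args.draw_img_save_dir == "./inference_results"
    assert args.use_onnx is False  # new flag, off by default
    assert args.warmup is False    # default flipped from True to False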
@@ -145,158 +150,167 @@ def create_predictor(args, mode, logger):
     if model_dir is None:
         logger.info("not find {} model file path {}".format(mode, model_dir))
         sys.exit(0)
-    model_file_path = model_dir + "/inference.pdmodel"
-    params_file_path = model_dir + "/inference.pdiparams"
-    if not os.path.exists(model_file_path):
-        raise ValueError("not find model file path {}".format(model_file_path))
-    if not os.path.exists(params_file_path):
-        raise ValueError("not find params file path {}".format(
-            params_file_path))
-
-    config = inference.Config(model_file_path, params_file_path)
-
-    if hasattr(args, 'precision'):
-        if args.precision == "fp16" and args.use_tensorrt:
-            precision = inference.PrecisionType.Half
-        elif args.precision == "int8":
-            precision = inference.PrecisionType.Int8
-        else:
-            precision = inference.PrecisionType.Float32
-    else:
-        precision = inference.PrecisionType.Float32
-
-    if args.use_gpu:
-        gpu_id = get_infer_gpuid()
-        if gpu_id is None:
-            raise ValueError(
-                "Not found GPU in current device. Please check your device or set args.use_gpu as False"
-            )
-        config.enable_use_gpu(args.gpu_mem, 0)
-        if args.use_tensorrt:
-            config.enable_tensorrt_engine(
-                precision_mode=precision,
-                max_batch_size=args.max_batch_size,
-                min_subgraph_size=args.min_subgraph_size)
-            # skip the minmum trt subgraph
-            if mode == "det":
-                min_input_shape = {
-                    "x": [1, 3, 50, 50],
-                    "conv2d_92.tmp_0": [1, 120, 20, 20],
-                    "conv2d_91.tmp_0": [1, 24, 10, 10],
-                    "conv2d_59.tmp_0": [1, 96, 20, 20],
-                    "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
-                    "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
-                    "conv2d_124.tmp_0": [1, 256, 20, 20],
-                    "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
-                    "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
-                    "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
-                    "elementwise_add_7": [1, 56, 2, 2],
-                    "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
-                }
-                max_input_shape = {
-                    "x": [1, 3, 2000, 2000],
-                    "conv2d_92.tmp_0": [1, 120, 400, 400],
-                    "conv2d_91.tmp_0": [1, 24, 200, 200],
-                    "conv2d_59.tmp_0": [1, 96, 400, 400],
-                    "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
-                    "conv2d_124.tmp_0": [1, 256, 400, 400],
-                    "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
-                    "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
-                    "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
-                    "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
-                    "elementwise_add_7": [1, 56, 400, 400],
-                    "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
-                }
-                opt_input_shape = {
-                    "x": [1, 3, 640, 640],
-                    "conv2d_92.tmp_0": [1, 120, 160, 160],
-                    "conv2d_91.tmp_0": [1, 24, 80, 80],
-                    "conv2d_59.tmp_0": [1, 96, 160, 160],
-                    "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
-                    "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
-                    "conv2d_124.tmp_0": [1, 256, 160, 160],
-                    "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
-                    "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
-                    "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
-                    "elementwise_add_7": [1, 56, 40, 40],
-                    "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
-                }
-                min_pact_shape = {
-                    "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
-                    "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
-                    "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
-                    "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
-                }
-                max_pact_shape = {
-                    "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
-                    "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
-                    "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
-                    "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
-                }
-                opt_pact_shape = {
-                    "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
-                    "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
-                    "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
-                    "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
-                }
-                min_input_shape.update(min_pact_shape)
-                max_input_shape.update(max_pact_shape)
-                opt_input_shape.update(opt_pact_shape)
-            elif mode == "rec":
-                min_input_shape = {"x": [1, 3, 32, 10]}
-                max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
-                opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
-            elif mode == "cls":
-                min_input_shape = {"x": [1, 3, 48, 10]}
-                max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
-                opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
-            else:
-                min_input_shape = {"x": [1, 3, 10, 10]}
-                max_input_shape = {"x": [1, 3, 1000, 1000]}
-                opt_input_shape = {"x": [1, 3, 500, 500]}
-            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
-                                              opt_input_shape)
-
-    else:
-        config.disable_gpu()
-        if hasattr(args, "cpu_threads"):
-            config.set_cpu_math_library_num_threads(args.cpu_threads)
-        else:
-            # default cpu threads as 10
-            config.set_cpu_math_library_num_threads(10)
-        if args.enable_mkldnn:
-            # cache 10 different shapes for mkldnn to avoid memory leak
-            config.set_mkldnn_cache_capacity(10)
-            config.enable_mkldnn()
-
-    # enable memory optim
-    config.enable_memory_optim()
-    #config.disable_glog_info()
-
-    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
-    if mode == 'table':
-        config.delete_pass("fc_fuse_pass")  # not supported for table
-    config.switch_use_feed_fetch_ops(False)
-    config.switch_ir_optim(True)
-
-    # create predictor
-    predictor = inference.create_predictor(config)
-    input_names = predictor.get_input_names()
-    for name in input_names:
-        input_tensor = predictor.get_input_handle(name)
-    output_names = predictor.get_output_names()
-    output_tensors = []
-    for output_name in output_names:
-        output_tensor = predictor.get_output_handle(output_name)
-        output_tensors.append(output_tensor)
-    return predictor, input_tensor, output_tensors, config
+    if args.use_onnx:
+        import onnxruntime as ort
+        model_file_path = model_dir
+        if not os.path.exists(model_file_path):
+            raise ValueError("not find model file path {}".format(
+                model_file_path))
+        sess = ort.InferenceSession(model_file_path)
+        return sess, sess.get_inputs()[0], None, None
+
+    else:
+        model_file_path = model_dir + "/inference.pdmodel"
+        params_file_path = model_dir + "/inference.pdiparams"
+        if not os.path.exists(model_file_path):
+            raise ValueError("not find model file path {}".format(
+                model_file_path))
+        if not os.path.exists(params_file_path):
+            raise ValueError("not find params file path {}".format(
+                params_file_path))
+
+        config = inference.Config(model_file_path, params_file_path)
+
+        if hasattr(args, 'precision'):
+            if args.precision == "fp16" and args.use_tensorrt:
+                precision = inference.PrecisionType.Half
+            elif args.precision == "int8":
+                precision = inference.PrecisionType.Int8
+            else:
+                precision = inference.PrecisionType.Float32
+        else:
+            precision = inference.PrecisionType.Float32
+
+        if args.use_gpu:
+            gpu_id = get_infer_gpuid()
+            if gpu_id is None:
+                logger.warning(
+                    "GPU is not found in current device by nvidia-smi. Please check your device or ignore it if run on jeston."
+                )
+            config.enable_use_gpu(args.gpu_mem, 0)
+            if args.use_tensorrt:
+                config.enable_tensorrt_engine(
+                    workspace_size=1 << 30,
+                    precision_mode=precision,
+                    max_batch_size=args.max_batch_size,
+                    min_subgraph_size=args.min_subgraph_size)
+                # skip the minmum trt subgraph
+                if mode == "det":
+                    min_input_shape = {
+                        "x": [1, 3, 50, 50],
+                        "conv2d_92.tmp_0": [1, 120, 20, 20],
+                        "conv2d_91.tmp_0": [1, 24, 10, 10],
+                        "conv2d_59.tmp_0": [1, 96, 20, 20],
+                        "nearest_interp_v2_1.tmp_0": [1, 256, 10, 10],
+                        "nearest_interp_v2_2.tmp_0": [1, 256, 20, 20],
+                        "conv2d_124.tmp_0": [1, 256, 20, 20],
+                        "nearest_interp_v2_3.tmp_0": [1, 64, 20, 20],
+                        "nearest_interp_v2_4.tmp_0": [1, 64, 20, 20],
+                        "nearest_interp_v2_5.tmp_0": [1, 64, 20, 20],
+                        "elementwise_add_7": [1, 56, 2, 2],
+                        "nearest_interp_v2_0.tmp_0": [1, 256, 2, 2]
+                    }
+                    max_input_shape = {
+                        "x": [1, 3, 1280, 1280],
+                        "conv2d_92.tmp_0": [1, 120, 400, 400],
+                        "conv2d_91.tmp_0": [1, 24, 200, 200],
+                        "conv2d_59.tmp_0": [1, 96, 400, 400],
+                        "nearest_interp_v2_1.tmp_0": [1, 256, 200, 200],
+                        "conv2d_124.tmp_0": [1, 256, 400, 400],
+                        "nearest_interp_v2_2.tmp_0": [1, 256, 400, 400],
+                        "nearest_interp_v2_3.tmp_0": [1, 64, 400, 400],
+                        "nearest_interp_v2_4.tmp_0": [1, 64, 400, 400],
+                        "nearest_interp_v2_5.tmp_0": [1, 64, 400, 400],
+                        "elementwise_add_7": [1, 56, 400, 400],
+                        "nearest_interp_v2_0.tmp_0": [1, 256, 400, 400]
+                    }
+                    opt_input_shape = {
+                        "x": [1, 3, 640, 640],
+                        "conv2d_92.tmp_0": [1, 120, 160, 160],
+                        "conv2d_91.tmp_0": [1, 24, 80, 80],
+                        "conv2d_59.tmp_0": [1, 96, 160, 160],
+                        "nearest_interp_v2_1.tmp_0": [1, 256, 80, 80],
+                        "nearest_interp_v2_2.tmp_0": [1, 256, 160, 160],
+                        "conv2d_124.tmp_0": [1, 256, 160, 160],
+                        "nearest_interp_v2_3.tmp_0": [1, 64, 160, 160],
+                        "nearest_interp_v2_4.tmp_0": [1, 64, 160, 160],
+                        "nearest_interp_v2_5.tmp_0": [1, 64, 160, 160],
+                        "elementwise_add_7": [1, 56, 40, 40],
+                        "nearest_interp_v2_0.tmp_0": [1, 256, 40, 40]
+                    }
+                    min_pact_shape = {
+                        "nearest_interp_v2_26.tmp_0": [1, 256, 20, 20],
+                        "nearest_interp_v2_27.tmp_0": [1, 64, 20, 20],
+                        "nearest_interp_v2_28.tmp_0": [1, 64, 20, 20],
+                        "nearest_interp_v2_29.tmp_0": [1, 64, 20, 20]
+                    }
+                    max_pact_shape = {
+                        "nearest_interp_v2_26.tmp_0": [1, 256, 400, 400],
+                        "nearest_interp_v2_27.tmp_0": [1, 64, 400, 400],
+                        "nearest_interp_v2_28.tmp_0": [1, 64, 400, 400],
+                        "nearest_interp_v2_29.tmp_0": [1, 64, 400, 400]
+                    }
+                    opt_pact_shape = {
+                        "nearest_interp_v2_26.tmp_0": [1, 256, 160, 160],
+                        "nearest_interp_v2_27.tmp_0": [1, 64, 160, 160],
+                        "nearest_interp_v2_28.tmp_0": [1, 64, 160, 160],
+                        "nearest_interp_v2_29.tmp_0": [1, 64, 160, 160]
+                    }
+                    min_input_shape.update(min_pact_shape)
+                    max_input_shape.update(max_pact_shape)
+                    opt_input_shape.update(opt_pact_shape)
+                elif mode == "rec":
+                    min_input_shape = {"x": [1, 3, 32, 10]}
+                    max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1024]}
+                    opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
+                elif mode == "cls":
+                    min_input_shape = {"x": [1, 3, 48, 10]}
+                    max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
+                    opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
+                else:
+                    min_input_shape = {"x": [1, 3, 10, 10]}
+                    max_input_shape = {"x": [1, 3, 512, 512]}
+                    opt_input_shape = {"x": [1, 3, 256, 256]}
+                config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
+                                                  opt_input_shape)
+
+        else:
+            config.disable_gpu()
+            if hasattr(args, "cpu_threads"):
+                config.set_cpu_math_library_num_threads(args.cpu_threads)
+            else:
+                # default cpu threads as 10
+                config.set_cpu_math_library_num_threads(10)
+            if args.enable_mkldnn:
+                # cache 10 different shapes for mkldnn to avoid memory leak
+                config.set_mkldnn_cache_capacity(10)
+                config.enable_mkldnn()
+                if args.precision == "fp16":
+                    config.enable_mkldnn_bfloat16()
+
+        # enable memory optim
+        config.enable_memory_optim()
+        config.disable_glog_info()
+
+        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+        if mode == 'table':
+            config.delete_pass("fc_fuse_pass")  # not supported for table
+        config.switch_use_feed_fetch_ops(False)
+        config.switch_ir_optim(True)
+
+        # create predictor
+        predictor = inference.create_predictor(config)
+        input_names = predictor.get_input_names()
+        for name in input_names:
+            input_tensor = predictor.get_input_handle(name)
+        output_names = predictor.get_output_names()
+        output_tensors = []
+        for output_name in output_names:
+            output_tensor = predictor.get_output_handle(output_name)
+            output_tensors.append(output_tensor)
+        return predictor, input_tensor, output_tensors, config
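With --use_onnx the function now returns an onnxruntime session plus its first input description instead of a Paddle predictor, keeping the four-element (predictor, input, outputs, config) shape with None placeholders. A minimal sketch of driving such a session (the model path and input shape are placeholder assumptions, not values from this commit):

    # Sketch only: using the session returned by the use_onnx branch.
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("det_model.onnx")   # placeholder path
    input_meta = sess.get_inputs()[0]               # the tuple's second element
    x = np.random.rand(1, 3, 640, 640).astype("float32")
    outputs = sess.run(None, {input_meta.name: x})  # None = fetch all outputs
    print([o.shape for o in outputs])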
 def get_infer_gpuid():
+    cmd = "nvidia-smi"
+    res = os.popen(cmd).readlines()
+    if len(res) == 0:
+        return None
     cmd = "env | grep CUDA_VISIBLE_DEVICES"
     env_cuda = os.popen(cmd).readlines()
     if len(env_cuda) == 0:
@@ -589,5 +603,12 @@ def get_rotate_crop_image(img, points):
     return dst_img
 
 
+def check_gpu(use_gpu):
+    if use_gpu and not paddle.is_compiled_with_cuda():
+        use_gpu = False
+    return use_gpu
+
+
 if __name__ == '__main__':
     pass
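check_gpu quietly downgrades a GPU request when the installed paddle wheel was built without CUDA, which is more graceful than failing later inside the predictor. A sketch of the intended call site (the surrounding code here is illustrative):

    # Sketch: guard the flag once, right after argument parsing.
    import paddle

    def check_gpu(use_gpu):
        # same logic as the helper added above
        if use_gpu and not paddle.is_compiled_with_cuda():
            use_gpu = False
        return use_gpu

    use_gpu = check_gpu(True)  # becomes False on CPU-only paddle builds
    print("running on GPU" if use_gpu else "falling back to CPU")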
@@ -32,7 +32,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -47,7 +47,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model)
+    load_model(config, model)
 
     # create data ops
     transforms = []
...
@@ -34,7 +34,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -59,7 +59,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    _ = load_dygraph_params(config, model, logger, None)
+    load_model(config, model)
 
     # build post process
     post_process_class = build_post_process(config['PostProcess'])
...
@@ -34,7 +34,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -68,7 +68,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model)
+    load_model(config, model)
 
     # build post process
     post_process_class = build_post_process(config['PostProcess'],
...
@@ -33,7 +33,7 @@ import paddle
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 
@@ -58,7 +58,7 @@ def main():
     model = build_model(config['Architecture'])
 
-    init_model(config, model)
+    load_model(config, model)
 
     # create data ops
     transforms = []
 
@@ -75,9 +75,7 @@ def main():
                 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
             ]
         elif config['Architecture']['algorithm'] == "SAR":
-            op[op_name]['keep_keys'] = [
-                'image', 'valid_ratio'
-            ]
+            op[op_name]['keep_keys'] = ['image', 'valid_ratio']
         else:
             op[op_name]['keep_keys'] = ['image']
         transforms.append(op)
...
@@ -34,11 +34,12 @@ from paddle.jit import to_static
 from ppocr.data import create_operators, transform
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import load_model
 from ppocr.utils.utility import get_image_file_list
 import tools.program as program
 import cv2
 
+
 def main(config, device, logger, vdl_writer):
     global_config = config['Global']
 
@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
     model = build_model(config['Architecture'])
-    init_model(config, model, logger)
+    load_model(config, model)
 
     # create data ops
     transforms = []
 
@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
 if __name__ == '__main__':
     config, device, logger, vdl_writer = program.preprocess()
     main(config, device, logger, vdl_writer)
-
@@ -159,7 +159,8 @@ def train(config,
           eval_class,
           pre_best_model_dict,
           logger,
-          vdl_writer=None):
+          vdl_writer=None,
+          scaler=None):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
     log_smooth_window = config['Global']['log_smooth_window']
@@ -211,15 +212,15 @@ def train(config,
     for epoch in range(start_epoch, epoch_num + 1):
         train_dataloader = build_dataloader(
             config, 'Train', device, logger, seed=epoch)
-        train_batch_cost = 0.0
         train_reader_cost = 0.0
-        batch_sum = 0
-        batch_start = time.time()
+        train_run_cost = 0.0
+        total_samples = 0
+        reader_start = time.time()
         max_iter = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
-            train_reader_cost += time.time() - batch_start
+            train_reader_cost += time.time() - reader_start
             if idx >= max_iter:
                 break
             lr = optimizer.get_lr()
@@ -230,16 +231,34 @@ def train(config,
                 preds = model(images, data=batch[1:])
             if model_type == "kie":
                 preds = model(batch)
+            train_start = time.time()
+            # use amp
+            if scaler:
+                with paddle.amp.auto_cast():
+                    if model_type == 'table' or extra_input:
+                        preds = model(images, data=batch[1:])
+                    else:
+                        preds = model(images)
             else:
-                preds = model(images)
+                if model_type == 'table' or extra_input:
+                    preds = model(images, data=batch[1:])
+                else:
+                    preds = model(images)
+
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
-            avg_loss.backward()
-            optimizer.step()
+
+            if scaler:
+                scaled_avg_loss = scaler.scale(avg_loss)
+                scaled_avg_loss.backward()
+                scaler.minimize(optimizer, scaled_avg_loss)
+            else:
+                avg_loss.backward()
+                optimizer.step()
             optimizer.clear_grad()
-            train_batch_cost += time.time() - batch_start
-            batch_sum += len(images)
+
+            train_run_cost += time.time() - train_start
+            total_samples += len(images)
             if not isinstance(lr_scheduler, float):
                 lr_scheduler.step()
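The scaler path follows Paddle's standard mixed-precision recipe: run the forward pass under auto_cast, scale the loss before backward so fp16 gradients do not underflow, and let the scaler drive the optimizer step. A self-contained toy sketch of that pattern (the model and data below are illustrative, not PaddleOCR code):

    # Toy example of the AMP pattern used above.
    import paddle

    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                     parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0,
                                   use_dynamic_loss_scaling=True)

    x = paddle.randn([8, 4])
    with paddle.amp.auto_cast():         # ops run in fp16 where safe
        loss = model(x).mean()
    scaled = scaler.scale(loss)          # scale up before backward
    scaled.backward()
    scaler.minimize(optimizer, scaled)   # unscale grads, then step
    optimizer.clear_grad()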
@@ -270,12 +289,13 @@ def train(config,
                 logs = train_stats.log()
                 strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                     epoch, epoch_num, global_step, logs, train_reader_cost /
-                    print_batch_step, train_batch_cost / print_batch_step,
-                    batch_sum, batch_sum / train_batch_cost)
+                    print_batch_step, (train_reader_cost + train_run_cost) /
+                    print_batch_step, total_samples,
+                    total_samples / (train_reader_cost + train_run_cost))
                 logger.info(strs)
-                train_batch_cost = 0.0
                 train_reader_cost = 0.0
-                batch_sum = 0
+                train_run_cost = 0.0
+                total_samples = 0
             # eval
             if global_step > start_eval_step and \
                 (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
@@ -328,7 +348,7 @@ def train(config,
                                     global_step)
             global_step += 1
             optimizer.clear_grad()
-            batch_start = time.time()
+            reader_start = time.time()
         if dist.get_rank() == 0:
             save_model(
                 model,
@@ -369,7 +389,11 @@ def eval(model,
     with paddle.no_grad():
         total_frame = 0.0
         total_time = 0.0
-        pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
+        pbar = tqdm(
+            total=len(valid_dataloader),
+            desc='eval model:',
+            position=0,
+            leave=True)
         max_iter = len(valid_dataloader) - 1 if platform.system(
         ) == "Windows" else len(valid_dataloader)
         for idx, batch in enumerate(valid_dataloader):
@@ -404,6 +428,55 @@ def eval(model,
     return metric
 
 
+def update_center(char_center, post_result, preds):
+    result, label = post_result
+    feats, logits = preds
+    logits = paddle.argmax(logits, axis=-1)
+    feats = feats.numpy()
+    logits = logits.numpy()
+
+    for idx_sample in range(len(label)):
+        if result[idx_sample][0] == label[idx_sample][0]:
+            feat = feats[idx_sample]
+            logit = logits[idx_sample]
+            for idx_time in range(len(logit)):
+                index = logit[idx_time]
+                if index in char_center.keys():
+                    char_center[index][0] = (
+                        char_center[index][0] * char_center[index][1] +
+                        feat[idx_time]) / (char_center[index][1] + 1)
+                    char_center[index][1] += 1
+                else:
+                    char_center[index] = [feat[idx_time], 1]
+    return char_center
+
+
+def get_center(model, eval_dataloader, post_process_class):
+    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
+    max_iter = len(eval_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(eval_dataloader)
+    char_center = dict()
+    for idx, batch in enumerate(eval_dataloader):
+        if idx >= max_iter:
+            break
+        images = batch[0]
+        start = time.time()
+        preds = model(images)
+
+        batch = [item.numpy() for item in batch]
+        # Obtain usable results from post-processing methods
+        post_result = post_process_class(preds, batch[1])
+
+        # update char_center
+        char_center = update_center(char_center, post_result, preds)
+        pbar.update(1)
+
+    pbar.close()
+    for key in char_center.keys():
+        char_center[key] = char_center[key][0]
+    return char_center
+
+
 def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
     profiler_options = FLAGS.profiler_options
...
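update_center keeps, for each character class, a running mean of the time-step features on correctly recognized samples, using the incremental form mean_new = (mean * n + x) / (n + 1) so no feature buffer is needed. A toy check of that identity (the values are arbitrary):

    # Toy check of the incremental mean used in update_center.
    import numpy as np

    xs = np.array([0.2, 0.5, 0.8, 0.1])
    mean, n = 0.0, 0
    for x in xs:
        mean = (mean * n + x) / (n + 1)  # same update as char_center[index]
        n += 1
    assert np.isclose(mean, xs.mean())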
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.save_load import load_model
 import tools.program as program
 
 dist.get_world_size()
@@ -97,15 +97,32 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer)
 
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
             len(valid_dataloader)))
+
+    use_amp = config["Global"].get("use_amp", False)
+    if use_amp:
+        AMP_RELATED_FLAGS_SETTING = {
+            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+            'FLAGS_max_inplace_grad_add': 8,
+        }
+        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+        scale_loss = config["Global"].get("scale_loss", 1.0)
+        use_dynamic_loss_scaling = config["Global"].get(
+            "use_dynamic_loss_scaling", False)
+        scaler = paddle.amp.GradScaler(
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+    else:
+        scaler = None
+
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
-                  eval_class, pre_best_model_dict, logger, vdl_writer)
+                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
 
 
 def test_reader(config, device, logger):
...
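AMP is therefore opt-in through the Global section of the training config; the three keys read above are the whole interface. An illustrative sketch of the values involved, shown as the dict that config["Global"] would carry (the numbers are example choices, not defaults from this commit):

    # Illustrative Global config values that enable the AMP path above.
    config = {
        "Global": {
            "use_amp": True,                   # turn on the scaler branch
            "scale_loss": 1024.0,              # initial loss scaling
            "use_dynamic_loss_scaling": True,  # let GradScaler adapt the scale
        }
    }
    use_amp = config["Global"].get("use_amp", False)
    assert use_amp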