Commit f7081e38 authored by WenmuZhou

Merge remote-tracking branch 'origin/lite' into lite

parents 40f20c6f 3f1cb773
@@ -47,6 +47,7 @@ def main():
     config['Architecture']["Head"]['out_channels'] = len(
         getattr(post_process_class, 'character'))
     model = build_model(config['Architecture'])
+    use_srn = config['Architecture']['algorithm'] == "SRN"
     best_model_dict = init_model(config, model, logger)
     if len(best_model_dict):
@@ -59,7 +60,7 @@ def main():
     # start eval
     metirc = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class)
+                          eval_class, use_srn)
     logger.info('metric eval ***************')
     for k, v in metirc.items():
         logger.info('{}:{}'.format(k, v))
......
@@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-c", "--config", help="configuration file to use")
+    parser.add_argument(
+        "-o", "--output_path", type=str, default='./output/infer/')
+    return parser.parse_args()
 def main():
     FLAGS = ArgsParser().parse_args()
     config = load_config(FLAGS.config)
@@ -51,14 +59,40 @@ def main():
     model.eval()
     save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
-    infer_shape = [3, 32, 100] if config['Architecture'][
-        'model_type'] != "det" else [3, 640, 640]
-    model = to_static(
-        model,
-        input_spec=[
-            paddle.static.InputSpec(
-                shape=[None] + infer_shape, dtype='float32')
-        ])
+    if config['Architecture']['algorithm'] == "SRN":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 1, 64, 256], dtype='float32'), [
+                    paddle.static.InputSpec(
+                        shape=[None, 256, 1],
+                        dtype="int64"), paddle.static.InputSpec(
+                            shape=[None, 25, 1],
+                            dtype="int64"), paddle.static.InputSpec(
+                                shape=[None, 8, 25, 25], dtype="int64"),
+                    paddle.static.InputSpec(
+                        shape=[None, 8, 25, 25], dtype="int64")
+                ]
+        ]
+        model = to_static(model, input_spec=other_shape)
+    else:
+        infer_shape = [3, -1, -1]
+        if config['Architecture']['model_type'] == "rec":
+            infer_shape = [3, 32, -1]  # for rec model, H must be 32
+            if 'Transform' in config['Architecture'] and config['Architecture'][
+                    'Transform'] is not None and config['Architecture'][
+                        'Transform']['name'] == 'TPS':
+                logger.info(
+                    'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+                )
+                infer_shape[-1] = 100
+        model = to_static(
+            model,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None] + infer_shape, dtype='float32')
+            ])
     paddle.jit.save(model, save_path)
     logger.info('inference model is saved to {}'.format(save_path))
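Note on the non-SRN branch above: it now exports with -1 (variable) dimensions instead of a fixed 3x32x100 / 3x640x640 shape. A minimal standalone sketch of the same mechanism, using a toy Conv2D layer in place of the OCR model and a hypothetical ./toy_export/ path, shows how an InputSpec with -1 dimensions lets the saved inference model accept variable-width input:

# Sketch of dynamic-shape export via paddle.jit.to_static / InputSpec.
# The Conv2D layer and the ./toy_export/ path are placeholders, not part of PaddleOCR.
import paddle
from paddle.jit import to_static

layer = paddle.nn.Conv2D(in_channels=3, out_channels=8, kernel_size=3, padding=1)
layer.eval()

# [None, 3, 32, -1]: batch size and width stay dynamic, height is fixed to 32
static_layer = to_static(
    layer,
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 32, -1], dtype='float32')
    ])
paddle.jit.save(static_layer, './toy_export/inference')

# The reloaded inference model accepts different widths without re-export
loaded = paddle.jit.load('./toy_export/inference')
print(loaded(paddle.randn([1, 3, 32, 100])).shape)  # [1, 8, 32, 100]
print(loaded(paddle.randn([2, 3, 32, 320])).shape)  # [2, 8, 32, 320]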
......
@@ -98,10 +98,10 @@ class TextClassifier(object):
             norm_img_batch = np.concatenate(norm_img_batch)
             norm_img_batch = norm_img_batch.copy()
             starttime = time.time()
             self.input_tensor.copy_from_cpu(norm_img_batch)
             self.predictor.run()
             prob_out = self.output_tensors[0].copy_to_cpu()
+            self.predictor.try_shrink_memory()
             cls_result = self.postprocess_op(prob_out)
             elapse += time.time() - starttime
             for rno in range(len(cls_result)):
......
@@ -39,10 +39,7 @@ class TextDetector(object):
         self.args = args
         self.det_algorithm = args.det_algorithm
         pre_process_list = [{
-            'DetResizeForTest': {
-                'limit_side_len': args.det_limit_side_len,
-                'limit_type': args.det_limit_type
-            }
+            'DetResizeForTest': None
         }, {
             'NormalizeImage': {
                 'std': [0.229, 0.224, 0.225],
@@ -64,7 +61,7 @@ class TextDetector(object):
             postprocess_params["box_thresh"] = args.det_db_box_thresh
             postprocess_params["max_candidates"] = 1000
             postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
-            postprocess_params["use_dilation"] = True
+            postprocess_params["use_dilation"] = args.use_dilation
         elif self.det_algorithm == "EAST":
             postprocess_params['name'] = 'EASTPostProcess'
             postprocess_params["score_thresh"] = args.det_east_score_thresh
@@ -183,7 +180,7 @@ class TextDetector(object):
             preds['maps'] = outputs[0]
         else:
             raise NotImplementedError
+        self.predictor.try_shrink_memory()
         post_result = self.postprocess_op(preds, shape_list)
         dt_boxes = post_result[0]['points']
         if self.det_algorithm == "SAST" and self.det_sast_polygon:
......
@@ -25,6 +25,7 @@ import numpy as np
 import math
 import time
 import traceback
+import paddle
 import tools.infer.utility as utility
 from ppocr.postprocess import build_post_process
@@ -46,6 +47,20 @@ class TextRecognizer(object):
             "character_dict_path": args.rec_char_dict_path,
             "use_space_char": args.use_space_char
         }
+        if self.rec_algorithm == "SRN":
+            postprocess_params = {
+                'name': 'SRNLabelDecode',
+                "character_type": args.rec_char_type,
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
+        elif self.rec_algorithm == "RARE":
+            postprocess_params = {
+                'name': 'AttnLabelDecode',
+                "character_type": args.rec_char_type,
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors = \
             utility.create_predictor(args, 'rec', logger)
@@ -70,6 +85,78 @@ class TextRecognizer(object):
         padding_im[:, :, 0:resized_w] = resized_image
         return padding_im
+    def resize_norm_img_srn(self, img, image_shape):
+        imgC, imgH, imgW = image_shape
+        img_black = np.zeros((imgH, imgW))
+        im_hei = img.shape[0]
+        im_wid = img.shape[1]
+        if im_wid <= im_hei * 1:
+            img_new = cv2.resize(img, (imgH * 1, imgH))
+        elif im_wid <= im_hei * 2:
+            img_new = cv2.resize(img, (imgH * 2, imgH))
+        elif im_wid <= im_hei * 3:
+            img_new = cv2.resize(img, (imgH * 3, imgH))
+        else:
+            img_new = cv2.resize(img, (imgW, imgH))
+        img_np = np.asarray(img_new)
+        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+        img_black[:, 0:img_np.shape[1]] = img_np
+        img_black = img_black[:, :, np.newaxis]
+        row, col, c = img_black.shape
+        c = 1
+        return np.reshape(img_black, (c, row, col)).astype(np.float32)
+    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
+        imgC, imgH, imgW = image_shape
+        feature_dim = int((imgH / 8) * (imgW / 8))
+        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+            (feature_dim, 1)).astype('int64')
+        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+            (max_text_length, 1)).astype('int64')
+        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias1 = np.tile(
+            gsrm_slf_attn_bias1,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+            [-1, 1, max_text_length, max_text_length])
+        gsrm_slf_attn_bias2 = np.tile(
+            gsrm_slf_attn_bias2,
+            [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+        encoder_word_pos = encoder_word_pos[np.newaxis, :]
+        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+        return [
+            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+            gsrm_slf_attn_bias2
+        ]
+    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
+        norm_img = self.resize_norm_img_srn(img, image_shape)
+        norm_img = norm_img[np.newaxis, :]
+        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+            self.srn_other_inputs(image_shape, num_heads, max_text_length)
+        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+        encoder_word_pos = encoder_word_pos.astype(np.int64)
+        gsrm_word_pos = gsrm_word_pos.astype(np.int64)
+        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                gsrm_slf_attn_bias2)
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
@@ -93,21 +180,64 @@ class TextRecognizer(object):
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
-                norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                max_wh_ratio)
-                norm_img = norm_img[np.newaxis, :]
-                norm_img_batch.append(norm_img)
+                if self.rec_algorithm != "SRN":
+                    norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+                else:
+                    norm_img = self.process_image_srn(
+                        img_list[indices[ino]], self.rec_image_shape, 8, 25)
+                    encoder_word_pos_list = []
+                    gsrm_word_pos_list = []
+                    gsrm_slf_attn_bias1_list = []
+                    gsrm_slf_attn_bias2_list = []
+                    encoder_word_pos_list.append(norm_img[1])
+                    gsrm_word_pos_list.append(norm_img[2])
+                    gsrm_slf_attn_bias1_list.append(norm_img[3])
+                    gsrm_slf_attn_bias2_list.append(norm_img[4])
+                    norm_img_batch.append(norm_img[0])
             norm_img_batch = np.concatenate(norm_img_batch)
             norm_img_batch = norm_img_batch.copy()
-            starttime = time.time()
-            self.input_tensor.copy_from_cpu(norm_img_batch)
-            self.predictor.run()
-            outputs = []
-            for output_tensor in self.output_tensors:
-                output = output_tensor.copy_to_cpu()
-                outputs.append(output)
-            preds = outputs[0]
+            if self.rec_algorithm == "SRN":
+                starttime = time.time()
+                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                gsrm_slf_attn_bias1_list = np.concatenate(
+                    gsrm_slf_attn_bias1_list)
+                gsrm_slf_attn_bias2_list = np.concatenate(
+                    gsrm_slf_attn_bias2_list)
+                inputs = [
+                    norm_img_batch,
+                    encoder_word_pos_list,
+                    gsrm_word_pos_list,
+                    gsrm_slf_attn_bias1_list,
+                    gsrm_slf_attn_bias2_list,
+                ]
+                input_names = self.predictor.get_input_names()
+                for i in range(len(input_names)):
+                    input_tensor = self.predictor.get_input_handle(input_names[
+                        i])
+                    input_tensor.copy_from_cpu(inputs[i])
+                self.predictor.run()
+                outputs = []
+                for output_tensor in self.output_tensors:
+                    output = output_tensor.copy_to_cpu()
+                    outputs.append(output)
+                preds = {"predict": outputs[2]}
+            else:
+                starttime = time.time()
+                self.input_tensor.copy_from_cpu(norm_img_batch)
+                self.predictor.run()
+                outputs = []
+                for output_tensor in self.output_tensors:
+                    output = output_tensor.copy_to_cpu()
+                    outputs.append(output)
+                preds = outputs[0]
+            self.predictor.try_shrink_memory()
             rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
@@ -118,9 +248,11 @@ class TextRecognizer(object):
 def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_recognizer = TextRecognizer(args)
+    total_run_time = 0.0
+    total_images_num = 0
     valid_image_file_list = []
     img_list = []
-    for image_file in image_file_list:
+    for idx, image_file in enumerate(image_file_list):
         img, flag = check_and_read_gif(image_file)
         if not flag:
             img = cv2.imread(image_file)
@@ -129,22 +261,29 @@ def main(args):
             continue
         valid_image_file_list.append(image_file)
         img_list.append(img)
-    try:
-        rec_res, predict_time = text_recognizer(img_list)
-    except:
-        logger.info(traceback.format_exc())
-        logger.info(
-            "ERROR!!!! \n"
-            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
-            "If your model has tps module: "
-            "TPS does not support variable shape.\n"
-            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
-        exit()
-    for ino in range(len(img_list)):
-        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
-                                               rec_res[ino]))
+        if len(img_list) >= args.rec_batch_num or idx == len(
+                image_file_list) - 1:
+            try:
+                rec_res, predict_time = text_recognizer(img_list)
+                total_run_time += predict_time
+            except:
+                logger.info(traceback.format_exc())
+                logger.info(
+                    "ERROR!!!! \n"
+                    "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
+                    "If your model has tps module: "
+                    "TPS does not support variable shape.\n"
+                    "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
+                )
+                exit()
+            for ino in range(len(img_list)):
+                logger.info("Predicts of {}:{}".format(valid_image_file_list[
+                    ino], rec_res[ino]))
+            total_images_num += len(valid_image_file_list)
+            valid_image_file_list = []
+            img_list = []
     logger.info("Total predict time for {} images, cost: {:.3f}".format(
-        len(img_list), predict_time))
+        total_images_num, total_run_time))
 if __name__ == "__main__":
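For the default SRN settings used above (image_shape 1x64x256, num_heads=8, max_text_length=25), the auxiliary inputs have fixed shapes that match the SRN InputSpec list in the export diff above. A standalone sketch (numpy only, mirroring srn_other_inputs) that rebuilds and checks those shapes; the function name make_srn_inputs is illustrative, not part of the repository:

# Sketch: rebuild the SRN auxiliary inputs for one image and check their shapes.
import numpy as np

def make_srn_inputs(image_shape=(1, 64, 256), num_heads=8, max_text_length=25):
    imgC, imgH, imgW = image_shape
    feature_dim = int((imgH / 8) * (imgW / 8))  # 8 * 32 = 256 for a 64x256 input

    encoder_word_pos = np.arange(feature_dim).reshape(
        (feature_dim, 1)).astype('int64')[np.newaxis, :]
    gsrm_word_pos = np.arange(max_text_length).reshape(
        (max_text_length, 1)).astype('int64')[np.newaxis, :]

    bias = np.ones((1, max_text_length, max_text_length))
    # upper / lower triangular masks, tiled over the attention heads
    attn_bias1 = np.tile(
        np.triu(bias, 1).reshape([-1, 1, max_text_length, max_text_length]),
        [1, num_heads, 1, 1]).astype('float32') * -1e9
    attn_bias2 = np.tile(
        np.tril(bias, -1).reshape([-1, 1, max_text_length, max_text_length]),
        [1, num_heads, 1, 1]).astype('float32') * -1e9
    return encoder_word_pos, gsrm_word_pos, attn_bias1, attn_bias2

pos, word_pos, b1, b2 = make_srn_inputs()
assert pos.shape == (1, 256, 1)                 # matches InputSpec [None, 256, 1]
assert word_pos.shape == (1, 25, 1)             # matches InputSpec [None, 25, 1]
assert b1.shape == b2.shape == (1, 8, 25, 25)   # matches InputSpec [None, 8, 25, 25]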
......
@@ -184,4 +184,4 @@ def main(args):
 if __name__ == "__main__":
-    main(utility.parse_args())
\ No newline at end of file
+    main(utility.parse_args())
@@ -47,6 +47,7 @@ def parse_args():
     parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
     parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
     parser.add_argument("--max_batch_size", type=int, default=10)
+    parser.add_argument("--use_dilation", type=bool, default=False)
     # EAST parmas
     parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
@@ -70,7 +71,7 @@ def parse_args():
         default="./ppocr/utils/ppocr_keys_v1.txt")
     parser.add_argument("--use_space_char", type=str2bool, default=True)
     parser.add_argument(
-        "--vis_font_path", type=str, default="./doc/simfang.ttf")
+        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
     parser.add_argument("--drop_score", type=float, default=0.5)
     # params for text classifier
@@ -123,9 +124,12 @@ def create_predictor(args, mode, logger):
             # cache 10 different shapes for mkldnn to avoid memory leak
             config.set_mkldnn_cache_capacity(10)
             config.enable_mkldnn()
+            # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
+            #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
             args.rec_batch_num = 1
-    # config.enable_memory_optim()
+    # enable memory optim
+    config.enable_memory_optim()
     config.disable_glog_info()
     config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
......
@@ -97,7 +97,7 @@ def main():
         preds = model(images)
         post_result = post_process_class(preds, shape_list)
         boxes = post_result[0]['points']
-        # write resule
+        # write result
         dt_boxes_json = []
         for box in boxes:
             tmp_json = {"transcription": ""}
......
@@ -62,7 +62,13 @@ def main():
         elif op_name in ['RecResizeImg']:
             op[op_name]['infer_mode'] = True
         elif op_name == 'KeepKeys':
-            op[op_name]['keep_keys'] = ['image']
+            if config['Architecture']['algorithm'] == "SRN":
+                op[op_name]['keep_keys'] = [
+                    'image', 'encoder_word_pos', 'gsrm_word_pos',
+                    'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
+                ]
+            else:
+                op[op_name]['keep_keys'] = ['image']
         transforms.append(op)
     global_config['infer_mode'] = True
     ops = create_operators(transforms, global_config)
@@ -74,10 +80,25 @@ def main():
             img = f.read()
             data = {'image': img}
         batch = transform(data, ops)
+        if config['Architecture']['algorithm'] == "SRN":
+            encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
+            gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
+            gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
+            gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
+            others = [
+                paddle.to_tensor(encoder_word_pos_list),
+                paddle.to_tensor(gsrm_word_pos_list),
+                paddle.to_tensor(gsrm_slf_attn_bias1_list),
+                paddle.to_tensor(gsrm_slf_attn_bias2_list)
+            ]
         images = np.expand_dims(batch[0], axis=0)
         images = paddle.to_tensor(images)
-        preds = model(images)
+        if config['Architecture']['algorithm'] == "SRN":
+            preds = model(images, others)
+        else:
+            preds = model(images)
         post_result = post_process_class(preds)
         for rec_reuslt in post_result:
             logger.info('\t result: {}'.format(rec_reuslt))
......
@@ -163,6 +163,11 @@ def train(config,
     if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
         start_eval_step = eval_batch_step[0]
         eval_batch_step = eval_batch_step[1]
+        if len(valid_dataloader) == 0:
+            logger.info(
+                'No Images in eval dataset, evaluation during training will be disabled'
+            )
+            start_eval_step = 1e111
         logger.info(
             "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
             format(start_eval_step, eval_batch_step))
@@ -174,16 +179,19 @@ def train(config,
     best_model_dict = {main_indicator: 0}
     best_model_dict.update(pre_best_model_dict)
     train_stats = TrainingStats(log_smooth_window, ['lr'])
+    model_average = False
     model.train()
+    use_srn = config['Architecture']['algorithm'] == "SRN"
     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']
     else:
         start_epoch = 1
     for epoch in range(start_epoch, epoch_num + 1):
-        if epoch > 0:
-            train_dataloader = build_dataloader(config, 'Train', device, logger)
+        train_dataloader = build_dataloader(
+            config, 'Train', device, logger, seed=epoch)
         train_batch_cost = 0.0
         train_reader_cost = 0.0
         batch_sum = 0
@@ -194,7 +202,12 @@ def train(config,
                 break
             lr = optimizer.get_lr()
             images = batch[0]
-            preds = model(images)
+            if use_srn:
+                others = batch[-4:]
+                preds = model(images, others)
+                model_average = True
+            else:
+                preds = model(images)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
             avg_loss.backward()
@@ -212,12 +225,12 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)
-            if cal_metric_during_train:  # onlt rec and cls need
+            if cal_metric_during_train:  # only rec and cls need
                 batch = [item.numpy() for item in batch]
                 post_result = post_process_class(preds, batch[1])
                 eval_class(post_result, batch)
-                metirc = eval_class.get_metric()
-                train_stats.update(metirc)
+                metric = eval_class.get_metric()
+                train_stats.update(metric)
             if vdl_writer is not None and dist.get_rank() == 0:
                 for k, v in train_stats.get().items():
@@ -238,21 +251,32 @@ def train(config,
             # eval
             if global_step > start_eval_step and \
                     (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
-                cur_metirc = eval(model, valid_dataloader, post_process_class,
-                                  eval_class)
-                cur_metirc_str = 'cur metirc, {}'.format(', '.join(
-                    ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
-                logger.info(cur_metirc_str)
+                if model_average:
+                    Model_Average = paddle.incubate.optimizer.ModelAverage(
+                        0.15,
+                        parameters=model.parameters(),
+                        min_average_window=10000,
+                        max_average_window=15625)
+                    Model_Average.apply()
+                cur_metric = eval(
+                    model,
+                    valid_dataloader,
+                    post_process_class,
+                    eval_class,
+                    use_srn=use_srn)
+                cur_metric_str = 'cur metric, {}'.format(', '.join(
+                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
+                logger.info(cur_metric_str)
                 # logger metric
                 if vdl_writer is not None:
-                    for k, v in cur_metirc.items():
+                    for k, v in cur_metric.items():
                         if isinstance(v, (float, int)):
                             vdl_writer.add_scalar('EVAL/{}'.format(k),
-                                                  cur_metirc[k], global_step)
-                if cur_metirc[main_indicator] >= best_model_dict[
+                                                  cur_metric[k], global_step)
+                if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
-                    best_model_dict.update(cur_metirc)
+                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                     save_model(
                         model,
@@ -263,7 +287,7 @@ def train(config,
                         prefix='best_accuracy',
                         best_model_dict=best_model_dict,
                         epoch=epoch)
-                best_str = 'best metirc, {}'.format(', '.join([
+                best_str = 'best metric, {}'.format(', '.join([
                     '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                 ]))
                 logger.info(best_str)
@@ -273,6 +297,7 @@ def train(config,
                                           best_model_dict[main_indicator],
                                           global_step)
             global_step += 1
+            optimizer.clear_grad()
             batch_start = time.time()
         if dist.get_rank() == 0:
             save_model(
@@ -294,7 +319,7 @@ def train(config,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
                 epoch=epoch)
-    best_str = 'best metirc, {}'.format(', '.join(
+    best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
     if dist.get_rank() == 0 and vdl_writer is not None:
@@ -302,7 +327,8 @@ def train(config,
     return
-def eval(model, valid_dataloader, post_process_class, eval_class):
+def eval(model, valid_dataloader, post_process_class, eval_class,
+         use_srn=False):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -313,7 +339,12 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
                 break
             images = batch[0]
             start = time.time()
-            preds = model(images)
+            if use_srn:
+                others = batch[-4:]
+                preds = model(images, others)
+            else:
+                preds = model(images)
             batch = [item.numpy() for item in batch]
             # Obtain usable results from post-processing methods
@@ -323,13 +354,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
             eval_class(post_result, batch)
             pbar.update(1)
             total_frame += len(images)
-        # Get final metirc,eg. acc or hmean
-        metirc = eval_class.get_metric()
+        # Get final metric,eg. acc or hmean
+        metric = eval_class.get_metric()
     pbar.close()
     model.train()
-    metirc['fps'] = total_frame / total_time
-    return metirc
+    metric['fps'] = total_frame / total_time
+    return metric
 def preprocess(is_train=False):
@@ -363,6 +394,7 @@ def preprocess(is_train=False):
     logger = get_logger(name='root', log_file=log_file)
     if config['Global']['use_visualdl']:
         from visualdl import LogWriter
+        save_model_dir = config['Global']['save_model_dir']
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
         os.makedirs(vdl_writer_path, exist_ok=True)
         vdl_writer = LogWriter(logdir=vdl_writer_path)
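The training-loop changes above (per-epoch dataloader rebuild, the optional ModelAverage step for SRN, and the added optimizer.clear_grad()) follow the standard Paddle 2.0 dynamic-graph step order: forward, loss, backward, step, clear gradients. A minimal self-contained sketch of that order with a toy linear model; the model, data, loss, and learning rate below are placeholders, not PaddleOCR code:

# Sketch: the Paddle 2.0 dygraph training step order used by the loop above.
import paddle

model = paddle.nn.Linear(4, 1)
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
loss_fn = paddle.nn.MSELoss()

model.train()
for step in range(3):
    images = paddle.randn([8, 4])          # stand-in for batch[0]
    labels = paddle.randn([8, 1])
    preds = model(images)                  # forward (SRN would also pass batch[-4:])
    avg_loss = loss_fn(preds, labels)
    avg_loss.backward()                    # accumulate gradients
    optimizer.step()                       # update parameters
    optimizer.clear_grad()                 # reset gradients before the next iteration
    print(float(avg_loss))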
......
@@ -50,6 +50,12 @@ def main(config, device, logger, vdl_writer):
     # build dataloader
     train_dataloader = build_dataloader(config, 'Train', device, logger)
+    if len(train_dataloader) == 0:
+        logger.error(
+            'No Images in train dataset, please check annotation file and path in the configuration file'
+        )
+        return
     if config['Eval']:
         valid_dataloader = build_dataloader(config, 'Eval', device, logger)
     else:
@@ -83,8 +89,10 @@ def main(config, device, logger, vdl_writer):
     # load pretrain model
     pre_best_model_dict = init_model(config, model, logger, optimizer)
-    logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
-                format(len(train_dataloader), len(valid_dataloader)))
+    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
+    if valid_dataloader is not None:
+        logger.info('valid dataloader has {} iters'.format(
+            len(valid_dataloader)))
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
......
-# for paddle.__version__ >= 2.0rc1
-python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
-# for paddle.__version__ < 2.0rc1
-# python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
+# recommended paddle.__version__ == 2.0.0
+python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml