Commit aa59fca5 authored by Leif's avatar Leif
Browse files

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 12d15752 f01f24c7
......@@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
def export_single_model(model, arch_config, save_path, logger):
def export_single_model(model, arch_config, save_path, logger, quanter=None):
if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [
......@@ -55,6 +55,18 @@ def export_single_model(model, arch_config, save_path, logger):
shape=[None, 3, 48, 160], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "SVTR":
if arch_config["Head"]["name"] == 'MultiHead':
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 48, -1], dtype="float32"),
]
else:
other_shape = [
paddle.static.InputSpec(
shape=[None, 3, 64, 256], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "PREN":
other_shape = [
paddle.static.InputSpec(
......@@ -83,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger):
shape=[None] + infer_shape, dtype="float32")
])
paddle.jit.save(model, save_path)
if quanter is None:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path))
return
......@@ -105,13 +120,35 @@ def main():
if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
if config["Architecture"]["Models"][key]["Head"][
"name"] == 'MultiHead': # multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
# just one final tensor needs to be exported for inference
config["Architecture"]["Models"][key][
"return_all_feats"] = False
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # multi head
out_channels_list = {}
char_num = len(getattr(post_process_class, 'character'))
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
load_model(config, model)
model.eval()
......
......@@ -158,7 +158,7 @@ class TextDetector(object):
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
......@@ -284,7 +284,7 @@ if __name__ == "__main__":
total_time += elapse
count += 1
save_pred = os.path.basename(image_file) + "\t" + str(
json.dumps(np.array(dt_boxes).astype(np.int32).tolist())) + "\n"
json.dumps([x.tolist() for x in dt_boxes])) + "\n"
save_results.append(save_pred)
logger.info(save_pred)
logger.info("The predict time of {}: {}".format(image_file, elapse))
......
......@@ -107,7 +107,7 @@ class TextRecognizer(object):
return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
imgW = int((imgH * max_wh_ratio))
if self.use_onnx:
w = self.input_tensor.shape[3:][0]
if w is not None and w > 0:
......@@ -131,6 +131,17 @@ class TextRecognizer(object):
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_svtr(self, img, image_shape):
    """Resize ``img`` to a fixed target size and normalize it.

    Args:
        img: image array with three axes; the last axis is moved to the
            front by the transpose below (presumably H x W x C as loaded
            by OpenCV -- TODO confirm against callers).
        image_shape: three-element sequence unpacked as
            (channels, height, width); only height and width are used.

    Returns:
        float32 array with the last input axis moved first, where each
        value is ``(pixel / 255 - 0.5) / 0.5``.
    """
    _, target_h, target_w = image_shape
    # cv2.resize takes the destination size as (width, height).
    scaled = cv2.resize(
        img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    channel_first = scaled.astype('float32').transpose((2, 0, 1))
    # Normalize in three steps: divide by 255, subtract 0.5, divide by 0.5.
    return (channel_first / 255 - 0.5) / 0.5
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
......@@ -255,18 +266,16 @@ class TextRecognizer(object):
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
imgC, imgH, imgW = self.rec_image_shape
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "SAR":
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
......@@ -274,7 +283,7 @@ class TextRecognizer(object):
valid_ratios = []
valid_ratios.append(valid_ratio)
norm_img_batch.append(norm_img)
else:
elif self.rec_algorithm == "SRN":
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
......@@ -286,6 +295,16 @@ class TextRecognizer(object):
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
elif self.rec_algorithm == "SVTR":
norm_img = self.resize_norm_img_svtr(
img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
if self.benchmark:
......
......@@ -271,9 +271,10 @@ def create_predictor(args, mode, logger):
elif mode == "rec":
if args.rec_algorithm != "CRNN":
use_dynamic_shape = False
min_input_shape = {"x": [1, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 1536]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
imgH = int(args.rec_image_shape.split(',')[-2])
min_input_shape = {"x": [1, 3, imgH, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, imgH, 1536]}
opt_input_shape = {"x": [args.rec_batch_num, 3, imgH, 320]}
elif mode == "cls":
min_input_shape = {"x": [1, 3, 48, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 1024]}
......@@ -300,8 +301,8 @@ def create_predictor(args, mode, logger):
# enable memory optim
config.enable_memory_optim()
config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.delete_pass("matmul_transpose_reshape_fuse_pass")
if mode == 'table':
config.delete_pass("fc_fuse_pass") # not supported for table
config.switch_use_feed_fetch_ops(False)
......
......@@ -57,6 +57,8 @@ def main():
continue
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
elif op_name == "SSLRotateResize":
op[op_name]["mode"] = "test"
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
......
......@@ -51,8 +51,28 @@ def main():
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
if config['Architecture']['Models'][key]['Head'][
'name'] == 'MultiHead': # for multi head
out_channels_list = {}
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head loss
out_channels_list = {}
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
......
......@@ -201,12 +201,19 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input = config['Architecture'][
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]:
extra_input = extra_input or config['Architecture']['Models'][key][
'algorithm'] in extra_input_models
else:
extra_input = config['Architecture']['algorithm'] in extra_input_models
try:
model_type = config['Architecture']['model_type']
except:
model_type = None
algorithm = config['Architecture']['algorithm']
start_epoch = best_model_dict[
......@@ -269,7 +276,12 @@ def train(config,
if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss
post_result = post_process_class(
preds['ctc'], batch[1]) # for CTC head out
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
metric = eval_class.get_metric()
train_stats.update(metric)
......@@ -541,7 +553,7 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE'
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
]
device = 'cpu'
......
......@@ -74,11 +74,49 @@ def main(config, device, logger, vdl_writer):
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
if config['Architecture']['Models'][key]['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
......@@ -91,7 +129,7 @@ def main(config, device, logger, vdl_writer):
config['Optimizer'],
epochs=config['Global']['epoch_num'],
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
model=model)
# build metric
eval_class = build_metric(config['Metric'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment