Commit 19eb7eb8 authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dy1

parents 0afe6c32 03b7daa5
tools/infer_rec.py

@@ -20,6 +20,7 @@ import numpy as np
 import os
 import sys
+import json

 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
@@ -46,12 +47,18 @@ def main():
     # build model
     if hasattr(post_process_class, 'character'):
-        config['Architecture']["Head"]['out_channels'] = len(
-            getattr(post_process_class, 'character'))
+        char_num = len(getattr(post_process_class, 'character'))
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num

     model = build_model(config['Architecture'])
-    init_model(config, model, logger)
+    init_model(config, model)

     # create data ops
     transforms = []
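Note: both this file and tools/train.py now route char_num into every sub-model head when the algorithm is Distillation. A minimal sketch of that resolution on a toy config (all dict values below are invented for illustration):

```python
# Toy illustration; config values are invented, only the branching mirrors
# the diff.
config = {
    'Architecture': {
        'algorithm': 'Distillation',
        'Models': {
            'Teacher': {'Head': {'name': 'CTCHead'}},
            'Student': {'Head': {'name': 'CTCHead'}},
        },
    }
}
char_num = 6625  # e.g. len(post_process_class.character)

if config['Architecture']['algorithm'] in ['Distillation']:  # distillation model
    for key in config['Architecture']['Models']:
        config['Architecture']['Models'][key]['Head']['out_channels'] = char_num
else:  # base rec model
    config['Architecture']['Head']['out_channels'] = char_num

assert config['Architecture']['Models']['Student']['Head']['out_channels'] == 6625
```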
@@ -73,35 +80,57 @@ def main():
     global_config['infer_mode'] = True
     ops = create_operators(transforms, global_config)

+    save_res_path = config['Global'].get('save_res_path',
+                                         "./output/rec/predicts_rec.txt")
+    if not os.path.exists(os.path.dirname(save_res_path)):
+        os.makedirs(os.path.dirname(save_res_path))
+
     model.eval()
-    for file in get_image_file_list(config['Global']['infer_img']):
-        logger.info("infer_img: {}".format(file))
-        with open(file, 'rb') as f:
-            img = f.read()
-            data = {'image': img}
-        batch = transform(data, ops)
-        if config['Architecture']['algorithm'] == "SRN":
-            encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
-            gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
-            gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
-            gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
-
-            others = [
-                paddle.to_tensor(encoder_word_pos_list),
-                paddle.to_tensor(gsrm_word_pos_list),
-                paddle.to_tensor(gsrm_slf_attn_bias1_list),
-                paddle.to_tensor(gsrm_slf_attn_bias2_list)
-            ]
-
-        images = np.expand_dims(batch[0], axis=0)
-        images = paddle.to_tensor(images)
-        if config['Architecture']['algorithm'] == "SRN":
-            preds = model(images, others)
-        else:
-            preds = model(images)
-        post_result = post_process_class(preds)
-        for rec_reuslt in post_result:
-            logger.info('\t result: {}'.format(rec_reuslt))
+    with open(save_res_path, "w") as fout:
+        for file in get_image_file_list(config['Global']['infer_img']):
+            logger.info("infer_img: {}".format(file))
+            with open(file, 'rb') as f:
+                img = f.read()
+                data = {'image': img}
+            batch = transform(data, ops)
+            if config['Architecture']['algorithm'] == "SRN":
+                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
+                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
+                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
+                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
+
+                others = [
+                    paddle.to_tensor(encoder_word_pos_list),
+                    paddle.to_tensor(gsrm_word_pos_list),
+                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
+                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
+                ]
+
+            images = np.expand_dims(batch[0], axis=0)
+            images = paddle.to_tensor(images)
+            if config['Architecture']['algorithm'] == "SRN":
+                preds = model(images, others)
+            else:
+                preds = model(images)
+            post_result = post_process_class(preds)
+            info = None
+            if isinstance(post_result, dict):
+                rec_info = dict()
+                for key in post_result:
+                    if len(post_result[key][0]) >= 2:
+                        rec_info[key] = {
+                            "label": post_result[key][0][0],
+                            "score": float(post_result[key][0][1]),
+                        }
+                info = json.dumps(rec_info)
+            else:
+                if len(post_result[0]) >= 2:
+                    info = post_result[0][0] + "\t" + str(post_result[0][1])
+
+            if info is not None:
+                logger.info("\t result: {}".format(info))
+                fout.write(file + "\t" + info)
     logger.info("success!")
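Note: the rewritten loop serializes two postprocess output shapes: a list of (label, score) tuples for plain rec models and, judging by the isinstance check, a dict keyed by sub-model name for distillation. A sketch of both paths with invented outputs:

```python
import json

# Invented sample outputs; only the branching mirrors the diff above.
plain_result = [("JOINT", 0.9834)]
distill_result = {
    "Student": [("JOINT", 0.9834)],
    "Teacher": [("JOINT", 0.9821)],
}

for post_result in (plain_result, distill_result):
    if isinstance(post_result, dict):  # distillation: one entry per sub-model
        rec_info = {
            key: {"label": post_result[key][0][0],
                  "score": float(post_result[key][0][1])}
            for key in post_result if len(post_result[key][0]) >= 2
        }
        info = json.dumps(rec_info)
    else:  # plain rec model: first (label, score) tuple
        info = post_result[0][0] + "\t" + str(post_result[0][1])
    print(info)
```

Also note the diff writes each record as file + "\t" + info with no trailing newline, so downstream parsers of predicts_rec.txt need to tolerate concatenated records.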
tools/infer_table.py (new file)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import json

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import paddle
from paddle.jit import to_static

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
import cv2


def main(config, device, logger, vdl_writer):
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))
    model = build_model(config['Architecture'])
    init_model(config, model, logger)

    # create data ops
    transforms = []
    use_padding = False
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        if op_name == "ResizeTableImage":
            use_padding = True
            padding_max_len = op['ResizeTableImage']['max_len']
        transforms.append(op)

    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img = f.read()
            data = {'image': img}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        res_html_code = post_result['res_html_code']
        res_loc = post_result['res_loc']
        img = cv2.imread(file)
        imgh, imgw = img.shape[0:2]
        res_loc_final = []
        for rno in range(len(res_loc[0])):
            x0, y0, x1, y1 = res_loc[0][rno]
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
            res_loc_final.append([left, top, right, bottom])
        res_loc_str = json.dumps(res_loc_final)
        logger.info("result: {}, {}".format(res_html_code, res_loc_final))
    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main(config, device, logger, vdl_writer)
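Note: the cell-location decoding above scales normalized corner coordinates back to pixel space and clamps them to the image bounds. A worked example with made-up numbers:

```python
# Made-up image size and normalized box; the math mirrors the loop above.
imgh, imgw = 480, 640
x0, y0, x1, y1 = 0.25, 0.10, 0.50, 0.30

left = max(int(imgw * x0), 0)            # 160
top = max(int(imgh * y0), 0)             # 48
right = min(int(imgw * x1), imgw - 1)    # 320
bottom = min(int(imgh * y1), imgh - 1)   # 144

assert [left, top, right, bottom] == [160, 48, 320, 144]
```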
tools/program.py

-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ from __future__ import print_function
 import os
 import sys
+import platform
 import yaml
 import time
 import shutil
@@ -159,6 +160,8 @@ def train(config,
     eval_batch_step = config['Global']['eval_batch_step']

     global_step = 0
+    if 'global_step' in pre_best_model_dict:
+        global_step = pre_best_model_dict['global_step']
     start_eval_step = 0
     if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
         start_eval_step = eval_batch_step[0]
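Note: restoring global_step keeps step-based eval and checkpoint schedules aligned across restarts. A toy sketch of the pickup (checkpoint metadata invented):

```python
# Invented checkpoint metadata; only the pickup logic mirrors the diff.
pre_best_model_dict = {'acc': 0.81, 'start_epoch': 12, 'global_step': 34000}

global_step = 0
if 'global_step' in pre_best_model_dict:
    global_step = pre_best_model_dict['global_step']

print(global_step)  # 34000: step-based schedules continue mid-run
```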
@@ -183,6 +186,12 @@ def train(config,
     model.train()

     use_srn = config['Architecture']['algorithm'] == "SRN"
+    use_nrtr = config['Architecture']['algorithm'] == "NRTR"
+
+    try:
+        model_type = config['Architecture']['model_type']
+    except:
+        model_type = None
+
     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']
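Note: the try/except guards configs written before the model_type key existed; for the missing-key case a dict .get is the equivalent idiom (sketch with a toy config):

```python
# Toy config without the key; .get matches the try/except for KeyError.
config = {'Architecture': {'algorithm': 'CRNN'}}
model_type = config['Architecture'].get('model_type', None)
assert model_type is None
```

The bare except: in the diff also swallows unrelated errors, which .get avoids.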
@@ -196,16 +205,18 @@ def train(config,
         train_reader_cost = 0.0
         batch_sum = 0
         batch_start = time.time()
+        max_iter = len(train_dataloader) - 1 if platform.system(
+        ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
             train_reader_cost += time.time() - batch_start
-            if idx >= len(train_dataloader):
+            if idx >= max_iter:
                 break
             lr = optimizer.get_lr()
             images = batch[0]
             if use_srn:
-                others = batch[-4:]
-                preds = model(images, others)
                 model_average = True
+            if use_srn or model_type == 'table' or use_nrtr:
+                preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
             loss = loss_class(preds, batch)
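Note: the len(train_dataloader) - 1 cap on Windows silently drops the last batch; the underlying DataLoader issue is assumed, not documented in this commit. The same guard, factored into a helper for clarity:

```python
import platform

def effective_iters(num_batches):
    # Mirrors the diff's Windows workaround; the DataLoader issue it
    # sidesteps is assumed, not documented in this commit.
    return num_batches - 1 if platform.system() == "Windows" else num_batches

print(effective_iters(100))  # 99 on Windows, 100 elsewhere
```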
@@ -227,8 +238,11 @@ def train(config,
             if cal_metric_during_train:  # only rec and cls need
                 batch = [item.numpy() for item in batch]
-                post_result = post_process_class(preds, batch[1])
-                eval_class(post_result, batch)
+                if model_type == 'table':
+                    eval_class(preds, batch)
+                else:
+                    post_result = post_process_class(preds, batch[1])
+                    eval_class(post_result, batch)
                 metric = eval_class.get_metric()
                 train_stats.update(metric)
@@ -264,6 +278,7 @@ def train(config,
                     valid_dataloader,
                     post_process_class,
                     eval_class,
+                    model_type,
                     use_srn=use_srn)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
@@ -287,7 +302,8 @@ def train(config,
                         is_best=True,
                         prefix='best_accuracy',
                         best_model_dict=best_model_dict,
-                        epoch=epoch)
+                        epoch=epoch,
+                        global_step=global_step)
                     best_str = 'best metric, {}'.format(', '.join([
                         '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                     ]))
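Note: all three save_model call sites now pass global_step, giving the resume path above something to read back. A minimal metadata round-trip sketch (hypothetical file layout; PaddleOCR's real .states format may differ):

```python
import json
import os
import tempfile

# Hypothetical metadata round-trip; only the idea mirrors the diff.
state = {'epoch': 12, 'global_step': 34000}
path = os.path.join(tempfile.mkdtemp(), 'latest.states')
with open(path, 'w') as f:
    json.dump(state, f)
with open(path) as f:
    restored = json.load(f)
assert restored['global_step'] == 34000
```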
@@ -309,7 +325,8 @@ def train(config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,
-                epoch=epoch)
+                epoch=epoch,
+                global_step=global_step)
         if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
             save_model(
                 model,
@@ -319,7 +336,8 @@ def train(config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
-                epoch=epoch)
+                epoch=epoch,
+                global_step=global_step)
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
     logger.info(best_str)
@@ -328,31 +346,37 @@ def train(config,
     return


-def eval(model, valid_dataloader, post_process_class, eval_class,
+def eval(model,
+         valid_dataloader,
+         post_process_class,
+         eval_class,
+         model_type,
          use_srn=False):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
         total_time = 0.0
         pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
+        max_iter = len(valid_dataloader) - 1 if platform.system(
+        ) == "Windows" else len(valid_dataloader)
         for idx, batch in enumerate(valid_dataloader):
-            if idx >= len(valid_dataloader):
+            if idx >= max_iter:
                 break
             images = batch[0]
             start = time.time()
-            if use_srn:
-                others = batch[-4:]
-                preds = model(images, others)
+            if use_srn or model_type == 'table':
+                preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
             batch = [item.numpy() for item in batch]
             # Obtain usable results from post-processing methods
-            post_result = post_process_class(preds, batch[1])
             total_time += time.time() - start
             # Evaluate the results of the current batch
-            eval_class(post_result, batch)
+            if model_type == 'table':
+                eval_class(preds, batch)
+            else:
+                post_result = post_process_class(preds, batch[1])
+                eval_class(post_result, batch)
             pbar.update(1)
             total_frame += len(images)
         # Get final metric,eg. acc or hmean
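Note: train() and eval() now branch identically: table models feed raw preds to the metric, everything else is decoded by the postprocessor first. A toy version of that dispatch with stub callables:

```python
# Stub postprocessor and metric; only the dispatch mirrors the diff.
def post_process_class(preds, labels):
    return [("decoded", 0.99)]

class EvalClass:
    def __call__(self, result, batch):
        self.last = result

eval_class = EvalClass()
preds = {"structure_probs": [0.1, 0.9]}  # stand-in for raw model output
batch = [None, ["labels"]]

for model_type in ('table', 'rec'):
    if model_type == 'table':
        eval_class(preds, batch)        # metric consumes raw predictions
    else:
        post_result = post_process_class(preds, batch[1])
        eval_class(post_result, batch)  # metric consumes decoded results
```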
@@ -375,7 +399,8 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
-        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
+        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
+        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
     ]

     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
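Note: preprocess() now whitelists the newly supported algorithms (PGNet, Distillation, NRTR, TableAttn). A standalone version of the check, with the list copied from the diff:

```python
SUPPORTED_ALGS = [
    'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
    'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
]

def check_algorithm(alg):
    # Mirrors the assert in preprocess(); raises instead of asserting.
    if alg not in SUPPORTED_ALGS:
        raise ValueError("unsupported algorithm: {}".format(alg))

check_algorithm('TableAttn')  # table recognition is now accepted
```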
tools/train.py
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_dygraph_params
 import tools.program as program

 dist.get_world_size()
@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])

     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
@@ -90,8 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, logger, optimizer)
+    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)

     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(