Unverified Commit 006d84bf authored by 崔浩, committed by GitHub

Merge branch 'PaddlePaddle:dygraph' into dygraph

parents 302ca30c 8beeb84c
@@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
 from ppocr.utils.utility import print_dict
 from ppocr.utils.logging import get_logger
+from ppocr.utils import profiler
 from ppocr.data import build_dataloader
 import numpy as np
@@ -42,6 +43,13 @@ class ArgsParser(ArgumentParser):
         self.add_argument("-c", "--config", help="configuration file to use")
         self.add_argument(
             "-o", "--opt", nargs='+', help="set configuration options")
+        self.add_argument(
+            '-p',
+            '--profiler_options',
+            type=str,
+            default=None,
+            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+        )
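+        # A typical value looks like "batch_range=[10, 20];profile_path=./profiler_log";
+        # the key names in that example are illustrative, ppocr/utils/profiler.py
+        # defines the options that are actually accepted.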
 
     def parse_args(self, argv=None):
         args = super(ArgsParser, self).parse_args(argv)
@@ -158,6 +166,7 @@ def train(config,
     epoch_num = config['Global']['epoch_num']
     print_batch_step = config['Global']['print_batch_step']
     eval_batch_step = config['Global']['eval_batch_step']
+    profiler_options = config['profiler_options']
 
     global_step = 0
     if 'global_step' in pre_best_model_dict:
@@ -186,10 +195,13 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
+    extra_input = config['Architecture'][
+        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
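+    # extra_input marks algorithms whose forward pass needs the extra label
+    # tensors in the batch (batch[1:]) in addition to the images.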
try:
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
except: except:
model_type = None model_type = None
algorithm = config['Architecture']['algorithm']
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
@@ -206,6 +218,7 @@ def train(config,
         max_iter = len(train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
+            profiler.add_profiler_step(profiler_options)
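+            # Advances the profiler by one step according to profiler_options;
+            # a no-op when --profiler_options is not set.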
             train_reader_cost += time.time() - batch_start
             if idx >= max_iter:
                 break
@@ -213,7 +226,7 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            if use_srn or model_type == 'table':
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
@@ -277,7 +290,7 @@ def train(config,
                     post_process_class,
                     eval_class,
                     model_type,
-                    use_srn=use_srn)
+                    extra_input=extra_input)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -348,8 +361,8 @@ def eval(model,
          valid_dataloader,
          post_process_class,
          eval_class,
-         model_type,
-         use_srn=False):
+         model_type=None,
+         extra_input=False):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -362,7 +375,7 @@ def eval(model,
                 break
             images = batch[0]
             start = time.time()
-            if use_srn or model_type == 'table':
+            if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
             else:
                 preds = model(images)
@@ -386,10 +399,76 @@ def eval(model,
     return metric
+
+
+def update_center(char_center, post_result, preds):
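+    # Keeps a running mean of the visual feature ("center") for every character
+    # index, updated only with samples whose prediction exactly matches the label.
+    # char_center maps char index -> [mean_feature, count].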
+    result, label = post_result
+    feats, logits = preds
+    logits = paddle.argmax(logits, axis=-1)
+    feats = feats.numpy()
+    logits = logits.numpy()
+
+    for idx_sample in range(len(label)):
+        if result[idx_sample][0] == label[idx_sample][0]:
+            feat = feats[idx_sample]
+            logit = logits[idx_sample]
+            for idx_time in range(len(logit)):
+                index = logit[idx_time]
+                if index in char_center.keys():
+                    char_center[index][0] = (
+                        char_center[index][0] * char_center[index][1] +
+                        feat[idx_time]) / (char_center[index][1] + 1)
+                    char_center[index][1] += 1
+                else:
+                    char_center[index] = [feat[idx_time], 1]
+    return char_center
+
+
+def get_center(model, eval_dataloader, post_process_class):
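+    # Runs the model over eval_dataloader, accumulates per-character feature
+    # centers with update_center(), and finally keeps only the mean feature
+    # (dropping the sample count) for each character index.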
+    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
+    max_iter = len(eval_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(eval_dataloader)
+    char_center = dict()
+    total_time = 0.0  # must be initialized before it is accumulated below
+    for idx, batch in enumerate(eval_dataloader):
+        if idx >= max_iter:
+            break
+        images = batch[0]
+        start = time.time()
+        preds = model(images)
+
+        batch = [item.numpy() for item in batch]
+        # Obtain usable results from post-processing methods
+        total_time += time.time() - start
+        # Evaluate the results of the current batch
+        post_result = post_process_class(preds, batch[1])
+
+        # update char_center
+        char_center = update_center(char_center, post_result, preds)
+        pbar.update(1)
+
+    pbar.close()
+    for key in char_center.keys():
+        char_center[key] = char_center[key][0]
+    return char_center
+
+
 def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
+    profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
     merge_config(FLAGS.opt)
+    profile_dic = {"profiler_options": FLAGS.profiler_options}
+    merge_config(profile_dic)
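+    # Merging profile_dic into the global config is what makes
+    # config['profiler_options'] available to train() above.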
+
+    if is_train:
+        # save_config
+        save_model_dir = config['Global']['save_model_dir']
+        os.makedirs(save_model_dir, exist_ok=True)
+        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+            yaml.dump(
+                dict(config), f, default_flow_style=False, sort_keys=False)
+        log_file = '{}/train.log'.format(save_model_dir)
+    else:
+        log_file = None
+    logger = get_logger(name='root', log_file=log_file)
 
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
@@ -398,24 +477,20 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-        'CLS', 'PGNet', 'Distillation', 'TableAttn'
+        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
+        'SEED'
     ]
+    windows_not_support_list = ['PSE']
+    if platform.system() == "Windows" and alg in windows_not_support_list:
+        logger.warning('{} is not support in Windows now'.format(
+            windows_not_support_list))
+        sys.exit()
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)
 
     config['Global']['distributed'] = dist.get_world_size() != 1
-    if is_train:
-        # save_config
-        save_model_dir = config['Global']['save_model_dir']
-        os.makedirs(save_model_dir, exist_ok=True)
-        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
-            yaml.dump(
-                dict(config), f, default_flow_style=False, sort_keys=False)
-        log_file = '{}/train.log'.format(save_model_dir)
-    else:
-        log_file = None
-    logger = get_logger(name='root', log_file=log_file)
     if config['Global']['use_visualdl']:
         from visualdl import LogWriter
         save_model_dir = config['Global']['save_model_dir']
...