Commit 41a1b292 authored by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

parents 9471054e 3d30899b
@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
         return config
-class AttrDict(dict):
-    """Single level attribute dict, NOT recursive"""
-    def __init__(self, **kwargs):
-        super(AttrDict, self).__init__()
-        super(AttrDict, self).update(kwargs)
-    def __getattr__(self, key):
-        if key in self:
-            return self[key]
-        raise AttributeError("object has no attribute '{}'".format(key))
-global_config = AttrDict()
-default_config = {'Global': {'debug': False, }}
 def load_config(file_path):
     """
     Load config from yml/yaml file.
@@ -94,38 +76,38 @@ def load_config(file_path):
         file_path (str): Path of the config file to be loaded.
     Returns: global config
     """
-    merge_config(default_config)
     _, ext = os.path.splitext(file_path)
     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
-    merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
-    return global_config
+    config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+    return config
-def merge_config(config):
+def merge_config(config, opts):
     """
     Merge config into global config.
     Args:
         config (dict): Config to be merged.
     Returns: global config
     """
-    for key, value in config.items():
+    for key, value in opts.items():
         if "." not in key:
-            if isinstance(value, dict) and key in global_config:
-                global_config[key].update(value)
+            if isinstance(value, dict) and key in config:
+                config[key].update(value)
             else:
-                global_config[key] = value
+                config[key] = value
         else:
             sub_keys = key.split('.')
             assert (
-                sub_keys[0] in global_config
+                sub_keys[0] in config
             ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
-                global_config.keys(), sub_keys[0])
-            cur = global_config[sub_keys[0]]
+                config.keys(), sub_keys[0])
+            cur = config[sub_keys[0]]
             for idx, sub_key in enumerate(sub_keys[1:]):
                 if idx == len(sub_keys) - 2:
                     cur[sub_key] = value
                 else:
                     cur = cur[sub_key]
+    return config
 def check_gpu(use_gpu):
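Note on the API change above: `load_config` now returns a plain dict, and `merge_config` takes that dict plus an `opts` dict and returns the merged result instead of mutating a module-level `global_config`. A minimal usage sketch, assuming this module is importable as `tools.program`; the config path is only an example:

```python
from tools.program import load_config, merge_config

# Load the YAML into a plain dict (path is illustrative).
config = load_config('configs/det/det_mv3_db.yml')

# Dotted keys address nested entries; plain keys shallow-update a top-level section.
config = merge_config(config, {
    'Global.epoch_num': 10,         # sets config['Global']['epoch_num']
    'Global': {'use_gpu': False},   # updates the Global section in place
})
print(config['Global']['epoch_num'])  # -> 10
```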
@@ -204,20 +186,24 @@ def train(config,
         model_type = None
     algorithm = config['Architecture']['algorithm']
-    if 'start_epoch' in best_model_dict:
-        start_epoch = best_model_dict['start_epoch']
-    else:
-        start_epoch = 1
+    start_epoch = best_model_dict[
+        'start_epoch'] if 'start_epoch' in best_model_dict else 1
+    train_reader_cost = 0.0
+    train_run_cost = 0.0
+    total_samples = 0
+    reader_start = time.time()
+    max_iter = len(train_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(train_dataloader)
     for epoch in range(start_epoch, epoch_num + 1):
-        train_dataloader = build_dataloader(
-            config, 'Train', device, logger, seed=epoch)
-        train_reader_cost = 0.0
-        train_run_cost = 0.0
-        total_samples = 0
-        reader_start = time.time()
-        max_iter = len(train_dataloader) - 1 if platform.system(
-        ) == "Windows" else len(train_dataloader)
+        if train_dataloader.dataset.need_reset:
+            train_dataloader = build_dataloader(
+                config, 'Train', device, logger, seed=epoch)
+            max_iter = len(train_dataloader) - 1 if platform.system(
+            ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
@@ -239,10 +225,11 @@ def train(config,
             else:
                 if model_type == 'table' or extra_input:
                     preds = model(images, data=batch[1:])
-                elif model_type == "kie":
+                elif model_type in ["kie", 'vqa']:
                     preds = model(batch)
                 else:
                     preds = model(images)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
@@ -256,6 +243,7 @@ def train(config,
             optimizer.clear_grad()
             train_run_cost += time.time() - train_start
+            global_step += 1
             total_samples += len(images)
             if not isinstance(lr_scheduler, float):
@@ -285,12 +273,13 @@ def train(config,
                 (global_step > 0 and global_step % print_batch_step == 0) or
                 (idx >= len(train_dataloader) - 1)):
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
+                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
                     epoch, epoch_num, global_step, logs, train_reader_cost /
                     print_batch_step, (train_reader_cost + train_run_cost) /
-                    print_batch_step, total_samples,
+                    print_batch_step, total_samples / print_batch_step,
                     total_samples / (train_reader_cost + train_run_cost))
                 logger.info(strs)
                 train_reader_cost = 0.0
                 train_run_cost = 0.0
                 total_samples = 0
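For clarity on the new log line above: the window counters are now reported as per-step averages over the logging interval, and `ips` is images per second over that same window. A self-contained sketch of the arithmetic with made-up numbers (`print_batch_step` is the logging interval):

```python
print_batch_step = 10      # log every 10 iterations
train_reader_cost = 0.8    # seconds spent waiting on the dataloader in this window
train_run_cost = 4.2       # seconds spent in forward/backward/optimizer step
total_samples = 640        # images consumed since the last log line

avg_reader_cost = train_reader_cost / print_batch_step                    # 0.08 s
avg_batch_cost = (train_reader_cost + train_run_cost) / print_batch_step  # 0.50 s
avg_samples = total_samples / print_batch_step                            # 64.0
ips = total_samples / (train_reader_cost + train_run_cost)                # 128.0 images/s
print(avg_reader_cost, avg_batch_cost, avg_samples, ips)
```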
@@ -330,6 +319,7 @@ def train(config,
                         optimizer,
                         save_model_dir,
                         logger,
+                        config,
                         is_best=True,
                         prefix='best_accuracy',
                         best_model_dict=best_model_dict,
@@ -344,8 +334,7 @@ def train(config,
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
-            global_step += 1
-            optimizer.clear_grad()
             reader_start = time.time()
         if dist.get_rank() == 0:
             save_model(
@@ -353,6 +342,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,
@@ -364,6 +354,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
@@ -401,19 +392,28 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
-            elif model_type == "kie":
+            elif model_type in ["kie", 'vqa']:
                 preds = model(batch)
             else:
                 preds = model(images)
-            batch = [item.numpy() for item in batch]
+            batch_numpy = []
+            for item in batch:
+                if isinstance(item, paddle.Tensor):
+                    batch_numpy.append(item.numpy())
+                else:
+                    batch_numpy.append(item)
             # Obtain usable results from post-processing methods
             total_time += time.time() - start
             # Evaluate the results of the current batch
             if model_type in ['table', 'kie']:
-                eval_class(preds, batch)
+                eval_class(preds, batch_numpy)
+            elif model_type in ['vqa']:
+                post_result = post_process_class(preds, batch_numpy)
+                eval_class(post_result, batch_numpy)
             else:
-                post_result = post_process_class(preds, batch[1])
-                eval_class(post_result, batch)
+                post_result = post_process_class(preds, batch_numpy[1])
+                eval_class(post_result, batch_numpy)
             pbar.update(1)
             total_frame += len(images)
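The `batch_numpy` change above converts only `paddle.Tensor` entries to numpy, so non-tensor items (which VQA batches can carry) pass through untouched. A minimal, self-contained illustration of that conversion with a made-up batch:

```python
import numpy as np
import paddle

batch = [paddle.to_tensor([[1.0, 2.0]]),  # model input -> converted
         ['a', 'label'],                  # plain Python list -> kept as-is
         np.zeros((2, 2))]                # already numpy -> kept as-is

batch_numpy = []
for item in batch:
    if isinstance(item, paddle.Tensor):
        batch_numpy.append(item.numpy())
    else:
        batch_numpy.append(item)

print([type(item).__name__ for item in batch_numpy])  # ['ndarray', 'list', 'ndarray']
```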
@@ -479,9 +479,9 @@ def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
     profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
+    config = merge_config(config, FLAGS.opt)
     profile_dic = {"profiler_options": FLAGS.profiler_options}
-    merge_config(profile_dic)
+    config = merge_config(config, profile_dic)
     if is_train:
         # save_config
@@ -503,20 +503,15 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
     ]
-    windows_not_support_list = ['PSE']
-    if platform.system() == "Windows" and alg in windows_not_support_list:
-        logger.warning('{} is not support in Windows now'.format(
-            windows_not_support_list))
-        sys.exit()
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)
     config['Global']['distributed'] = dist.get_world_size() != 1
-    if config['Global']['use_visualdl']:
+    if config['Global']['use_visualdl'] and dist.get_rank() == 0:
         from visualdl import LogWriter
         save_model_dir = config['Global']['save_model_dir']
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
@@ -27,8 +27,6 @@ import yaml
 import paddle
 import paddle.distributed as dist
-paddle.seed(2)
 from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.losses import build_loss
@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
 from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import set_seed
 import tools.program as program
 dist.get_world_size()
@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_model(config, model, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer,
+                                     config['Architecture']["model_type"])
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
@@ -145,5 +145,7 @@ def test_reader(config, device, logger):
 if __name__ == '__main__':
     config, device, logger, vdl_writer = program.preprocess(is_train=True)
+    seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
+    set_seed(seed)
     main(config, device, logger, vdl_writer)
     # test_reader(config, device, logger)
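The hard-coded `paddle.seed(2)` above is replaced by `set_seed(seed)` driven by `Global.seed` (defaulting to 1024). The helper is imported from `ppocr.utils.utility`; as a rough sketch of what such a helper typically does (the actual implementation may differ), it seeds every RNG the training pipeline touches:

```python
import random

import numpy as np
import paddle


def set_seed(seed=1024):
    # Seed Python, NumPy and Paddle RNGs so training runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)
```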