Unverified Commit 8bae1e40 authored by MissPenguin's avatar MissPenguin Committed by GitHub
Browse files

Merge pull request #5174 from WenmuZhou/fix_vqa

vqa code integrated into ppocr training system
parents 9fa209e3 1cbe4bf2
......@@ -61,7 +61,8 @@ def main():
else:
model_type = None
best_model_dict = load_model(config, model)
best_model_dict = load_model(
config, model, model_type=config['Architecture']["model_type"])
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
......
......@@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger):
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
config = merge_config(config, FLAGS.opt)
logger = get_logger()
# build post process
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -27,8 +27,6 @@ import yaml
import paddle
import paddle.distributed as dist
paddle.seed(2)
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
......@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program
dist.get_world_size()
......@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer)
pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
......@@ -145,5 +145,7 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment