Unverified Commit 8bae1e40 authored by MissPenguin's avatar MissPenguin Committed by GitHub
Browse files

Merge pull request #5174 from WenmuZhou/fix_vqa

vqa code integrated into ppocr training system
parents 9fa209e3 1cbe4bf2
...@@ -61,7 +61,8 @@ def main(): ...@@ -61,7 +61,8 @@ def main():
else: else:
model_type = None model_type = None
best_model_dict = load_model(config, model) best_model_dict = load_model(
config, model, model_type=config['Architecture']["model_type"])
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger):
def main(): def main():
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
merge_config(FLAGS.opt) config = merge_config(config, FLAGS.opt)
logger = get_logger() logger = get_logger()
# build post process # build post process
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -27,8 +27,6 @@ import yaml ...@@ -27,8 +27,6 @@ import yaml
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
paddle.seed(2)
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
...@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer ...@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer): ...@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_model(config, model, optimizer) pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
...@@ -145,5 +145,7 @@ def test_reader(config, device, logger): ...@@ -145,5 +145,7 @@ def test_reader(config, device, logger):
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True) config, device, logger, vdl_writer = program.preprocess(is_train=True)
seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
# test_reader(config, device, logger) # test_reader(config, device, logger)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment