"docs/source/en/using-diffusers/other-modalities.md" did not exist on "bea7eb43143fe6abdd583a367e5c6b6467b12714"
Unverified Commit 48124054 authored by shaohua.zhang's avatar shaohua.zhang Committed by GitHub
Browse files

add common code to reduce code duplication

this funtion is mainly for the train ,evel,export_model
parent e0fa21bd
...@@ -22,6 +22,7 @@ import yaml ...@@ -22,6 +22,7 @@ import yaml
import os import os
from ppocr.utils.utility import create_module from ppocr.utils.utility import create_module
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -31,8 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run ...@@ -31,8 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run from eval_utils.eval_rec_utils import eval_rec_run
from ppocr.utils.save_load import save_model from ppocr.utils.save_load import save_model
import numpy as np import numpy as np
from ppocr.utils.character import cal_predicts_accuracy from ppocr.utils.character import cal_predicts_accuracy, CharacterOps
class ArgsParser(ArgumentParser): class ArgsParser(ArgumentParser):
def __init__(self): def __init__(self):
...@@ -374,3 +374,28 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): ...@@ -374,3 +374,28 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path) save_model(train_info_dict['train_program'], save_path)
return return
def preProcess():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger.info(config)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
config['Global']['char_ops'] = CharacterOps(config['Global'])
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
startup_program = fluid.Program()
train_program = fluid.Program()
isContain_det = False
if alg in ['EAST', 'DB']:
isContain_det = True
return startup_program, train_program, place, config, isContain_det
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment