Commit d3c50fda authored by xmy0916's avatar xmy0916
Browse files

fix bugs

parent 46ac85ad
import yaml import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path import os.path
import logging
logging.basicConfig(level=logging.INFO)
support_list = { support_list = {
'it':'italian', 'xi':'spanish', 'pu':'portuguese', 'ru':'russian', 'ar':'arabic', 'it':'italian', 'xi':'spanish', 'pu':'portuguese', 'ru':'russian', 'ar':'arabic',
...@@ -16,6 +18,7 @@ You can download it from \ ...@@ -16,6 +18,7 @@ You can download it from \
https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/" https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
# Load the base training configuration that every language-specific config
# is derived from. The file is opened in a ``with`` block so the handle is
# closed as soon as parsing finishes instead of being left to the GC.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; this is
# only acceptable because the file is a local, trusted config. Consider
# yaml.safe_load if the config never needs Python-specific tags.
with open("./rec_multi_language_lite_train.yml", 'rb') as _base_config_file:
    global_config = yaml.load(_base_config_file, Loader=yaml.Loader)

# Directory three levels above the current working directory; relative paths
# from the config and the command line are joined onto it for existence checks.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
class ArgsParser(ArgumentParser): class ArgsParser(ArgumentParser):
def __init__(self): def __init__(self):
...@@ -32,7 +35,7 @@ class ArgsParser(ArgumentParser): ...@@ -32,7 +35,7 @@ class ArgsParser(ArgumentParser):
self.add_argument( self.add_argument(
"--dict",type=str,help="you can use this command to change the dictionary default path") "--dict",type=str,help="you can use this command to change the dictionary default path")
self.add_argument( self.add_argument(
"--dataset_root_path",type=str,help="you can use this command to change the dataset default root path") "--data_dir",type=str,help="you can use this command to change the dataset default root path")
def parse_args(self, argv=None): def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv) args = super(ArgsParser, self).parse_args(argv)
...@@ -51,15 +54,19 @@ class ArgsParser(ArgumentParser): ...@@ -51,15 +54,19 @@ class ArgsParser(ArgumentParser):
return config return config
def _set_language(self, type):
    """Point the global config at the files of the selected language.

    Updates ``global_config`` in place: the dictionary path, the model
    output directory and the train/eval label lists are all derived from
    the language abbreviation ``type[0]``.

    Args:
        type: Indexable value whose first element is the language
            abbreviation (presumably the parsed ``-l/--language`` option;
            verify against the argument definition). The abbreviation
            must be a key of ``support_list``.

    Returns:
        The language abbreviation ``type[0]``.

    Raises:
        AssertionError: If no language was given, the abbreviation is not
            in ``support_list``, or the default dictionary file does not
            exist under ``project_path``.
    """
    # Explicit raises instead of ``assert`` statements: asserts are removed
    # under ``python -O``, which would silently disable this validation.
    # AssertionError is kept so the failure type stays the same as before.
    if not type:
        raise AssertionError(
            "please use -l or --language to choose language type")

    lang = type[0]
    if lang not in support_list:
        raise AssertionError(
            "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
            "please check your running command".format(support_list, type))

    global_config['Global']['character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
    global_config['Global']['save_model_dir'] = './output/rec_{}_lite'.format(lang)
    global_config['Train']['dataset']['label_file_list'] = ["train_data/{}_train.txt".format(lang)]
    global_config['Eval']['dataset']['label_file_list'] = ["train_data/{}_val.txt".format(lang)]

    # The dictionary path stored in the config is relative, so it is
    # resolved against project_path before checking that it exists.
    dict_file = os.path.join(
        project_path, global_config['Global']['character_dict_path'])
    if not os.path.isfile(dict_file):
        raise AssertionError(
            "Loss default dictionary file {}_dict.txt.You can download it from "
            "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(lang))
    return lang
...@@ -88,27 +95,42 @@ def merge_config(config): ...@@ -88,27 +95,42 @@ def merge_config(config):
cur[sub_key] = value cur[sub_key] = value
else: else:
cur = cur[sub_key] cur = cur[sub_key]
def loss_file(path):
    """Log a warning when ``path`` does not exist.

    A missing path is only reported; nothing is raised, so the caller
    continues normally.

    Args:
        path: File or directory path to check.

    Returns:
        None.
    """
    if os.path.exists(path):
        return
    # Lazy %-style arguments let logging do the interpolation; the emitted
    # message is identical to the previous str.format() version.
    logging.warning(
        'There is no such file:%s,Please do not forget to put in the specified file',
        path)
if __name__ == '__main__':
    # Script entry point: parse the command line, apply the overrides to
    # global_config, write the language-specific YAML file and log the
    # resulting paths.
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)

    # Each optional override replaces the default set for the language.
    # The given path is resolved against project_path and a warning is
    # logged if it does not exist; generation continues either way.
    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        # Fixed: this previously referenced the undefined name
        # ``Eval_label_path`` and raised NameError whenever --val was given.
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    # Write the generated config next to this script, replacing any file
    # left over from a previous run.
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)
    with open(save_file_path, 'w') as f:
        yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)

    logging.info("Project path is :{}".format(project_path))
    logging.info("Train list path set to :{}".format(global_config['Train']['dataset']['label_file_list'][0]))
    logging.info("Eval list path set to :{}".format(global_config['Eval']['dataset']['label_file_list'][0]))
    logging.info("Dataset root path set to :{}".format(global_config['Eval']['dataset']['data_dir']))
    logging.info("Dict path set to :{}".format(global_config['Global']['character_dict_path']))
    logging.info("Config file set to :configs/rec/multi_language/{}".format(save_file_path))
...@@ -64,7 +64,7 @@ Metric: ...@@ -64,7 +64,7 @@ Metric:
Train: Train:
dataset: dataset:
name: SimpleDataSet name: SimpleDataSet
data_dir: ./train_data/ data_dir: train_data/
label_file_list: ["./train_data/train_list.txt"] label_file_list: ["./train_data/train_list.txt"]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -85,7 +85,7 @@ Train: ...@@ -85,7 +85,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: SimpleDataSet name: SimpleDataSet
data_dir: ./train_data/ data_dir: train_data/
label_file_list: ["./train_data/val_list.txt"] label_file_list: ["./train_data/val_list.txt"]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
......
File added
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment