Unverified Commit 773a8c45 authored by xiaoting's avatar xiaoting Committed by GitHub
Browse files

Merge pull request #3851 from tink2123/upload_seed

Add seed for ocr_rec
parents 6a41a37a 560f2f49
...@@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46 ...@@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46
cython cython
lxml lxml
premailer premailer
openpyxl openpyxl
\ No newline at end of file fasttext==0.9.1
\ No newline at end of file
...@@ -54,8 +54,7 @@ def main(): ...@@ -54,8 +54,7 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
use_sar = config['Architecture']['algorithm'] == "SAR"
if "model_type" in config['Architecture'].keys(): if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
else: else:
...@@ -72,7 +71,7 @@ def main(): ...@@ -72,7 +71,7 @@ def main():
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn, use_sar) eval_class, model_type, extra_input)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
...@@ -186,12 +186,13 @@ def train(config, ...@@ -186,12 +186,13 @@ def train(config,
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
use_nrtr = config['Architecture']['algorithm'] == "NRTR" extra_input = config['Architecture'][
use_sar = config['Architecture']['algorithm'] == 'SAR' 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
try: try:
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
except: except:
model_type = None model_type = None
algorithm = config['Architecture']['algorithm']
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
...@@ -215,7 +216,7 @@ def train(config, ...@@ -215,7 +216,7 @@ def train(config,
images = batch[0] images = batch[0]
if use_srn: if use_srn:
model_average = True model_average = True
if use_srn or model_type == 'table' or use_nrtr or use_sar: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
else: else:
preds = model(images) preds = model(images)
...@@ -279,8 +280,7 @@ def train(config, ...@@ -279,8 +280,7 @@ def train(config,
post_process_class, post_process_class,
eval_class, eval_class,
model_type, model_type,
use_srn=use_srn, extra_input=extra_input)
use_sar=use_sar)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
...@@ -352,8 +352,7 @@ def eval(model, ...@@ -352,8 +352,7 @@ def eval(model,
post_process_class, post_process_class,
eval_class, eval_class,
model_type=None, model_type=None,
use_srn=False, extra_input=False):
use_sar=False):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -366,7 +365,7 @@ def eval(model, ...@@ -366,7 +365,7 @@ def eval(model,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
if use_srn or model_type == 'table' or use_sar: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
else: else:
preds = model(images) preds = model(images)
...@@ -402,7 +401,8 @@ def preprocess(is_train=False): ...@@ -402,7 +401,8 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE' 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'ASTER'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment