Unverified Commit 5c664bf4 authored by xiaoting's avatar xiaoting Committed by GitHub
Browse files

Merge pull request #3721 from Topdu/dygraph

add rec_nrtr
parents 28a40efe 2bf8ad9b
......@@ -186,6 +186,8 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
try:
model_type = config['Architecture']['model_type']
except:
......@@ -213,7 +215,7 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if use_srn or model_type == 'table':
if use_srn or model_type == 'table' or use_nrtr:
preds = model(images, data=batch[1:])
else:
preds = model(images)
......@@ -398,7 +400,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'TableAttn'
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment