Commit 1623c17c authored by Topdu's avatar Topdu
Browse files

add rec_nrtr

parent b6f0a903
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
\ No newline at end of file
......@@ -22,6 +22,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
......@@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict
import tools.program as program
def main():
global_config = config['Global']
# build dataloader
......
......@@ -186,7 +186,7 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
else:
......@@ -211,6 +211,9 @@ def train(config,
others = batch[-4:]
preds = model(images, others)
model_average = True
elif use_nrtr:
max_len = batch[2].max()
preds = model(images, batch[1][:,:2+max_len])
else:
preds = model(images)
loss = loss_class(preds, batch)
......@@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break
images = batch[0]
start = time.time()
if use_srn:
others = batch[-4:]
preds = model(images, others)
else:
preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
post_result = post_process_class(preds, batch[1])
......@@ -386,7 +387,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation'
'CLS', 'PGNet', 'Distillation','NRTR'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment