Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
5c664bf4
"sgl-router/src/vscode:/vscode.git/clone" did not exist on "cd4da1f19b1422ae960ab330a6d0e5d780fc5de5"
Unverified
Commit
5c664bf4
authored
Aug 26, 2021
by
xiaoting
Committed by
GitHub
Aug 26, 2021
Browse files
Merge pull request #3721 from Topdu/dygraph
add rec_nrtr
parents
28a40efe
2bf8ad9b
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
4 deletions
+6
-4
tools/program.py
tools/program.py
+6
-4
No files found.
tools/program.py
View file @
5c664bf4
...
@@ -186,9 +186,11 @@ def train(config,
...
@@ -186,9 +186,11 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
try
:
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
except
:
model_type
=
None
model_type
=
None
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
...
@@ -213,7 +215,7 @@ def train(config,
...
@@ -213,7 +215,7 @@ def train(config,
images
=
batch
[
0
]
images
=
batch
[
0
]
if
use_srn
:
if
use_srn
:
model_average
=
True
model_average
=
True
if
use_srn
or
model_type
==
'table'
:
if
use_srn
or
model_type
==
'table'
or
use_nrtr
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
...
@@ -398,7 +400,7 @@ def preprocess(is_train=False):
...
@@ -398,7 +400,7 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'TableAttn'
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
]
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment