Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
773a8c45
Unverified
Commit
773a8c45
authored
Sep 28, 2021
by
xiaoting
Committed by
GitHub
Sep 28, 2021
Browse files
Merge pull request #3851 from tink2123/upload_seed
Add seed for ocr_rec
parents
6a41a37a
560f2f49
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
13 deletions
+13
-13
requirements.txt
requirements.txt
+2
-1
tools/eval.py
tools/eval.py
+2
-3
tools/program.py
tools/program.py
+9
-9
No files found.
requirements.txt
View file @
773a8c45
...
@@ -12,3 +12,4 @@ cython
...
@@ -12,3 +12,4 @@ cython
lxml
lxml
premailer
premailer
openpyxl
openpyxl
fasttext
==0.9.1
\ No newline at end of file
tools/eval.py
View file @
773a8c45
...
@@ -54,8 +54,7 @@ def main():
...
@@ -54,8 +54,7 @@ def main():
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"SAR"
]
use_sar
=
config
[
'Architecture'
][
'algorithm'
]
==
"SAR"
if
"model_type"
in
config
[
'Architecture'
].
keys
():
if
"model_type"
in
config
[
'Architecture'
].
keys
():
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
else
:
...
@@ -72,7 +71,7 @@ def main():
...
@@ -72,7 +71,7 @@ def main():
# start eval
# start eval
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
,
use_sar
)
eval_class
,
model_type
,
extra_input
)
logger
.
info
(
'metric eval ***************'
)
logger
.
info
(
'metric eval ***************'
)
for
k
,
v
in
metric
.
items
():
for
k
,
v
in
metric
.
items
():
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
...
...
tools/program.py
View file @
773a8c45
...
@@ -186,12 +186,13 @@ def train(config,
...
@@ -186,12 +186,13 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
extra_input
=
config
[
'Architecture'
][
use_sar
=
config
[
'Architecture'
][
'algorithm'
]
==
'SAR'
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
try
:
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
except
:
model_type
=
None
model_type
=
None
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
start_epoch
=
best_model_dict
[
'start_epoch'
]
...
@@ -215,7 +216,7 @@ def train(config,
...
@@ -215,7 +216,7 @@ def train(config,
images
=
batch
[
0
]
images
=
batch
[
0
]
if
use_srn
:
if
use_srn
:
model_average
=
True
model_average
=
True
if
use_srn
or
model_type
==
'table'
or
use_nrtr
or
use_sar
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
...
@@ -279,8 +280,7 @@ def train(config,
...
@@ -279,8 +280,7 @@ def train(config,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
model_type
,
model_type
,
use_srn
=
use_srn
,
extra_input
=
extra_input
)
use_sar
=
use_sar
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
logger
.
info
(
cur_metric_str
)
logger
.
info
(
cur_metric_str
)
...
@@ -352,8 +352,7 @@ def eval(model,
...
@@ -352,8 +352,7 @@ def eval(model,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
model_type
=
None
,
model_type
=
None
,
use_srn
=
False
,
extra_input
=
False
):
use_sar
=
False
):
model
.
eval
()
model
.
eval
()
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
total_frame
=
0.0
total_frame
=
0.0
...
@@ -366,7 +365,7 @@ def eval(model,
...
@@ -366,7 +365,7 @@ def eval(model,
break
break
images
=
batch
[
0
]
images
=
batch
[
0
]
start
=
time
.
time
()
start
=
time
.
time
()
if
use_srn
or
model_type
==
'table'
or
use_sar
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
...
@@ -402,7 +401,8 @@ def preprocess(is_train=False):
...
@@ -402,7 +401,8 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'ASTER'
]
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment