Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
0bf6a75e
Commit
0bf6a75e
authored
Oct 12, 2021
by
LDOUBLEV
Browse files
Merge branch 'dygraph' of
https://github.com/PaddlePaddle/PaddleOCR
into dygraph
parents
faa88edd
af0bac58
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
22 deletions
+41
-22
tools/program.py
tools/program.py
+41
-22
No files found.
tools/program.py
View file @
0bf6a75e
...
...
@@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
from
ppocr.utils.save_load
import
save_model
from
ppocr.utils.utility
import
print_dict
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils
import
profiler
from
ppocr.data
import
build_dataloader
import
numpy
as
np
...
...
@@ -42,6 +43,13 @@ class ArgsParser(ArgumentParser):
self
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
self
.
add_argument
(
"-o"
,
"--opt"
,
nargs
=
'+'
,
help
=
"set configuration options"
)
self
.
add_argument
(
'-p'
,
'--profiler_options'
,
type
=
str
,
default
=
None
,
help
=
'The option of profiler, which should be in format
\"
key1=value1;key2=value2;key3=value3
\"
.'
)
def
parse_args
(
self
,
argv
=
None
):
args
=
super
(
ArgsParser
,
self
).
parse_args
(
argv
)
...
...
@@ -158,6 +166,7 @@ def train(config,
epoch_num
=
config
[
'Global'
][
'epoch_num'
]
print_batch_step
=
config
[
'Global'
][
'print_batch_step'
]
eval_batch_step
=
config
[
'Global'
][
'eval_batch_step'
]
profiler_options
=
config
[
'profiler_options'
]
global_step
=
0
if
'global_step'
in
pre_best_model_dict
:
...
...
@@ -186,12 +195,13 @@ def train(config,
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_nrtr
=
config
[
'Architecture'
][
'algorithm'
]
==
"NRTR"
use_sar
=
config
[
'Architecture'
][
'algorithm'
]
==
'SAR'
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
model_type
=
None
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
...
...
@@ -208,6 +218,7 @@ def train(config,
max_iter
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
train_dataloader
)
for
idx
,
batch
in
enumerate
(
train_dataloader
):
profiler
.
add_profiler_step
(
profiler_options
)
train_reader_cost
+=
time
.
time
()
-
batch_start
if
idx
>=
max_iter
:
break
...
...
@@ -215,7 +226,7 @@ def train(config,
images
=
batch
[
0
]
if
use_srn
:
model_average
=
True
if
use_srn
or
model_type
==
'table'
or
use_nrtr
or
use_sar
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
...
...
@@ -279,8 +290,7 @@ def train(config,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
use_srn
,
use_sar
=
use_sar
)
extra_input
=
extra_input
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
logger
.
info
(
cur_metric_str
)
...
...
@@ -351,9 +361,8 @@ def eval(model,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
=
False
,
use_sar
=
False
):
model_type
=
None
,
extra_input
=
False
):
model
.
eval
()
with
paddle
.
no_grad
():
total_frame
=
0.0
...
...
@@ -366,7 +375,7 @@ def eval(model,
break
images
=
batch
[
0
]
start
=
time
.
time
()
if
use_srn
or
model_type
==
'table'
or
use_sar
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
preds
=
model
(
images
)
...
...
@@ -392,8 +401,23 @@ def eval(model,
def
preprocess
(
is_train
=
False
):
FLAGS
=
ArgsParser
().
parse_args
()
profiler_options
=
FLAGS
.
profiler_options
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
profile_dic
=
{
"profiler_options"
:
FLAGS
.
profiler_options
}
merge_config
(
profile_dic
)
if
is_train
:
# save_config
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
os
.
makedirs
(
save_model_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
else
:
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu
=
config
[
'Global'
][
'use_gpu'
]
...
...
@@ -402,24 +426,19 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
]
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
]
windows_not_support_list
=
[
'PSE'
]
if
platform
.
system
()
==
"Windows"
and
alg
in
windows_not_support_list
:
logger
.
warning
(
'{} is not support in Windows now'
.
format
(
windows_not_support_list
))
sys
.
exit
()
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
paddle
.
set_device
(
device
)
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
if
is_train
:
# save_config
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
os
.
makedirs
(
save_model_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
log_file
=
'{}/train.log'
.
format
(
save_model_dir
)
else
:
log_file
=
None
logger
=
get_logger
(
name
=
'root'
,
log_file
=
log_file
)
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment