wangsen / paddle_dbnet · Commit 41a1b292

Authored Jan 20, 2022 by Leif

Merge remote-tracking branch 'origin/dygraph' into dygraph

Parents: 9471054e, 3d30899b

Showing 2 changed files with 60 additions and 63 deletions (+60, -63):

    tools/program.py  +55, -60
    tools/train.py    +5, -3

tools/program.py
@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
         return config


-class AttrDict(dict):
-    """Single level attribute dict, NOT recursive"""
-
-    def __init__(self, **kwargs):
-        super(AttrDict, self).__init__()
-        super(AttrDict, self).update(kwargs)
-
-    def __getattr__(self, key):
-        if key in self:
-            return self[key]
-        raise AttributeError("object has no attribute '{}'".format(key))
-
-
-global_config = AttrDict()
-
-default_config = {'Global': {'debug': False, }}
-
-
 def load_config(file_path):
     """
     Load config from yml/yaml file.
@@ -94,38 +76,38 @@ def load_config(file_path):
         file_path (str): Path of the config file to be loaded.
     Returns: global config
     """
-    merge_config(default_config)
     _, ext = os.path.splitext(file_path)
     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
-    merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
-    return global_config
+    config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+    return config


-def merge_config(config):
+def merge_config(config, opts):
     """
     Merge config into global config.
     Args:
         config (dict): Config to be merged.
     Returns: global config
     """
-    for key, value in config.items():
+    for key, value in opts.items():
         if "." not in key:
-            if isinstance(value, dict) and key in global_config:
-                global_config[key].update(value)
+            if isinstance(value, dict) and key in config:
+                config[key].update(value)
             else:
-                global_config[key] = value
+                config[key] = value
         else:
             sub_keys = key.split('.')
             assert (
-                sub_keys[0] in global_config
+                sub_keys[0] in config
             ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
-                global_config.keys(), sub_keys[0])
-            cur = global_config[sub_keys[0]]
+                config.keys(), sub_keys[0])
+            cur = config[sub_keys[0]]
             for idx, sub_key in enumerate(sub_keys[1:]):
                 if idx == len(sub_keys) - 2:
                     cur[sub_key] = value
                 else:
                     cur = cur[sub_key]
+    return config


 def check_gpu(use_gpu):
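Note: the substantive change in this hunk is that load_config and merge_config no longer funnel everything through a module-level global_config (the AttrDict removed above); merge_config now takes the config dict explicitly plus an opts dict of overrides and returns the merged result. A minimal standalone sketch of the override semantics, with made-up config values for illustration:

def merge_config(config, opts):
    # Plain keys replace or update top-level entries; dotted keys such as
    # 'Global.epoch_num' walk into nested dicts before assigning.
    for key, value in opts.items():
        if "." not in key:
            if isinstance(value, dict) and key in config:
                config[key].update(value)
            else:
                config[key] = value
        else:
            sub_keys = key.split('.')
            assert sub_keys[0] in config, \
                "unknown top-level key: {}".format(sub_keys[0])
            cur = config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]
    return config

config = {'Global': {'epoch_num': 500, 'use_gpu': True}}
config = merge_config(config, {'Global.epoch_num': 10})
print(config)  # {'Global': {'epoch_num': 10, 'use_gpu': True}}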
@@ -204,20 +186,24 @@ def train(config,
     model_type = None
     algorithm = config['Architecture']['algorithm']

-    if 'start_epoch' in best_model_dict:
-        start_epoch = best_model_dict[
-            'start_epoch']
-    else:
-        start_epoch = 1
+    start_epoch = best_model_dict[
+        'start_epoch'] if 'start_epoch' in best_model_dict else 1

+    train_reader_cost = 0.0
+    train_run_cost = 0.0
+    total_samples = 0
+    reader_start = time.time()
+    max_iter = len(train_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(train_dataloader)
+
     for epoch in range(start_epoch, epoch_num + 1):
-        train_dataloader = build_dataloader(
-            config, 'Train', device, logger, seed=epoch)
-        train_reader_cost = 0.0
-        train_run_cost = 0.0
-        total_samples = 0
-        reader_start = time.time()
-        max_iter = len(train_dataloader) - 1 if platform.system(
-        ) == "Windows" else len(train_dataloader)
+        if train_dataloader.dataset.need_reset:
+            train_dataloader = build_dataloader(
+                config, 'Train', device, logger, seed=epoch)
+            max_iter = len(train_dataloader) - 1 if platform.system(
+            ) == "Windows" else len(train_dataloader)
+
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
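Note: the loop above previously rebuilt the train dataloader unconditionally at the top of every epoch; after this change it is rebuilt only when the dataset exposes a truthy need_reset attribute. A toy illustration of that pattern (the names below are stand-ins, not the real ppocr classes):

class ToyDataset:
    def __init__(self, need_reset):
        self.need_reset = need_reset

def build_toy_loader(seed):
    print("rebuilding dataloader, seed =", seed)
    return {"dataset": ToyDataset(need_reset=False), "seed": seed}

loader = {"dataset": ToyDataset(need_reset=True), "seed": 0}
for epoch in range(1, 4):
    if loader["dataset"].need_reset:  # only rebuild on demand
        loader = build_toy_loader(seed=epoch)
# prints once, for epoch 1; epochs 2 and 3 reuse the existing loader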
@@ -239,10 +225,11 @@ def train(config,
             else:
                 if model_type == 'table' or extra_input:
                     preds = model(images, data=batch[1:])
-                elif model_type == "kie":
+                elif model_type in ["kie", 'vqa']:
                     preds = model(batch)
                 else:
                     preds = model(images)

             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
@@ -256,6 +243,7 @@ def train(config,
             optimizer.clear_grad()

             train_run_cost += time.time() - train_start
+            global_step += 1
             total_samples += len(images)

             if not isinstance(lr_scheduler, float):
@@ -285,12 +273,13 @@ def train(config,
                 (global_step > 0 and global_step % print_batch_step == 0) or
                 (idx >= len(train_dataloader) - 1)):
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
+                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
                     epoch, epoch_num, global_step, logs, train_reader_cost /
                     print_batch_step, (train_reader_cost + train_run_cost) /
-                    print_batch_step, total_samples,
+                    print_batch_step, total_samples / print_batch_step,
                     total_samples / (train_reader_cost + train_run_cost))
                 logger.info(strs)
                 train_reader_cost = 0.0
                 train_run_cost = 0.0
                 total_samples = 0
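Note: the renamed log fields are per-window averages; the accumulators cover one logging window of print_batch_step iterations and are zeroed after each log line. Worked numbers (the values are invented for illustration):

print_batch_step = 10
train_reader_cost = 0.5   # seconds spent waiting on data in this window
train_run_cost = 4.5      # seconds spent in forward/backward in this window
total_samples = 1280      # images consumed in this window

avg_reader_cost = train_reader_cost / print_batch_step                     # 0.05 s
avg_batch_cost = (train_reader_cost + train_run_cost) / print_batch_step   # 0.50 s
avg_samples = total_samples / print_batch_step                             # 128.0
ips = total_samples / (train_reader_cost + train_run_cost)                 # 256.0 images/s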
@@ -330,6 +319,7 @@ def train(config,
                         optimizer,
                         save_model_dir,
                         logger,
+                        config,
                         is_best=True,
                         prefix='best_accuracy',
                         best_model_dict=best_model_dict,
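Note: save_model now receives the config as an extra argument here and at the two call sites below, presumably so the resolved run configuration can be persisted next to the checkpoint; the diff itself only shows the added argument. A hypothetical sketch of that idea, with invented names:

import json
import os

def save_config_snapshot(config, save_model_dir):
    # Hypothetical helper: dump the resolved config beside the weights so a
    # checkpoint stays reproducible without the original command line.
    os.makedirs(save_model_dir, exist_ok=True)
    path = os.path.join(save_model_dir, 'config.json')
    with open(path, 'w') as f:
        json.dump(config, f, indent=2, default=str)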
@@ -344,8 +334,7 @@ def train(config,
                 vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                       best_model_dict[main_indicator],
                                       global_step)
-            global_step += 1
-            optimizer.clear_grad()
             reader_start = time.time()
         if dist.get_rank() == 0:
             save_model(
@@ -353,6 +342,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,
@@ -364,6 +354,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
@@ -401,19 +392,28 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
-            elif model_type == "kie":
+            elif model_type in ["kie", 'vqa']:
                 preds = model(batch)
             else:
                 preds = model(images)
-            batch = [item.numpy() for item in batch]
+            batch_numpy = []
+            for item in batch:
+                if isinstance(item, paddle.Tensor):
+                    batch_numpy.append(item.numpy())
+                else:
+                    batch_numpy.append(item)
             # Obtain usable results from post-processing methods
             total_time += time.time() - start
             # Evaluate the results of the current batch
             if model_type in ['table', 'kie']:
-                eval_class(preds, batch)
+                eval_class(preds, batch_numpy)
+            elif model_type in ['vqa']:
+                post_result = post_process_class(preds, batch_numpy)
+                eval_class(post_result, batch_numpy)
             else:
-                post_result = post_process_class(preds, batch[1])
-                eval_class(post_result, batch)
+                post_result = post_process_class(preds, batch_numpy[1])
+                eval_class(post_result, batch_numpy)
             pbar.update(1)
             total_frame += len(images)
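Note: the eval loop previously called item.numpy() on every batch element, which breaks once kie/vqa batches carry non-Tensor items (dicts, lists, strings); the new code converts only paddle.Tensor items. A short sketch of the same conversion (assumes paddle is installed; the batch contents are made up):

import paddle

batch = [paddle.to_tensor([1.0, 2.0]), {"question": "total?"}, [0, 1]]
batch_numpy = []
for item in batch:
    if isinstance(item, paddle.Tensor):
        batch_numpy.append(item.numpy())  # Tensor -> numpy.ndarray
    else:
        batch_numpy.append(item)          # leave plain Python objects alone
print([type(item).__name__ for item in batch_numpy])
# ['ndarray', 'dict', 'list']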
@@ -479,9 +479,9 @@ def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
     profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
+    config = merge_config(config, FLAGS.opt)
     profile_dic = {"profiler_options": FLAGS.profiler_options}
-    merge_config(profile_dic)
+    config = merge_config(config, profile_dic)

     if is_train:
         # save_config
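Note: with merge_config now functional, preprocess threads the config through by return value: load it, merge the command-line overrides, then merge the profiler options. FLAGS.opt is, as far as this diff shows, a dict of dotted-key overrides; a toy illustration of producing such a dict from "key=value" pairs (the parsing below is an assumption, not the real ArgsParser):

import yaml

def parse_opts(pairs):
    opts = {}
    for pair in pairs:
        key, value = pair.split('=', 1)
        # yaml.load turns "false" into False, "10" into 10, and so on
        opts[key.strip()] = yaml.load(value, Loader=yaml.Loader)
    return opts

opts = parse_opts(["Global.use_gpu=false", "Global.epoch_num=10"])
print(opts)  # {'Global.use_gpu': False, 'Global.epoch_num': 10}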
@@ -503,20 +503,15 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
     ]
-    windows_not_support_list = ['PSE']
-    if platform.system() == "Windows" and alg in windows_not_support_list:
-        logger.warning('{} is not support in Windows now'.format(
-            windows_not_support_list))
-        sys.exit()

     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)

     config['Global']['distributed'] = dist.get_world_size() != 1

-    if config['Global']['use_visualdl']:
+    if config['Global']['use_visualdl'] and dist.get_rank() == 0:
         from visualdl import LogWriter
         save_model_dir = config['Global']['save_model_dir']
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
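Note: the added dist.get_rank() == 0 condition means that in multi-process training only rank 0 creates the VisualDL writer, so worker processes do not each write to (or race on) the same log directory. A minimal sketch of the guard (assumes paddle and visualdl are installed; the log path is made up):

import paddle.distributed as dist

if dist.get_rank() == 0:
    from visualdl import LogWriter  # imported lazily, as in the diff
    vdl_writer = LogWriter(logdir='./output/vdl/')
else:
    vdl_writer = None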
tools/train.py
@@ -27,8 +27,6 @@ import yaml
 import paddle
 import paddle.distributed as dist

-paddle.seed(2)
-
 from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.losses import build_loss
@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
 from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import set_seed
 import tools.program as program

 dist.get_world_size()
@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_model(config, model, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer,
+                                     config['Architecture']["model_type"])

     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
@@ -145,5 +145,7 @@ def test_reader(config, device, logger):
 if __name__ == '__main__':
     config, device, logger, vdl_writer = program.preprocess(is_train=True)
+    seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
+    set_seed(seed)
     main(config, device, logger, vdl_writer)
     # test_reader(config, device, logger)
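Note: taken together, the tools/train.py changes replace a hard-coded paddle.seed(2) executed at import time with a seed read from config['Global'] at startup, defaulting to 1024. A minimal sketch of what a set_seed helper typically does (the body below is an assumption; the diff only shows the import and the call):

import random

import numpy as np
import paddle

def set_seed(seed=1024):
    # Seed every RNG the training run touches so results are repeatable.
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)

config = {'Global': {}}  # no 'seed' key -> fall back to the default
seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
set_seed(seed)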