Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
96c91907
Unverified
Commit
96c91907
authored
Nov 05, 2020
by
dyning
Committed by
GitHub
Nov 05, 2020
Browse files
Merge pull request #1105 from dyning/dygraph
updata structure of dygraph
parents
7d09cd19
1ae37919
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
170 additions
and
101 deletions
+170
-101
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+1
-1
ppocr/utils/logging.py
ppocr/utils/logging.py
+0
-1
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+12
-21
tools/export_model.py
tools/export_model.py
+76
-0
tools/program.py
tools/program.py
+49
-20
tools/train.py
tools/train.py
+32
-58
No files found.
ppocr/postprocess/__init__.py
View file @
96c91907
...
@@ -24,8 +24,8 @@ __all__ = ['build_post_process']
...
@@ -24,8 +24,8 @@ __all__ = ['build_post_process']
def
build_post_process
(
config
,
global_config
=
None
):
def
build_post_process
(
config
,
global_config
=
None
):
from
.db_postprocess
import
DBPostProcess
from
.db_postprocess
import
DBPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
]
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/utils/logging.py
View file @
96c91907
...
@@ -52,7 +52,6 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.INFO):
...
@@ -52,7 +52,6 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.INFO):
stream_handler
=
logging
.
StreamHandler
(
stream
=
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
stream
=
sys
.
stdout
)
stream_handler
.
setFormatter
(
formatter
)
stream_handler
.
setFormatter
(
formatter
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
if
log_file
is
not
None
and
dist
.
get_rank
()
==
0
:
if
log_file
is
not
None
and
dist
.
get_rank
()
==
0
:
log_file_folder
=
os
.
path
.
split
(
log_file
)[
0
]
log_file_folder
=
os
.
path
.
split
(
log_file
)[
0
]
os
.
makedirs
(
log_file_folder
,
exist_ok
=
True
)
os
.
makedirs
(
log_file_folder
,
exist_ok
=
True
)
...
...
ppocr/utils/save_load.py
View file @
96c91907
...
@@ -42,16 +42,12 @@ def _mkdir_if_not_exist(path, logger):
...
@@ -42,16 +42,12 @@ def _mkdir_if_not_exist(path, logger):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
load_dygraph_pretrain
(
def
load_dygraph_pretrain
(
model
,
logger
,
path
=
None
,
load_static_weights
=
False
):
model
,
logger
,
path
=
None
,
load_static_weights
=
False
,
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
path
))
"exists."
.
format
(
path
))
if
load_static_weights
:
if
load_static_weights
:
pre_state_dict
=
paddle
.
io
.
load_program_state
(
path
)
pre_state_dict
=
paddle
.
static
.
load_program_state
(
path
)
param_state_dict
=
{}
param_state_dict
=
{}
model_dict
=
model
.
state_dict
()
model_dict
=
model
.
state_dict
()
for
key
in
model_dict
.
keys
():
for
key
in
model_dict
.
keys
():
...
@@ -110,21 +106,16 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
...
@@ -110,21 +106,16 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
elif
pretrained_model
:
elif
pretrained_model
:
load_static_weights
=
gloabl_config
.
get
(
'load_static_weights'
,
False
)
load_static_weights
=
gloabl_config
.
get
(
'load_static_weights'
,
False
)
if
pretrained_model
:
if
not
isinstance
(
pretrained_model
,
list
):
if
not
isinstance
(
pretrained_model
,
list
):
pretrained_model
=
[
pretrained_model
]
pretrained_model
=
[
pretrained_model
]
if
not
isinstance
(
load_static_weights
,
list
):
if
not
isinstance
(
load_static_weights
,
list
):
load_static_weights
=
[
load_static_weights
]
*
len
(
pretrained_model
)
load_static_weights
=
[
load_static_weights
]
*
len
(
for
idx
,
pretrained
in
enumerate
(
pretrained_model
):
pretrained_model
)
load_static
=
load_static_weights
[
idx
]
for
idx
,
pretrained
in
enumerate
(
pretrained_model
):
load_dygraph_pretrain
(
load_static
=
load_static_weights
[
idx
]
model
,
logger
,
path
=
pretrained
,
load_static_weights
=
load_static
)
load_dygraph_pretrain
(
logger
.
info
(
"load pretrained model from {}"
.
format
(
model
,
pretrained_model
))
logger
,
path
=
pretrained
,
load_static_weights
=
load_static
)
logger
.
info
(
"load pretrained model from {}"
.
format
(
pretrained_model
))
else
:
else
:
logger
.
info
(
'train from scratch'
)
logger
.
info
(
'train from scratch'
)
return
best_model_dict
return
best_model_dict
...
...
tools/export_model.py
0 → 100755
View file @
96c91907
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
paddle
from
paddle.jit
import
to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
tools.program
import
load_config
from
tools.program
import
merge_config
def
parse_args
():
def
str2bool
(
v
):
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
parser
.
add_argument
(
"-o"
,
"--output_path"
,
type
=
str
,
default
=
'./output/infer/'
)
return
parser
.
parse_args
()
class
Model
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
model
):
super
(
Model
,
self
).
__init__
()
self
.
pre_model
=
model
# Please modify the 'shape' according to actual needs
@
to_static
(
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
32
,
None
],
dtype
=
'float32'
)
])
def
forward
(
self
,
inputs
):
x
=
self
.
pre_model
(
inputs
)
return
x
def
main
():
FLAGS
=
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
config
[
'Global'
])
# build model
#for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
model
.
eval
()
model
=
Model
(
model
)
paddle
.
jit
.
save
(
model
,
FLAGS
.
output_path
)
if
__name__
==
"__main__"
:
main
()
tools/program.py
View file @
96c91907
...
@@ -28,6 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
...
@@ -28,6 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from
ppocr.utils.stats
import
TrainingStats
from
ppocr.utils.stats
import
TrainingStats
from
ppocr.utils.save_load
import
save_model
from
ppocr.utils.save_load
import
save_model
from
ppocr.utils.utility
import
print_dict
from
ppocr.utils.logging
import
get_logger
from
ppocr.data
import
build_dataloader
import
numpy
as
np
class
ArgsParser
(
ArgumentParser
):
class
ArgsParser
(
ArgumentParser
):
...
@@ -136,18 +140,18 @@ def check_gpu(use_gpu):
...
@@ -136,18 +140,18 @@ def check_gpu(use_gpu):
def
train
(
config
,
def
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
model
,
loss_class
,
loss_class
,
optimizer
,
optimizer
,
lr_scheduler
,
lr_scheduler
,
train_dataloader
,
valid_dataloader
,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
pre_best_model_dict
,
pre_best_model_dict
,
logger
,
logger
,
vdl_writer
=
None
):
vdl_writer
=
None
):
global_step
=
0
cal_metric_during_train
=
config
[
'Global'
].
get
(
'cal_metric_during_train'
,
cal_metric_during_train
=
config
[
'Global'
].
get
(
'cal_metric_during_train'
,
False
)
False
)
...
@@ -156,6 +160,7 @@ def train(config,
...
@@ -156,6 +160,7 @@ def train(config,
print_batch_step
=
config
[
'Global'
][
'print_batch_step'
]
print_batch_step
=
config
[
'Global'
][
'print_batch_step'
]
eval_batch_step
=
config
[
'Global'
][
'eval_batch_step'
]
eval_batch_step
=
config
[
'Global'
][
'eval_batch_step'
]
global_step
=
0
start_eval_step
=
0
start_eval_step
=
0
if
type
(
eval_batch_step
)
==
list
and
len
(
eval_batch_step
)
>=
2
:
if
type
(
eval_batch_step
)
==
list
and
len
(
eval_batch_step
)
>=
2
:
start_eval_step
=
eval_batch_step
[
0
]
start_eval_step
=
eval_batch_step
[
0
]
...
@@ -179,26 +184,24 @@ def train(config,
...
@@ -179,26 +184,24 @@ def train(config,
start_epoch
=
0
start_epoch
=
0
for
epoch
in
range
(
start_epoch
,
epoch_num
):
for
epoch
in
range
(
start_epoch
,
epoch_num
):
if
epoch
>
0
:
train_loader
=
build_dataloader
(
config
,
'Train'
,
device
)
for
idx
,
batch
in
enumerate
(
train_dataloader
):
for
idx
,
batch
in
enumerate
(
train_dataloader
):
if
idx
>=
len
(
train_dataloader
):
if
idx
>=
len
(
train_dataloader
):
break
break
if
not
isinstance
(
lr_scheduler
,
float
):
lr_scheduler
.
step
()
lr
=
optimizer
.
get_lr
()
lr
=
optimizer
.
get_lr
()
t1
=
time
.
time
()
t1
=
time
.
time
()
batch
=
[
paddle
.
to_
variable
(
x
)
for
x
in
batch
]
batch
=
[
paddle
.
to_
tensor
(
x
)
for
x
in
batch
]
images
=
batch
[
0
]
images
=
batch
[
0
]
preds
=
model
(
images
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
avg_loss
=
loss
[
'loss'
]
if
config
[
'Global'
][
'distributed'
]:
avg_loss
.
backward
()
avg_loss
=
model
.
scale_loss
(
avg_loss
)
avg_loss
.
backward
()
model
.
apply_collective_grads
()
else
:
avg_loss
.
backward
()
optimizer
.
step
()
optimizer
.
step
()
optimizer
.
clear_grad
()
optimizer
.
clear_grad
()
if
not
isinstance
(
lr_scheduler
,
float
):
lr_scheduler
.
step
()
# logger and visualdl
# logger and visualdl
stats
=
{
k
:
v
.
numpy
().
mean
()
for
k
,
v
in
loss
.
items
()}
stats
=
{
k
:
v
.
numpy
().
mean
()
for
k
,
v
in
loss
.
items
()}
...
@@ -220,7 +223,8 @@ def train(config,
...
@@ -220,7 +223,8 @@ def train(config,
vdl_writer
.
add_scalar
(
'TRAIN/{}'
.
format
(
k
),
v
,
global_step
)
vdl_writer
.
add_scalar
(
'TRAIN/{}'
.
format
(
k
),
v
,
global_step
)
vdl_writer
.
add_scalar
(
'TRAIN/lr'
,
lr
,
global_step
)
vdl_writer
.
add_scalar
(
'TRAIN/lr'
,
lr
,
global_step
)
if
global_step
>
0
and
global_step
%
print_batch_step
==
0
:
if
dist
.
get_rank
(
)
==
0
and
global_step
>
0
and
global_step
%
print_batch_step
==
0
:
logs
=
train_stats
.
log
()
logs
=
train_stats
.
log
()
strs
=
'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'
.
format
(
strs
=
'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'
.
format
(
epoch
,
epoch_num
,
global_step
,
logs
,
train_batch_elapse
)
epoch
,
epoch_num
,
global_step
,
logs
,
train_batch_elapse
)
...
@@ -229,7 +233,7 @@ def train(config,
...
@@ -229,7 +233,7 @@ def train(config,
if
global_step
>
start_eval_step
and
\
if
global_step
>
start_eval_step
and
\
(
global_step
-
start_eval_step
)
%
eval_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
(
global_step
-
start_eval_step
)
%
eval_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
cur_metirc
=
eval
(
model
,
valid_dataloader
,
post_process_class
,
cur_metirc
=
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
)
eval_class
,
logger
,
print_batch_step
)
cur_metirc_str
=
'cur metirc, {}'
.
format
(
', '
.
join
(
cur_metirc_str
=
'cur metirc, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metirc
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metirc
.
items
()]))
logger
.
info
(
cur_metirc_str
)
logger
.
info
(
cur_metirc_str
)
...
@@ -291,16 +295,17 @@ def train(config,
...
@@ -291,16 +295,17 @@ def train(config,
return
return
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
):
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
logger
,
print_batch_step
):
model
.
eval
()
model
.
eval
()
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
total_frame
=
0.0
total_frame
=
0.0
total_time
=
0.0
total_time
=
0.0
pbar
=
tqdm
(
total
=
len
(
valid_dataloader
),
desc
=
'eval model:
'
)
#
pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
for
idx
,
batch
in
enumerate
(
valid_dataloader
):
for
idx
,
batch
in
enumerate
(
valid_dataloader
):
if
idx
>=
len
(
valid_dataloader
):
if
idx
>=
len
(
valid_dataloader
):
break
break
images
=
paddle
.
to_
variable
(
batch
[
0
])
images
=
paddle
.
to_
tensor
(
batch
[
0
])
start
=
time
.
time
()
start
=
time
.
time
()
preds
=
model
(
images
)
preds
=
model
(
images
)
...
@@ -310,11 +315,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
...
@@ -310,11 +315,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
# Evaluate the results of the current batch
eval_class
(
post_result
,
batch
)
eval_class
(
post_result
,
batch
)
pbar
.
update
(
1
)
#
pbar.update(1)
total_frame
+=
len
(
images
)
total_frame
+=
len
(
images
)
if
idx
%
print_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
logger
.
info
(
'tackling images for eval: {}/{}'
.
format
(
idx
,
len
(
valid_dataloader
)))
# Get final metirc,eg. acc or hmean
# Get final metirc,eg. acc or hmean
metirc
=
eval_class
.
get_metric
()
metirc
=
eval_class
.
get_metric
()
pbar
.
close
()
# pbar.close()
model
.
train
()
model
.
train
()
metirc
[
'fps'
]
=
total_frame
/
total_time
metirc
[
'fps'
]
=
total_frame
/
total_time
return
metirc
return
metirc
...
@@ -336,4 +345,24 @@ def preprocess():
...
@@ -336,4 +345,24 @@ def preprocess():
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
paddle
.
set_device
(
device
)
device
=
paddle
.
set_device
(
device
)
return
device
,
config
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
# save_config
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
os
.
makedirs
(
save_model_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
logger
=
get_logger
(
log_file
=
'{}/train.log'
.
format
(
save_model_dir
))
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
os
.
makedirs
(
vdl_writer_path
,
exist_ok
=
True
)
vdl_writer
=
LogWriter
(
logdir
=
vdl_writer_path
)
else
:
vdl_writer
=
None
print_dict
(
config
,
logger
)
logger
.
info
(
'train with paddle {} and device {}'
.
format
(
paddle
.
__version__
,
device
))
return
config
,
device
,
logger
,
vdl_writer
tools/train.py
View file @
96c91907
...
@@ -27,11 +27,11 @@ import yaml
...
@@ -27,11 +27,11 @@ import yaml
import
paddle
import
paddle
import
paddle.distributed
as
dist
import
paddle.distributed
as
dist
paddle
.
manual_
seed
(
2
)
paddle
.
seed
(
2
)
from
ppocr.utils.logging
import
get_logger
from
ppocr.data
import
build_dataloader
from
ppocr.data
import
build_dataloader
from
ppocr.modeling
import
build_model
,
build_loss
from
ppocr.modeling.architectures
import
build_model
from
ppocr.losses
import
build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
...
@@ -48,95 +48,69 @@ def main(config, device, logger, vdl_writer):
...
@@ -48,95 +48,69 @@ def main(config, device, logger, vdl_writer):
dist
.
init_parallel_env
()
dist
.
init_parallel_env
()
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
# build dataloader
# build dataloader
train_loader
,
train_info_dict
=
build_dataloader
(
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
)
config
[
'TRAIN'
],
device
,
global_config
[
'distributed'
],
global_config
)
if
config
[
'Eval'
]:
if
config
[
'EVAL'
]:
valid_dataloader
=
build_dataloader
(
config
,
'Eval'
,
device
,
logger
)
eval_loader
,
_
=
build_dataloader
(
config
[
'EVAL'
],
device
,
False
,
global_config
)
else
:
else
:
eval_loader
=
None
valid_dataloader
=
None
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
global_config
)
global_config
)
# build model
# build model
#
for rec algorithm
#for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
if
hasattr
(
post_process_class
,
'character'
):
c
onfig
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
c
har_num
=
len
(
getattr
(
post_process_class
,
'character'
))
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
if
config
[
'Global'
][
'distributed'
]:
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
# build loss
loss_class
=
build_loss
(
config
[
'Loss'
])
# build optim
# build optim
optimizer
,
lr_scheduler
=
build_optimizer
(
optimizer
,
lr_scheduler
=
build_optimizer
(
config
[
'Optimizer'
],
config
[
'Optimizer'
],
epochs
=
config
[
'Global'
][
'epoch_num'
],
epochs
=
config
[
'Global'
][
'epoch_num'
],
step_each_epoch
=
len
(
train_loader
),
step_each_epoch
=
len
(
train_
data
loader
),
parameters
=
model
.
parameters
())
parameters
=
model
.
parameters
())
best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
# build loss
loss_class
=
build_loss
(
config
[
'Loss'
])
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
# start train
# start train
program
.
train
(
config
,
mo
de
l
,
loss_class
,
optimizer
,
lr_scheduler
,
program
.
train
(
config
,
train_dataloa
de
r
,
valid_dataloader
,
device
,
model
,
train_loader
,
eval_load
er
,
post_process_class
,
eval_class
,
loss_class
,
optimizer
,
lr_schedul
er
,
post_process_class
,
best_model_dict
,
logger
,
vdl_writer
)
eval_class
,
pre_
best_model_dict
,
logger
,
vdl_writer
)
def
test_reader
(
config
,
pla
ce
,
logger
,
global_config
):
def
test_reader
(
config
,
devi
ce
,
logger
):
train_
loader
,
_
=
build_dataloader
(
loader
=
build_dataloader
(
config
,
'Train'
,
device
)
config
[
'TRAIN'
],
place
,
global_config
=
global_config
)
#
loader = build_dataloader(config, 'Eval', device
)
import
time
import
time
starttime
=
time
.
time
()
starttime
=
time
.
time
()
count
=
0
count
=
0
try
:
try
:
for
data
in
train_
loader
:
for
data
in
loader
()
:
count
+=
1
count
+=
1
if
count
%
1
==
0
:
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
starttime
=
time
.
time
()
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
count
,
count
,
len
(
data
[
0
]
),
batch_time
))
len
(
data
),
batch_time
))
except
Exception
as
e
:
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
logger
.
info
(
e
)
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
def
dis_main
():
device
,
config
=
program
.
preprocess
()
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
paddle
.
disable_static
(
device
)
# save_config
os
.
makedirs
(
config
[
'Global'
][
'save_model_dir'
],
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
config
[
'Global'
][
'save_model_dir'
],
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
logger
=
get_logger
(
log_file
=
'{}/train.log'
.
format
(
config
[
'Global'
][
'save_model_dir'
]))
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
vdl_writer
=
LogWriter
(
logdir
=
config
[
'Global'
][
'save_model_dir'
])
else
:
vdl_writer
=
None
print_dict
(
config
,
logger
)
logger
.
info
(
'train with paddle {} and device {}'
.
format
(
paddle
.
__version__
,
device
))
main
(
config
,
device
,
logger
,
vdl_writer
)
# test_reader(config, device, logger, config['Global'])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
# main
()
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
# dist.spawn(dis_main, nprocs=2, selelcted_gpus='6,7'
)
main
(
config
,
device
,
logger
,
vdl_writer
)
dis_main
(
)
#
test_reader(config, device, logger
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment