Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
fa675f89
Commit
fa675f89
authored
Nov 04, 2020
by
dyning
Browse files
updata structure of dygraph
parent
7d09cd19
Changes
38
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
144 additions
and
159 deletions
+144
-159
ppocr/losses/det_db_loss.py
ppocr/losses/det_db_loss.py
+0
-0
ppocr/losses/rec_ctc_loss.py
ppocr/losses/rec_ctc_loss.py
+0
-0
ppocr/modeling/__init__.py
ppocr/modeling/__init__.py
+0
-26
ppocr/modeling/architectures/__init__.py
ppocr/modeling/architectures/__init__.py
+10
-2
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+11
-19
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+0
-1
ppocr/modeling/backbones/det_mobilenet_v3.py
ppocr/modeling/backbones/det_mobilenet_v3.py
+1
-2
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+2
-2
ppocr/modeling/heads/rec_ctc_head.py
ppocr/modeling/heads/rec_ctc_head.py
+3
-4
ppocr/modeling/necks/__init__.py
ppocr/modeling/necks/__init__.py
+2
-3
ppocr/modeling/necks/db_fpn.py
ppocr/modeling/necks/db_fpn.py
+2
-2
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+1
-2
ppocr/optimizer/__init__.py
ppocr/optimizer/__init__.py
+2
-0
ppocr/optimizer/optimizer.py
ppocr/optimizer/optimizer.py
+2
-2
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+1
-1
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+15
-16
tools/program.py
tools/program.py
+47
-14
tools/train.py
tools/train.py
+45
-63
No files found.
ppocr/
modeling/
losses/det_db_loss.py
→
ppocr/losses/det_db_loss.py
View file @
fa675f89
File moved
ppocr/
modeling/
losses/rec_ctc_loss.py
→
ppocr/losses/rec_ctc_loss.py
View file @
fa675f89
File moved
ppocr/modeling/__init__.py
deleted
100755 → 0
View file @
7d09cd19
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
from
.losses
import
build_loss
__all__
=
[
'build_model'
,
'build_loss'
]
def
build_model
(
config
):
from
.architectures
import
Model
config
=
copy
.
deepcopy
(
config
)
module_class
=
Model
(
config
)
return
module_class
ppocr/modeling/architectures/__init__.py
View file @
fa675f89
...
...
@@ -12,5 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.model
import
Model
__all__
=
[
'Model'
]
\ No newline at end of file
import
copy
__all__
=
[
'build_model'
]
def
build_model
(
config
):
from
.base_model
import
BaseModel
config
=
copy
.
deepcopy
(
config
)
module_class
=
BaseModel
(
config
)
return
module_class
\ No newline at end of file
ppocr/modeling/architectures/model.py
→
ppocr/modeling/architectures/
base_
model.py
View file @
fa675f89
...
...
@@ -15,34 +15,25 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
,
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
'/home/zhoujun20/PaddleOCR'
)
from
paddle
import
nn
from
ppocr.modeling.transform
import
build_transform
from
ppocr.modeling.backbones
import
build_backbone
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.heads
import
build_head
__all__
=
[
'Model'
]
__all__
=
[
'
Base
Model'
]
class
Model
(
nn
.
Layer
):
class
BaseModel
(
nn
.
Layer
):
def
__init__
(
self
,
config
):
"""
Detection
module for OCR.
the
module for OCR.
args:
config (dict): the super parameters for module.
"""
super
(
Model
,
self
).
__init__
()
algorithm
=
config
[
'algorithm'
]
self
.
type
=
config
[
'type'
]
self
.
model_name
=
'{}_{}'
.
format
(
self
.
type
,
algorithm
)
super
(
BaseModel
,
self
).
__init__
()
in_channels
=
config
.
get
(
'in_channels'
,
3
)
model_type
=
config
[
'model_type'
]
# build transfrom,
# for rec, transfrom can be TPS,None
# for det and cls, transfrom shoule to be None,
...
...
@@ -57,7 +48,7 @@ class Model(nn.Layer):
# build backbone, backbone is need for del, rec and cls
config
[
"Backbone"
][
'in_channels'
]
=
in_channels
self
.
backbone
=
build_backbone
(
config
[
"Backbone"
],
self
.
type
)
self
.
backbone
=
build_backbone
(
config
[
"Backbone"
],
model_
type
)
in_channels
=
self
.
backbone
.
out_channels
# build neck
...
...
@@ -71,6 +62,7 @@ class Model(nn.Layer):
config
[
'Neck'
][
'in_channels'
]
=
in_channels
self
.
neck
=
build_neck
(
config
[
'Neck'
])
in_channels
=
self
.
neck
.
out_channels
# # build head, head is need for det, rec and cls
config
[
"Head"
][
'in_channels'
]
=
in_channels
self
.
head
=
build_head
(
config
[
"Head"
])
...
...
ppocr/modeling/backbones/__init__.py
View file @
fa675f89
...
...
@@ -19,7 +19,6 @@ def build_backbone(config, model_type):
if
model_type
==
'det'
:
from
.det_mobilenet_v3
import
MobileNetV3
from
.det_resnet_vd
import
ResNet
support_dict
=
[
'MobileNetV3'
,
'ResNet'
,
'ResNet_SAST'
]
elif
model_type
==
'rec'
:
from
.rec_mobilenet_v3
import
MobileNetV3
...
...
ppocr/modeling/backbones/det_mobilenet_v3.py
View file @
fa675f89
...
...
@@ -130,7 +130,6 @@ class MobileNetV3(nn.Layer):
if_act
=
True
,
act
=
'hard_swish'
,
name
=
'conv_last'
))
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
self
.
out_channels
.
append
(
make_divisible
(
scale
*
cls_ch_squeeze
))
for
i
,
stage
in
enumerate
(
self
.
stages
):
...
...
ppocr/modeling/heads/__init__.py
View file @
fa675f89
...
...
@@ -20,8 +20,8 @@ def build_head(config):
from
.det_db_head
import
DBHead
# rec head
from
.rec_ctc_head
import
CTC
support_dict
=
[
'DBHead'
,
'CTC'
]
from
.rec_ctc_head
import
CTC
Head
support_dict
=
[
'DBHead'
,
'CTC
Head
'
]
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'head only support {}'
.
format
(
...
...
ppocr/modeling/heads/rec_ctc_head.py
View file @
fa675f89
...
...
@@ -33,10 +33,9 @@ def get_para_bias_attr(l2_decay, k, name):
regularizer
=
regularizer
,
initializer
=
initializer
,
name
=
name
+
"_b_attr"
)
return
[
weight_attr
,
bias_attr
]
class
CTC
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
fc_decay
=
1e-5
,
**
kwargs
):
super
(
CTC
,
self
).
__init__
()
class
CTCHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
fc_decay
=
0.0004
,
**
kwargs
):
super
(
CTCHead
,
self
).
__init__
()
weight_attr
,
bias_attr
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
in_channels
,
name
=
'ctc_fc'
)
self
.
fc
=
nn
.
Linear
(
...
...
ppocr/modeling/necks/__init__.py
View file @
fa675f89
...
...
@@ -14,11 +14,10 @@
__all__
=
[
'build_neck'
]
def
build_neck
(
config
):
from
.fpn
import
FPN
from
.
db_
fpn
import
DB
FPN
from
.rnn
import
SequenceEncoder
support_dict
=
[
'FPN'
,
'SequenceEncoder'
]
support_dict
=
[
'
DB
FPN'
,
'SequenceEncoder'
]
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'neck only support {}'
.
format
(
...
...
ppocr/modeling/necks/fpn.py
→
ppocr/modeling/necks/
db_
fpn.py
View file @
fa675f89
...
...
@@ -22,9 +22,9 @@ import paddle.nn.functional as F
from
paddle
import
ParamAttr
class
FPN
(
nn
.
Layer
):
class
DB
FPN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
**
kwargs
):
super
(
FPN
,
self
).
__init__
()
super
(
DB
FPN
,
self
).
__init__
()
self
.
out_channels
=
out_channels
weight_attr
=
paddle
.
nn
.
initializer
.
MSRA
(
uniform
=
False
)
...
...
ppocr/modeling/necks/rnn.py
View file @
fa675f89
...
...
@@ -76,8 +76,7 @@ class SequenceEncoder(nn.Layer):
'fc'
:
EncoderWithFC
,
'rnn'
:
EncoderWithRNN
}
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
encoder_type
,
support_encoder_dict
.
keys
())
assert
encoder_type
in
support_encoder_dict
,
'{} must in {}'
.
format
(
encoder_type
,
support_encoder_dict
.
keys
())
self
.
encoder
=
support_encoder_dict
[
encoder_type
](
self
.
encoder_reshape
.
out_channels
,
hidden_size
)
...
...
ppocr/optimizer/__init__.py
View file @
fa675f89
...
...
@@ -50,6 +50,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step3 build optimizer
optim_name
=
config
.
pop
(
'name'
)
# Regularization is invalid. The bug will be fixed in paddle-rc. The param is
# weight_decay.
optim
=
getattr
(
optimizer
,
optim_name
)(
learning_rate
=
lr
,
regularization
=
reg
,
**
config
)
...
...
ppocr/optimizer/optimizer.py
View file @
fa675f89
...
...
@@ -40,8 +40,8 @@ class Momentum(object):
opt
=
optim
.
Momentum
(
learning_rate
=
self
.
learning_rate
,
momentum
=
self
.
momentum
,
parameters
=
self
.
weight_decay
,
weight_decay
=
parameters
)
parameters
=
parameters
,
weight_decay
=
self
.
weight_decay
)
return
opt
...
...
ppocr/postprocess/__init__.py
View file @
fa675f89
...
...
@@ -24,8 +24,8 @@ __all__ = ['build_post_process']
def
build_post_process
(
config
,
global_config
=
None
):
from
.db_postprocess
import
DBPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
support_dict
=
[
'DBPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/utils/save_load.py
View file @
fa675f89
...
...
@@ -46,7 +46,7 @@ def load_dygraph_pretrain(
model
,
logger
,
path
=
None
,
load_static_weights
=
False
,
):
load_static_weights
=
False
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
path
))
...
...
@@ -110,7 +110,6 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
elif
pretrained_model
:
load_static_weights
=
gloabl_config
.
get
(
'load_static_weights'
,
False
)
if
pretrained_model
:
if
not
isinstance
(
pretrained_model
,
list
):
pretrained_model
=
[
pretrained_model
]
if
not
isinstance
(
load_static_weights
,
list
):
...
...
tools/program.py
View file @
fa675f89
...
...
@@ -28,7 +28,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from
ppocr.utils.stats
import
TrainingStats
from
ppocr.utils.save_load
import
save_model
from
ppocr.utils.utility
import
print_dict
from
ppocr.utils.logging
import
get_logger
from
ppocr.data
import
build_dataloader
import
numpy
as
np
class
ArgsParser
(
ArgumentParser
):
def
__init__
(
self
):
...
...
@@ -136,18 +139,18 @@ def check_gpu(use_gpu):
def
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
train_dataloader
,
valid_dataloader
,
post_process_class
,
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
=
None
):
global_step
=
0
cal_metric_during_train
=
config
[
'Global'
].
get
(
'cal_metric_during_train'
,
False
)
...
...
@@ -156,6 +159,7 @@ def train(config,
print_batch_step
=
config
[
'Global'
][
'print_batch_step'
]
eval_batch_step
=
config
[
'Global'
][
'eval_batch_step'
]
global_step
=
0
start_eval_step
=
0
if
type
(
eval_batch_step
)
==
list
and
len
(
eval_batch_step
)
>=
2
:
start_eval_step
=
eval_batch_step
[
0
]
...
...
@@ -179,14 +183,15 @@ def train(config,
start_epoch
=
0
for
epoch
in
range
(
start_epoch
,
epoch_num
):
if
epoch
>
0
:
train_loader
=
build_dataloader
(
config
,
'Train'
,
device
)
for
idx
,
batch
in
enumerate
(
train_dataloader
):
if
idx
>=
len
(
train_dataloader
):
break
if
not
isinstance
(
lr_scheduler
,
float
):
lr_scheduler
.
step
()
lr
=
optimizer
.
get_lr
()
t1
=
time
.
time
()
batch
=
[
paddle
.
to_
variable
(
x
)
for
x
in
batch
]
batch
=
[
paddle
.
to_
tensor
(
x
)
for
x
in
batch
]
images
=
batch
[
0
]
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
...
...
@@ -199,6 +204,8 @@ def train(config,
avg_loss
.
backward
()
optimizer
.
step
()
optimizer
.
clear_grad
()
if
not
isinstance
(
lr_scheduler
,
float
):
lr_scheduler
.
step
()
# logger and visualdl
stats
=
{
k
:
v
.
numpy
().
mean
()
for
k
,
v
in
loss
.
items
()}
...
...
@@ -228,8 +235,8 @@ def train(config,
# eval
if
global_step
>
start_eval_step
and
\
(
global_step
-
start_eval_step
)
%
eval_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
cur_metirc
=
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
)
cur_metirc
=
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
logger
,
print_batch_step
)
cur_metirc_str
=
'cur metirc, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metirc
.
items
()]))
logger
.
info
(
cur_metirc_str
)
...
...
@@ -291,12 +298,14 @@ def train(config,
return
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
):
def
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
logger
,
print_batch_step
):
model
.
eval
()
with
paddle
.
no_grad
():
total_frame
=
0.0
total_time
=
0.0
pbar
=
tqdm
(
total
=
len
(
valid_dataloader
),
desc
=
'eval model:
'
)
#
pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
for
idx
,
batch
in
enumerate
(
valid_dataloader
):
if
idx
>=
len
(
valid_dataloader
):
break
...
...
@@ -310,11 +319,14 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
eval_class
(
post_result
,
batch
)
pbar
.
update
(
1
)
#
pbar.update(1)
total_frame
+=
len
(
images
)
if
idx
%
print_batch_step
==
0
:
logger
.
info
(
'tackling images for eval: {}/{}'
.
format
(
idx
,
len
(
valid_dataloader
)))
# Get final metirc,eg. acc or hmean
metirc
=
eval_class
.
get_metric
()
pbar
.
close
()
#
pbar.close()
model
.
train
()
metirc
[
'fps'
]
=
total_frame
/
total_time
return
metirc
...
...
@@ -336,4 +348,25 @@ def preprocess():
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
paddle
.
set_device
(
device
)
return
device
,
config
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
paddle
.
disable_static
(
device
)
# save_config
save_model_dir
=
config
[
'Global'
][
'save_model_dir'
]
os
.
makedirs
(
save_model_dir
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
logger
=
get_logger
(
log_file
=
'{}/train.log'
.
format
(
save_model_dir
))
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
os
.
makedirs
(
vdl_writer_path
,
exist_ok
=
True
)
vdl_writer
=
LogWriter
(
logdir
=
vdl_writer_path
)
else
:
vdl_writer
=
None
print_dict
(
config
,
logger
)
logger
.
info
(
'train with paddle {} and device {}'
.
format
(
paddle
.
__version__
,
device
))
return
config
,
device
,
logger
,
vdl_writer
tools/train.py
View file @
fa675f89
...
...
@@ -31,7 +31,8 @@ paddle.manual_seed(2)
from
ppocr.utils.logging
import
get_logger
from
ppocr.data
import
build_dataloader
from
ppocr.modeling
import
build_model
,
build_loss
from
ppocr.modeling.architectures
import
build_model
from
ppocr.losses
import
build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
...
...
@@ -48,95 +49,76 @@ def main(config, device, logger, vdl_writer):
dist
.
init_parallel_env
()
global_config
=
config
[
'Global'
]
# build dataloader
train_loader
,
train_info_dict
=
build_dataloader
(
config
[
'TRAIN'
],
device
,
global_config
[
'distributed'
],
global_config
)
if
config
[
'EVAL'
]:
eval_loader
,
_
=
build_dataloader
(
config
[
'EVAL'
],
device
,
False
,
global_config
)
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
)
if
config
[
'Eval'
]:
valid_dataloader
=
build_dataloader
(
config
,
'Eval'
,
device
)
else
:
eval_loader
=
None
valid_dataloader
=
None
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
global_config
)
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
global_config
)
# build model
#
for rec algorithm
#for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
c
onfig
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
len
(
getattr
(
post_process_class
,
'character'
))
c
har_num
=
len
(
getattr
(
post_process_class
,
'character'
))
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
# build loss
loss_class
=
build_loss
(
config
[
'Loss'
])
# build optim
optimizer
,
lr_scheduler
=
build_optimizer
(
config
[
'Optimizer'
],
optimizer
,
lr_scheduler
=
build_optimizer
(
config
[
'Optimizer'
],
epochs
=
config
[
'Global'
][
'epoch_num'
],
step_each_epoch
=
len
(
train_loader
),
step_each_epoch
=
len
(
train_
data
loader
),
parameters
=
model
.
parameters
())
best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
# build loss
loss_class
=
build_loss
(
config
[
'Loss'
])
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
# start train
program
.
train
(
config
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
train_loader
,
eval_loader
,
post_process_class
,
eval_class
,
best_model_dict
,
logger
,
vdl_writer
)
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
def
test_reader
(
config
,
place
,
logger
,
global_config
):
train_loader
,
_
=
build_dataloader
(
config
[
'TRAIN'
],
place
,
global_config
=
global_config
)
# start train
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
)
def
test_reader
(
config
,
device
,
logger
):
loader
=
build_dataloader
(
config
,
'Train'
,
device
)
# loader = build_dataloader(config, 'Eval', device)
import
time
starttime
=
time
.
time
()
count
=
0
try
:
for
data
in
train_
loader
:
for
data
in
loader
()
:
count
+=
1
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
count
,
len
(
data
[
0
]),
batch_time
))
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
count
,
len
(
data
),
batch_time
))
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
def
dis_main
():
device
,
config
=
program
.
preprocess
()
config
[
'Global'
][
'distributed'
]
=
dist
.
get_world_size
()
!=
1
paddle
.
disable_static
(
device
)
# save_config
os
.
makedirs
(
config
[
'Global'
][
'save_model_dir'
],
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
config
[
'Global'
][
'save_model_dir'
],
'config.yml'
),
'w'
)
as
f
:
yaml
.
dump
(
dict
(
config
),
f
,
default_flow_style
=
False
,
sort_keys
=
False
)
logger
=
get_logger
(
log_file
=
'{}/train.log'
.
format
(
config
[
'Global'
][
'save_model_dir'
]))
if
config
[
'Global'
][
'use_visualdl'
]:
from
visualdl
import
LogWriter
vdl_writer
=
LogWriter
(
logdir
=
config
[
'Global'
][
'save_model_dir'
])
else
:
vdl_writer
=
None
print_dict
(
config
,
logger
)
logger
.
info
(
'train with paddle {} and device {}'
.
format
(
paddle
.
__version__
,
device
))
main
(
config
,
device
,
logger
,
vdl_writer
)
# test_reader(config, device, logger, config['Global'])
if
__name__
==
'__main__'
:
# main
()
# dist.spawn(dis_main, nprocs=2, selelcted_gpus='6,7'
)
dis_main
(
)
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
#
test_reader(config, device, logger
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment