wangsen / paddle_dbnet / Commits / 85aeae71

Unverified commit 85aeae71, authored Jun 09, 2021 by Double_V, committed by GitHub on Jun 09, 2021

Merge pull request #3002 from littletomatodonkey/dyg/add_distillation

add distillation

Parents: d93a445d 95d07675
Showing 20 changed files with 708 additions and 161 deletions (+708 / -161):
configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml  +158  -0
ppocr/losses/__init__.py                                                 +24   -15
ppocr/losses/basic_loss.py                                               +103  -0
ppocr/losses/cls_loss.py                                                 +1    -1
ppocr/losses/combined_loss.py                                            +58   -0
ppocr/losses/distillation_loss.py                                        +108  -0
ppocr/losses/rec_ctc_loss.py                                             +1    -1
ppocr/metrics/__init__.py                                                +12   -9
ppocr/metrics/distillation_metric.py                                     +76   -0
ppocr/modeling/architectures/__init__.py                                 +12   -4
ppocr/modeling/architectures/base_model.py                               +10   -2
ppocr/modeling/architectures/distillation_model.py                       +60   -0
ppocr/modeling/backbones/det_mobilenet_v3.py                             +13   -32
ppocr/modeling/backbones/rec_mobilenet_v3.py                             +3    -6
ppocr/modeling/heads/det_db_head.py                                      +5    -16
ppocr/modeling/heads/rec_ctc_head.py                                     +5    -8
ppocr/modeling/necks/db_fpn.py                                           +8    -16
ppocr/postprocess/__init__.py                                            +9    -8
ppocr/postprocess/rec_postprocess.py                                     +31   -0
ppocr/utils/save_load.py                                                 +11   -43
configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml (new file, 0 → 100644)

Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  character_type: ch
  max_text_length: 25
  infer_mode: false
  use_space_char: false
  distributed: true
  save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.0005
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 1.0e-05

Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 48
      Head:
        name: CTCHead
        fc_decay: 0.00001
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 48
      Head:
        name: CTCHead
        fc_decay: 0.00001

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecAug:
    - CTCLabelEncode:
    - RecResizeImg:
        image_shape: [3, 32, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label
        - length
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_sections: 1
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - CTCLabelEncode:
    - RecResizeImg:
        image_shape: [3, 32, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label
        - length
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 8
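In this config, each entry under Architecture.Models becomes one sub-model (here Student and Teacher) and each entry in Loss.loss_config_list becomes one weighted term of the combined loss. A minimal sketch for inspecting those sections is below; it assumes PyYAML is installed and the file path above, and is not part of the commit. Training itself would be launched the same way as other PaddleOCR configs, by pointing the training entry point at this YAML.

```python
# Sketch: inspect the distillation-specific sections of the config.
# Assumes PyYAML is installed and the repository layout shown above.
import yaml

cfg_path = "configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

# One sub-model per entry under Architecture.Models.
print(list(cfg["Architecture"]["Models"].keys()))        # ['Student', 'Teacher']

# One weighted CombinedLoss term per entry in Loss.loss_config_list.
for term in cfg["Loss"]["loss_config_list"]:
    (name, params), = term.items()
    print(name, params["weight"], params.get("key"))
```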
ppocr/losses/__init__.py
...
@@ -13,28 +13,37 @@
 # limitations under the License.
 import copy
+import paddle
+import paddle.nn as nn

+# det loss
+from .det_db_loss import DBLoss
+from .det_east_loss import EASTLoss
+from .det_sast_loss import SASTLoss
+
+# rec loss
+from .rec_ctc_loss import CTCLoss
+from .rec_att_loss import AttentionLoss
+from .rec_srn_loss import SRNLoss
+
+# basic loss function
+from .basic_loss import DistanceLoss
+
+# cls loss
+from .cls_loss import ClsLoss
+
+# combined loss function
+from .combined_loss import CombinedLoss
+
+# e2e loss
+from .e2e_pg_loss import PGLoss
+

 def build_loss(config):
-    # det loss
-    from .det_db_loss import DBLoss
-    from .det_east_loss import EASTLoss
-    from .det_sast_loss import SASTLoss
-
-    # rec loss
-    from .rec_ctc_loss import CTCLoss
-    from .rec_att_loss import AttentionLoss
-    from .rec_srn_loss import SRNLoss
-
-    # cls loss
-    from .cls_loss import ClsLoss
-
-    # e2e loss
-    from .e2e_pg_loss import PGLoss
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss']
+        'SRNLoss', 'PGLoss', 'CombinedLoss']

     config = copy.deepcopy(config)
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('loss only support {}'.format(
...
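With CombinedLoss now registered in build_loss, the Loss section of the distillation config can be built like any other loss. A hedged sketch, assuming paddle is installed and the repository root is on PYTHONPATH; the values mirror the YAML above:

```python
# Sketch: build the distillation loss from the config's Loss section.
# Assumes the ppocr package is importable; not part of the commit itself.
from ppocr.losses import build_loss

loss_config = {
    "name": "CombinedLoss",
    "loss_config_list": [
        {"DistillationCTCLoss": {"weight": 1.0,
                                 "model_name_list": ["Student", "Teacher"],
                                 "key": "head_out"}},
        {"DistillationDMLLoss": {"weight": 1.0, "act": "softmax",
                                 "model_name_pairs": [["Student", "Teacher"]],
                                 "key": "head_out"}},
        {"DistillationDistanceLoss": {"weight": 1.0, "mode": "l2",
                                      "model_name_pairs": [["Student", "Teacher"]],
                                      "key": "backbone_out"}},
    ],
}
loss_fn = build_loss(loss_config)  # a CombinedLoss wrapping three weighted terms
```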
ppocr/losses/basic_loss.py (new file, 0 → 100644)

#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss


class CELoss(nn.Layer):
    def __init__(self, epsilon=None):
        super().__init__()
        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
            epsilon = None
        self.epsilon = epsilon

    def _labelsmoothing(self, target, class_num):
        if target.shape[-1] != class_num:
            one_hot_target = F.one_hot(target, class_num)
        else:
            one_hot_target = target
        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
        return soft_target

    def forward(self, x, label):
        loss_dict = {}
        if self.epsilon is not None:
            class_num = x.shape[-1]
            label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(x, axis=-1)
            loss = paddle.sum(x * label, axis=-1)
        else:
            if label.shape[-1] == x.shape[-1]:
                label = F.softmax(label, axis=-1)
                soft_label = True
            else:
                soft_label = False
            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
        return loss


class DMLLoss(nn.Layer):
    """
    DMLLoss
    """

    def __init__(self, act=None):
        super().__init__()
        if act is not None:
            assert act in ["softmax", "sigmoid"]
        if act == "softmax":
            self.act = nn.Softmax(axis=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            self.act = None

    def forward(self, out1, out2):
        if self.act is not None:
            out1 = self.act(out1)
            out2 = self.act(out2)

        log_out1 = paddle.log(out1)
        log_out2 = paddle.log(out2)
        loss = (F.kl_div(
            log_out1, out2, reduction='batchmean') + F.kl_div(
                log_out2, out1, reduction='batchmean')) / 2.0
        return loss


class DistanceLoss(nn.Layer):
    """
    DistanceLoss:
        mode: loss mode
    """

    def __init__(self, mode="l2", **kargs):
        super().__init__()
        assert mode in ["l1", "l2", "smooth_l1"]
        if mode == "l1":
            self.loss_func = nn.L1Loss(**kargs)
        elif mode == "l2":
            self.loss_func = nn.MSELoss(**kargs)
        elif mode == "smooth_l1":
            self.loss_func = nn.SmoothL1Loss(**kargs)

    def forward(self, x, y):
        return self.loss_func(x, y)
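For reference, a small standalone check of the symmetric KL term that DMLLoss computes when act="softmax"; a sketch assuming paddle is installed, with random logits, not part of the commit:

```python
# Sketch: the symmetric KL ("DML") term computed by DMLLoss with act="softmax".
# Identical inputs give a value near 0; different predictions give a positive value.
import paddle
import paddle.nn.functional as F

student_logits = paddle.randn([4, 10])
teacher_logits = paddle.randn([4, 10])

p = F.softmax(student_logits, axis=-1)
q = F.softmax(teacher_logits, axis=-1)
dml = (F.kl_div(paddle.log(p), q, reduction='batchmean') +
       F.kl_div(paddle.log(q), p, reduction='batchmean')) / 2.0
print(float(dml))
```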
ppocr/losses/cls_loss.py
...
@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
         super(ClsLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(reduction='mean')

-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         label = batch[1]
         loss = self.loss_func(input=predicts, label=label)
         return {'loss': loss}
ppocr/losses/combined_loss.py (new file, 0 → 100644)

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss


class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combionation of loss function
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'operator config should be a list')
        for config in loss_config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))

    def forward(self, input, batch, **kargs):
        loss_dict = {}
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}
            weight = self.loss_weight[idx]
            loss = {
                "{}_{}".format(key, idx): loss[key] * weight
                for key in loss
            }
            loss_dict.update(loss)
        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
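Each weighted term contributes its own entries to the output dict, re-keyed as "{sub-loss key}_{term index}", and a summed "loss" entry is appended at the end. A small standalone sketch of that bookkeeping with made-up sub-loss outputs (the real key names come from the Distillation* losses, so these are illustrative only):

```python
# Sketch of CombinedLoss bookkeeping with hypothetical sub-loss outputs.
# Assumes paddle is installed; key names and values below are illustrative.
import paddle

sub_loss_outputs = [
    {"loss_ctc_Student_0": paddle.to_tensor(2.3)},   # term 0, weight 1.0
    {"loss_dml_0": paddle.to_tensor(0.4)},           # term 1, weight 1.0
]
loss_weight = [1.0, 1.0]

loss_dict = {}
for idx, loss in enumerate(sub_loss_outputs):
    weight = loss_weight[idx]
    loss_dict.update({"{}_{}".format(k, idx): v * weight for k, v in loss.items()})
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
print(sorted(loss_dict))   # ['loss', 'loss_ctc_Student_0_0', 'loss_dml_0_1']
```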
ppocr/losses/distillation_loss.py (new file, 0 → 100644)

#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss


class DistillationDMLLoss(DMLLoss):
    """
    """

    def __init__(self, model_name_pairs=[], act=None, key=None,
                 name="loss_dml"):
        super().__init__(act=act)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss
        return loss_dict


class DistillationCTCLoss(CTCLoss):
    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
        super().__init__()
        self.model_name_list = model_name_list
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}".format(self.name, model_name,
                                                idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, model_name)] = loss
        return loss_dict


class DistillationDistanceLoss(DistanceLoss):
    """
    """

    def __init__(self,
                 mode="l2",
                 model_name_pairs=[],
                 key=None,
                 name="loss_distance",
                 **kargs):
        super().__init__(mode=mode, **kargs)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name + "_l2"

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
                        key]
            else:
                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
                                               idx)] = loss
        return loss_dict
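All three wrappers index into the prediction dict produced by DistillationModel: one entry per sub-model, each holding the feature dict that BaseModel returns when return_all_feats is true. A hedged sketch with random tensors (shapes are illustrative), assuming the repository root is on PYTHONPATH:

```python
# Sketch: the prediction layout the Distillation* losses expect.
# "head_out"/"backbone_out" mirror the keys set by BaseModel.forward; shapes
# here are illustrative only.
import paddle
from ppocr.losses.distillation_loss import DistillationDMLLoss

predicts = {
    "Student": {"backbone_out": paddle.randn([2, 48, 1, 40]),
                "head_out": paddle.randn([2, 40, 6625])},
    "Teacher": {"backbone_out": paddle.randn([2, 48, 1, 40]),
                "head_out": paddle.randn([2, 40, 6625])},
}

dml = DistillationDMLLoss(
    model_name_pairs=[["Student", "Teacher"]], act="softmax", key="head_out")
print(dml(predicts, batch=None))   # e.g. {'loss_dml_0': Tensor(...)}
```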
ppocr/losses/rec_ctc_loss.py
...
@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
         super(CTCLoss, self).__init__()
         self.loss_func = nn.CTCLoss(blank=0, reduction='none')

-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
         preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
...
ppocr/metrics/__init__.py
...
@@ -19,20 +19,23 @@ from __future__ import unicode_literals
 import copy

-__all__ = ['build_metric']
+__all__ = ["build_metric"]

+from .det_metric import DetMetric
+from .rec_metric import RecMetric
+from .cls_metric import ClsMetric
+from .e2e_metric import E2EMetric
+from .distillation_metric import DistillationMetric

-def build_metric(config):
-    from .det_metric import DetMetric
-    from .rec_metric import RecMetric
-    from .cls_metric import ClsMetric
-    from .e2e_metric import E2EMetric
-    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']

+def build_metric(config):
+    support_dict = [
+        "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
+    ]
     config = copy.deepcopy(config)
-    module_name = config.pop('name')
+    module_name = config.pop("name")
     assert module_name in support_dict, Exception(
-        'metric only support {}'.format(support_dict))
+        "metric only support {}".format(support_dict))
     module_class = eval(module_name)(**config)
     return module_class
ppocr/metrics/distillation_metric.py (new file, 0 → 100644)

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import copy

from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric


class DistillationMetric(object):
    def __init__(self,
                 key=None,
                 base_metric_name="RecMetric",
                 main_indicator='acc',
                 **kwargs):
        self.main_indicator = main_indicator
        self.key = key
        self.main_indicator = main_indicator
        self.base_metric_name = base_metric_name
        self.kwargs = kwargs
        self.metrics = None

    def _init_metrcis(self, preds):
        self.metrics = dict()
        mod = importlib.import_module(__name__)
        for key in preds:
            self.metrics[key] = getattr(mod, self.base_metric_name)(
                main_indicator=self.main_indicator, **self.kwargs)
            self.metrics[key].reset()

    def __call__(self, preds, *args, **kwargs):
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrcis(preds)
        output = dict()
        for key in preds:
            metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
            for sub_key in metric:
                output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
                 'norm_edit_dis': 0,
            }
        """
        output = dict()
        for key in self.metrics:
            metric = self.metrics[key].get_metric()
            # main indicator
            if key == self.key:
                output.update(metric)
            else:
                for sub_key in metric:
                    output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def reset(self):
        for key in self.metrics:
            self.metrics[key].reset()
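DistillationMetric lazily creates one base metric per sub-model key it sees in the predictions, reports the main key's numbers unprefixed, and prefixes everything else with the sub-model name. A hedged construction sketch matching the config's Metric section, assuming the repository root is on PYTHONPATH:

```python
# Sketch: building the distillation metric the way the config's Metric section does.
# Assumes the ppocr package is importable.
from ppocr.metrics import build_metric

metric = build_metric({
    "name": "DistillationMetric",
    "base_metric_name": "RecMetric",
    "main_indicator": "acc",
    "key": "Student",
})
# One RecMetric is created lazily per sub-model name found in the predictions;
# get_metric() reports the Student's numbers unprefixed and the Teacher's as
# "Teacher_acc", "Teacher_norm_edit_dis", and so on.
```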
ppocr/modeling/architectures/__init__.py
...
@@ -13,12 +13,20 @@
 # limitations under the License.

 import copy
+import importlib
+
+from .base_model import BaseModel
+from .distillation_model import DistillationModel

 __all__ = ['build_model']

 def build_model(config):
-    from .base_model import BaseModel
     config = copy.deepcopy(config)
-    module_class = BaseModel(config)
-    return module_class
\ No newline at end of file
+    if not "name" in config:
+        arch = BaseModel(config)
+    else:
+        name = config.pop("name")
+        mod = importlib.import_module(__name__)
+        arch = getattr(mod, name)(config)
+    return arch
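The dispatch is purely name-based: a config without a "name" field still builds a plain BaseModel, while "name: DistillationModel" (as in the YAML above) resolves the class from this module via importlib. A standalone sketch of just that decision, with the classes replaced by strings so it runs without the repository:

```python
# Standalone sketch of the dispatch added to build_model; classes are replaced
# by strings so the snippet runs without the ppocr package installed.
def pick_architecture(config):
    if "name" not in config:
        return "BaseModel"        # single-model path, unchanged behaviour
    return config.pop("name")     # e.g. "DistillationModel"

print(pick_architecture({"model_type": "rec"}))                         # BaseModel
print(pick_architecture({"name": "DistillationModel", "Models": {}}))   # DistillationModel
```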
ppocr/modeling/architectures/base_model.py
...
@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
             config (dict): the super parameters for module.
         """
         super(BaseModel, self).__init__()
-
         in_channels = config.get('in_channels', 3)
         model_type = config['model_type']
         # build transfrom,
...
@@ -68,14 +67,23 @@ class BaseModel(nn.Layer):
         config["Head"]['in_channels'] = in_channels
         self.head = build_head(config["Head"])

+        self.return_all_feats = config.get("return_all_feats", False)
+
     def forward(self, x, data=None):
+        y = dict()
         if self.use_transform:
             x = self.transform(x)
         x = self.backbone(x)
+        y["backbone_out"] = x
         if self.use_neck:
             x = self.neck(x)
+        y["neck_out"] = x
         if data is None:
             x = self.head(x)
         else:
             x = self.head(x, data)
-        return x
+        y["head_out"] = x
+        if self.return_all_feats:
+            return y
+        else:
+            return x
ppocr/modeling/architectures/distillation_model.py (new file, 0 → 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model

__all__ = ['DistillationModel']


class DistillationModel(nn.Layer):
    def __init__(self, config):
        """
        the module for OCR distillation.
        args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for key in config["Models"]:
            model_config = config["Models"][key]
            freeze_params = False
            pretrained = None
            if "freeze_params" in model_config:
                freeze_params = model_config.pop("freeze_params")
            if "pretrained" in model_config:
                pretrained = model_config.pop("pretrained")
            model = BaseModel(model_config)
            if pretrained is not None:
                init_model(model, path=pretrained)
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

    def forward(self, x):
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x)
        return result_dict
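Putting the pieces together, the Architecture section of the YAML builds one BaseModel per entry in Models, and the wrapper's forward returns a dict keyed by those names. A hedged sketch of building and running the two-branch model, assuming the repository root is on PYTHONPATH; Head.out_channels is normally filled in from the character dictionary by the training program, so the value below is illustrative:

```python
# Sketch: build the two-branch architecture and run a dummy forward pass.
# Assumes ppocr is importable; out_channels=6625 is an illustrative value the
# training program would normally derive from the character dictionary.
import copy
import paddle
from ppocr.modeling.architectures import build_model

branch = {
    "model_type": "rec",
    "algorithm": "CRNN",
    "Transform": None,
    "Backbone": {"name": "MobileNetV3", "scale": 0.5,
                 "model_name": "small", "small_stride": [1, 2, 2, 2]},
    "Neck": {"name": "SequenceEncoder", "encoder_type": "rnn", "hidden_size": 48},
    "Head": {"name": "CTCHead", "fc_decay": 0.00001, "out_channels": 6625},
    "return_all_feats": True,
}
config = {"name": "DistillationModel", "algorithm": "Distillation",
          "Models": {"Student": branch, "Teacher": copy.deepcopy(branch)}}

model = build_model(config)
out = model(paddle.randn([1, 3, 32, 320]))
print(out.keys())              # dict_keys(['Student', 'Teacher'])
print(out["Student"].keys())   # backbone_out / neck_out / head_out
```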
ppocr/modeling/backbones/det_mobilenet_v3.py
...
@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
             padding=1,
             groups=1,
             if_act=True,
-            act='hardswish',
-            name='conv1')
+            act='hardswish')

         self.stages = []
         self.out_channels = []
...
@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
                         kernel_size=k,
                         stride=s,
                         use_se=se,
-                        act=nl,
-                        name="conv" + str(i + 2)))
+                        act=nl))
                 inplanes = make_divisible(scale * c)
                 i += 1
             block_list.append(
...
@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
                     padding=0,
                     groups=1,
                     if_act=True,
-                    act='hardswish',
-                    name='conv_last'))
+                    act='hardswish'))
             self.stages.append(nn.Sequential(*block_list))
             self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
         for i, stage in enumerate(self.stages):
...
@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
                 padding,
                 groups=1,
                 if_act=True,
-                act=None,
-                name=None):
+                act=None):
         super(ConvBNLayer, self).__init__()
         self.if_act = if_act
         self.act = act
...
@@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
             stride=stride,
             padding=padding,
             groups=groups,
-            weight_attr=ParamAttr(name=name + '_weights'),
             bias_attr=False)

-        self.bn = nn.BatchNorm(
-            num_channels=out_channels,
-            act=None,
-            param_attr=ParamAttr(name=name + "_bn_scale"),
-            bias_attr=ParamAttr(name=name + "_bn_offset"),
-            moving_mean_name=name + "_bn_mean",
-            moving_variance_name=name + "_bn_variance")
+        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)

     def forward(self, x):
         x = self.conv(x)
...
@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
                 kernel_size,
                 stride,
                 use_se,
-                act=None,
-                name=''):
+                act=None):
         super(ResidualUnit, self).__init__()
         self.if_shortcut = stride == 1 and in_channels == out_channels
         self.if_se = use_se
...
@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
             stride=1,
             padding=0,
             if_act=True,
-            act=act,
-            name=name + "_expand")
+            act=act)
         self.bottleneck_conv = ConvBNLayer(
             in_channels=mid_channels,
             out_channels=mid_channels,
...
@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
             padding=int((kernel_size - 1) // 2),
             groups=mid_channels,
             if_act=True,
-            act=act,
-            name=name + "_depthwise")
+            act=act)
         if self.if_se:
-            self.mid_se = SEModule(mid_channels, name=name + "_se")
+            self.mid_se = SEModule(mid_channels)
         self.linear_conv = ConvBNLayer(
             in_channels=mid_channels,
             out_channels=out_channels,
...
@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
             stride=1,
             padding=0,
             if_act=False,
-            act=None,
-            name=name + "_linear")
+            act=None)

     def forward(self, inputs):
         x = self.expand_conv(inputs)
...
@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
 class SEModule(nn.Layer):
-    def __init__(self, in_channels, reduction=4, name=""):
+    def __init__(self, in_channels, reduction=4):
         super(SEModule, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool2D(1)
         self.conv1 = nn.Conv2D(
...
@@ -266,17 +251,13 @@ class SEModule(nn.Layer):
             out_channels=in_channels // reduction,
             kernel_size=1,
             stride=1,
-            padding=0,
-            weight_attr=ParamAttr(name=name + "_1_weights"),
-            bias_attr=ParamAttr(name=name + "_1_offset"))
+            padding=0)
         self.conv2 = nn.Conv2D(
             in_channels=in_channels // reduction,
             out_channels=in_channels,
             kernel_size=1,
             stride=1,
-            padding=0,
-            weight_attr=ParamAttr(name + "_2_weights"),
-            bias_attr=ParamAttr(name=name + "_2_offset"))
+            padding=0)

     def forward(self, inputs):
         outputs = self.avg_pool(inputs)
...
ppocr/modeling/backbones/rec_mobilenet_v3.py
...
@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
             padding=1,
             groups=1,
             if_act=True,
-            act='hardswish',
-            name='conv1')
+            act='hardswish')
         i = 0
         block_list = []
         inplanes = make_divisible(inplanes * scale)
...
@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
                     kernel_size=k,
                     stride=s,
                     use_se=se,
-                    act=nl,
-                    name='conv' + str(i + 2)))
+                    act=nl))
             inplanes = make_divisible(scale * c)
             i += 1
         self.blocks = nn.Sequential(*block_list)
...
@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
             padding=0,
             groups=1,
             if_act=True,
-            act='hardswish',
-            name='conv_last')
+            act='hardswish')
         self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
         self.out_channels = make_divisible(scale * cls_ch_squeeze)
...
ppocr/modeling/heads/det_db_head.py
...
@@ -23,10 +23,10 @@ import paddle.nn.functional as F
 from paddle import ParamAttr


-def get_bias_attr(k, name):
+def get_bias_attr(k):
     stdv = 1.0 / math.sqrt(k * 1.0)
     initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
-    bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr")
+    bias_attr = ParamAttr(initializer=initializer)
     return bias_attr
...
@@ -38,18 +38,14 @@ class Head(nn.Layer):
             out_channels=in_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(name=name_list[0] + '.w_0'),
+            weight_attr=ParamAttr(),
             bias_attr=False)
         self.conv_bn1 = nn.BatchNorm(
             num_channels=in_channels // 4,
             param_attr=ParamAttr(
-                name=name_list[1] + '.w_0',
                 initializer=paddle.nn.initializer.Constant(value=1.0)),
             bias_attr=ParamAttr(
-                name=name_list[1] + '.b_0',
                 initializer=paddle.nn.initializer.Constant(value=1e-4)),
-            moving_mean_name=name_list[1] + '.w_1',
-            moving_variance_name=name_list[1] + '.w_2',
             act='relu')
         self.conv2 = nn.Conv2DTranspose(
             in_channels=in_channels // 4,
...
@@ -57,19 +53,14 @@ class Head(nn.Layer):
             kernel_size=2,
             stride=2,
             weight_attr=ParamAttr(
-                name=name_list[2] + '.w_0',
                 initializer=paddle.nn.initializer.KaimingUniform()),
-            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
+            bias_attr=get_bias_attr(in_channels // 4))
         self.conv_bn2 = nn.BatchNorm(
             num_channels=in_channels // 4,
             param_attr=ParamAttr(
-                name=name_list[3] + '.w_0',
                 initializer=paddle.nn.initializer.Constant(value=1.0)),
             bias_attr=ParamAttr(
-                name=name_list[3] + '.b_0',
                 initializer=paddle.nn.initializer.Constant(value=1e-4)),
-            moving_mean_name=name_list[3] + '.w_1',
-            moving_variance_name=name_list[3] + '.w_2',
             act="relu")
         self.conv3 = nn.Conv2DTranspose(
             in_channels=in_channels // 4,
...
@@ -77,10 +68,8 @@ class Head(nn.Layer):
             kernel_size=2,
             stride=2,
             weight_attr=ParamAttr(
-                name=name_list[4] + '.w_0',
                 initializer=paddle.nn.initializer.KaimingUniform()),
-            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), )
+            bias_attr=get_bias_attr(in_channels // 4), )

     def forward(self, x):
         x = self.conv1(x)
...
ppocr/modeling/heads/rec_ctc_head.py
...
@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
 from paddle.nn import functional as F


-def get_para_bias_attr(l2_decay, k, name):
+def get_para_bias_attr(l2_decay, k):
     regularizer = paddle.regularizer.L2Decay(l2_decay)
     stdv = 1.0 / math.sqrt(k * 1.0)
     initializer = nn.initializer.Uniform(-stdv, stdv)
-    weight_attr = ParamAttr(
-        regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
-    bias_attr = ParamAttr(
-        regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
+    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
     return [weight_attr, bias_attr]
...
@@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
     def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
         super(CTCHead, self).__init__()
         weight_attr, bias_attr = get_para_bias_attr(
-            l2_decay=fc_decay, k=in_channels, name='ctc_fc')
+            l2_decay=fc_decay, k=in_channels)
         self.fc = nn.Linear(
             in_channels,
             out_channels,
             weight_attr=weight_attr,
-            bias_attr=bias_attr,
-            name='ctc_fc')
+            bias_attr=bias_attr)
         self.out_channels = out_channels

     def forward(self, x, labels=None):
...
ppocr/modeling/necks/db_fpn.py
...
@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
             in_channels=in_channels[0],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_51.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in3_conv = nn.Conv2D(
             in_channels=in_channels[1],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_50.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in4_conv = nn.Conv2D(
             in_channels=in_channels[2],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_49.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in5_conv = nn.Conv2D(
             in_channels=in_channels[3],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_48.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p5_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_52.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p4_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_53.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p3_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_54.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p2_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_55.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)

     def forward(self, x):
...
ppocr/postprocess/__init__.py
...
@@ -21,18 +21,19 @@ import copy
 __all__ = ['build_post_process']

+from .db_postprocess import DBPostProcess
+from .east_postprocess import EASTPostProcess
+from .sast_postprocess import SASTPostProcess
+from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
+from .cls_postprocess import ClsPostProcess
+from .pg_postprocess import PGPostProcess

-def build_post_process(config, global_config=None):
-    from .db_postprocess import DBPostProcess
-    from .east_postprocess import EASTPostProcess
-    from .sast_postprocess import SASTPostProcess
-    from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
-    from .cls_postprocess import ClsPostProcess
-    from .pg_postprocess import PGPostProcess

+def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
-        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
+        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
+        'DistillationCTCLabelDecode'
     ]

     config = copy.deepcopy(config)
...
...
ppocr/postprocess/rec_postprocess.py
View file @
85aeae71
...
...
@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
return
dict_character
class
DistillationCTCLabelDecode
(
CTCLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
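The distillation decoder simply fans the standard CTC decode out over the named sub-models. A hedged construction sketch matching the config's PostProcess section; it assumes the repository root is on PYTHONPATH and that the character dict path from the config is reachable from the working directory:

```python
# Sketch: building the distillation post-process as the config's PostProcess
# section does. Assumes the ppocr package and the character dict are available.
from ppocr.postprocess import build_post_process

post_process = build_post_process({
    "name": "DistillationCTCLabelDecode",
    "character_dict_path": "ppocr/utils/ppocr_keys_v1.txt",
    "character_type": "ch",
    "use_space_char": False,
    "model_name": ["Student", "Teacher"],
    "key": "head_out",
})
# Given preds = {"Student": {"head_out": ...}, "Teacher": {"head_out": ...}},
# the call returns roughly {"Student": [(text, score), ...], "Teacher": [...]}.
```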
ppocr/utils/save_load.py
...
@@ -23,6 +23,8 @@ import six
 import paddle

+from ppocr.utils.logging import get_logger
+
 __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
...
@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
             raise OSError('Failed to mkdir {}'.format(path))


-def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
-    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
-                         "exists.".format(path))
-    if load_static_weights:
-        pre_state_dict = paddle.static.load_program_state(path)
-        param_state_dict = {}
-        model_dict = model.state_dict()
-        for key in model_dict.keys():
-            weight_name = model_dict[key].name
-            weight_name = weight_name.replace('binarize', '').replace(
-                'thresh', '')  # for DB
-            if weight_name in pre_state_dict.keys():
-                # logger.info('Load weight: {}, shape: {}'.format(
-                #     weight_name, pre_state_dict[weight_name].shape))
-                if 'encoder_rnn' in key:
-                    # delete axis which is 1
-                    pre_state_dict[weight_name] = pre_state_dict[
-                        weight_name].squeeze()
-                    # change axis
-                    if len(pre_state_dict[weight_name].shape) > 1:
-                        pre_state_dict[weight_name] = pre_state_dict[
-                            weight_name].transpose((1, 0))
-                param_state_dict[key] = pre_state_dict[weight_name]
-            else:
-                param_state_dict[key] = model_dict[key]
-        model.set_state_dict(param_state_dict)
-        return
-    param_state_dict = paddle.load(path + '.pdparams')
-    model.set_state_dict(param_state_dict)
-    return
-
-
-def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
+def init_model(config, model, optimizer=None, lr_scheduler=None):
     """
     load model from checkpoint or pretrained_model
     """
+    logger = get_logger()
     global_config = config['Global']
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
...
@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
         best_model_dict = states_dict.get('best_model_dict', {})
         if 'epoch' in states_dict:
             best_model_dict['start_epoch'] = states_dict['epoch'] + 1
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        load_static_weights = global_config.get('load_static_weights', False)
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
-        if not isinstance(load_static_weights, list):
-            load_static_weights = [load_static_weights] * len(pretrained_model)
-        for idx, pretrained in enumerate(pretrained_model):
-            load_static = load_static_weights[idx]
-            load_dygraph_pretrain(
-                model, logger, path=pretrained, load_static_weights=load_static)
+        for pretrained in pretrained_model:
+            if not (os.path.isdir(pretrained) or
+                    os.path.exists(pretrained + '.pdparams')):
+                raise ValueError("Model pretrain path {} does not "
+                                 "exists.".format(pretrained))
+            param_state_dict = paddle.load(pretrained + '.pdparams')
+            model.set_state_dict(param_state_dict)
             logger.info("load pretrained model from {}".format(
                 pretrained_model))
     else:
...
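After this change, static-graph weight conversion is dropped and pretrained weights are expected as dygraph .pdparams files. A minimal sketch of the loading pattern init_model now uses, assuming paddle is installed; the path below is hypothetical:

```python
# Sketch of the pretrained-weight loading pattern init_model now uses.
# The path is hypothetical; only dygraph .pdparams files are accepted.
import os
import paddle

pretrained = "./pretrain_models/rec_mv3_none_bilstm_ctc"   # hypothetical path
if not (os.path.isdir(pretrained) or os.path.exists(pretrained + ".pdparams")):
    raise ValueError("Model pretrain path {} does not exists.".format(pretrained))

param_state_dict = paddle.load(pretrained + ".pdparams")
# model.set_state_dict(param_state_dict)   # model: any paddle.nn.Layer instance
```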