Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
0343756e
Commit
0343756e
authored
Jun 03, 2021
by
littletomatodonkey
Browse files
fix metric
parent
b48f7609
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
38 deletions
+41
-38
configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
...h_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
+4
-4
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+6
-15
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+8
-5
ppocr/metrics/__init__.py
ppocr/metrics/__init__.py
+12
-9
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+11
-5
No files found.
configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
View file @
0343756e
...
@@ -95,17 +95,17 @@ Loss:
...
@@ -95,17 +95,17 @@ Loss:
model_name_pairs
:
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
key
:
backbone_out
PostProcess
:
PostProcess
:
name
:
DistillationCTCLabelDecode
name
:
DistillationCTCLabelDecode
model_name
:
"
Student"
model_name
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
key
:
head_out
Metric
:
Metric
:
name
:
RecMetric
name
:
DistillationMetric
base_metric_name
:
RecMetric
main_indicator
:
acc
main_indicator
:
acc
key
:
"
Student"
Train
:
Train
:
dataset
:
dataset
:
...
...
ppocr/losses/basic_loss.py
View file @
0343756e
...
@@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss
...
@@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss
class
CELoss
(
nn
.
Layer
):
class
CELoss
(
nn
.
Layer
):
def
__init__
(
self
,
name
=
"loss_ce"
,
epsilon
=
None
):
def
__init__
(
self
,
epsilon
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
name
=
name
if
epsilon
is
not
None
and
(
epsilon
<=
0
or
epsilon
>=
1
):
if
epsilon
is
not
None
and
(
epsilon
<=
0
or
epsilon
>=
1
):
epsilon
=
None
epsilon
=
None
self
.
epsilon
=
epsilon
self
.
epsilon
=
epsilon
...
@@ -52,9 +51,7 @@ class CELoss(nn.Layer):
...
@@ -52,9 +51,7 @@ class CELoss(nn.Layer):
else
:
else
:
soft_label
=
False
soft_label
=
False
loss
=
F
.
cross_entropy
(
x
,
label
=
label
,
soft_label
=
soft_label
)
loss
=
F
.
cross_entropy
(
x
,
label
=
label
,
soft_label
=
soft_label
)
return
loss
loss_dict
[
self
.
name
]
=
paddle
.
mean
(
loss
)
return
loss_dict
class
DMLLoss
(
nn
.
Layer
):
class
DMLLoss
(
nn
.
Layer
):
...
@@ -62,11 +59,10 @@ class DMLLoss(nn.Layer):
...
@@ -62,11 +59,10 @@ class DMLLoss(nn.Layer):
DMLLoss
DMLLoss
"""
"""
def
__init__
(
self
,
act
=
None
,
name
=
"loss_dml"
):
def
__init__
(
self
,
act
=
None
):
super
().
__init__
()
super
().
__init__
()
if
act
is
not
None
:
if
act
is
not
None
:
assert
act
in
[
"softmax"
,
"sigmoid"
]
assert
act
in
[
"softmax"
,
"sigmoid"
]
self
.
name
=
name
if
act
==
"softmax"
:
if
act
==
"softmax"
:
self
.
act
=
nn
.
Softmax
(
axis
=-
1
)
self
.
act
=
nn
.
Softmax
(
axis
=-
1
)
elif
act
==
"sigmoid"
:
elif
act
==
"sigmoid"
:
...
@@ -75,7 +71,6 @@ class DMLLoss(nn.Layer):
...
@@ -75,7 +71,6 @@ class DMLLoss(nn.Layer):
self
.
act
=
None
self
.
act
=
None
def
forward
(
self
,
out1
,
out2
):
def
forward
(
self
,
out1
,
out2
):
loss_dict
=
{}
if
self
.
act
is
not
None
:
if
self
.
act
is
not
None
:
out1
=
self
.
act
(
out1
)
out1
=
self
.
act
(
out1
)
out2
=
self
.
act
(
out2
)
out2
=
self
.
act
(
out2
)
...
@@ -85,18 +80,16 @@ class DMLLoss(nn.Layer):
...
@@ -85,18 +80,16 @@ class DMLLoss(nn.Layer):
loss
=
(
F
.
kl_div
(
loss
=
(
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out2
,
log_out1
,
reduction
=
'batchmean'
))
/
2.0
log_out2
,
log_out1
,
reduction
=
'batchmean'
))
/
2.0
loss_dict
[
self
.
name
]
=
loss
return
loss
return
loss_dict
class
DistanceLoss
(
nn
.
Layer
):
class
DistanceLoss
(
nn
.
Layer
):
"""
"""
DistanceLoss:
DistanceLoss:
mode: loss mode
mode: loss mode
name: loss key in the output dict
"""
"""
def
__init__
(
self
,
mode
=
"l2"
,
name
=
"loss_dist"
,
**
kargs
):
def
__init__
(
self
,
mode
=
"l2"
,
**
kargs
):
super
().
__init__
()
super
().
__init__
()
assert
mode
in
[
"l1"
,
"l2"
,
"smooth_l1"
]
assert
mode
in
[
"l1"
,
"l2"
,
"smooth_l1"
]
if
mode
==
"l1"
:
if
mode
==
"l1"
:
...
@@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer):
...
@@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer):
elif
mode
==
"smooth_l1"
:
elif
mode
==
"smooth_l1"
:
self
.
loss_func
=
nn
.
SmoothL1Loss
(
**
kargs
)
self
.
loss_func
=
nn
.
SmoothL1Loss
(
**
kargs
)
self
.
name
=
"{}_{}"
.
format
(
name
,
mode
)
def
forward
(
self
,
x
,
y
):
def
forward
(
self
,
x
,
y
):
return
{
self
.
name
:
self
.
loss_func
(
x
,
y
)
}
return
self
.
loss_func
(
x
,
y
)
ppocr/losses/distillation_loss.py
View file @
0343756e
...
@@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss):
def
__init__
(
self
,
model_name_pairs
=
[],
act
=
None
,
key
=
None
,
def
__init__
(
self
,
model_name_pairs
=
[],
act
=
None
,
key
=
None
,
name
=
"loss_dml"
):
name
=
"loss_dml"
):
super
().
__init__
(
act
=
act
,
name
=
name
)
super
().
__init__
(
act
=
act
)
assert
isinstance
(
model_name_pairs
,
list
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
def
forward
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
loss_dict
=
dict
()
...
@@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss):
loss
=
super
().
forward
(
out1
,
out2
)
loss
=
super
().
forward
(
out1
,
out2
)
if
isinstance
(
loss
,
dict
):
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
for
key
in
loss
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
key
,
idx
)]
=
loss
[
loss_dict
[
"{}_{}_{}
_{}
"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
key
]
idx
)]
=
loss
[
key
]
else
:
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
return
loss_dict
return
loss_dict
...
@@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss):
...
@@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss):
key
=
None
,
key
=
None
,
name
=
"loss_distance"
,
name
=
"loss_distance"
,
**
kargs
):
**
kargs
):
super
().
__init__
(
mode
=
mode
,
name
=
name
,
**
kargs
)
super
().
__init__
(
mode
=
mode
,
**
kargs
)
assert
isinstance
(
model_name_pairs
,
list
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
+
"_l2"
def
forward
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
loss_dict
=
dict
()
...
@@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss):
...
@@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
key
,
idx
)]
=
loss
[
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
key
,
idx
)]
=
loss
[
key
]
key
]
else
:
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
loss_dict
[
"{}_{}_{}_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
loss
return
loss_dict
return
loss_dict
ppocr/metrics/__init__.py
View file @
0343756e
...
@@ -19,20 +19,23 @@ from __future__ import unicode_literals
...
@@ -19,20 +19,23 @@ from __future__ import unicode_literals
import
copy
import
copy
__all__
=
[
'
build_metric
'
]
__all__
=
[
"
build_metric
"
]
from
.det_metric
import
DetMetric
from
.rec_metric
import
RecMetric
from
.cls_metric
import
ClsMetric
from
.e2e_metric
import
E2EMetric
from
.distillation_metric
import
DistillationMetric
def
build_metric
(
config
):
from
.det_metric
import
DetMetric
from
.rec_metric
import
RecMetric
from
.cls_metric
import
ClsMetric
from
.e2e_metric
import
E2EMetric
support_dict
=
[
'DetMetric'
,
'RecMetric'
,
'ClsMetric'
,
'E2EMetric'
]
def
build_metric
(
config
):
support_dict
=
[
"DetMetric"
,
"RecMetric"
,
"ClsMetric"
,
"E2EMetric"
,
"DistillationMetric"
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'
name
'
)
module_name
=
config
.
pop
(
"
name
"
)
assert
module_name
in
support_dict
,
Exception
(
assert
module_name
in
support_dict
,
Exception
(
'
metric only support {}
'
.
format
(
support_dict
))
"
metric only support {}
"
.
format
(
support_dict
))
module_class
=
eval
(
module_name
)(
**
config
)
module_class
=
eval
(
module_name
)(
**
config
)
return
module_class
return
module_class
ppocr/postprocess/rec_postprocess.py
View file @
0343756e
...
@@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
...
@@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
character_dict_path
=
None
,
character_dict_path
=
None
,
character_type
=
'ch'
,
character_type
=
'ch'
,
use_space_char
=
False
,
use_space_char
=
False
,
model_name
=
"student"
,
model_name
=
[
"student"
]
,
key
=
None
,
key
=
None
,
**
kwargs
):
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
character_dict_path
,
character_type
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
model_name
=
model_name
self
.
key
=
key
self
.
key
=
key
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
self
.
model_name
]
output
=
dict
()
if
self
.
key
is
not
None
:
for
name
in
self
.
model_name
:
pred
=
pred
[
self
.
key
]
pred
=
preds
[
name
]
return
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
AttnLabelDecode
(
BaseRecLabelDecode
):
class
AttnLabelDecode
(
BaseRecLabelDecode
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment