Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a5530565
Commit
a5530565
authored
Oct 14, 2021
by
Leif
Browse files
Merge remote-tracking branch 'origin/dygraph' into dygraph
parents
a9d5349c
37eec4d5
Changes
77
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
555 additions
and
22 deletions
+555
-22
doc/doc_ch/focal_loss_formula.png
doc/doc_ch/focal_loss_formula.png
+0
-0
doc/doc_ch/focal_loss_image.png
doc/doc_ch/focal_loss_image.png
+0
-0
doc/doc_ch/rec_algo_compare.png
doc/doc_ch/rec_algo_compare.png
+0
-0
doc/doc_ch/recognition.md
doc/doc_ch/recognition.md
+3
-2
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+37
-1
ppocr/data/imaug/operators.py
ppocr/data/imaug/operators.py
+14
-1
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+13
-6
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+3
-2
ppocr/losses/ace_loss.py
ppocr/losses/ace_loss.py
+50
-0
ppocr/losses/center_loss.py
ppocr/losses/center_loss.py
+89
-0
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+4
-0
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+1
-1
ppocr/losses/rec_aster_loss.py
ppocr/losses/rec_aster_loss.py
+99
-0
ppocr/losses/rec_ctc_loss.py
ppocr/losses/rec_ctc_loss.py
+10
-2
ppocr/losses/rec_enhanced_ctc_loss.py
ppocr/losses/rec_enhanced_ctc_loss.py
+70
-0
ppocr/losses/rec_sar_loss.py
ppocr/losses/rec_sar_loss.py
+6
-3
ppocr/metrics/rec_metric.py
ppocr/metrics/rec_metric.py
+11
-2
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+3
-1
ppocr/modeling/backbones/rec_resnet_aster.py
ppocr/modeling/backbones/rec_resnet_aster.py
+140
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+2
-1
No files found.
doc/doc_ch/focal_loss_formula.png
0 → 100644
View file @
a5530565
23.3 KB
doc/doc_ch/focal_loss_image.png
0 → 100644
View file @
a5530565
125 KB
doc/doc_ch/rec_algo_compare.png
0 → 100644
View file @
a5530565
224 KB
doc/doc_ch/recognition.md
View file @
a5530565
...
...
@@ -234,6 +234,9 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
| rec_resnet_stn_bilstm_att.yml | SEED | Aster_Resnet | STN | BiLSTM | att |
*
其中SEED模型需要额外加载FastText训练好的
[
语言模型
](
https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz
)
训练中文数据,推荐使用
[
rec_chinese_lite_train_v2.0.yml
](
../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
)
,如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
...
...
@@ -460,5 +463,3 @@ python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_trai
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
```
ppocr/data/imaug/label_ops.py
View file @
a5530565
...
...
@@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
label
=
[
0
]
*
len
(
self
.
character
)
for
x
in
text
:
label
[
x
]
+=
1
data
[
'label_ace'
]
=
np
.
array
(
label
)
return
data
def
add_special_char
(
self
,
dict_character
):
...
...
@@ -342,6 +347,38 @@ class AttnLabelEncode(BaseRecLabelEncode):
return
idx
class
SEEDLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
max_text_length
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
end_str
=
"eos"
dict_character
=
dict_character
+
[
self
.
end_str
]
return
dict_character
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
if
len
(
text
)
>=
self
.
max_text_len
:
return
None
data
[
'length'
]
=
np
.
array
(
len
(
text
))
+
1
# conclude eos
text
=
text
+
[
len
(
self
.
character
)
-
1
]
*
(
self
.
max_text_len
-
len
(
text
)
)
data
[
'label'
]
=
np
.
array
(
text
)
return
data
class
SRNLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
...
...
@@ -421,7 +458,6 @@ class TableLabelEncode(object):
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
...
...
ppocr/data/imaug/operators.py
View file @
a5530565
...
...
@@ -23,6 +23,7 @@ import sys
import
six
import
cv2
import
numpy
as
np
import
fasttext
class
DecodeImage
(
object
):
...
...
@@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
elif
self
.
img_mode
==
'RGB'
:
assert
img
.
shape
[
2
]
==
3
,
'invalid shape of image[%s]'
%
(
img
.
shape
)
img
=
img
[:,
:,
::
-
1
]
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
if
self
.
channel_first
:
img
=
img
.
transpose
((
2
,
0
,
1
))
data
[
'image'
]
=
img
return
data
class
NormalizeImage
(
object
):
""" normalize image such as substract mean, divide std
"""
...
...
@@ -133,6 +135,17 @@ class ToCHWImage(object):
return
data
class
Fasttext
(
object
):
def
__init__
(
self
,
path
=
"None"
,
**
kwargs
):
self
.
fast_model
=
fasttext
.
load_model
(
path
)
def
__call__
(
self
,
data
):
label
=
data
[
'label'
]
fast_label
=
self
.
fast_model
[
label
]
data
[
'fast_label'
]
=
fast_label
return
data
class
KeepKeys
(
object
):
def
__init__
(
self
,
keep_keys
,
**
kwargs
):
self
.
keep_keys
=
keep_keys
...
...
ppocr/data/imaug/rec_img_aug.py
View file @
a5530565
...
...
@@ -88,17 +88,19 @@ class RecResizeImg(object):
image_shape
,
infer_mode
=
False
,
character_type
=
'ch'
,
padding
=
True
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
infer_mode
=
infer_mode
self
.
character_type
=
character_type
self
.
padding
=
padding
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
if
self
.
infer_mode
and
self
.
character_type
==
"ch"
:
norm_img
=
resize_norm_img_chinese
(
img
,
self
.
image_shape
)
else
:
norm_img
=
resize_norm_img
(
img
,
self
.
image_shape
)
norm_img
=
resize_norm_img
(
img
,
self
.
image_shape
,
self
.
padding
)
data
[
'image'
]
=
norm_img
return
data
...
...
@@ -174,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return
padding_im
,
resize_shape
,
pad_shape
,
valid_ratio
def
resize_norm_img
(
img
,
image_shape
):
def
resize_norm_img
(
img
,
image_shape
,
padding
=
True
):
imgC
,
imgH
,
imgW
=
image_shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
if
not
padding
:
resized_image
=
cv2
.
resize
(
img
,
(
imgW
,
imgH
),
interpolation
=
cv2
.
INTER_LINEAR
)
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
resized_image
=
resized_image
.
astype
(
'float32'
)
if
image_shape
[
0
]
==
1
:
resized_image
=
resized_image
/
255
...
...
ppocr/losses/__init__.py
View file @
a5530565
...
...
@@ -28,6 +28,8 @@ from .rec_att_loss import AttentionLoss
from
.rec_srn_loss
import
SRNLoss
from
.rec_nrtr_loss
import
NRTRLoss
from
.rec_sar_loss
import
SARLoss
from
.rec_aster_loss
import
AsterLoss
# cls loss
from
.cls_loss
import
ClsLoss
...
...
@@ -48,9 +50,8 @@ def build_loss(config):
support_dict
=
[
'DBLoss'
,
'PSELoss'
,
'EASTLoss'
,
'SASTLoss'
,
'CTCLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
]
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'loss only support {}'
.
format
(
...
...
ppocr/losses/ace_loss.py
0 → 100644
View file @
a5530565
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
class
ACELoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
self
.
loss_func
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
ignore_index
=
0
,
reduction
=
'none'
,
soft_label
=
True
,
axis
=-
1
)
def
__call__
(
self
,
predicts
,
batch
):
if
isinstance
(
predicts
,
(
list
,
tuple
)):
predicts
=
predicts
[
-
1
]
B
,
N
=
predicts
.
shape
[:
2
]
div
=
paddle
.
to_tensor
([
N
]).
astype
(
'float32'
)
predicts
=
nn
.
functional
.
softmax
(
predicts
,
axis
=-
1
)
aggregation_preds
=
paddle
.
sum
(
predicts
,
axis
=
1
)
aggregation_preds
=
paddle
.
divide
(
aggregation_preds
,
div
)
length
=
batch
[
2
].
astype
(
"float32"
)
batch
=
batch
[
3
].
astype
(
"float32"
)
batch
[:,
0
]
=
paddle
.
subtract
(
div
,
length
)
batch
=
paddle
.
divide
(
batch
,
div
)
loss
=
self
.
loss_func
(
aggregation_preds
,
batch
)
return
{
"loss_ace"
:
loss
}
ppocr/losses/center_loss.py
0 → 100644
View file @
a5530565
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
pickle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
class
CenterLoss
(
nn
.
Layer
):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
def
__init__
(
self
,
num_classes
=
6625
,
feat_dim
=
96
,
init_center
=
False
,
center_file_path
=
None
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
feat_dim
=
feat_dim
self
.
centers
=
paddle
.
randn
(
shape
=
[
self
.
num_classes
,
self
.
feat_dim
]).
astype
(
"float64"
)
#random center
if
init_center
:
assert
os
.
path
.
exists
(
center_file_path
),
f
"center path(
{
center_file_path
}
) must exist when init_center is set as True."
with
open
(
center_file_path
,
'rb'
)
as
f
:
char_dict
=
pickle
.
load
(
f
)
for
key
in
char_dict
.
keys
():
self
.
centers
[
key
]
=
paddle
.
to_tensor
(
char_dict
[
key
])
def
__call__
(
self
,
predicts
,
batch
):
assert
isinstance
(
predicts
,
(
list
,
tuple
))
features
,
predicts
=
predicts
feats_reshape
=
paddle
.
reshape
(
features
,
[
-
1
,
features
.
shape
[
-
1
]]).
astype
(
"float64"
)
label
=
paddle
.
argmax
(
predicts
,
axis
=
2
)
label
=
paddle
.
reshape
(
label
,
[
label
.
shape
[
0
]
*
label
.
shape
[
1
]])
batch_size
=
feats_reshape
.
shape
[
0
]
#calc feat * feat
dist1
=
paddle
.
sum
(
paddle
.
square
(
feats_reshape
),
axis
=
1
,
keepdim
=
True
)
dist1
=
paddle
.
expand
(
dist1
,
[
batch_size
,
self
.
num_classes
])
#dist2 of centers
dist2
=
paddle
.
sum
(
paddle
.
square
(
self
.
centers
),
axis
=
1
,
keepdim
=
True
)
#num_classes
dist2
=
paddle
.
expand
(
dist2
,
[
self
.
num_classes
,
batch_size
]).
astype
(
"float64"
)
dist2
=
paddle
.
transpose
(
dist2
,
[
1
,
0
])
#first x * x + y * y
distmat
=
paddle
.
add
(
dist1
,
dist2
)
tmp
=
paddle
.
matmul
(
feats_reshape
,
paddle
.
transpose
(
self
.
centers
,
[
1
,
0
]))
distmat
=
distmat
-
2.0
*
tmp
#generate the mask
classes
=
paddle
.
arange
(
self
.
num_classes
).
astype
(
"int64"
)
label
=
paddle
.
expand
(
paddle
.
unsqueeze
(
label
,
1
),
(
batch_size
,
self
.
num_classes
))
mask
=
paddle
.
equal
(
paddle
.
expand
(
classes
,
[
batch_size
,
self
.
num_classes
]),
label
).
astype
(
"float64"
)
#get mask
dist
=
paddle
.
multiply
(
distmat
,
mask
)
loss
=
paddle
.
sum
(
paddle
.
clip
(
dist
,
min
=
1e-12
,
max
=
1e+12
))
/
batch_size
return
{
'loss_center'
:
loss
}
ppocr/losses/combined_loss.py
View file @
a5530565
...
...
@@ -15,6 +15,10 @@
import
paddle
import
paddle.nn
as
nn
from
.rec_ctc_loss
import
CTCLoss
from
.center_loss
import
CenterLoss
from
.ace_loss
import
ACELoss
from
.distillation_loss
import
DistillationCTCLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationDistanceLoss
,
DistillationDBLoss
,
DistillationDilaDBLoss
...
...
ppocr/losses/distillation_loss.py
View file @
a5530565
...
...
@@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
map_name
,
idx
)]
=
loss
[
key
]
0
],
pair
[
1
],
self
.
map
s
_name
,
idx
)]
=
loss
[
key
]
else
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
[
_c
],
idx
)]
=
loss
...
...
ppocr/losses/rec_aster_loss.py
0 → 100644
View file @
a5530565
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
class
CosineEmbeddingLoss
(
nn
.
Layer
):
def
__init__
(
self
,
margin
=
0.
):
super
(
CosineEmbeddingLoss
,
self
).
__init__
()
self
.
margin
=
margin
self
.
epsilon
=
1e-12
def
forward
(
self
,
x1
,
x2
,
target
):
similarity
=
paddle
.
fluid
.
layers
.
reduce_sum
(
x1
*
x2
,
dim
=-
1
)
/
(
paddle
.
norm
(
x1
,
axis
=-
1
)
*
paddle
.
norm
(
x2
,
axis
=-
1
)
+
self
.
epsilon
)
one_list
=
paddle
.
full_like
(
target
,
fill_value
=
1
)
out
=
paddle
.
fluid
.
layers
.
reduce_mean
(
paddle
.
where
(
paddle
.
equal
(
target
,
one_list
),
1.
-
similarity
,
paddle
.
maximum
(
paddle
.
zeros_like
(
similarity
),
similarity
-
self
.
margin
)))
return
out
class
AsterLoss
(
nn
.
Layer
):
def
__init__
(
self
,
weight
=
None
,
size_average
=
True
,
ignore_index
=-
100
,
sequence_normalize
=
False
,
sample_normalize
=
True
,
**
kwargs
):
super
(
AsterLoss
,
self
).
__init__
()
self
.
weight
=
weight
self
.
size_average
=
size_average
self
.
ignore_index
=
ignore_index
self
.
sequence_normalize
=
sequence_normalize
self
.
sample_normalize
=
sample_normalize
self
.
loss_sem
=
CosineEmbeddingLoss
()
self
.
is_cosin_loss
=
True
self
.
loss_func_rec
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'none'
)
def
forward
(
self
,
predicts
,
batch
):
targets
=
batch
[
1
].
astype
(
"int64"
)
label_lengths
=
batch
[
2
].
astype
(
'int64'
)
sem_target
=
batch
[
3
].
astype
(
'float32'
)
embedding_vectors
=
predicts
[
'embedding_vectors'
]
rec_pred
=
predicts
[
'rec_pred'
]
if
not
self
.
is_cosin_loss
:
sem_loss
=
paddle
.
sum
(
self
.
loss_sem
(
embedding_vectors
,
sem_target
))
else
:
label_target
=
paddle
.
ones
([
embedding_vectors
.
shape
[
0
]])
sem_loss
=
paddle
.
sum
(
self
.
loss_sem
(
embedding_vectors
,
sem_target
,
label_target
))
# rec loss
batch_size
,
def_max_length
=
targets
.
shape
[
0
],
targets
.
shape
[
1
]
mask
=
paddle
.
zeros
([
batch_size
,
def_max_length
])
for
i
in
range
(
batch_size
):
mask
[
i
,
:
label_lengths
[
i
]]
=
1
mask
=
paddle
.
cast
(
mask
,
"float32"
)
max_length
=
max
(
label_lengths
)
assert
max_length
==
rec_pred
.
shape
[
1
]
targets
=
targets
[:,
:
max_length
]
mask
=
mask
[:,
:
max_length
]
rec_pred
=
paddle
.
reshape
(
rec_pred
,
[
-
1
,
rec_pred
.
shape
[
2
]])
input
=
nn
.
functional
.
log_softmax
(
rec_pred
,
axis
=
1
)
targets
=
paddle
.
reshape
(
targets
,
[
-
1
,
1
])
mask
=
paddle
.
reshape
(
mask
,
[
-
1
,
1
])
output
=
-
paddle
.
index_sample
(
input
,
index
=
targets
)
*
mask
output
=
paddle
.
sum
(
output
)
if
self
.
sequence_normalize
:
output
=
output
/
paddle
.
sum
(
mask
)
if
self
.
sample_normalize
:
output
=
output
/
batch_size
loss
=
output
+
sem_loss
*
0.1
return
{
'loss'
:
loss
}
ppocr/losses/rec_ctc_loss.py
View file @
a5530565
...
...
@@ -21,16 +21,24 @@ from paddle import nn
class
CTCLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
use_focal_loss
=
False
,
**
kwargs
):
super
(
CTCLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CTCLoss
(
blank
=
0
,
reduction
=
'none'
)
self
.
use_focal_loss
=
use_focal_loss
def
forward
(
self
,
predicts
,
batch
):
if
isinstance
(
predicts
,
(
list
,
tuple
)):
predicts
=
predicts
[
-
1
]
predicts
=
predicts
.
transpose
((
1
,
0
,
2
))
N
,
B
,
_
=
predicts
.
shape
preds_lengths
=
paddle
.
to_tensor
([
N
]
*
B
,
dtype
=
'int64'
)
labels
=
batch
[
1
].
astype
(
"int32"
)
label_lengths
=
batch
[
2
].
astype
(
'int64'
)
loss
=
self
.
loss_func
(
predicts
,
labels
,
preds_lengths
,
label_lengths
)
loss
=
loss
.
mean
()
# sum
if
self
.
use_focal_loss
:
weight
=
paddle
.
exp
(
-
loss
)
weight
=
paddle
.
subtract
(
paddle
.
to_tensor
([
1.0
]),
weight
)
weight
=
paddle
.
square
(
weight
)
loss
=
paddle
.
multiply
(
loss
,
weight
)
loss
=
loss
.
mean
()
return
{
'loss'
:
loss
}
ppocr/losses/rec_enhanced_ctc_loss.py
0 → 100644
View file @
a5530565
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
from
paddle
import
nn
from
.ace_loss
import
ACELoss
from
.center_loss
import
CenterLoss
from
.rec_ctc_loss
import
CTCLoss
class
EnhancedCTCLoss
(
nn
.
Layer
):
def
__init__
(
self
,
use_focal_loss
=
False
,
use_ace_loss
=
False
,
ace_loss_weight
=
0.1
,
use_center_loss
=
False
,
center_loss_weight
=
0.05
,
num_classes
=
6625
,
feat_dim
=
96
,
init_center
=
False
,
center_file_path
=
None
,
**
kwargs
):
super
(
EnhancedCTCLoss
,
self
).
__init__
()
self
.
ctc_loss_func
=
CTCLoss
(
use_focal_loss
=
use_focal_loss
)
self
.
use_ace_loss
=
False
if
use_ace_loss
:
self
.
use_ace_loss
=
use_ace_loss
self
.
ace_loss_func
=
ACELoss
()
self
.
ace_loss_weight
=
ace_loss_weight
self
.
use_center_loss
=
False
if
use_center_loss
:
self
.
use_center_loss
=
use_center_loss
self
.
center_loss_func
=
CenterLoss
(
num_classes
=
num_classes
,
feat_dim
=
feat_dim
,
init_center
=
init_center
,
center_file_path
=
center_file_path
)
self
.
center_loss_weight
=
center_loss_weight
def
__call__
(
self
,
predicts
,
batch
):
loss
=
self
.
ctc_loss_func
(
predicts
,
batch
)[
"loss"
]
if
self
.
use_center_loss
:
center_loss
=
self
.
center_loss_func
(
predicts
,
batch
)[
"loss_center"
]
*
self
.
center_loss_weight
loss
=
loss
+
center_loss
if
self
.
use_ace_loss
:
ace_loss
=
self
.
ace_loss_func
(
predicts
,
batch
)[
"loss_ace"
]
*
self
.
ace_loss_weight
loss
=
loss
+
ace_loss
return
{
'enhanced_ctc_loss'
:
loss
}
ppocr/losses/rec_sar_loss.py
View file @
a5530565
...
...
@@ -9,11 +9,14 @@ from paddle import nn
class
SARLoss
(
nn
.
Layer
):
def
__init__
(
self
,
**
kwargs
):
super
(
SARLoss
,
self
).
__init__
()
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"mean"
,
ignore_index
=
96
)
self
.
loss_func
=
paddle
.
nn
.
loss
.
CrossEntropyLoss
(
reduction
=
"mean"
,
ignore_index
=
92
)
def
forward
(
self
,
predicts
,
batch
):
predict
=
predicts
[:,
:
-
1
,
:]
# ignore last index of outputs to be in same seq_len with targets
label
=
batch
[
1
].
astype
(
"int64"
)[:,
1
:]
# ignore first index of target in loss calculation
predict
=
predicts
[:,
:
-
1
,
:]
# ignore last index of outputs to be in same seq_len with targets
label
=
batch
[
1
].
astype
(
"int64"
)[:,
1
:]
# ignore first index of target in loss calculation
batch_size
,
num_steps
,
num_classes
=
predict
.
shape
[
0
],
predict
.
shape
[
1
],
predict
.
shape
[
2
]
assert
len
(
label
.
shape
)
==
len
(
list
(
predict
.
shape
))
-
1
,
\
...
...
ppocr/metrics/rec_metric.py
View file @
a5530565
...
...
@@ -13,13 +13,20 @@
# limitations under the License.
import
Levenshtein
import
string
class
RecMetric
(
object
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
**
kwargs
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
is_filter
=
False
,
**
kwargs
):
self
.
main_indicator
=
main_indicator
self
.
is_filter
=
is_filter
self
.
reset
()
def
_normalize_text
(
self
,
text
):
text
=
''
.
join
(
filter
(
lambda
x
:
x
in
(
string
.
digits
+
string
.
ascii_letters
),
text
))
return
text
.
lower
()
def
__call__
(
self
,
pred_label
,
*
args
,
**
kwargs
):
preds
,
labels
=
pred_label
correct_num
=
0
...
...
@@ -28,6 +35,9 @@ class RecMetric(object):
for
(
pred
,
pred_conf
),
(
target
,
_
)
in
zip
(
preds
,
labels
):
pred
=
pred
.
replace
(
" "
,
""
)
target
=
target
.
replace
(
" "
,
""
)
if
self
.
is_filter
:
pred
=
self
.
_normalize_text
(
pred
)
target
=
self
.
_normalize_text
(
target
)
norm_edit_dis
+=
Levenshtein
.
distance
(
pred
,
target
)
/
max
(
len
(
pred
),
len
(
target
),
1
)
if
pred
==
target
:
...
...
@@ -57,4 +67,3 @@ class RecMetric(object):
self
.
correct_num
=
0
self
.
all_num
=
0
self
.
norm_edit_dis
=
0
ppocr/modeling/backbones/__init__.py
View file @
a5530565
...
...
@@ -28,8 +28,10 @@ def build_backbone(config, model_type):
from
.rec_mv1_enhance
import
MobileNetV1Enhance
from
.rec_nrtr_mtb
import
MTB
from
.rec_resnet_31
import
ResNet31
from
.rec_resnet_aster
import
ResNet_ASTER
support_dict
=
[
'MobileNetV1Enhance'
,
'MobileNetV3'
,
'ResNet'
,
'ResNetFPN'
,
'MTB'
,
"ResNet31"
'MobileNetV1Enhance'
,
'MobileNetV3'
,
'ResNet'
,
'ResNetFPN'
,
'MTB'
,
"ResNet31"
,
"ResNet_ASTER"
]
elif
model_type
==
"e2e"
:
from
.e2e_resnet_vd_pg
import
ResNet
...
...
ppocr/modeling/backbones/rec_resnet_aster.py
0 → 100644
View file @
a5530565
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
sys
import
math
def
conv3x3
(
in_planes
,
out_planes
,
stride
=
1
):
"""3x3 convolution with padding"""
return
nn
.
Conv2D
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias_attr
=
False
)
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
"""1x1 convolution"""
return
nn
.
Conv2D
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias_attr
=
False
)
def
get_sinusoid_encoding
(
n_position
,
feat_dim
,
wave_length
=
10000
):
# [n_position]
positions
=
paddle
.
arange
(
0
,
n_position
)
# [feat_dim]
dim_range
=
paddle
.
arange
(
0
,
feat_dim
)
dim_range
=
paddle
.
pow
(
wave_length
,
2
*
(
dim_range
//
2
)
/
feat_dim
)
# [n_position, feat_dim]
angles
=
paddle
.
unsqueeze
(
positions
,
axis
=
1
)
/
paddle
.
unsqueeze
(
dim_range
,
axis
=
0
)
angles
=
paddle
.
cast
(
angles
,
"float32"
)
angles
[:,
0
::
2
]
=
paddle
.
sin
(
angles
[:,
0
::
2
])
angles
[:,
1
::
2
]
=
paddle
.
cos
(
angles
[:,
1
::
2
])
return
angles
class
AsterBlock
(
nn
.
Layer
):
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
):
super
(
AsterBlock
,
self
).
__init__
()
self
.
conv1
=
conv1x1
(
inplanes
,
planes
,
stride
)
self
.
bn1
=
nn
.
BatchNorm2D
(
planes
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
conv3x3
(
planes
,
planes
)
self
.
bn2
=
nn
.
BatchNorm2D
(
planes
)
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
out
+=
residual
out
=
self
.
relu
(
out
)
return
out
class
ResNet_ASTER
(
nn
.
Layer
):
"""For aster or crnn"""
def
__init__
(
self
,
with_lstm
=
True
,
n_group
=
1
,
in_channels
=
3
):
super
(
ResNet_ASTER
,
self
).
__init__
()
self
.
with_lstm
=
with_lstm
self
.
n_group
=
n_group
self
.
layer0
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
,
32
,
kernel_size
=
(
3
,
3
),
stride
=
1
,
padding
=
1
,
bias_attr
=
False
),
nn
.
BatchNorm2D
(
32
),
nn
.
ReLU
())
self
.
inplanes
=
32
self
.
layer1
=
self
.
_make_layer
(
32
,
3
,
[
2
,
2
])
# [16, 50]
self
.
layer2
=
self
.
_make_layer
(
64
,
4
,
[
2
,
2
])
# [8, 25]
self
.
layer3
=
self
.
_make_layer
(
128
,
6
,
[
2
,
1
])
# [4, 25]
self
.
layer4
=
self
.
_make_layer
(
256
,
6
,
[
2
,
1
])
# [2, 25]
self
.
layer5
=
self
.
_make_layer
(
512
,
3
,
[
2
,
1
])
# [1, 25]
if
with_lstm
:
self
.
rnn
=
nn
.
LSTM
(
512
,
256
,
direction
=
"bidirect"
,
num_layers
=
2
)
self
.
out_channels
=
2
*
256
else
:
self
.
out_channels
=
512
def
_make_layer
(
self
,
planes
,
blocks
,
stride
):
downsample
=
None
if
stride
!=
[
1
,
1
]
or
self
.
inplanes
!=
planes
:
downsample
=
nn
.
Sequential
(
conv1x1
(
self
.
inplanes
,
planes
,
stride
),
nn
.
BatchNorm2D
(
planes
))
layers
=
[]
layers
.
append
(
AsterBlock
(
self
.
inplanes
,
planes
,
stride
,
downsample
))
self
.
inplanes
=
planes
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
AsterBlock
(
self
.
inplanes
,
planes
))
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
x0
=
self
.
layer0
(
x
)
x1
=
self
.
layer1
(
x0
)
x2
=
self
.
layer2
(
x1
)
x3
=
self
.
layer3
(
x2
)
x4
=
self
.
layer4
(
x3
)
x5
=
self
.
layer5
(
x4
)
cnn_feat
=
x5
.
squeeze
(
2
)
# [N, c, w]
cnn_feat
=
paddle
.
transpose
(
cnn_feat
,
perm
=
[
0
,
2
,
1
])
if
self
.
with_lstm
:
rnn_feat
,
_
=
self
.
rnn
(
cnn_feat
)
return
rnn_feat
else
:
return
cnn_feat
ppocr/modeling/heads/__init__.py
View file @
a5530565
...
...
@@ -29,13 +29,14 @@ def build_head(config):
from
.rec_srn_head
import
SRNHead
from
.rec_nrtr_head
import
Transformer
from
.rec_sar_head
import
SARHead
from
.rec_aster_head
import
AsterHead
# cls head
from
.cls_head
import
ClsHead
support_dict
=
[
'DBHead'
,
'PSEHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'TableAttentionHead'
,
'SARHead'
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
]
#table head
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment