Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
721c76b4
Commit
721c76b4
authored
Dec 16, 2021
by
LDOUBLEV
Browse files
fix conflict
parents
98162be4
b77f9ec0
Changes
289
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
305 additions
and
156 deletions
+305
-156
ppocr/modeling/backbones/rec_resnet_aster.py
ppocr/modeling/backbones/rec_resnet_aster.py
+4
-1
ppocr/modeling/heads/det_pse_head.py
ppocr/modeling/heads/det_pse_head.py
+11
-9
ppocr/modeling/heads/rec_aster_head.py
ppocr/modeling/heads/rec_aster_head.py
+4
-0
ppocr/modeling/heads/rec_att_head.py
ppocr/modeling/heads/rec_att_head.py
+2
-2
ppocr/modeling/heads/rec_sar_head.py
ppocr/modeling/heads/rec_sar_head.py
+19
-1
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+24
-16
ppocr/modeling/necks/fpn.py
ppocr/modeling/necks/fpn.py
+63
-25
ppocr/modeling/necks/rnn.py
ppocr/modeling/necks/rnn.py
+1
-1
ppocr/modeling/transforms/stn.py
ppocr/modeling/transforms/stn.py
+4
-1
ppocr/modeling/transforms/tps.py
ppocr/modeling/transforms/tps.py
+4
-0
ppocr/modeling/transforms/tps_spatial_transformer.py
ppocr/modeling/transforms/tps_spatial_transformer.py
+4
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+8
-9
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+13
-8
ppocr/postprocess/east_postprocess.py
ppocr/postprocess/east_postprocess.py
+5
-11
ppocr/postprocess/locality_aware_nms.py
ppocr/postprocess/locality_aware_nms.py
+1
-0
ppocr/postprocess/pse_postprocess/pse/README.md
ppocr/postprocess/pse_postprocess/pse/README.md
+2
-1
ppocr/postprocess/pse_postprocess/pse/__init__.py
ppocr/postprocess/pse_postprocess/pse/__init__.py
+9
-3
ppocr/postprocess/pse_postprocess/pse_postprocess.py
ppocr/postprocess/pse_postprocess/pse_postprocess.py
+15
-9
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+18
-59
ppocr/utils/EN_symbol_dict.txt
ppocr/utils/EN_symbol_dict.txt
+94
-0
No files found.
ppocr/modeling/backbones/rec_resnet_aster.py
View file @
721c76b4
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/resnet_aster.py
"""
import
paddle
import
paddle.nn
as
nn
...
...
ppocr/modeling/heads/det_pse_head.py
View file @
721c76b4
# copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,22 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
"""
from
paddle
import
nn
class
PSEHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_dim
=
256
,
out_channels
=
7
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
hidden_dim
=
256
,
out_channels
=
7
,
**
kwargs
):
super
(
PSEHead
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
in_channels
,
hidden_dim
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
conv1
=
nn
.
Conv2D
(
in_channels
,
hidden_dim
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
bn1
=
nn
.
BatchNorm2D
(
hidden_dim
)
self
.
relu1
=
nn
.
ReLU
()
self
.
conv2
=
nn
.
Conv2D
(
hidden_dim
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
conv2
=
nn
.
Conv2D
(
hidden_dim
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
,
**
kwargs
):
out
=
self
.
conv1
(
x
)
...
...
ppocr/modeling/heads/rec_aster_head.py
View file @
721c76b4
...
...
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/modeling/heads/rec_att_head.py
View file @
721c76b4
...
...
@@ -53,7 +53,6 @@ class AttentionHead(nn.Layer):
output_hiddens
.
append
(
paddle
.
unsqueeze
(
outputs
,
axis
=
1
))
output
=
paddle
.
concat
(
output_hiddens
,
axis
=
1
)
probs
=
self
.
generator
(
output
)
else
:
targets
=
paddle
.
zeros
(
shape
=
[
batch_size
],
dtype
=
"int32"
)
probs
=
None
...
...
@@ -75,7 +74,8 @@ class AttentionHead(nn.Layer):
probs_step
,
axis
=
1
)],
axis
=
1
)
next_input
=
probs_step
.
argmax
(
axis
=
1
)
targets
=
next_input
if
not
self
.
training
:
probs
=
paddle
.
nn
.
functional
.
softmax
(
probs
,
axis
=
2
)
return
probs
...
...
ppocr/modeling/heads/rec_sar_head.py
View file @
721c76b4
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -275,7 +294,6 @@ class ParallelSARDecoder(BaseDecoder):
if
img_metas
is
not
None
and
self
.
mask
:
valid_ratios
=
img_metas
[
-
1
]
label
=
label
.
cuda
()
lab_embedding
=
self
.
embedding
(
label
)
# bsz * seq_len * emb_dim
out_enc
=
out_enc
.
unsqueeze
(
1
)
...
...
ppocr/modeling/heads/table_att_head.py
View file @
721c76b4
...
...
@@ -23,32 +23,40 @@ import numpy as np
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
max_text_length
=
100
,
max_elem_length
=
800
,
max_cell_num
=
500
,
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
elem_num
=
30
self
.
max_text_length
=
100
self
.
max_elem_length
=
500
self
.
max_cell_num
=
500
self
.
max_text_length
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
max_cell_num
=
max_cell_num
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
self
.
structure_generator
=
nn
.
Linear
(
hidden_size
,
self
.
elem_num
)
self
.
loc_type
=
loc_type
self
.
in_max_len
=
in_max_len
if
self
.
loc_type
==
1
:
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
else
:
if
self
.
in_max_len
==
640
:
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
elif
self
.
in_max_len
==
800
:
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
input_ont_hot
=
F
.
one_hot
(
input_char
,
onehot_dim
)
return
input_ont_hot
...
...
@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer):
if
len
(
fea
.
shape
)
==
3
:
pass
else
:
last_shape
=
int
(
np
.
prod
(
fea
.
shape
[
2
:]))
# gry added
last_shape
=
int
(
np
.
prod
(
fea
.
shape
[
2
:]))
# gry added
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
batch_size
=
fea
.
shape
[
0
]
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
output_hiddens
=
[]
if
self
.
training
and
targets
is
not
None
:
structure
=
targets
[
0
]
for
i
in
range
(
self
.
max_elem_length
+
1
):
for
i
in
range
(
self
.
max_elem_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
...
@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
alpha
=
None
max_elem_length
=
paddle
.
to_tensor
(
self
.
max_elem_length
)
i
=
0
while
i
<
max_elem_length
+
1
:
while
i
<
max_elem_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
temp_elem
,
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
...
@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer):
structure_probs_step
=
self
.
structure_generator
(
outputs
)
temp_elem
=
structure_probs_step
.
argmax
(
axis
=
1
,
dtype
=
"int32"
)
i
+=
1
output
=
paddle
.
concat
(
output_hiddens
,
axis
=
1
)
structure_probs
=
self
.
structure_generator
(
output
)
structure_probs
=
F
.
softmax
(
structure_probs
)
...
...
@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer):
loc_concat
=
paddle
.
concat
([
output
,
loc_fea
],
axis
=
2
)
loc_preds
=
self
.
loc_generator
(
loc_concat
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
AttentionGRUCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionGRUCell
,
self
).
__init__
()
...
...
ppocr/modeling/necks/fpn.py
View file @
721c76b4
...
...
@@ -11,64 +11,102 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/neck/fpn.py
"""
import
paddle.nn
as
nn
import
paddle
import
math
import
paddle.nn.functional
as
F
class
Conv_BN_ReLU
(
nn
.
Layer
):
def
__init__
(
self
,
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
):
def
__init__
(
self
,
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
):
super
(
Conv_BN_ReLU
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2D
(
in_planes
,
out_planes
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
bias_attr
=
False
)
self
.
conv
=
nn
.
Conv2D
(
in_planes
,
out_planes
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
bias_attr
=
False
)
self
.
bn
=
nn
.
BatchNorm2D
(
out_planes
,
momentum
=
0.1
)
self
.
relu
=
nn
.
ReLU
()
for
m
in
self
.
sublayers
():
if
isinstance
(
m
,
nn
.
Conv2D
):
n
=
m
.
_kernel_size
[
0
]
*
m
.
_kernel_size
[
1
]
*
m
.
_out_channels
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Normal
(
0
,
math
.
sqrt
(
2.
/
n
)))
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Normal
(
0
,
math
.
sqrt
(
2.
/
n
)))
elif
isinstance
(
m
,
nn
.
BatchNorm2D
):
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
1.0
))
m
.
bias
=
paddle
.
create_parameter
(
shape
=
m
.
bias
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
1.0
))
m
.
bias
=
paddle
.
create_parameter
(
shape
=
m
.
bias
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
def
forward
(
self
,
x
):
return
self
.
relu
(
self
.
bn
(
self
.
conv
(
x
)))
class
FPN
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
FPN
,
self
).
__init__
()
# Top layer
self
.
toplayer_
=
Conv_BN_ReLU
(
in_channels
[
3
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
toplayer_
=
Conv_BN_ReLU
(
in_channels
[
3
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# Lateral layers
self
.
latlayer1_
=
Conv_BN_ReLU
(
in_channels
[
2
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
latlayer1_
=
Conv_BN_ReLU
(
in_channels
[
2
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
latlayer2_
=
Conv_BN_ReLU
(
in_channels
[
1
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
latlayer2_
=
Conv_BN_ReLU
(
in_channels
[
1
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
latlayer3_
=
Conv_BN_ReLU
(
in_channels
[
0
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
latlayer3_
=
Conv_BN_ReLU
(
in_channels
[
0
],
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
# Smooth layers
self
.
smooth1_
=
Conv_BN_ReLU
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
smooth2_
=
Conv_BN_ReLU
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
smooth1_
=
Conv_BN_ReLU
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
smooth3_
=
Conv_BN_ReLU
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
smooth2_
=
Conv_BN_ReLU
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
smooth3_
=
Conv_BN_ReLU
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
out_channels
=
out_channels
*
4
for
m
in
self
.
sublayers
():
if
isinstance
(
m
,
nn
.
Conv2D
):
n
=
m
.
_kernel_size
[
0
]
*
m
.
_kernel_size
[
1
]
*
m
.
_out_channels
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Normal
(
0
,
math
.
sqrt
(
2.
/
n
)))
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Normal
(
0
,
math
.
sqrt
(
2.
/
n
)))
elif
isinstance
(
m
,
nn
.
BatchNorm2D
):
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
1.0
))
m
.
bias
=
paddle
.
create_parameter
(
shape
=
m
.
bias
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
m
.
weight
=
paddle
.
create_parameter
(
shape
=
m
.
weight
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
1.0
))
m
.
bias
=
paddle
.
create_parameter
(
shape
=
m
.
bias
.
shape
,
dtype
=
'float32'
,
default_initializer
=
paddle
.
nn
.
initializer
.
Constant
(
0.0
))
def
_upsample
(
self
,
x
,
scale
=
1
):
return
F
.
upsample
(
x
,
scale_factor
=
scale
,
mode
=
'bilinear'
)
...
...
@@ -81,15 +119,15 @@ class FPN(nn.Layer):
p5
=
self
.
toplayer_
(
f5
)
f4
=
self
.
latlayer1_
(
f4
)
p4
=
self
.
_upsample_add
(
p5
,
f4
,
2
)
p4
=
self
.
_upsample_add
(
p5
,
f4
,
2
)
p4
=
self
.
smooth1_
(
p4
)
f3
=
self
.
latlayer2_
(
f3
)
p3
=
self
.
_upsample_add
(
p4
,
f3
,
2
)
p3
=
self
.
_upsample_add
(
p4
,
f3
,
2
)
p3
=
self
.
smooth2_
(
p3
)
f2
=
self
.
latlayer3_
(
f2
)
p2
=
self
.
_upsample_add
(
p3
,
f2
,
2
)
p2
=
self
.
_upsample_add
(
p3
,
f2
,
2
)
p2
=
self
.
smooth3_
(
p2
)
p3
=
self
.
_upsample
(
p3
,
2
)
...
...
@@ -97,4 +135,4 @@ class FPN(nn.Layer):
p5
=
self
.
_upsample
(
p5
,
8
)
fuse
=
paddle
.
concat
([
p2
,
p3
,
p4
,
p5
],
axis
=
1
)
return
fuse
\ No newline at end of file
return
fuse
ppocr/modeling/necks/rnn.py
View file @
721c76b4
...
...
@@ -51,7 +51,7 @@ class EncoderWithFC(nn.Layer):
super
(
EncoderWithFC
,
self
).
__init__
()
self
.
out_channels
=
hidden_size
weight_attr
,
bias_attr
=
get_para_bias_attr
(
l2_decay
=
0.00001
,
k
=
in_channels
,
name
=
'reduce_encoder_fea'
)
l2_decay
=
0.00001
,
k
=
in_channels
)
self
.
fc
=
nn
.
Linear
(
in_channels
,
hidden_size
,
...
...
ppocr/modeling/transforms/stn.py
View file @
721c76b4
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/stn_head.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/modeling/transforms/tps.py
View file @
721c76b4
...
...
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/transformation.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
ppocr/modeling/transforms/tps_spatial_transformer.py
View file @
721c76b4
...
...
@@ -11,6 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/tps_spatial_transformer.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
ppocr/postprocess/__init__.py
View file @
721c76b4
...
...
@@ -18,7 +18,6 @@ from __future__ import print_function
from
__future__
import
unicode_literals
import
copy
import
platform
__all__
=
[
'build_post_process'
]
...
...
@@ -26,24 +25,24 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
\
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEEDLabelDecode
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEEDLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
if
platform
.
system
()
!=
"Windows"
:
# pse is not support in Windows
from
.pse_postprocess
import
PSEPostProcess
def
build_post_process
(
config
,
global_config
=
None
):
support_dict
=
[
'DBPostProcess'
,
'PSEPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
from
.pse_postprocess
import
PSEPostProcess
support_dict
.
append
(
'PSEPostProcess'
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
if
module_name
==
"None"
:
...
...
ppocr/postprocess/db_postprocess.py
View file @
721c76b4
...
...
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refered from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
@@ -190,7 +193,8 @@ class DBPostProcess(object):
class
DistillationDBPostProcess
(
object
):
def
__init__
(
self
,
model_name
=
[
"student"
],
def
__init__
(
self
,
model_name
=
[
"student"
],
key
=
None
,
thresh
=
0.3
,
box_thresh
=
0.6
,
...
...
@@ -201,12 +205,13 @@ class DistillationDBPostProcess(object):
**
kwargs
):
self
.
model_name
=
model_name
self
.
key
=
key
self
.
post_process
=
DBPostProcess
(
thresh
=
thresh
,
box_thresh
=
box_thresh
,
max_candidates
=
max_candidates
,
unclip_ratio
=
unclip_ratio
,
use_dilation
=
use_dilation
,
score_mode
=
score_mode
)
self
.
post_process
=
DBPostProcess
(
thresh
=
thresh
,
box_thresh
=
box_thresh
,
max_candidates
=
max_candidates
,
unclip_ratio
=
unclip_ratio
,
use_dilation
=
use_dilation
,
score_mode
=
score_mode
)
def
__call__
(
self
,
predicts
,
shape_list
):
results
=
{}
...
...
ppocr/postprocess/east_postprocess.py
View file @
721c76b4
...
...
@@ -20,6 +20,7 @@ import numpy as np
from
.locality_aware_nms
import
nms_locality
import
cv2
import
paddle
import
lanms
import
os
import
sys
...
...
@@ -29,6 +30,7 @@ class EASTPostProcess(object):
"""
The post process for EAST.
"""
def
__init__
(
self
,
score_thresh
=
0.8
,
cover_thresh
=
0.1
,
...
...
@@ -38,11 +40,6 @@ class EASTPostProcess(object):
self
.
score_thresh
=
score_thresh
self
.
cover_thresh
=
cover_thresh
self
.
nms_thresh
=
nms_thresh
# c++ la-nms is faster, but only support python 3.5
self
.
is_python35
=
False
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
self
.
is_python35
=
True
def
restore_rectangle_quad
(
self
,
origin
,
geometry
):
"""
...
...
@@ -79,11 +76,8 @@ class EASTPostProcess(object):
boxes
=
np
.
zeros
((
text_box_restored
.
shape
[
0
],
9
),
dtype
=
np
.
float32
)
boxes
[:,
:
8
]
=
text_box_restored
.
reshape
((
-
1
,
8
))
boxes
[:,
8
]
=
score_map
[
xy_text
[:,
0
],
xy_text
[:,
1
]]
if
self
.
is_python35
:
import
lanms
boxes
=
lanms
.
merge_quadrangle_n9
(
boxes
,
nms_thresh
)
else
:
boxes
=
nms_locality
(
boxes
.
astype
(
np
.
float64
),
nms_thresh
)
boxes
=
lanms
.
merge_quadrangle_n9
(
boxes
,
nms_thresh
)
# boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
if
boxes
.
shape
[
0
]
==
0
:
return
[]
# Here we filter some low score boxes by the average score map,
...
...
@@ -139,4 +133,4 @@ class EASTPostProcess(object):
continue
boxes_norm
.
append
(
box
)
dt_boxes_list
.
append
({
'points'
:
np
.
array
(
boxes_norm
)})
return
dt_boxes_list
\ No newline at end of file
return
dt_boxes_list
ppocr/postprocess/locality_aware_nms.py
View file @
721c76b4
"""
Locality aware nms.
This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
"""
import
numpy
as
np
...
...
ppocr/postprocess/pse_postprocess/pse/README.md
View file @
721c76b4
## 编译
code from https://github.com/whai362/pan_pp.pytorch
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/post_processing/pse
```
python
python3
setup
.
py
build_ext
--
inplace
```
ppocr/postprocess/pse_postprocess/pse/__init__.py
View file @
721c76b4
...
...
@@ -17,7 +17,13 @@ import subprocess
python_path
=
sys
.
executable
if
subprocess
.
call
(
'cd ppocr/postprocess/pse_postprocess/pse;{} setup.py build_ext --inplace;cd -'
.
format
(
python_path
),
shell
=
True
)
!=
0
:
raise
RuntimeError
(
'Cannot compile pse: {}'
.
format
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))))
ori_path
=
os
.
getcwd
()
os
.
chdir
(
'ppocr/postprocess/pse_postprocess/pse'
)
if
subprocess
.
call
(
'{} setup.py build_ext --inplace'
.
format
(
python_path
),
shell
=
True
)
!=
0
:
raise
RuntimeError
(
'Cannot compile pse: {}, if your system is windows, you need to install all the default components of `desktop development using C++` in visual studio 2019+'
.
format
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))))
os
.
chdir
(
ori_path
)
from
.pse
import
pse
\ No newline at end of file
from
.pse
import
pse
ppocr/postprocess/pse_postprocess/pse_postprocess.py
View file @
721c76b4
#
C
opyright (c) 2021 PaddlePaddle Authors. All Rights Reserve
d
.
#
c
opyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/whai362/PSENet/blob/python3/models/head/psenet_head.py
"""
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -47,7 +51,8 @@ class PSEPostProcess(object):
pred
=
outs_dict
[
'maps'
]
if
not
isinstance
(
pred
,
paddle
.
Tensor
):
pred
=
paddle
.
to_tensor
(
pred
)
pred
=
F
.
interpolate
(
pred
,
scale_factor
=
4
//
self
.
scale
,
mode
=
'bilinear'
)
pred
=
F
.
interpolate
(
pred
,
scale_factor
=
4
//
self
.
scale
,
mode
=
'bilinear'
)
score
=
F
.
sigmoid
(
pred
[:,
0
,
:,
:])
...
...
@@ -60,7 +65,9 @@ class PSEPostProcess(object):
boxes_batch
=
[]
for
batch_index
in
range
(
pred
.
shape
[
0
]):
boxes
,
scores
=
self
.
boxes_from_bitmap
(
score
[
batch_index
],
kernels
[
batch_index
],
shape_list
[
batch_index
])
boxes
,
scores
=
self
.
boxes_from_bitmap
(
score
[
batch_index
],
kernels
[
batch_index
],
shape_list
[
batch_index
])
boxes_batch
.
append
({
'points'
:
boxes
,
'scores'
:
scores
})
return
boxes_batch
...
...
@@ -98,15 +105,14 @@ class PSEPostProcess(object):
mask
=
np
.
zeros
((
box_height
,
box_width
),
np
.
uint8
)
mask
[
points
[:,
1
],
points
[:,
0
]]
=
255
contours
,
_
=
cv2
.
findContours
(
mask
,
cv2
.
RETR_EXTERNAL
,
cv2
.
CHAIN_APPROX_SIMPLE
)
contours
,
_
=
cv2
.
findContours
(
mask
,
cv2
.
RETR_EXTERNAL
,
cv2
.
CHAIN_APPROX_SIMPLE
)
bbox
=
np
.
squeeze
(
contours
[
0
],
1
)
else
:
raise
NotImplementedError
bbox
[:,
0
]
=
np
.
clip
(
np
.
round
(
bbox
[:,
0
]
/
ratio_w
),
0
,
src_w
)
bbox
[:,
1
]
=
np
.
clip
(
np
.
round
(
bbox
[:,
1
]
/
ratio_h
),
0
,
src_h
)
bbox
[:,
0
]
=
np
.
clip
(
np
.
round
(
bbox
[:,
0
]
/
ratio_w
),
0
,
src_w
)
bbox
[:,
1
]
=
np
.
clip
(
np
.
round
(
bbox
[:,
1
]
/
ratio_h
),
0
,
src_h
)
boxes
.
append
(
bbox
)
scores
.
append
(
score_i
)
return
boxes
,
scores
ppocr/postprocess/rec_postprocess.py
View file @
721c76b4
...
...
@@ -21,33 +21,15 @@ import re
class
BaseRecLabelDecode
(
object
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
):
support_character_type
=
[
'ch'
,
'en'
,
'EN_symbol'
,
'french'
,
'german'
,
'japan'
,
'korean'
,
'it'
,
'xi'
,
'pu'
,
'ru'
,
'ar'
,
'ta'
,
'ug'
,
'fa'
,
'ur'
,
'rs'
,
'oc'
,
'rsc'
,
'bg'
,
'uk'
,
'be'
,
'te'
,
'ka'
,
'chinese_cht'
,
'hi'
,
'mr'
,
'ne'
,
'EN'
,
'latin'
,
'arabic'
,
'cyrillic'
,
'devanagari'
]
assert
character_type
in
support_character_type
,
"Only {} are supported now but get {}"
.
format
(
support_character_type
,
character_type
)
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
if
character_type
==
"en"
:
self
.
character_str
=
[]
if
character_dict_path
is
None
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
elif
character_type
==
"EN_symbol"
:
# same with ASTER setting (use 94 char).
self
.
character_str
=
string
.
printable
[:
-
6
]
dict_character
=
list
(
self
.
character_str
)
elif
character_type
in
support_character_type
:
self
.
character_str
=
[]
assert
character_dict_path
is
not
None
,
"character_dict_path should not be None when character_type is {}"
.
format
(
character_type
)
else
:
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
...
...
@@ -57,9 +39,6 @@ class BaseRecLabelDecode(object):
self
.
character_str
.
append
(
" "
)
dict_character
=
list
(
self
.
character_str
)
else
:
raise
NotImplementedError
self
.
character_type
=
character_type
dict_character
=
self
.
add_special_char
(
dict_character
)
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
...
...
@@ -102,13 +81,10 @@ class BaseRecLabelDecode(object):
class
CTCLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
CTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
tuple
):
...
...
@@ -136,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
model_name
=
[
"student"
],
key
=
None
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
...
...
@@ -162,13 +137,9 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
class
NRTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'EN_symbol'
,
use_space_char
=
True
,
**
kwargs
):
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
True
,
**
kwargs
):
super
(
NRTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
...
...
@@ -230,13 +201,10 @@ class NRTRLabelDecode(BaseRecLabelDecode):
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
AttnLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
...
...
@@ -313,13 +281,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
class
SEEDLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SEEDLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
use_space_char
)
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
...
...
@@ -394,13 +359,10 @@ class SEEDLabelDecode(BaseRecLabelDecode):
class
SRNLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'en'
,
use_space_char
=
False
,
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SRNLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
use_space_char
)
self
.
max_text_length
=
kwargs
.
get
(
'max_text_length'
,
25
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
...
...
@@ -616,13 +578,10 @@ class TableLabelDecode(object):
class
SARLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
def
__init__
(
self
,
character_dict_path
=
None
,
use_space_char
=
False
,
**
kwargs
):
super
(
SARLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
use_space_char
)
self
.
rm_symbol
=
kwargs
.
get
(
'rm_symbol'
,
False
)
...
...
ppocr/utils/EN_symbol_dict.txt
0 → 100644
View file @
721c76b4
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
:
;
<
=
>
?
@
[
\
]
^
_
`
{
|
}
~
\ No newline at end of file
Prev
1
…
4
5
6
7
8
9
10
11
12
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment