wangsen / paddle_dbnet · Commits · 76320bf0

Commit 76320bf0, authored Apr 09, 2021 by littletomatodonkey

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dev/add_thread_pred

Parents: e19bedf5, 824ceca6

Changes: 47. Showing 20 changed files with 3049 additions and 80 deletions (+3049, -80).
ppocr/metrics/__init__.py                      +2    -1
ppocr/metrics/e2e_metric.py                    +81   -0
ppocr/metrics/eval_det_iou.py                  +4    -3
ppocr/modeling/backbones/__init__.py           +3    -0
ppocr/modeling/backbones/e2e_resnet_vd_pg.py   +265  -0
ppocr/modeling/heads/__init__.py               +3    -2
ppocr/modeling/heads/e2e_pg_head.py            +253  -0
ppocr/modeling/necks/__init__.py               +3    -1
ppocr/modeling/necks/pg_fpn.py                 +314  -0
ppocr/postprocess/__init__.py                  +2    -1
ppocr/postprocess/pg_postprocess.py            +155  -0
ppocr/postprocess/sast_postprocess.py          +127  -72
ppocr/utils/dict/arabic_dict.txt               +162  -0
ppocr/utils/dict/cyrillic_dict.txt             +163  -0
ppocr/utils/dict/devanagari_dict.txt           +167  -0
ppocr/utils/dict/latin_dict.txt                +185  -0
ppocr/utils/e2e_metric/Deteval.py              +458  -0
ppocr/utils/e2e_metric/polygon_fast.py         +83   -0
ppocr/utils/e2e_utils/extract_batchsize.py     +87   -0
ppocr/utils/e2e_utils/extract_textpoint.py     +532  -0
ppocr/metrics/__init__.py · View file @ 76320bf0

...
@@ -26,8 +26,9 @@ def build_metric(config):
     from .det_metric import DetMetric
     from .rec_metric import RecMetric
     from .cls_metric import ClsMetric
+    from .e2e_metric import E2EMetric
-    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
+    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
     config = copy.deepcopy(config)
     module_name = config.pop('name')
...
ppocr/metrics/e2e_metric.py (new file, 0 → 100644) · View file @ 76320bf0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

__all__ = ['E2EMetric']

from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
from ppocr.utils.e2e_utils.extract_textpoint import get_dict


class E2EMetric(object):
    def __init__(self,
                 character_dict_path,
                 main_indicator='f_score_e2e',
                 **kwargs):
        self.label_list = get_dict(character_dict_path)
        self.max_index = len(self.label_list)
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        temp_gt_polyons_batch = batch[2]
        temp_gt_strs_batch = batch[3]
        ignore_tags_batch = batch[4]
        gt_polyons_batch = []
        gt_strs_batch = []

        temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
        for temp_list in temp_gt_polyons_batch:
            t = []
            for index in temp_list:
                if index[0] != -1 and index[1] != -1:
                    t.append(index)
            gt_polyons_batch.append(t)

        temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
        for temp_list in temp_gt_strs_batch:
            t = ""
            for index in temp_list:
                if index < self.max_index:
                    t += self.label_list[index]
            gt_strs_batch.append(t)

        for pred, gt_polyons, gt_strs, ignore_tags in zip(
            [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
            # prepare gt
            gt_info_list = [{
                'points': gt_polyon,
                'text': gt_str,
                'ignore': ignore_tag
            } for gt_polyon, gt_str, ignore_tag in
                            zip(gt_polyons, gt_strs, ignore_tags)]
            # prepare det
            e2e_info_list = [{
                'points': det_polyon,
                'text': pred_str
            } for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
            result = get_socre(gt_info_list, e2e_info_list)
            self.results.append(result)

    def get_metric(self):
        metircs = combine_results(self.results)
        self.reset()
        return metircs

    def reset(self):
        self.results = []  # clear results
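
For orientation, a minimal sketch (not part of this commit) of how E2EMetric consumes a batch: per the code above, batch[2] holds padded ground-truth polygons where (-1, -1) marks padding, batch[3] holds labels encoded as dict indices, and batch[4] holds ignore tags. The dict path and the toy values below are assumptions for illustration.

import numpy as np
from ppocr.metrics.e2e_metric import E2EMetric

# assumed: a one-character-per-line dict file compatible with get_dict
metric = E2EMetric(character_dict_path='ppocr/utils/ic15_dict.txt')

# one toy sample: a single axis-aligned box labelled with dict indices 0..2
gt_polys = np.array([[[[0, 0], [10, 0], [10, 5], [0, 5], [-1, -1]]]])  # batch[2]
gt_labels = np.array([[[0, 1, 2]]])                                    # batch[3]
ignore_tags = np.array([[False]])                                      # batch[4]
preds = {'points': [np.array([[0, 0], [10, 0], [10, 5], [0, 5]])],
         'strs': ['abc']}

metric(preds, [None, None, gt_polys, gt_labels, ignore_tags])
print(metric.get_metric())  # combine_results over the accumulated samples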
ppocr/metrics/eval_det_iou.py · View file @ 76320bf0

...
@@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
                        pairs.append({'gt': gtNum, 'det': detNum})
                        detMatchedNums.append(detNum)
                        evaluationLog += "Match GT #" + \
                            str(gtNum) + " with Det #" + str(detNum) + "\n"

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
...
@@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
        precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
        hmean = 0 if (precision + recall) == 0 else 2.0 * \
            precision * recall / (precision + recall)

        matchedSum += detMatched
        numGlobalCareGt += numGtCare
...
@@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
        methodPrecision = 0 if numGlobalCareDet == 0 else float(
            matchedSum) / numGlobalCareDet
        methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
            methodRecall * methodPrecision / (methodRecall + methodPrecision)
        # print(methodRecall, methodPrecision, methodHmean)
        # sys.exit(-1)
        methodMetrics = {
...
ppocr/modeling/backbones/__init__.py · View file @ 76320bf0

...
@@ -26,6 +26,9 @@ def build_backbone(config, model_type):
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
         support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+    elif model_type == 'e2e':
+        from .e2e_resnet_vd_pg import ResNet
+        support_dict = ['ResNet']
     else:
         raise NotImplementedError
...
ppocr/modeling/backbones/e2e_resnet_vd_pg.py (new file, 0 → 100644) · View file @ 76320bf0
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F

__all__ = ["ResNet"]


class ConvBNLayer(nn.Layer):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            groups=1,
            is_vd_mode=False,
            act=None,
            name=None, ):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class BottleneckBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels * 4,
            kernel_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels * 4,
                kernel_size=1,
                stride=stride,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y


class BasicBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 shortcut=True,
                 if_first=False,
                 name=None):
        super(BasicBlock, self).__init__()
        self.stride = stride
        self.conv0 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            act=None,
            name=name + "_branch2b")

        if not shortcut:
            self.short = ConvBNLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                is_vd_mode=False if if_first else True,
                name=name + "_branch1")

        self.shortcut = shortcut

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)
        y = paddle.add(x=short, y=conv1)
        y = F.relu(y)
        return y


class ResNet(nn.Layer):
    def __init__(self, in_channels=3, layers=50, **kwargs):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [18, 34, 50, 101, 152, 200]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            # depth = [3, 4, 6, 3]
            depth = [3, 4, 6, 3, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        elif layers == 200:
            depth = [3, 12, 48, 3]
        num_channels = [64, 256, 512, 1024,
                        2048] if layers >= 50 else [64, 64, 128, 256]
        num_filters = [64, 128, 256, 512, 512]

        self.conv1_1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=7,
            stride=2,
            act='relu',
            name="conv1_1")
        self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)

        self.stages = []
        self.out_channels = [3, 64]
        # num_filters = [64, 128, 256, 512, 512]
        if layers >= 50:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    if layers in [101, 152] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    bottleneck_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BottleneckBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block] * 4,
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(bottleneck_block)
                self.out_channels.append(num_filters[block] * 4)
                self.stages.append(nn.Sequential(*block_list))
        else:
            for block in range(len(depth)):
                block_list = []
                shortcut = False
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    basic_block = self.add_sublayer(
                        'bb_%d_%d' % (block, i),
                        BasicBlock(
                            in_channels=num_channels[block]
                            if i == 0 else num_filters[block],
                            out_channels=num_filters[block],
                            stride=2 if i == 0 and block != 0 else 1,
                            shortcut=shortcut,
                            if_first=block == i == 0,
                            name=conv_name))
                    shortcut = True
                    block_list.append(basic_block)
                self.out_channels.append(num_filters[block])
                self.stages.append(nn.Sequential(*block_list))

    def forward(self, inputs):
        out = [inputs]
        y = self.conv1_1(inputs)
        out.append(y)
        y = self.pool2d_max(y)
        for block in self.stages:
            y = block(y)
            out.append(y)
        return out
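
A quick shape check (not part of this commit) makes the backbone's contract concrete: for layers=50 the e2e ResNet returns the raw input plus six feature maps, which is exactly the seven-tensor sequence the PGFPN neck below unpacks. The 640x640 input size is an assumption for illustration.

import paddle
from ppocr.modeling.backbones.e2e_resnet_vd_pg import ResNet

backbone = ResNet(in_channels=3, layers=50)
feats = backbone(paddle.randn([1, 3, 640, 640]))
for i, f in enumerate(feats):
    print(i, f.shape)
# 0: [1, 3, 640, 640]     the input itself
# 1: [1, 64, 320, 320]    conv1_1, stride 2
# 2: [1, 256, 160, 160]   stage outputs at strides 4, 8, 16, 32, 64
# 3: [1, 512, 80, 80]
# 4: [1, 1024, 40, 40]
# 5: [1, 2048, 20, 20]
# 6: [1, 2048, 10, 10]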
ppocr/modeling/heads/__init__.py · View file @ 76320bf0

...
@@ -20,6 +20,7 @@ def build_head(config):
     from .det_db_head import DBHead
     from .det_east_head import EASTHead
     from .det_sast_head import SASTHead
+    from .e2e_pg_head import PGHead

     # rec head
     from .rec_ctc_head import CTCHead
...
@@ -30,8 +31,8 @@ def build_head(config):
     from .cls_head import ClsHead
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead'
+        'SRNHead', 'PGHead'
     ]

     module_name = config.pop('name')
     assert module_name in support_dict, Exception('head only support {}'.format(
...
ppocr/modeling/heads/e2e_pg_head.py (new file, 0 → 100644) · View file @ 76320bf0
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 groups=1,
                 if_act=True,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.if_act = if_act
        self.act = act
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            weight_attr=ParamAttr(name=name + '_weights'),
            bias_attr=False)
        self.bn = nn.BatchNorm(
            num_channels=out_channels,
            act=act,
            param_attr=ParamAttr(name="bn_" + name + "_scale"),
            bias_attr=ParamAttr(name="bn_" + name + "_offset"),
            moving_mean_name="bn_" + name + "_mean",
            moving_variance_name="bn_" + name + "_variance",
            use_global_stats=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class PGHead(nn.Layer):
    """
    """

    def __init__(self, in_channels, **kwargs):
        super(PGHead, self).__init__()
        self.conv_f_score1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_score{}".format(1))
        self.conv_f_score2 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            act='relu',
            name="conv_f_score{}".format(2))
        self.conv_f_score3 = ConvBNLayer(
            in_channels=64,
            out_channels=128,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_score{}".format(3))

        self.conv1 = nn.Conv2D(
            in_channels=128,
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
            bias_attr=False)

        self.conv_f_boder1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_boder{}".format(1))
        self.conv_f_boder2 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            act='relu',
            name="conv_f_boder{}".format(2))
        self.conv_f_boder3 = ConvBNLayer(
            in_channels=64,
            out_channels=128,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_boder{}".format(3))
        self.conv2 = nn.Conv2D(
            in_channels=128,
            out_channels=4,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
            bias_attr=False)

        self.conv_f_char1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=128,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_char{}".format(1))
        self.conv_f_char2 = ConvBNLayer(
            in_channels=128,
            out_channels=128,
            kernel_size=3,
            stride=1,
            padding=1,
            act='relu',
            name="conv_f_char{}".format(2))
        self.conv_f_char3 = ConvBNLayer(
            in_channels=128,
            out_channels=256,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_char{}".format(3))
        self.conv_f_char4 = ConvBNLayer(
            in_channels=256,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
            act='relu',
            name="conv_f_char{}".format(4))
        self.conv_f_char5 = ConvBNLayer(
            in_channels=256,
            out_channels=256,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_char{}".format(5))
        self.conv3 = nn.Conv2D(
            in_channels=256,
            out_channels=37,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
            bias_attr=False)

        self.conv_f_direc1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_direc{}".format(1))
        self.conv_f_direc2 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            act='relu',
            name="conv_f_direc{}".format(2))
        self.conv_f_direc3 = ConvBNLayer(
            in_channels=64,
            out_channels=128,
            kernel_size=1,
            stride=1,
            padding=0,
            act='relu',
            name="conv_f_direc{}".format(3))
        self.conv4 = nn.Conv2D(
            in_channels=128,
            out_channels=2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
            bias_attr=False)

    def forward(self, x):
        f_score = self.conv_f_score1(x)
        f_score = self.conv_f_score2(f_score)
        f_score = self.conv_f_score3(f_score)
        f_score = self.conv1(f_score)
        f_score = F.sigmoid(f_score)

        # f_border
        f_border = self.conv_f_boder1(x)
        f_border = self.conv_f_boder2(f_border)
        f_border = self.conv_f_boder3(f_border)
        f_border = self.conv2(f_border)

        f_char = self.conv_f_char1(x)
        f_char = self.conv_f_char2(f_char)
        f_char = self.conv_f_char3(f_char)
        f_char = self.conv_f_char4(f_char)
        f_char = self.conv_f_char5(f_char)
        f_char = self.conv3(f_char)

        f_direction = self.conv_f_direc1(x)
        f_direction = self.conv_f_direc2(f_direction)
        f_direction = self.conv_f_direc3(f_direction)
        f_direction = self.conv4(f_direction)

        predicts = {}
        predicts['f_score'] = f_score
        predicts['f_border'] = f_border
        predicts['f_char'] = f_char
        predicts['f_direction'] = f_direction
        return predicts
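
The head's output channels follow directly from the final convolutions above: 1 score map, 4 border offsets, 37 character logits (36 classes plus a CTC blank, matching blank=36 in pg_postprocess.py below), and 2 direction components. A small sketch, not from this commit, assuming a 128-channel neck output:

import paddle
from ppocr.modeling.heads.e2e_pg_head import PGHead

head = PGHead(in_channels=128)
preds = head(paddle.randn([1, 128, 160, 160]))
for name, tensor in preds.items():
    print(name, tensor.shape)
# f_score [1, 1, 160, 160], f_border [1, 4, 160, 160],
# f_char [1, 37, 160, 160], f_direction [1, 2, 160, 160]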
ppocr/modeling/necks/__init__.py · View file @ 76320bf0

...
@@ -14,12 +14,14 @@
 __all__ = ['build_neck']


 def build_neck(config):
     from .db_fpn import DBFPN
     from .east_fpn import EASTFPN
     from .sast_fpn import SASTFPN
     from .rnn import SequenceEncoder
-    support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
+    from .pg_fpn import PGFPN
+    support_dict = [
+        'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN'
+    ]

     module_name = config.pop('name')
     assert module_name in support_dict, Exception('neck only support {}'.format(
...
ppocr/modeling/necks/pg_fpn.py (new file, 0 → 100644) · View file @ 76320bf0
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 groups=1,
                 is_vd_mode=False,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()

        self.is_vd_mode = is_vd_mode
        self._pool2d_avg = nn.AvgPool2D(
            kernel_size=2, stride=2, padding=0, ceil_mode=True)
        self._conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = nn.BatchNorm(
            out_channels,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance',
            use_global_stats=False)

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class DeConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=4,
                 stride=2,
                 padding=1,
                 groups=1,
                 if_act=True,
                 act=None,
                 name=None):
        super(DeConvBNLayer, self).__init__()

        self.if_act = if_act
        self.act = act
        self.deconv = nn.Conv2DTranspose(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            weight_attr=ParamAttr(name=name + '_weights'),
            bias_attr=False)
        self.bn = nn.BatchNorm(
            num_channels=out_channels,
            act=act,
            param_attr=ParamAttr(name="bn_" + name + "_scale"),
            bias_attr=ParamAttr(name="bn_" + name + "_offset"),
            moving_mean_name="bn_" + name + "_mean",
            moving_variance_name="bn_" + name + "_variance",
            use_global_stats=False)

    def forward(self, x):
        x = self.deconv(x)
        x = self.bn(x)
        return x


class PGFPN(nn.Layer):
    def __init__(self, in_channels, **kwargs):
        super(PGFPN, self).__init__()
        num_inputs = [2048, 2048, 1024, 512, 256]
        num_outputs = [256, 256, 192, 192, 128]
        self.out_channels = 128
        self.conv_bn_layer_1 = ConvBNLayer(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act=None,
            name='FPN_d1')
        self.conv_bn_layer_2 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act=None,
            name='FPN_d2')
        self.conv_bn_layer_3 = ConvBNLayer(
            in_channels=256,
            out_channels=128,
            kernel_size=3,
            stride=1,
            act=None,
            name='FPN_d3')
        self.conv_bn_layer_4 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=2,
            act=None,
            name='FPN_d4')
        self.conv_bn_layer_5 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu',
            name='FPN_d5')
        self.conv_bn_layer_6 = ConvBNLayer(
            in_channels=64,
            out_channels=128,
            kernel_size=3,
            stride=2,
            act=None,
            name='FPN_d6')
        self.conv_bn_layer_7 = ConvBNLayer(
            in_channels=128,
            out_channels=128,
            kernel_size=3,
            stride=1,
            act='relu',
            name='FPN_d7')
        self.conv_bn_layer_8 = ConvBNLayer(
            in_channels=128,
            out_channels=128,
            kernel_size=1,
            stride=1,
            act=None,
            name='FPN_d8')

        self.conv_h0 = ConvBNLayer(
            in_channels=num_inputs[0],
            out_channels=num_outputs[0],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(0))
        self.conv_h1 = ConvBNLayer(
            in_channels=num_inputs[1],
            out_channels=num_outputs[1],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(1))
        self.conv_h2 = ConvBNLayer(
            in_channels=num_inputs[2],
            out_channels=num_outputs[2],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(2))
        self.conv_h3 = ConvBNLayer(
            in_channels=num_inputs[3],
            out_channels=num_outputs[3],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(3))
        self.conv_h4 = ConvBNLayer(
            in_channels=num_inputs[4],
            out_channels=num_outputs[4],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(4))

        self.dconv0 = DeConvBNLayer(
            in_channels=num_outputs[0],
            out_channels=num_outputs[0 + 1],
            name="dconv_{}".format(0))
        self.dconv1 = DeConvBNLayer(
            in_channels=num_outputs[1],
            out_channels=num_outputs[1 + 1],
            act=None,
            name="dconv_{}".format(1))
        self.dconv2 = DeConvBNLayer(
            in_channels=num_outputs[2],
            out_channels=num_outputs[2 + 1],
            act=None,
            name="dconv_{}".format(2))
        self.dconv3 = DeConvBNLayer(
            in_channels=num_outputs[3],
            out_channels=num_outputs[3 + 1],
            act=None,
            name="dconv_{}".format(3))
        self.conv_g1 = ConvBNLayer(
            in_channels=num_outputs[1],
            out_channels=num_outputs[1],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(1))
        self.conv_g2 = ConvBNLayer(
            in_channels=num_outputs[2],
            out_channels=num_outputs[2],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(2))
        self.conv_g3 = ConvBNLayer(
            in_channels=num_outputs[3],
            out_channels=num_outputs[3],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(3))
        self.conv_g4 = ConvBNLayer(
            in_channels=num_outputs[4],
            out_channels=num_outputs[4],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(4))
        self.convf = ConvBNLayer(
            in_channels=num_outputs[4],
            out_channels=num_outputs[4],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_f{}".format(4))

    def forward(self, x):
        c0, c1, c2, c3, c4, c5, c6 = x
        # FPN_Down_Fusion
        f = [c0, c1, c2]
        g = [None, None, None]
        h = [None, None, None]
        h[0] = self.conv_bn_layer_1(f[0])
        h[1] = self.conv_bn_layer_2(f[1])
        h[2] = self.conv_bn_layer_3(f[2])

        g[0] = self.conv_bn_layer_4(h[0])
        g[1] = paddle.add(g[0], h[1])
        g[1] = F.relu(g[1])
        g[1] = self.conv_bn_layer_5(g[1])
        g[1] = self.conv_bn_layer_6(g[1])

        g[2] = paddle.add(g[1], h[2])
        g[2] = F.relu(g[2])
        g[2] = self.conv_bn_layer_7(g[2])
        f_down = self.conv_bn_layer_8(g[2])

        # FPN UP Fusion
        f1 = [c6, c5, c4, c3, c2]
        g = [None, None, None, None, None]
        h = [None, None, None, None, None]
        h[0] = self.conv_h0(f1[0])
        h[1] = self.conv_h1(f1[1])
        h[2] = self.conv_h2(f1[2])
        h[3] = self.conv_h3(f1[3])
        h[4] = self.conv_h4(f1[4])

        g[0] = self.dconv0(h[0])
        g[1] = paddle.add(g[0], h[1])
        g[1] = F.relu(g[1])
        g[1] = self.conv_g1(g[1])
        g[1] = self.dconv1(g[1])
        g[2] = paddle.add(g[1], h[2])
        g[2] = F.relu(g[2])
        g[2] = self.conv_g2(g[2])
        g[2] = self.dconv2(g[2])
        g[3] = paddle.add(g[2], h[3])
        g[3] = F.relu(g[3])
        g[3] = self.conv_g3(g[3])
        g[3] = self.dconv3(g[3])
        g[4] = paddle.add(x=g[3], y=h[4])
        g[4] = F.relu(g[4])
        g[4] = self.conv_g4(g[4])
        f_up = self.convf(g[4])
        f_common = paddle.add(f_down, f_up)
        f_common = F.relu(f_common)
        return f_common
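
Not from the diff: a sketch wiring the e2e backbone above into this neck. PGFPN's hard-coded num_inputs ([2048, 2048, 1024, 512, 256] for c6 down to c2) matches the layers=50 backbone, and the fused map comes out with 128 channels at 1/4 input resolution.

import paddle
from ppocr.modeling.backbones.e2e_resnet_vd_pg import ResNet
from ppocr.modeling.necks.pg_fpn import PGFPN

backbone = ResNet(in_channels=3, layers=50)
neck = PGFPN(in_channels=backbone.out_channels)
f_common = neck(backbone(paddle.randn([1, 3, 640, 640])))
print(f_common.shape)  # [1, 128, 160, 160]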
ppocr/postprocess/__init__.py · View file @ 76320bf0

...
@@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
     from .sast_postprocess import SASTPostProcess
     from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
     from .cls_postprocess import ClsPostProcess
+    from .pg_postprocess import PGPostProcess

     support_dict = ['DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
-                    'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode']
+                    'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess']

     config = copy.deepcopy(config)
...
ppocr/postprocess/pg_postprocess.py (new file, 0 → 100644) · View file @ 76320bf0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))

from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.visual import *
import paddle


class PGPostProcess(object):
    """
    The post process for PGNet.
    """

    def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
        self.Lexicon_Table = get_dict(character_dict_path)
        self.valid_set = valid_set
        self.score_thresh = score_thresh

        # c++ la-nms is faster, but only support python 3.5
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True

    def __call__(self, outs_dict, shape_list):
        p_score = outs_dict['f_score']
        p_border = outs_dict['f_border']
        p_char = outs_dict['f_char']
        p_direction = outs_dict['f_direction']
        if isinstance(p_score, paddle.Tensor):
            p_score = p_score[0].numpy()
            p_border = p_border[0].numpy()
            p_direction = p_direction[0].numpy()
            p_char = p_char[0].numpy()
        else:
            p_score = p_score[0]
            p_border = p_border[0]
            p_direction = p_direction[0]
            p_char = p_char[0]

        src_h, src_w, ratio_h, ratio_w = shape_list[0]
        is_curved = self.valid_set == "totaltext"
        instance_yxs_list = generate_pivot_list(
            p_score,
            p_char,
            p_direction,
            score_thresh=self.score_thresh,
            is_backbone=True,
            is_curved=is_curved)
        p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
        char_seq_idx_set = []
        for i in range(len(instance_yxs_list)):
            gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
            f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
            feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
            feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
            feature_len = [len(feature_seq[0])]
            featyre_seq = paddle.to_tensor(feature_seq)
            feature_len = np.array([feature_len]).astype(np.int64)
            length = paddle.to_tensor(feature_len)
            seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
                input=featyre_seq, blank=36, input_length=length)
            seq_pred_str = seq_pred[0].numpy().tolist()[0]
            seq_len = seq_pred[1].numpy()[0][0]
            temp_t = []
            for c in seq_pred_str[:seq_len]:
                temp_t.append(c)
            char_seq_idx_set.append(temp_t)
        seq_strs = []
        for char_idx_set in char_seq_idx_set:
            pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
            seq_strs.append(pr_str)
        poly_list = []
        keep_str_list = []
        all_point_list = []
        all_point_pair_list = []
        for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
            if len(yx_center_line) == 1:
                yx_center_line.append(yx_center_line[-1])

            offset_expand = 1.0
            if self.valid_set == 'totaltext':
                offset_expand = 1.2

            point_pair_list = []
            for batch_id, y, x in yx_center_line:
                offset = p_border[:, y, x].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(
                        offset, axis=1, keepdims=True)
                    expand_length = np.clip(
                        offset_length * (offset_expand - 1),
                        a_min=0.5,
                        a_max=3.0)
                    offset_detal = offset / offset_length * expand_length
                    offset = offset + offset_detal
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
                    [ratio_w, ratio_h]).reshape(-1, 2)
                point_pair_list.append(point_pair)

                all_point_list.append([
                    int(round(x * 4.0 / ratio_w)),
                    int(round(y * 4.0 / ratio_h))
                ])
                all_point_pair_list.append(point_pair.round().astype(np.int32)
                                           .tolist())

            detected_poly, pair_length_info = point_pair2poly(point_pair_list)
            detected_poly = expand_poly_along_width(
                detected_poly, shrink_ratio_of_width=0.2)
            detected_poly[:, 0] = np.clip(
                detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(
                detected_poly[:, 1], a_min=0, a_max=src_h)

            if len(keep_str) < 2:
                continue

            keep_str_list.append(keep_str)
            if self.valid_set == 'partvgg':
                middle_point = len(detected_poly) // 2
                detected_poly = detected_poly[
                    [0, middle_point - 1, middle_point, -1], :]
                poly_list.append(detected_poly)
            elif self.valid_set == 'totaltext':
                poly_list.append(detected_poly)
            else:
                print('--> Not supported format.')
                exit(-1)
        data = {
            'points': poly_list,
            'strs': keep_str_list,
        }
        return data
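
Not from the diff: PGPostProcess is normally reached through the config-driven builder registered in ppocr/postprocess/__init__.py above; the dict path and threshold here are assumptions for illustration.

from ppocr.postprocess import build_post_process

post_process = build_post_process({
    'name': 'PGPostProcess',
    'character_dict_path': 'ppocr/utils/ic15_dict.txt',  # assumed dict file
    'valid_set': 'totaltext',  # 'totaltext' or 'partvgg'; other values exit(-1)
    'score_thresh': 0.5,
})
# usage: post_process(outs_dict, shape_list) -> {'points': [...], 'strs': [...]}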
ppocr/postprocess/sast_postprocess.py · View file @ 76320bf0

...
@@ -18,6 +18,7 @@ from __future__ import print_function
 import os
 import sys

 __dir__ = os.path.dirname(__file__)
 sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..'))
...
@@ -49,12 +50,12 @@ class SASTPostProcess(object):
         self.shrink_ratio_of_width = shrink_ratio_of_width
         self.expand_scale = expand_scale
         self.tcl_map_thresh = tcl_map_thresh

         # c++ la-nms is faster, but only support python 3.5
         self.is_python35 = False
         if sys.version_info.major == 3 and sys.version_info.minor == 5:
             self.is_python35 = True

     def point_pair2poly(self, point_pair_list):
         """
         Transfer vertical point_pairs into poly point in clockwise.
...
@@ -66,31 +67,42 @@ class SASTPostProcess(object):
             point_list[idx] = point_pair[0]
             point_list[point_num - 1 - idx] = point_pair[1]
         return np.array(point_list).reshape(-1, 2)

     def shrink_quad_along_width(self,
                                 quad,
                                 begin_width_ratio=0.,
                                 end_width_ratio=1.):
         """
         Generate shrink_quad_along_width.
         """
         ratio_pair = np.array(
             [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
         p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
         p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
         return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

     def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
         """
         expand poly along width.
         """
         point_num = poly.shape[0]
         left_quad = np.array(
             [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
         left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
             (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
         left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
                                                         1.0)
         right_quad = np.array(
             [
                 poly[point_num // 2 - 2], poly[point_num // 2 - 1],
                 poly[point_num // 2], poly[point_num // 2 + 1]
             ],
             dtype=np.float32)
         right_ratio = 1.0 + \
             shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
             (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
         right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
                                                          right_ratio)
         poly[0] = left_quad_expand[0]
         poly[-1] = left_quad_expand[-1]
         poly[point_num // 2 - 1] = right_quad_expand[1]
...
@@ -100,7 +112,7 @@ class SASTPostProcess(object):
     def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
         """Restore quad."""
         xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
         xy_text = xy_text[:, ::-1]  # (n, 2)

         # Sort the text boxes via the y axis
         xy_text = xy_text[np.argsort(xy_text[:, 1])]
...
@@ -112,7 +124,7 @@ class SASTPostProcess(object):
         point_num = int(tvo_map.shape[-1] / 2)
         assert point_num == 4
         tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
         xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
         quads = xy_text_tile - tvo_map

         return scores, quads, xy_text
...
@@ -121,14 +133,12 @@ class SASTPostProcess(object):
         """
         compute area of a quad.
         """
         edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
                 (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
                 (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
                 (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
         return np.sum(edge) / 2.

     def nms(self, dets):
         if self.is_python35:
             import lanms
...
@@ -141,7 +151,7 @@ class SASTPostProcess(object):
         """
         Cluster pixels in tcl_map based on quads.
         """
         instance_count = quads.shape[0] + 1  # contain background
         instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
         if instance_count == 1:
             return instance_count, instance_label_map
...
@@ -149,18 +159,19 @@ class SASTPostProcess(object):
         # predict text center
         xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
         n = xy_text.shape[0]
         xy_text = xy_text[:, ::-1]  # (n, 2)
         tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
         pred_tc = xy_text - tco

         # get gt text center
         m = quads.shape[0]
         gt_tc = np.mean(quads, axis=1)  # (m, 2)

         pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
                                (1, m, 1))  # (n, m, 2)
         gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))  # (n, m, 2)
         dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2)  # (n, m)
         xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)

         instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
         return instance_count, instance_label_map
...
@@ -169,26 +180,47 @@ class SASTPostProcess(object):
         """
         Estimate sample points number.
         """
         eh = (np.linalg.norm(quad[0] - quad[3]) +
               np.linalg.norm(quad[1] - quad[2])) / 2.0
         ew = (np.linalg.norm(quad[0] - quad[1]) +
               np.linalg.norm(quad[2] - quad[3])) / 2.0

         dense_sample_pts_num = max(2, int(ew))
         dense_xy_center_line = xy_text[np.linspace(
             0,
             xy_text.shape[0] - 1,
             dense_sample_pts_num,
             endpoint=True,
             dtype=np.float32).astype(np.int32)]

         dense_xy_center_line_diff = dense_xy_center_line[
             1:] - dense_xy_center_line[:-1]
         estimate_arc_len = np.sum(
             np.linalg.norm(dense_xy_center_line_diff, axis=1))

         sample_pts_num = max(2, int(estimate_arc_len / eh))
         return sample_pts_num

     def detect_sast(self,
                     tcl_map,
                     tvo_map,
                     tbo_map,
                     tco_map,
                     ratio_w,
                     ratio_h,
                     src_w,
                     src_h,
                     shrink_ratio_of_width=0.3,
                     tcl_map_thresh=0.5,
                     offset_expand=1.0,
                     out_strid=4.0):
         """
         first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
         """
         # restore quad
         scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
                                                    tvo_map)
         dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
         dets = self.nms(dets)
         if dets.shape[0] == 0:
...
@@ -202,7 +234,8 @@ class SASTPostProcess(object):
         # instance segmentation
         # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
         instance_count, instance_label_map = self.cluster_by_quads_tco(
             tcl_map, tcl_map_thresh, quads, tco_map)

         # restore single poly with tcl instance.
         poly_list = []
...
@@ -212,10 +245,10 @@ class SASTPostProcess(object):
             q_area = quad_areas[instance_idx - 1]
             if q_area < 5:
                 continue

             #
             len1 = float(np.linalg.norm(quad[0] - quad[1]))
             len2 = float(np.linalg.norm(quad[1] - quad[2]))
             min_len = min(len1, len2)
             if min_len < 3:
                 continue
...
@@ -225,16 +258,18 @@ class SASTPostProcess(object):
                 continue

             # filter low confidence instance
             xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
             if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
                 # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
                 continue

             # sort xy_text
             left_center_pt = np.array(
                 [[(quad[0, 0] + quad[-1, 0]) / 2.0,
                   (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
             right_center_pt = np.array(
                 [[(quad[1, 0] + quad[2, 0]) / 2.0,
                   (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
             proj_unit_vec = (right_center_pt - left_center_pt) / \
                 (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
             proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
...
@@ -245,33 +280,45 @@ class SASTPostProcess(object):
                 sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
             else:
                 sample_pts_num = self.sample_pts_num
             xy_center_line = xy_text[np.linspace(
                 0,
                 xy_text.shape[0] - 1,
                 sample_pts_num,
                 endpoint=True,
                 dtype=np.float32).astype(np.int32)]

             point_pair_list = []
             for x, y in xy_center_line:
                 # get corresponding offset
                 offset = tbo_map[y, x, :].reshape(2, 2)
                 if offset_expand != 1.0:
                     offset_length = np.linalg.norm(
                         offset, axis=1, keepdims=True)
                     expand_length = np.clip(
                         offset_length * (offset_expand - 1),
                         a_min=0.5,
                         a_max=3.0)
                     offset_detal = offset / offset_length * expand_length
                     offset = offset + offset_detal
                 # original point
                 ori_yx = np.array([y, x], dtype=np.float32)
                 point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
                     [ratio_w, ratio_h]).reshape(-1, 2)
                 point_pair_list.append(point_pair)

             # ndarry: (x, 2), expand poly along width
             detected_poly = self.point_pair2poly(point_pair_list)
             detected_poly = self.expand_poly_along_width(detected_poly,
                                                          shrink_ratio_of_width)
             detected_poly[:, 0] = np.clip(
                 detected_poly[:, 0], a_min=0, a_max=src_w)
             detected_poly[:, 1] = np.clip(
                 detected_poly[:, 1], a_min=0, a_max=src_h)
             poly_list.append(detected_poly)

         return poly_list

     def __call__(self, outs_dict, shape_list):
         score_list = outs_dict['f_score']
         border_list = outs_dict['f_border']
         tvo_list = outs_dict['f_tvo']
...
@@ -281,20 +328,28 @@ class SASTPostProcess(object):
             border_list = border_list.numpy()
             tvo_list = tvo_list.numpy()
             tco_list = tco_list.numpy()

         img_num = len(shape_list)
         poly_lists = []
         for ino in range(img_num):
             p_score = score_list[ino].transpose((1, 2, 0))
             p_border = border_list[ino].transpose((1, 2, 0))
             p_tvo = tvo_list[ino].transpose((1, 2, 0))
             p_tco = tco_list[ino].transpose((1, 2, 0))
             src_h, src_w, ratio_h, ratio_w = shape_list[ino]

             poly_list = self.detect_sast(
                 p_score,
                 p_tvo,
                 p_border,
                 p_tco,
                 ratio_w,
                 ratio_h,
                 src_w,
                 src_h,
                 shrink_ratio_of_width=self.shrink_ratio_of_width,
                 tcl_map_thresh=self.tcl_map_thresh,
                 offset_expand=self.expand_scale)
             poly_lists.append({'points': np.array(poly_list)})

         return poly_lists
ppocr/utils/dict/arabic_dict.txt (new file, 0 → 100644) · View file @ 76320bf0
!
#
$
%
&
'
(
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
É
é
ء
آ
أ
ؤ
إ
ئ
ا
ب
ة
ت
ث
ج
ح
خ
د
ذ
ر
ز
س
ش
ص
ض
ط
ظ
ع
غ
ف
ق
ك
ل
م
ن
ه
و
ى
ي
ً
ٌ
ٍ
َ
ُ
ِ
ّ
ْ
ٓ
ٔ
ٰ
ٱ
ٹ
پ
چ
ڈ
ڑ
ژ
ک
ڭ
گ
ں
ھ
ۀ
ہ
ۂ
ۃ
ۆ
ۇ
ۈ
ۋ
ی
ې
ے
ۓ
ە
١
٢
٣
٤
٥
٦
٧
٨
٩
ppocr/utils/dict/cyrillic_dict.txt (new file, 0 → 100644) · View file @ 76320bf0
!
#
$
%
&
'
(
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
É
é
Ё
Є
І
Ј
Љ
Ў
А
Б
В
Г
Д
Е
Ж
З
И
Й
К
Л
М
Н
О
П
Р
С
Т
У
Ф
Х
Ц
Ч
Ш
Щ
Ъ
Ы
Ь
Э
Ю
Я
а
б
в
г
д
е
ж
з
и
й
к
л
м
н
о
п
р
с
т
у
ф
х
ц
ч
ш
щ
ъ
ы
ь
э
ю
я
ё
ђ
є
і
ј
љ
њ
ћ
ў
џ
Ґ
ґ
ppocr/utils/dict/devanagari_dict.txt (new file, 0 → 100644) · View file @ 76320bf0
!
#
$
%
&
'
(
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
É
é
ँ
ं
ः
अ
आ
इ
ई
उ
ऊ
ऋ
ए
ऐ
ऑ
ओ
औ
क
ख
ग
घ
ङ
च
छ
ज
झ
ञ
ट
ठ
ड
ढ
ण
त
थ
द
ध
न
ऩ
प
फ
ब
भ
म
य
र
ऱ
ल
ळ
व
श
ष
स
ह
़
ा
ि
ी
ु
ू
ृ
ॅ
े
ै
ॉ
ो
ौ
्
॒
क़
ख़
ग़
ज़
ड़
ढ़
फ़
ॠ
।
०
१
२
३
४
५
६
७
८
९
॰
ppocr/utils/dict/latin_dict.txt (new file, 0 → 100644) · View file @ 76320bf0
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
}
¡
£
§
ª
«
°
²
³
´
µ
·
º
»
¿
À
Á
Â
Ä
Å
Ç
È
É
Ê
Ë
Ì
Í
Î
Ï
Ò
Ó
Ô
Õ
Ö
Ú
Ü
Ý
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ñ
ò
ó
ô
õ
ö
ø
ù
ú
û
ü
ý
ą
Ć
ć
Č
č
Đ
đ
ę
ı
Ł
ł
ō
Œ
œ
Š
š
Ÿ
Ž
ž
ʒ
β
δ
ε
з
Ṡ
‘
€
™
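
These dictionaries all follow the same one-character-per-line layout, so a loader in the spirit of get_dict (used by e2e_metric.py and pg_postprocess.py above) can be sketched as below; load_char_dict is a hypothetical name for illustration, not the function shipped in this commit.

def load_char_dict(path):
    """Read a one-character-per-line dict file into a list of characters."""
    chars = []
    with open(path, 'rb') as f:
        for line in f:
            # decode as UTF-8 and drop only the line break, keeping spaces
            chars.append(line.decode('utf-8').strip('\r\n'))
    return chars

chars = load_char_dict('ppocr/utils/dict/latin_dict.txt')
print(len(chars))  # 185 entries for the latin dict above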
ppocr/utils/e2e_metric/Deteval.py (new file, 0 → 100755) · View file @ 76320bf0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
ppocr.utils.e2e_metric.polygon_fast
import
iod
,
area_of_intersection
,
area
def
get_socre
(
gt_dict
,
pred_dict
):
allInputs
=
1
def
input_reading_mod
(
pred_dict
):
"""This helper reads input from txt files"""
det
=
[]
n
=
len
(
pred_dict
)
for
i
in
range
(
n
):
points
=
pred_dict
[
i
][
'points'
]
text
=
pred_dict
[
i
][
'text'
]
point
=
","
.
join
(
map
(
str
,
points
.
reshape
(
-
1
,
)))
det
.
append
([
point
,
text
])
return
det
def
gt_reading_mod
(
gt_dict
):
"""This helper reads groundtruths from mat files"""
gt
=
[]
n
=
len
(
gt_dict
)
for
i
in
range
(
n
):
points
=
gt_dict
[
i
][
'points'
]
h
=
len
(
points
)
text
=
gt_dict
[
i
][
'text'
]
xx
=
[
np
.
array
(
[
'x:'
],
dtype
=
'<U2'
),
0
,
np
.
array
(
[
'y:'
],
dtype
=
'<U2'
),
0
,
np
.
array
(
[
'#'
],
dtype
=
'<U1'
),
np
.
array
(
[
'#'
],
dtype
=
'<U1'
)
]
t_x
,
t_y
=
[],
[]
for
j
in
range
(
h
):
t_x
.
append
(
points
[
j
][
0
])
t_y
.
append
(
points
[
j
][
1
])
xx
[
1
]
=
np
.
array
([
t_x
],
dtype
=
'int16'
)
xx
[
3
]
=
np
.
array
([
t_y
],
dtype
=
'int16'
)
if
text
!=
""
and
"#"
not
in
text
:
xx
[
4
]
=
np
.
array
([
text
],
dtype
=
'U{}'
.
format
(
len
(
text
)))
xx
[
5
]
=
np
.
array
([
'c'
],
dtype
=
'<U1'
)
gt
.
append
(
xx
)
return
gt
def
detection_filtering
(
detections
,
groundtruths
,
threshold
=
0.5
):
for
gt_id
,
gt
in
enumerate
(
groundtruths
):
if
(
gt
[
5
]
==
'#'
)
and
(
gt
[
1
].
shape
[
1
]
>
1
):
gt_x
=
list
(
map
(
int
,
np
.
squeeze
(
gt
[
1
])))
gt_y
=
list
(
map
(
int
,
np
.
squeeze
(
gt
[
3
])))
for
det_id
,
detection
in
enumerate
(
detections
):
detection_orig
=
detection
detection
=
[
float
(
x
)
for
x
in
detection
[
0
].
split
(
','
)]
detection
=
list
(
map
(
int
,
detection
))
det_x
=
detection
[
0
::
2
]
det_y
=
detection
[
1
::
2
]
det_gt_iou
=
iod
(
det_x
,
det_y
,
gt_x
,
gt_y
)
if
det_gt_iou
>
threshold
:
detections
[
det_id
]
=
[]
detections
[:]
=
[
item
for
item
in
detections
if
item
!=
[]]
return
detections
def
sigma_calculation
(
det_x
,
det_y
,
gt_x
,
gt_y
):
"""
sigma = inter_area / gt_area
"""
return
np
.
round
((
area_of_intersection
(
det_x
,
det_y
,
gt_x
,
gt_y
)
/
area
(
gt_x
,
gt_y
)),
2
)
def
tau_calculation
(
det_x
,
det_y
,
gt_x
,
gt_y
):
if
area
(
det_x
,
det_y
)
==
0.0
:
return
0
return
np
.
round
((
area_of_intersection
(
det_x
,
det_y
,
gt_x
,
gt_y
)
/
area
(
det_x
,
det_y
)),
2
)
##############################Initialization###################################
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
###############################################################################
for
input_id
in
range
(
allInputs
):
if
(
input_id
!=
'.DS_Store'
)
and
(
input_id
!=
'Pascal_result.txt'
)
and
(
input_id
!=
'Pascal_result_curved.txt'
)
and
(
input_id
!=
'Pascal_result_non_curved.txt'
)
and
(
input_id
!=
'Deteval_result.txt'
)
and
(
input_id
!=
'Deteval_result_curved.txt'
)
\
and
(
input_id
!=
'Deteval_result_non_curved.txt'
):
detections
=
input_reading_mod
(
pred_dict
)
groundtruths
=
gt_reading_mod
(
gt_dict
)
detections
=
detection_filtering
(
detections
,
groundtruths
)
# filters detections overlapping with DC area
dc_id
=
[]
for
i
in
range
(
len
(
groundtruths
)):
if
groundtruths
[
i
][
5
]
==
'#'
:
dc_id
.
append
(
i
)
cnt
=
0
for
a
in
dc_id
:
num
=
a
-
cnt
del
groundtruths
[
num
]
cnt
+=
1
local_sigma_table
=
np
.
zeros
((
len
(
groundtruths
),
len
(
detections
)))
local_tau_table
=
np
.
zeros
((
len
(
groundtruths
),
len
(
detections
)))
local_pred_str
=
{}
local_gt_str
=
{}
for
gt_id
,
gt
in
enumerate
(
groundtruths
):
if
len
(
detections
)
>
0
:
for
det_id
,
detection
in
enumerate
(
detections
):
detection_orig
=
detection
detection
=
[
float
(
x
)
for
x
in
detection
[
0
].
split
(
','
)]
detection
=
list
(
map
(
int
,
detection
))
pred_seq_str
=
detection_orig
[
1
].
strip
()
det_x
=
detection
[
0
::
2
]
det_y
=
detection
[
1
::
2
]
gt_x
=
list
(
map
(
int
,
np
.
squeeze
(
gt
[
1
])))
gt_y
=
list
(
map
(
int
,
np
.
squeeze
(
gt
[
3
])))
gt_seq_str
=
str
(
gt
[
4
].
tolist
()[
0
])
local_sigma_table
[
gt_id
,
det_id
]
=
sigma_calculation
(
det_x
,
det_y
,
gt_x
,
gt_y
)
local_tau_table
[
gt_id
,
det_id
]
=
tau_calculation
(
det_x
,
det_y
,
gt_x
,
gt_y
)
local_pred_str
[
det_id
]
=
pred_seq_str
local_gt_str
[
gt_id
]
=
gt_seq_str
global_sigma
=
local_sigma_table
global_tau
=
local_tau_table
global_pred_str
=
local_pred_str
global_gt_str
=
local_gt_str
single_data
=
{}
single_data
[
'sigma'
]
=
global_sigma
single_data
[
'global_tau'
]
=
global_tau
single_data
[
'global_pred_str'
]
=
global_pred_str
single_data
[
'global_gt_str'
]
=
global_gt_str
return
single_data
def combine_results(all_data):
    tr = 0.7
    tp = 0.6
    fsc_k = 0.8
    k = 2
    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    for data in all_data:
        global_sigma.append(data['sigma'])
        global_tau.append(data['global_tau'])
        global_pred_str.append(data['global_pred_str'])
        global_gt_str.append(data['global_gt_str'])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0
    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]

            if (gt_matching_num_qualified_sigma_candidates == 1) and (
                    gt_matching_num_qualified_tau_candidates == 1) and (
                        det_matching_num_qualified_sigma_candidates == 1) and (
                            det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                gt_str_cur = global_gt_str[idy][gt_id]
                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
                    0]]
                if pred_str_cur == gt_str_cur:
                    hit_str_num += 1
                else:
                    if pred_str_cur.lower() == gt_str_cur.lower():
                        hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num
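    # A standalone, simplified sketch of the one-to-one criterion above
    # (hypothetical 2x2 tables, not part of the original file; the real code
    # also checks uniqueness along the detection axis and updates the flag
    # arrays):
    #
    #     import numpy as np
    #     tr_, tp_ = 0.7, 0.6
    #     sigma = np.array([[0.90, 0.10],
    #                       [0.05, 0.85]])  # rows: ground truths, cols: dets
    #     tau = np.array([[0.80, 0.05],
    #                     [0.10, 0.75]])
    #     for g in range(sigma.shape[0]):
    #         sig_hits = np.where(sigma[g, :] > tr_)[0]
    #         tau_hits = np.where(tau[g, :] > tp_)[0]
    #         if len(sig_hits) == 1 and len(tau_hits) == 1 \
    #                 and sig_hits[0] == tau_hits[0]:
    #             print("gt %d <-> det %d is a one-to-one match" %
    #                   (g, sig_hits[0]))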
    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the ground truth was already matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                # search for all detections that overlap with this ground truth
                qualified_tau_candidates = np.where((
                    local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and
                        (local_sigma_table[gt_id, qualified_tau_candidates] >=
                         tr)):
                        # becomes a one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
                      >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][
                        qualified_tau_candidates[0].tolist()[0]]
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num
    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy):
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was already matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                # search for all ground truths that overlap with this detection
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
                         tp) and
                        (local_sigma_table[qualified_sigma_candidates, det_id]
                         >= tr)):
                        # becomes a one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates,
                                             det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    pred_str_cur = global_pred_str[idy][det_id]
                    gt_len = len(qualified_sigma_candidates[0])
                    for idx in range(gt_len):
                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
                        if ele_gt_id not in global_gt_str[idy]:
                            continue
                        gt_str_cur = global_gt_str[idy][ele_gt_id]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                            break
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                            break
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num
    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        ####### first check for one-to-one case ##########
        local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        ####### then check for one-to-many case ##########
        local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

        ####### then check for many-to-one case ##########
        local_accumulative_recall, local_accumulative_precision, \
            global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(
                local_sigma_table, local_tau_table, local_accumulative_recall,
                local_accumulative_precision, global_accumulative_recall,
                global_accumulative_precision, gt_flag, det_flag, idx)
        hit_str_count += hit_str_num

    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final
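# A quick numeric check of the reductions above (toy counts, not from any
# real run): with 10 ground truths, 9 detections, accumulated recall 8.0,
# accumulated precision 7.5 and 6 exact transcription hits:
#
#     recall_ = 8.0 / 10                 # 0.80
#     precision_ = 7.5 / 9               # ~0.833
#     f_ = 2 * precision_ * recall_ / (precision_ + recall_)
#     recall_e2e_ = 6.0 / 10             # 0.60
#     precision_e2e_ = 6.0 / 9           # ~0.667
#     f_e2e_ = 2 * precision_e2e_ * recall_e2e_ / (precision_e2e_ +
#                                                  recall_e2e_)
#     print(round(f_, 3), round(f_e2e_, 3))  # 0.816 0.632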
ppocr/utils/e2e_metric/polygon_fast.py 0 → 100755
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the 'AREA' calculations in this script are handled with shapely
polygon operations (Polygon.area on the polygons, their intersections
and unions).
"""
def area(x, y):
    polygon = Polygon(np.stack([x, y], axis=1))
    return float(polygon.area)
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
    """
    This helper determines whether the two polygons intersect, using an
    approximation: the intersection is represented by the overlap of the
    minimum bounding rectangles [xmin, ymin, xmax, ymax].
    """
    det_ymax = np.max(det_y)
    det_xmax = np.max(det_x)
    det_ymin = np.min(det_y)
    det_xmin = np.min(det_x)

    gt_ymax = np.max(gt_y)
    gt_xmax = np.max(gt_x)
    gt_ymin = np.min(gt_y)
    gt_xmin = np.min(gt_x)

    all_min_ymax = np.minimum(det_ymax, gt_ymax)
    all_max_ymin = np.maximum(det_ymin, gt_ymin)
    intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))

    all_min_xmax = np.minimum(det_xmax, gt_xmax)
    all_max_xmin = np.maximum(det_xmin, gt_xmin)
    intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))

    return intersect_heights * intersect_widths
def area_of_intersection(det_x, det_y, gt_x, gt_y):
    p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
    p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
    return float(p1.intersection(p2).area)
def area_of_union(det_x, det_y, gt_x, gt_y):
    p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
    p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
    return float(p1.union(p2).area)
def iou(det_x, det_y, gt_x, gt_y):
    # the +1.0 keeps the denominator non-zero for degenerate polygons
    return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
        area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
def iod(det_x, det_y, gt_x, gt_y):
    """
    This helper determines the fraction of the intersection area over the
    detection area.
    """
    return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
        area(det_x, det_y) + 1.0)
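# A quick hypothetical check of the helpers on two overlapping 10x10 squares
# (not part of the original file; assumes shapely is installed):
#
#     det_x, det_y = [0, 10, 10, 0], [0, 0, 10, 10]
#     gt_x, gt_y = [5, 15, 15, 5], [0, 0, 10, 10]
#     print(area(det_x, det_y))                              # 100.0
#     print(area_of_intersection(det_x, det_y, gt_x, gt_y))  # 50.0
#     print(round(iou(det_x, det_y, gt_x, gt_y), 3))         # 0.331
#     print(round(iod(det_x, det_y, gt_x, gt_y), 3))         # 0.495
#
# The +1.0 in both denominators slightly deflates the ratios (50/151 and
# 50/101 rather than 50/150 and 50/100) in exchange for never dividing by
# zero.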
ppocr/utils/e2e_utils/extract_batchsize.py 0 → 100644
import paddle
import numpy as np
import copy
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
    """
    Regroup the TCL RoIs by image and repeat or randomly drop entries so
    that each group contains exactly tcl_bs items.
    """
    pos_lists_, pos_masks_, label_lists_ = [], [], []
    img_bs = batch_size
    ngpu = int(batch_size / img_bs)
    img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
    pos_lists_split, pos_masks_split, label_lists_split = [], [], []
    for i in range(ngpu):
        pos_lists_split.append([])
        pos_masks_split.append([])
        label_lists_split.append([])

    for i in range(img_ids.shape[0]):
        img_id = img_ids[i]
        gpu_id = int(img_id / img_bs)
        img_id = img_id % img_bs
        pos_list = pos_lists[i].copy()
        pos_list[:, 0] = img_id
        pos_lists_split[gpu_id].append(pos_list)
        pos_masks_split[gpu_id].append(pos_masks[i].copy())
        label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))

    # repeat or delete
    for i in range(ngpu):
        vp_len = len(pos_lists_split[i])
        if vp_len <= tcl_bs:
            for j in range(0, tcl_bs - vp_len):
                pos_list = pos_lists_split[i][j].copy()
                pos_lists_split[i].append(pos_list)
                pos_mask = pos_masks_split[i][j].copy()
                pos_masks_split[i].append(pos_mask)
                label_list = copy.deepcopy(label_lists_split[i][j])
                label_lists_split[i].append(label_list)
        else:
            for j in range(0, vp_len - tcl_bs):
                c_len = len(pos_lists_split[i])
                pop_id = np.random.permutation(c_len)[0]
                pos_lists_split[i].pop(pop_id)
                pos_masks_split[i].pop(pop_id)
                label_lists_split[i].pop(pop_id)
    # merge
    for i in range(ngpu):
        pos_lists_.extend(pos_lists_split[i])
        pos_masks_.extend(pos_masks_split[i])
        label_lists_.extend(label_lists_split[i])
    return pos_lists_, pos_masks_, label_lists_
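# A hypothetical round trip through the balancing logic above (the array
# shapes are assumptions, not taken from the repo: 3 RoIs from a
# single-image batch, padded up to tcl_bs=4 by repeating entries):
#
#     pos_lists = [np.zeros((64, 3), dtype=np.int32) for _ in range(3)]
#     pos_masks = [np.ones((64, 1), dtype=np.float32) for _ in range(3)]
#     label_lists = [np.zeros((50, 1), dtype=np.int32) for _ in range(3)]
#     pl, pm, ll = org_tcl_rois(1, pos_lists, pos_masks, label_lists,
#                               tcl_bs=4)
#     print(len(pl), len(pm), len(ll))  # 4 4 4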
def pre_process(label_list, pos_list, pos_mask, max_text_length,
                max_text_nums, pad_num, tcl_bs):
    label_list = label_list.numpy()
    batch, _, _, _ = label_list.shape
    pos_list = pos_list.numpy()
    pos_mask = pos_mask.numpy()
    pos_list_t = []
    pos_mask_t = []
    label_list_t = []
    for i in range(batch):
        for j in range(max_text_nums):
            if pos_mask[i, j].any():
                pos_list_t.append(pos_list[i][j])
                pos_mask_t.append(pos_mask[i][j])
                label_list_t.append(label_list[i][j])
    pos_list, pos_mask, label_list = org_tcl_rois(
        batch, pos_list_t, pos_mask_t, label_list_t, tcl_bs)
    label = []
    tt = [l.tolist() for l in label_list]
    for i in range(tcl_bs):
        k = 0
        for j in range(max_text_length):
            if tt[i][j][0] != pad_num:
                k += 1
            else:
                break
        label.append(k)
    label = paddle.to_tensor(label)
    label = paddle.cast(label, dtype='int64')
    pos_list = paddle.to_tensor(pos_list)
    pos_mask = paddle.to_tensor(pos_mask)
    label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
    label_list = paddle.cast(label_list, dtype='int32')
    return pos_list, pos_mask, label_list, label
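# The per-instance label value above is just the count of leading non-pad
# characters; a tiny hypothetical illustration of that inner loop (pad_num
# is whatever pad index the config uses, 36 here as an assumption):
#
#     pad_num = 36
#     seq = [[5], [12], [7], [36], [36]]  # one row of tt: char ids + padding
#     k = 0
#     for j in range(len(seq)):
#         if seq[j][0] != pad_num:
#             k += 1
#         else:
#             break
#     print(k)  # 3 -> this text instance has 3 valid characters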
ppocr/utils/e2e_utils/extract_textpoint.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
    character_str = ""
    with open(character_dict_path, "rb") as fin:
        lines = fin.readlines()
        for line in lines:
            line = line.decode('utf-8').strip("\n").strip("\r\n")
            character_str += line
        dict_character = list(character_str)
    return dict_character
def softmax(logits):
    """
    logits: N x d
    """
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist
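# Subtracting the row maximum makes the exponentials numerically safe; a
# quick hypothetical check (not part of the original file):
#
#     logits = np.array([[2.0, 1.0, 0.1]])
#     print(softmax(logits))           # ~[[0.659 0.242 0.099]]
#     print(softmax(logits + 1000.0))  # same result; the shift prevents
#                                      # overflow in np.exp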
def get_keep_pos_idxs(labels, remove_blank=None):
    """
    Remove duplicates and return the position indexes of the kept items.
    remove_blank should be either None (keep blanks) or the blank index
    (e.g. 95).
    """
    duplicate_len_list = []
    keep_pos_idx_list = []
    keep_char_idx_list = []
    for k, v_ in groupby(labels):
        current_len = len(list(v_))
        if k != remove_blank:
            current_idx = int(sum(duplicate_len_list) + current_len // 2)
            keep_pos_idx_list.append(current_idx)
            keep_char_idx_list.append(k)
        duplicate_len_list.append(current_len)
    return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
    new_labels = [x for x in labels if x != blank]
    return new_labels
def insert_blank(labels, blank=0):
    new_labels = [blank]
    for l in labels:
        new_labels += [l, blank]
    return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
    """
    CTC greedy (best path) decoder.
    """
    raw_str = np.argmax(np.array(probs_seq), axis=1)
    remove_blank_in_pos = None if keep_blank_in_idxs else blank
    dedup_str, keep_idx_list = get_keep_pos_idxs(
        raw_str, remove_blank=remove_blank_in_pos)
    dst_str = remove_blank(dedup_str, blank=blank)
    return dst_str, keep_idx_list
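# A small hypothetical decode (not part of the original file): five time
# steps over three classes, with class 2 playing the blank.
#
#     probs = np.array([[0.9, 0.05, 0.05],
#                       [0.9, 0.05, 0.05],
#                       [0.05, 0.05, 0.9],
#                       [0.05, 0.9, 0.05],
#                       [0.05, 0.9, 0.05]])
#     dst, keep = ctc_greedy_decoder(probs, blank=2, keep_blank_in_idxs=True)
#     print(dst)   # [0, 1]: repeats collapse, the blank is dropped
#     print(keep)  # [1, 2, 4]: the center position of every run, blanks
#                  # included, so point coverage is preserved downstream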
def instance_ctc_greedy_decoder(gather_info, logits_map,
                                keep_blank_in_idxs=True):
    """
    gather_info: [[y, x], [y, x] ...]
    logits_map: H x W x (n_chars + 1)
    """
    _, _, C = logits_map.shape
    ys, xs = zip(*gather_info)
    logits_seq = logits_map[list(ys), list(xs)]  # n x 96
    probs_seq = softmax(logits_seq)
    dst_str, keep_idx_list = ctc_greedy_decoder(
        probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
    return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
                          keep_blank_in_idxs=True):
    """
    Run the greedy CTC decoder over every text instance in the image.
    """
    decoder_results = []
    for gather_info in gather_info_list:
        res = instance_ctc_greedy_decoder(
            gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
        decoder_results.append(res)
    return decoder_results
def sort_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """

    def sort_part_with_direction(pos_list, point_direction):
        pos_list = np.array(pos_list).reshape(-1, 2)
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 2)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
                                                              point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction

    return sorted_point, np.array(sorted_direction)
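# A tiny hypothetical check of the projection sort above (not part of the
# original file): with a direction field pointing along +x everywhere,
# points end up ordered by their x coordinate.
#
#     f_dir_toy = np.zeros((8, 8, 2))
#     f_dir_toy[..., 0] = 1.0              # x component of the field
#     pts = [[2, 5], [2, 1], [2, 3]]       # unsorted [y, x] points
#     sorted_pts, _ = sort_with_direction(pts, f_dir_toy)
#     print(sorted_pts)                    # [[2, 1], [2, 3], [2, 5]]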
def add_id(pos_list, image_id=0):
    """
    Add an image id to each gathered point, for inference.
    """
    new_list = []
    for item in pos_list:
        new_list.append((image_id, item[0], item[1]))
    return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    # expand along the text direction at both ends
    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(
        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    left_list = []
    right_list = []
    for i in range(append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ly < h and lx < w and (ly, lx) not in left_list:
            left_list.append((ly, lx))
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ry < h and rx < w and (ry, rx) not in right_list:
            right_list.append((ry, rx))

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    binary_tcl_map: h x w
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    # expand along the text direction, stopping once a point leaves the TCL map
    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(
        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    max_append_num = 2 * append_num

    left_list = []
    right_list = []
    for i in range(max_append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ly < h and lx < w and (ly, lx) not in left_list:
            if binary_tcl_map[ly, lx] > 0.5:
                left_list.append((ly, lx))
            else:
                break

    for i in range(max_append_num):
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if ry < h and rx < w and (ry, rx) not in right_list:
            if binary_tcl_map[ry, rx] > 0.5:
                right_list.append((ry, rx))
            else:
                break

    all_list = left_list[::-1] + sorted_list + right_list
    return all_list
def generate_pivot_list_curved(p_score,
                               p_char_maps,
                               f_direction,
                               score_thresh=0.5,
                               is_expand=True,
                               is_backbone=False,
                               image_id=0):
    """
    Return the center points and end points of each TCL instance,
    filtered with the character maps.
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map)
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8)

    # get TCL instances
    all_pos_yxs = []
    center_pos_yxs = []
    end_points_yxs = []
    instance_center_pos_yxs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))
            ### FIX-ME, eliminate outlier
            if len(pos_list) < 3:
                continue

            if is_expand:
                pos_list_sorted = sort_and_expand_with_direction_v2(
                    pos_list, f_direction, p_tcl_map)
            else:
                pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
            all_pos_yxs.append(pos_list_sorted)

    # use the decoder to filter background points
    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decode_res = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
    for decoded_str, keep_yxs_list in decode_res:
        if is_backbone:
            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
            instance_center_pos_yxs.append(keep_yxs_list_with_id)
        else:
            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
            center_pos_yxs.extend(keep_yxs_list)

    if is_backbone:
        return instance_center_pos_yxs
    else:
        return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
                                   p_char_maps,
                                   f_direction,
                                   score_thresh=0.5,
                                   is_backbone=False,
                                   image_id=0):
    """
    Return the center points and end points of each TCL instance,
    filtered with the character maps.
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map_bi = (p_score > score_thresh) * 1.0
    instance_count, instance_label_map = cv2.connectedComponents(
        p_tcl_map_bi.astype(np.uint8), connectivity=8)

    # get TCL instances
    all_pos_yxs = []
    center_pos_yxs = []
    end_points_yxs = []
    instance_center_pos_yxs = []

    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))
            ### FIX-ME, eliminate outlier
            if len(pos_list) < 5:
                continue

            # decide whether the instance is closer to horizontal or vertical
            main_direction = extract_main_direction(pos_list,
                                                    f_direction)  # y, x
            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y, x
            is_h_angle = abs(np.sum(
                main_direction * reference_direction)) < math.cos(
                    math.pi / 180 * 70)

            point_yxs = np.array(pos_list)
            max_y, max_x = np.max(point_yxs, axis=0)
            min_y, min_x = np.min(point_yxs, axis=0)
            is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)

            pos_list_final = []
            if is_h_len:
                xs = np.unique(xs)
                for x in xs:
                    ys = instance_label_map[:, x].copy().reshape((-1, ))
                    y = int(np.where(ys == instance_id)[0].mean())
                    pos_list_final.append((y, x))
            else:
                ys = np.unique(ys)
                for y in ys:
                    xs = instance_label_map[y, :].copy().reshape((-1, ))
                    x = int(np.where(xs == instance_id)[0].mean())
                    pos_list_final.append((y, x))

            pos_list_sorted, _ = sort_with_direction(pos_list_final,
                                                     f_direction)
            all_pos_yxs.append(pos_list_sorted)

    # use the decoder to filter background points
    p_char_maps = p_char_maps.transpose([1, 2, 0])
    decode_res = ctc_decoder_for_image(
        all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
    for decoded_str, keep_yxs_list in decode_res:
        if is_backbone:
            keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
            instance_center_pos_yxs.append(keep_yxs_list_with_id)
        else:
            end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
            center_pos_yxs.extend(keep_yxs_list)

    if is_backbone:
        return instance_center_pos_yxs
    else:
        return center_pos_yxs, end_points_yxs
def generate_pivot_list(p_score,
                        p_char_maps,
                        f_direction,
                        score_thresh=0.5,
                        is_backbone=False,
                        is_curved=True,
                        image_id=0):
    """
    Wrap the curved and horizontal variants behind one entry point.
    """
    if is_curved:
        return generate_pivot_list_curved(
            p_score,
            p_char_maps,
            f_direction,
            score_thresh=score_thresh,
            is_expand=True,
            is_backbone=is_backbone,
            image_id=image_id)
    else:
        return generate_pivot_list_horizontal(
            p_score,
            p_char_maps,
            f_direction,
            score_thresh=score_thresh,
            is_backbone=is_backbone,
            image_id=image_id)
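# A hedged sketch of how this entry point is driven (the shapes follow the
# docstrings above; the random maps are stand-ins for real PGNet head
# outputs, so the result is typically empty or meaningless):
#
#     H, W, n_chars = 128, 128, 36
#     p_score = np.random.rand(1, H, W)                  # TCL score map
#     f_direction = np.random.rand(2, H, W) - 0.5        # direction field
#     p_char_maps = np.random.rand(n_chars + 1, H, W)    # char logits + blank
#     center_yxs, end_yxs = generate_pivot_list(
#         p_score, p_char_maps, f_direction, score_thresh=0.5, is_curved=True)
#     print(len(center_yxs), len(end_yxs))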
# for the refine module
def extract_main_direction(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    """
    pos_list = np.array(pos_list)
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    average_direction = average_direction / (
        np.linalg.norm(average_direction) + 1e-6)
    return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
    """
    pos_list_full = np.array(pos_list).reshape(-1, 3)
    pos_list = pos_list_full[:, 1:]
    point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    average_direction = np.mean(point_direction, axis=0, keepdims=True)
    pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
    sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
    return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
    """
    f_direction: h x w x 2
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
    """

    def sort_part_with_direction(pos_list_full, point_direction):
        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
        pos_list = pos_list_full[:, 1:]
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
        sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
        return sorted_list, sorted_direction

    pos_list = np.array(pos_list).reshape(-1, 3)
    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
                                                              point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        first_part_point = sorted_point[:middle_num]
        first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)

        last_part_point = sorted_point[middle_num:]
        last_point_direction = sorted_direction[middle_num:]
        sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
            last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction

    return sorted_point
def generate_pivot_list_tt_inference(p_score,
                                     p_char_maps,
                                     f_direction,
                                     score_thresh=0.5,
                                     is_backbone=False,
                                     is_curved=True,
                                     image_id=0):
    """
    Return the center points and end points of each TCL instance,
    filtered with the character maps.
    """
    p_score = p_score[0]
    f_direction = f_direction.transpose(1, 2, 0)
    p_tcl_map = (p_score > score_thresh) * 1.0
    skeleton_map = thin(p_tcl_map)
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton_map.astype(np.uint8), connectivity=8)

    # get TCL instances
    all_pos_yxs = []
    if instance_count > 0:
        for instance_id in range(1, instance_count):
            pos_list = []
            ys, xs = np.where(instance_label_map == instance_id)
            pos_list = list(zip(ys, xs))
            ### FIX-ME, eliminate outlier
            if len(pos_list) < 3:
                continue
            pos_list_sorted = sort_and_expand_with_direction_v2(
                pos_list, f_direction, p_tcl_map)
            pos_list_sorted_with_id = add_id(pos_list_sorted,
                                             image_id=image_id)
            all_pos_yxs.append(pos_list_sorted_with_id)
    return all_pos_yxs