# wangsen / paddle_dbnet — Commit df001f3c

Merge remote-tracking branch 'origin/dygraph' into dygraph

Authored Apr 06, 2022 by Leif. Parents: 9cce1213, bdca6cd7.

The full commit changes 76 files; this page shows 16 changed files, with 514 additions and 50 deletions (+514 −50).
Changed files on this page:

- ppstructure/vqa/README.md (+4 −0)
- requirements.txt (+0 −1)
- test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml (+2 −2)
- test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt (+0 −0, moved)
- test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml (+2 −2)
- test_tipc/prepare.sh (+7 −0)
- test_tipc/supplementary/config.py (+1 −1)
- tools/end2end/convert_ppocr_label.py (+94 −0, new file)
- tools/end2end/draw_html.py (+73 −0, new file)
- tools/end2end/eval_end2end.py (+193 −0, new file)
- tools/end2end/readme.md (+69 −0, new file)
- tools/infer/predict_det.py (+8 −22)
- tools/infer/utility.py (+0 −1)
- tools/infer_vqa_token_ser_re.py (+1 −1)
- tools/program.py (+1 −1)
- tools/test_hubserving.py (+59 −19)
**ppstructure/vqa/README.md**

```diff
@@ -242,3 +242,7 @@ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Archi
 - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
 - microsoft/unilm/layoutxlm, https://github.com/microsoft/unilm/tree/master/layoutxlm
 - XFUND dataset, https://github.com/doc-analysis/XFUND
+
+## License
+The content of this project itself is licensed under the
+[Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
```
**requirements.txt**

```diff
@@ -12,4 +12,3 @@ cython
 lxml
 premailer
 openpyxl
-fasttext==0.9.1
```
**test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml**

```diff
@@ -56,7 +56,7 @@ PostProcess:
   thresh: 0
   box_thresh: 0.85
   min_area: 16
-  box_type: box # 'box' or 'poly'
+  box_type: quad # 'quad' or 'poly'
   scale: 1

 Metric:
```
**test_tipc/configs/det_r18_vd_v2_0/train_infer_python.txt → test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt**

File moved, no content changes.
**test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml**

```diff
@@ -55,7 +55,7 @@ PostProcess:
   thresh: 0
   box_thresh: 0.85
   min_area: 16
-  box_type: box # 'box' or 'poly'
+  box_type: quad # 'quad' or 'poly'
   scale: 1

 Metric:
```
**test_tipc/prepare.sh**

```diff
@@ -60,6 +60,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
         ln -s ./icdar2015_lite ./icdar2015
         cd ../
     cd ./inference && tar xf rec_inference.tar && cd ../
+    if [ ${model_name} == "ch_PPOCRv2_det" ] || [ ${model_name} == "ch_PPOCRv2_det_PACT" ]; then
+        wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
+        cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../
+    fi
+    if [ ${model_name} == "det_r18_db_v2_0" ]; then
+        wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
+    fi
     if [ ${model_name} == "en_server_pgnetA" ]; then
         wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
         wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
```
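The new `det_r18_db_v2_0` branch downloads the ResNet18_vd backbone weights before TIPC runs the DB training test. For context, `prepare.sh` takes a TIPC config file and a mode string; with the renamed config, an invocation might look like this (an illustrative call following the usual TIPC pattern, not taken from this diff):

```
bash test_tipc/prepare.sh test_tipc/configs/det_r18_vd_db_v2_0/train_infer_python.txt lite_train_lite_infer
```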
**test_tipc/supplementary/config.py**

```diff
@@ -122,7 +122,7 @@ def preprocess(is_train=False):
         log_file = '{}/train.log'.format(save_model_dir)
     else:
         log_file = None
-    logger = get_logger(name='root', log_file=log_file)
+    logger = get_logger(log_file=log_file)
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['use_gpu']
```
**tools/end2end/convert_ppocr_label.py** (new file)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import json
import os


def poly_to_string(poly):
    # Flatten an N x 2 point array into a tab-separated coordinate string.
    if len(poly.shape) > 1:
        poly = np.array(poly).flatten()

    string = "\t".join(str(i) for i in poly)
    return string


def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
    if not os.path.exists(label_dir):
        raise ValueError(f"The file {label_dir} does not exist!")

    assert label_dir != save_dir, "label_dir and save_dir must be different"

    label_file = open(label_dir, 'r')
    data = label_file.readlines()

    gt_dict = {}

    for line in data:
        try:
            tmp = line.split('\t')
            assert len(tmp) == 2, ""
        except Exception:
            tmp = line.strip().split(' ')

        gt_lists = []

        if tmp[0].split('/')[0] is not None:
            img_path = tmp[0]
            anno = json.loads(tmp[1])
            gt_collect = []
            for dic in anno:
                #txt = dic['transcription'].replace(' ', '')  # ignore blank
                txt = dic['transcription']
                if 'score' in dic and float(dic['score']) < 0.5:
                    continue
                if u'\u3000' in txt:
                    txt = txt.replace(u'\u3000', u' ')
                #while ' ' in txt:
                #    txt = txt.replace(' ', '')
                poly = np.array(dic['points']).flatten()
                if txt == "###":
                    txt_tag = 1  ## ignore 1
                else:
                    txt_tag = 0
                if mode == "gt":
                    gt_label = poly_to_string(poly) + "\t" + str(
                        txt_tag) + "\t" + txt + "\n"
                else:
                    gt_label = poly_to_string(poly) + "\t" + txt + "\n"

                gt_lists.append(gt_label)

            gt_dict[img_path] = gt_lists
        else:
            continue

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Write one label file per image.
    for img_name in gt_dict.keys():
        save_name = img_name.split("/")[-1]
        save_file = os.path.join(save_dir, save_name + ".txt")
        with open(save_file, "w") as f:
            f.writelines(gt_dict[img_name])

    print("The converted labels are saved in {}".format(save_dir))


if __name__ == "__main__":

    ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt"
    convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/")

    ppocr_label_pred = "./infer_results/ch_PPOCRV2_infer.txt"
    convert_label(ppocr_label_pred, "pred", "./save_PPOCRV2_infer/")
```
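A minimal sketch of the conversion for one image (the label line and path below are invented for illustration): a PPOCR-format line `<img_path>\t<json annotations>` becomes per-image rows of tab-separated corner coordinates, an ignore tag (`gt` mode only, 1 for `###`), and the transcription.

```python
import json
import numpy as np

# Hypothetical single-image PPOCR label line: "<img_path>\t<json annotations>"
line = 'imgs/demo.jpg\t' + json.dumps(
    [{"transcription": "hello", "points": [[0, 0], [10, 0], [10, 5], [0, 5]]}])

img_path, anno = line.split('\t')
dic = json.loads(anno)[0]
poly = np.array(dic['points']).flatten()
# gt-mode output row: 8 coordinates, ignore tag (0 = keep), transcription
print("\t".join(str(i) for i in poly) + "\t0\t" + dic['transcription'])
# -> 0	0	10	0	10	5	0	5	0	hello
```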
**tools/end2end/draw_html.py** (new file)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse


def str2bool(v):
    return v.lower() in ("true", "t", "1")


def init_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", type=str, default="")
    parser.add_argument("--save_html_path", type=str, default="./default.html")
    parser.add_argument("--width", type=int, default=640)
    return parser


def parse_args():
    parser = init_args()
    return parser.parse_args()


def draw_debug_img(args):
    html_path = args.save_html_path

    with open(html_path, 'w') as html:
        html.write('<html>\n<body>\n')
        html.write('<table border="1">\n')
        html.write(
            "<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
        )
        path = args.image_dir
        for i, filename in enumerate(sorted(os.listdir(path))):
            if filename.endswith("txt"):
                continue
            # One table row per image in the directory.
            base = "{}/{}".format(path, filename)
            html.write("<tr>\n")
            html.write(f'<td> {filename}\n GT')
            html.write(f'<td>GT\n<img src="{base}" width={args.width}></td>')
            html.write("</tr>\n")
        html.write('<style>\n')
        html.write('span {\n')
        html.write('    color: red;\n')
        html.write('}\n')
        html.write('</style>\n')
        html.write('</table>\n')
        html.write('</body>\n</html>\n')
    print(f"The html file is saved in {html_path}")
    return


if __name__ == "__main__":
    args = parse_args()
    draw_debug_img(args)
```
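Assuming a directory of visualization images such as the one written in step 1 of the end-to-end readme below (the paths here are illustrative), the tool can be invoked as:

```
python3 tools/end2end/draw_html.py --image_dir ./ch_PP-OCRv2_results/ --save_html_path ./vis.html --width 640
```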
**tools/end2end/eval_end2end.py** (new file)

```python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import shapely
from shapely.geometry import Polygon
import numpy as np
from collections import defaultdict
import operator
import editdistance


def strQ2B(ustring):
    # Convert full-width characters to their half-width equivalents.
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:  # full-width space
            inside_code = 32
        elif (inside_code >= 65281 and inside_code <= 65374):
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring


def polygon_from_str(polygon_points):
    """
    Create a shapely polygon object from gt or dt line.
    """
    polygon_points = np.array(polygon_points).reshape(4, 2)
    polygon = Polygon(polygon_points).convex_hull
    return polygon


def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.
    """
    if not poly1.intersects(poly2):
        # this test is fast and can accelerate calculation
        iou = 0
    else:
        try:
            inter_area = poly1.intersection(poly2).area
            union_area = poly1.area + poly2.area - inter_area
            iou = float(inter_area) / union_area
        except shapely.geos.TopologicalError:
            # except Exception as e:
            #     print(e)
            print('shapely.geos.TopologicalError occurred, iou set to 0')
            iou = 0
    return iou


def ed(str1, str2):
    return editdistance.eval(str1, str2)


def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for i, val_name in enumerate(val_names):
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            if len(parts) == 9:
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])
            ignore_masks.append(parts[8])

        val_path = os.path.join(res_dir, val_name)
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            # print(line)
            parts = line.strip().split("\t")
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        # collect all gt/dt pairs whose IoU clears the threshold
        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        all_ious = defaultdict(tuple)
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt
        for gt_dt_pair in sorted_gt_dt_pairs:
            index_gt, index_dt = gt_dt_pair
            if not gt_match[index_gt] and not dt_match[index_dt]:
                gt_match[index_gt] = True
                dt_match[index_dt] = True
                if ignore_blank:
                    gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
                    dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
                else:
                    gt_str = strQ2B(gts[index_gt][8])
                    dt_str = strQ2B(dts[index_dt][8])
                if ignore_masks[index_gt] == '0':
                    ed_sum += ed(gt_str, dt_str)
                    num_gt_chars += len(gt_str)
                    if gt_str == dt_str:
                        hit += 1
                    gt_count += 1
                    dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if not dt_match_flag:
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if not gt_match_flag and ignore_masks[tindex] == '0':
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    avg_edit_dist_img = ed_sum / len(val_names)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")


if __name__ == '__main__':
    # if len(sys.argv) != 3:
    #     print("python3 ocr_e2e_eval.py gt_dir res_dir")
    #     exit(-1)
    # gt_folder = sys.argv[1]
    # pred_folder = sys.argv[2]
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
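To make the matching criterion concrete, here is a minimal sketch (coordinates invented for illustration) of `polygon_iou` on two overlapping axis-aligned boxes; pairs whose IoU falls below the 0.5 `iou_thresh` are never matched:

```python
from shapely.geometry import Polygon
import numpy as np

def polygon_from_str(points):
    return Polygon(np.array(points).reshape(4, 2)).convex_hull

a = polygon_from_str([0, 0, 2, 0, 2, 2, 0, 2])  # 2x2 box at the origin
b = polygon_from_str([1, 0, 3, 0, 3, 2, 1, 2])  # same box shifted right by 1
inter = a.intersection(b).area                  # 2.0
iou = inter / (a.area + b.area - inter)         # 2 / (4 + 4 - 2) = 1/3
print(iou)  # ~0.333 < 0.5, so this gt/dt pair would not be matched
```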
**tools/end2end/readme.md** (new file)

# Introduction

The `tools/end2end` directory contains the metric evaluation code and visualization tools for the chained text detection + text recognition prediction pipeline. This section describes how to evaluate the end-to-end metrics of text detection + text recognition.

## End-to-end evaluation steps

**Step 1:** Run `tools/infer/predict_system.py` to obtain the saved results:

```
python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
```

The detection/recognition visualizations are saved under `./ch_PP-OCRv2_results/` by default, and the predictions are saved in `./ch_PP-OCRv2_results/system_results.txt` by default, in the following format:

```
all-sum-510/00224225.jpg	[{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
```

**Step 2:** Convert the data saved in step 1 into the format required by the end-to-end evaluation.

Edit the code in `tools/convert_ppocr_label.py`: in the `convert_label` calls, set the input label path, the mode, and the output label path, to convert both the ground-truth labels and the prediction labels of the test data:

```
ppocr_label_gt = "gt_label.txt"
convert_label(ppocr_label_gt, "gt", "./save_gt_label/")

ppocr_label_pred = "./ch_PP-OCRv2_results/system_results.txt"
convert_label(ppocr_label_pred, "pred", "./save_PPOCRV2_infer/")
```

Then run `convert_ppocr_label.py`:

```
python3 tools/convert_ppocr_label.py
```

This produces:

```
├── ./save_gt_label/
├── ./save_PPOCRV2_infer/
```

**Step 3:** Run the end-to-end evaluation with `tools/eval_end2end.py` to compute the end-to-end metrics:

```
python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir"
```

For example:

```
python3 tools/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
```

This prints results like the following; fmeasure is the main metric of interest:

```
hit, dt_count, gt_count 1557 2693 3283
character_acc: 61.77%
avg_edit_dist_field: 3.08
avg_edit_dist_img: 51.82
precision: 57.82%
recall: 47.43%
fmeasure: 52.11%
```
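As a quick sanity check, the precision, recall, and fmeasure lines follow directly from the `hit, dt_count, gt_count` totals (a minimal sketch; `eval_end2end.py` additionally adds a tiny `eps` to each denominator to avoid division by zero):

```python
# Recompute the example metrics above from the raw counts.
hit, dt_count, gt_count = 1557, 2693, 3283
precision = hit / dt_count                                # 0.5782
recall = hit / gt_count                                   # 0.4743
fmeasure = 2 * precision * recall / (precision + recall)  # 0.5211
print(f"{precision:.2%} {recall:.2%} {fmeasure:.2%}")     # 57.82% 47.43% 52.11%
```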
**tools/infer/predict_det.py**

```diff
@@ -150,27 +150,13 @@ class TextDetector(object):
             logger=logger)

     def order_points_clockwise(self, pts):
-        """
-        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
-        # sort the points based on their x-coordinates
-        """
-        xSorted = pts[np.argsort(pts[:, 0]), :]
-
-        # grab the left-most and right-most points from the sorted
-        # x-coordinate points
-        leftMost = xSorted[:2, :]
-        rightMost = xSorted[2:, :]
-
-        # now, sort the left-most coordinates according to their
-        # y-coordinates so we can grab the top-left and bottom-left
-        # points, respectively
-        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
-        (tl, bl) = leftMost
-
-        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
-        (tr, br) = rightMost
-
-        rect = np.array([tl, tr, br, bl], dtype="float32")
-        return rect
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        diff = np.diff(pts, axis=1)
+        rect[1] = pts[np.argmin(diff)]
+        rect[3] = pts[np.argmax(diff)]
+        return rect

     def clip_det_res(self, points, img_height, img_width):
```
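The replacement body is the classic sum/difference corner-ordering trick: for a roughly axis-aligned quadrilateral, the top-left corner minimizes x + y, the bottom-right maximizes it, and `np.diff` (y − x per row) picks out the top-right (minimum) and bottom-left (maximum). A small numeric sketch with invented points:

```python
import numpy as np

pts = np.array([[9, 1], [1, 0], [10, 8], [0, 7]], dtype="float32")

rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)             # x + y per point: [10, 1, 18, 7]
rect[0] = pts[np.argmin(s)]     # top-left     -> [1, 0]
rect[2] = pts[np.argmax(s)]     # bottom-right -> [10, 8]
diff = np.diff(pts, axis=1)     # y - x per point: [-8, -1, -2, 7]
rect[1] = pts[np.argmin(diff)]  # top-right    -> [9, 1]
rect[3] = pts[np.argmax(diff)]  # bottom-left  -> [0, 7]
print(rect)                     # corners in clockwise order from top-left
```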
**tools/infer/utility.py**

```diff
@@ -622,7 +622,6 @@ def get_rotate_crop_image(img, points):
 def check_gpu(use_gpu):
     if use_gpu and not paddle.is_compiled_with_cuda():
         use_gpu = False
     return use_gpu
```
**tools/infer_vqa_token_ser_re.py**

```diff
@@ -151,7 +151,7 @@ def preprocess():
     ser_config = load_config(FLAGS.config_ser)
     ser_config = merge_config(ser_config, FLAGS.opt_ser)
-    logger = get_logger(name='root')
+    logger = get_logger()
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
```
**tools/program.py**

```diff
@@ -525,7 +525,7 @@ def preprocess(is_train=False):
         log_file = '{}/train.log'.format(save_model_dir)
     else:
         log_file = None
-    logger = get_logger(name='root', log_file=log_file)
+    logger = get_logger(log_file=log_file)
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
```
**tools/test_hubserving.py**

```diff
@@ -25,7 +25,9 @@ import numpy as np
 import time
 from PIL import Image
 from ppocr.utils.utility import get_image_file_list
-from tools.infer.utility import draw_ocr, draw_boxes
+from tools.infer.utility import draw_ocr, draw_boxes, str2bool
+from ppstructure.utility import draw_structure_result
+from ppstructure.predict_system import to_excel
 import requests
 import json
```

```diff
@@ -69,8 +71,33 @@ def draw_server_result(image_file, res):
     return draw_img


-def main(url, image_path):
-    image_file_list = get_image_file_list(image_path)
+def save_structure_res(res, save_folder, image_file):
+    img = cv2.imread(image_file)
+    excel_save_folder = os.path.join(save_folder, os.path.basename(image_file))
+    os.makedirs(excel_save_folder, exist_ok=True)
+    # save res
+    with open(
+            os.path.join(excel_save_folder, 'res.txt'), 'w',
+            encoding='utf8') as f:
+        for region in res:
+            if region['type'] == 'Table':
+                excel_path = os.path.join(excel_save_folder,
+                                          '{}.xlsx'.format(region['bbox']))
+                to_excel(region['res'], excel_path)
+            elif region['type'] == 'Figure':
+                x1, y1, x2, y2 = region['bbox']
+                print(region['bbox'])
+                roi_img = img[y1:y2, x1:x2, :]
+                img_path = os.path.join(excel_save_folder,
+                                        '{}.jpg'.format(region['bbox']))
+                cv2.imwrite(img_path, roi_img)
+            else:
+                for text_result in region['res']:
+                    f.write('{}\n'.format(json.dumps(text_result)))
+
+
+def main(args):
+    image_file_list = get_image_file_list(args.image_dir)
     is_visualize = False
     headers = {"Content-type": "application/json"}
     cnt = 0
```

```diff
@@ -80,38 +107,51 @@ def main(url, image_path):
         if img is None:
             logger.info("error in loading image:{}".format(image_file))
             continue
+        img_name = os.path.basename(image_file)
-        # 发送HTTP请求
+        # send http request
         starttime = time.time()
         data = {'images': [cv2_to_base64(img)]}
-        r = requests.post(url=url, headers=headers, data=json.dumps(data))
+        r = requests.post(
+            url=args.server_url, headers=headers, data=json.dumps(data))
         elapse = time.time() - starttime
         total_time += elapse
         logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
         res = r.json()["results"][0]
         logger.info(res)

-        if is_visualize:
-            draw_img = draw_server_result(image_file, res)
+        if args.visualize:
+            draw_img = None
+            if 'structure_table' in args.server_url:
+                to_excel(res['html'], './{}.xlsx'.format(img_name))
+            elif 'structure_system' in args.server_url:
+                save_structure_res(res['regions'], args.output, image_file)
+            else:
+                draw_img = draw_server_result(image_file, res)
             if draw_img is not None:
-                draw_img_save = "./server_results/"
-                if not os.path.exists(draw_img_save):
-                    os.makedirs(draw_img_save)
+                if not os.path.exists(args.output):
+                    os.makedirs(args.output)
                 cv2.imwrite(
-                    os.path.join(draw_img_save, os.path.basename(image_file)),
+                    os.path.join(args.output, os.path.basename(image_file)),
                     draw_img[:, :, ::-1])
                 logger.info("The visualized image saved in {}".format(
-                    os.path.join(draw_img_save, os.path.basename(image_file))))
+                    os.path.join(args.output, os.path.basename(image_file))))
         cnt += 1
         if cnt % 100 == 0:
             logger.info("{} processed".format(cnt))
     logger.info("avg time cost: {}".format(float(total_time) / cnt))


+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description="args for hub serving")
+    parser.add_argument("--server_url", type=str, required=True)
+    parser.add_argument("--image_dir", type=str, required=True)
+    parser.add_argument("--visualize", type=str2bool, default=False)
+    parser.add_argument("--output", type=str, default='./hubserving_result')
+    args = parser.parse_args()
+    return args
+
+
 if __name__ == '__main__':
-    if len(sys.argv) != 3:
-        logger.info("Usage: %s server_url image_path" % sys.argv[0])
-    else:
-        server_url = sys.argv[1]
-        image_path = sys.argv[2]
-        main(server_url, image_path)
+    args = parse_args()
+    main(args)
```
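With the new argparse interface, a typical call might look like the following (the URL assumes PaddleHub Serving's default port and the ocr_system module, and the paths are illustrative; adjust to your deployment):

```
python3 tools/test_hubserving.py --server_url http://127.0.0.1:8866/predict/ocr_system --image_dir ./doc/imgs/ --visualize false --output ./hubserving_result
```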