Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a323fce6
"src/vscode:/vscode.git/clone" did not exist on "ca6a71749c774ca5b1b0f180d909bdac59750da9"
Commit
a323fce6
authored
Jan 05, 2022
by
WenmuZhou
Browse files
vqa code integrated into ppocr training system
parent
1ded2ac4
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
571 additions
and
1002 deletions
+571
-1002
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+8
-3
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+3
-0
ppocr/modeling/backbones/vqa_layoutlm.py
ppocr/modeling/backbones/vqa_layoutlm.py
+123
-0
ppocr/optimizer/__init__.py
ppocr/optimizer/__init__.py
+3
-1
ppocr/optimizer/optimizer.py
ppocr/optimizer/optimizer.py
+35
-0
ppocr/optimizer/regularizer.py
ppocr/optimizer/regularizer.py
+15
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+4
-1
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
+51
-0
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+93
-0
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+40
-5
ppocr/utils/utility.py
ppocr/utils/utility.py
+19
-1
ppocr/utils/visual.py
ppocr/utils/visual.py
+98
-0
ppstructure/docs/model_list.md
ppstructure/docs/model_list.md
+2
-2
ppstructure/vqa/README.md
ppstructure/vqa/README.md
+64
-158
ppstructure/vqa/eval_re.py
ppstructure/vqa/eval_re.py
+0
-125
ppstructure/vqa/eval_ser.py
ppstructure/vqa/eval_ser.py
+0
-177
ppstructure/vqa/helper/trans_xfun_data.py
ppstructure/vqa/helper/trans_xfun_data.py
+13
-1
ppstructure/vqa/infer.sh
ppstructure/vqa/infer.sh
+0
-61
ppstructure/vqa/infer_re.py
ppstructure/vqa/infer_re.py
+0
-165
ppstructure/vqa/infer_ser.py
ppstructure/vqa/infer_ser.py
+0
-302
No files found.
ppocr/modeling/architectures/base_model.py
View file @
a323fce6
...
@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
...
@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
in_channels
=
self
.
neck
.
out_channels
in_channels
=
self
.
neck
.
out_channels
# # build head, head is need for det, rec and cls
# # build head, head is need for det, rec and cls
config
[
"Head"
][
'in_channels'
]
=
in_channels
if
'Head'
not
in
config
or
config
[
'Head'
]
is
None
:
self
.
head
=
build_head
(
config
[
"Head"
])
self
.
use_head
=
False
else
:
self
.
use_head
=
True
config
[
"Head"
][
'in_channels'
]
=
in_channels
self
.
head
=
build_head
(
config
[
"Head"
])
self
.
return_all_feats
=
config
.
get
(
"return_all_feats"
,
False
)
self
.
return_all_feats
=
config
.
get
(
"return_all_feats"
,
False
)
...
@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
...
@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
if
self
.
use_neck
:
if
self
.
use_neck
:
x
=
self
.
neck
(
x
)
x
=
self
.
neck
(
x
)
y
[
"neck_out"
]
=
x
y
[
"neck_out"
]
=
x
x
=
self
.
head
(
x
,
targets
=
data
)
if
self
.
use_head
:
x
=
self
.
head
(
x
,
targets
=
data
)
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
y
.
update
(
x
)
y
.
update
(
x
)
else
:
else
:
...
...
ppocr/modeling/backbones/__init__.py
View file @
a323fce6
...
@@ -43,6 +43,9 @@ def build_backbone(config, model_type):
...
@@ -43,6 +43,9 @@ def build_backbone(config, model_type):
from
.table_resnet_vd
import
ResNet
from
.table_resnet_vd
import
ResNet
from
.table_mobilenet_v3
import
MobileNetV3
from
.table_mobilenet_v3
import
MobileNetV3
support_dict
=
[
"ResNet"
,
"MobileNetV3"
]
support_dict
=
[
"ResNet"
,
"MobileNetV3"
]
elif
model_type
==
'vqa'
:
from
.vqa_layoutlm
import
LayoutLMForSer
,
LayoutXLMForSer
,
LayoutXLMForRe
support_dict
=
[
"LayoutLMForSer"
,
"LayoutXLMForSer"
,
'LayoutXLMForRe'
]
else
:
else
:
raise
NotImplementedError
raise
NotImplementedError
...
...
ppocr/modeling/backbones/vqa_layoutlm.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
from
paddle
import
nn
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMForTokenClassification
,
LayoutXLMForRelationExtraction
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMForTokenClassification
__all__
=
[
"LayoutXLMForSer"
,
'LayoutLMForSer'
]
class
NLPBaseModel
(
nn
.
Layer
):
def
__init__
(
self
,
base_model_class
,
model_class
,
type
=
'ser'
,
pretrained_model
=
None
,
checkpoints
=
None
,
**
kwargs
):
super
(
NLPBaseModel
,
self
).
__init__
()
assert
pretrained_model
is
not
None
or
checkpoints
is
not
None
,
"one of pretrained_model and checkpoints must be not None"
if
checkpoints
is
not
None
:
self
.
model
=
model_class
.
from_pretrained
(
checkpoints
)
else
:
base_model
=
base_model_class
.
from_pretrained
(
pretrained_model
)
if
type
==
'ser'
:
self
.
model
=
model_class
(
base_model
,
num_classes
=
kwargs
[
'num_classes'
],
dropout
=
None
)
else
:
self
.
model
=
model_class
(
base_model
,
dropout
=
None
)
self
.
out_channels
=
1
class
LayoutXLMForSer
(
NLPBaseModel
):
def
__init__
(
self
,
num_classes
,
pretrained_model
=
'layoutxlm-base-uncased'
,
checkpoints
=
None
,
**
kwargs
):
super
(
LayoutXLMForSer
,
self
).
__init__
(
LayoutXLMModel
,
LayoutXLMForTokenClassification
,
'ser'
,
pretrained_model
,
checkpoints
,
num_classes
=
num_classes
)
def
forward
(
self
,
x
):
x
=
self
.
model
(
input_ids
=
x
[
0
],
bbox
=
x
[
2
],
image
=
x
[
3
],
attention_mask
=
x
[
4
],
token_type_ids
=
x
[
5
],
position_ids
=
None
,
head_mask
=
None
,
labels
=
None
)
return
x
[
0
]
class
LayoutLMForSer
(
NLPBaseModel
):
def
__init__
(
self
,
num_classes
,
pretrained_model
=
'layoutxlm-base-uncased'
,
checkpoints
=
None
,
**
kwargs
):
super
(
LayoutLMForSer
,
self
).
__init__
(
LayoutLMModel
,
LayoutLMForTokenClassification
,
'ser'
,
pretrained_model
,
checkpoints
,
num_classes
=
num_classes
)
def
forward
(
self
,
x
):
x
=
self
.
model
(
input_ids
=
x
[
0
],
bbox
=
x
[
2
],
attention_mask
=
x
[
4
],
token_type_ids
=
x
[
5
],
position_ids
=
None
,
output_hidden_states
=
False
)
return
x
class
LayoutXLMForRe
(
NLPBaseModel
):
def
__init__
(
self
,
pretrained_model
=
'layoutxlm-base-uncased'
,
checkpoints
=
None
,
**
kwargs
):
super
(
LayoutXLMForRe
,
self
).
__init__
(
LayoutXLMModel
,
LayoutXLMForRelationExtraction
,
're'
,
pretrained_model
,
checkpoints
)
def
forward
(
self
,
x
):
x
=
self
.
model
(
input_ids
=
x
[
0
],
bbox
=
x
[
1
],
labels
=
None
,
image
=
x
[
2
],
attention_mask
=
x
[
3
],
token_type_ids
=
x
[
4
],
position_ids
=
None
,
head_mask
=
None
,
entities
=
x
[
5
],
relations
=
x
[
6
])
return
x
ppocr/optimizer/__init__.py
View file @
a323fce6
...
@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
...
@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization
# step2 build regularization
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
reg_config
=
config
.
pop
(
'regularizer'
)
reg_config
=
config
.
pop
(
'regularizer'
)
reg_name
=
reg_config
.
pop
(
'name'
)
+
'Decay'
reg_name
=
reg_config
.
pop
(
'name'
)
if
not
hasattr
(
regularizer
,
reg_name
):
reg_name
+=
'Decay'
reg
=
getattr
(
regularizer
,
reg_name
)(
**
reg_config
)()
reg
=
getattr
(
regularizer
,
reg_name
)(
**
reg_config
)()
else
:
else
:
reg
=
None
reg
=
None
...
...
ppocr/optimizer/optimizer.py
View file @
a323fce6
...
@@ -158,3 +158,38 @@ class Adadelta(object):
...
@@ -158,3 +158,38 @@ class Adadelta(object):
name
=
self
.
name
,
name
=
self
.
name
,
parameters
=
parameters
)
parameters
=
parameters
)
return
opt
return
opt
class
AdamW
(
object
):
def
__init__
(
self
,
learning_rate
=
0.001
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-08
,
weight_decay
=
0.01
,
grad_clip
=
None
,
name
=
None
,
lazy_mode
=
False
,
**
kwargs
):
self
.
learning_rate
=
learning_rate
self
.
beta1
=
beta1
self
.
beta2
=
beta2
self
.
epsilon
=
epsilon
self
.
learning_rate
=
learning_rate
self
.
weight_decay
=
0.01
if
weight_decay
is
None
else
weight_decay
self
.
grad_clip
=
grad_clip
self
.
name
=
name
self
.
lazy_mode
=
lazy_mode
def
__call__
(
self
,
parameters
):
opt
=
optim
.
AdamW
(
learning_rate
=
self
.
learning_rate
,
beta1
=
self
.
beta1
,
beta2
=
self
.
beta2
,
epsilon
=
self
.
epsilon
,
weight_decay
=
self
.
weight_decay
,
grad_clip
=
self
.
grad_clip
,
name
=
self
.
name
,
lazy_mode
=
self
.
lazy_mode
,
parameters
=
parameters
)
return
opt
ppocr/optimizer/regularizer.py
View file @
a323fce6
...
@@ -50,3 +50,18 @@ class L2Decay(object):
...
@@ -50,3 +50,18 @@ class L2Decay(object):
def
__call__
(
self
):
def
__call__
(
self
):
reg
=
paddle
.
regularizer
.
L2Decay
(
self
.
regularization_coeff
)
reg
=
paddle
.
regularizer
.
L2Decay
(
self
.
regularization_coeff
)
return
reg
return
reg
class
ConstDecay
(
object
):
"""
Const L2 Weight Decay Regularization, which encourages the weights to be sparse.
Args:
factor(float): regularization coeff. Default:0.0.
"""
def
__init__
(
self
,
factor
=
0.0
):
super
(
ConstDecay
,
self
).
__init__
()
self
.
regularization_coeff
=
factor
def
__call__
(
self
):
return
self
.
regularization_coeff
ppocr/postprocess/__init__.py
View file @
a323fce6
...
@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
...
@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEEDLabelDecode
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
SEEDLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.vqa_token_ser_layoutlm_postprocess
import
VQASerTokenLayoutLMPostProcess
from
.vqa_token_re_layoutlm_postprocess
import
VQAReTokenLayoutLMPostProcess
def
build_post_process
(
config
,
global_config
=
None
):
def
build_post_process
(
config
,
global_config
=
None
):
...
@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
...
@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'VQAReTokenLayoutLMPostProcess'
]
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
if
config
[
'name'
]
==
'PSEPostProcess'
:
...
...
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
0 → 100644
View file @
a323fce6
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
class
VQAReTokenLayoutLMPostProcess
(
object
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
**
kwargs
):
super
(
VQAReTokenLayoutLMPostProcess
,
self
).
__init__
()
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
label
is
not
None
:
return
self
.
_metric
(
preds
,
label
)
else
:
return
self
.
_infer
(
preds
,
*
args
,
**
kwargs
)
def
_metric
(
self
,
preds
,
label
):
return
preds
[
'pred_relations'
],
label
[
6
],
label
[
5
]
def
_infer
(
self
,
preds
,
*
args
,
**
kwargs
):
ser_results
=
kwargs
[
'ser_results'
]
entity_idx_dict_batch
=
kwargs
[
'entity_idx_dict_batch'
]
pred_relations
=
preds
[
'pred_relations'
]
# 进行 relations 到 ocr信息的转换
results
=
[]
for
pred_relation
,
ser_result
,
entity_idx_dict
in
zip
(
pred_relations
,
ser_results
,
entity_idx_dict_batch
):
result
=
[]
used_tail_id
=
[]
for
relation
in
pred_relation
:
if
relation
[
'tail_id'
]
in
used_tail_id
:
continue
used_tail_id
.
append
(
relation
[
'tail_id'
])
ocr_info_head
=
ser_result
[
entity_idx_dict
[
relation
[
'head_id'
]]]
ocr_info_tail
=
ser_result
[
entity_idx_dict
[
relation
[
'tail_id'
]]]
result
.
append
((
ocr_info_head
,
ocr_info_tail
))
results
.
append
(
result
)
return
results
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
0 → 100644
View file @
a323fce6
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
ppocr.utils.utility
import
load_vqa_bio_label_maps
class
VQASerTokenLayoutLMPostProcess
(
object
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
class_path
,
**
kwargs
):
super
(
VQASerTokenLayoutLMPostProcess
,
self
).
__init__
()
label2id_map
,
self
.
id2label_map
=
load_vqa_bio_label_maps
(
class_path
)
self
.
label2id_map_for_draw
=
dict
()
for
key
in
label2id_map
:
if
key
.
startswith
(
"I-"
):
self
.
label2id_map_for_draw
[
key
]
=
label2id_map
[
"B"
+
key
[
1
:]]
else
:
self
.
label2id_map_for_draw
[
key
]
=
label2id_map
[
key
]
self
.
id2label_map_for_show
=
dict
()
for
key
in
self
.
label2id_map_for_draw
:
val
=
self
.
label2id_map_for_draw
[
key
]
if
key
==
"O"
:
self
.
id2label_map_for_show
[
val
]
=
key
if
key
.
startswith
(
"B-"
)
or
key
.
startswith
(
"I-"
):
self
.
id2label_map_for_show
[
val
]
=
key
[
2
:]
else
:
self
.
id2label_map_for_show
[
val
]
=
key
def
__call__
(
self
,
preds
,
batch
=
None
,
*
args
,
**
kwargs
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
if
batch
is
not
None
:
return
self
.
_metric
(
preds
,
batch
[
1
])
else
:
return
self
.
_infer
(
preds
,
**
kwargs
)
def
_metric
(
self
,
preds
,
label
):
pred_idxs
=
preds
.
argmax
(
axis
=
2
)
decode_out_list
=
[[]
for
_
in
range
(
pred_idxs
.
shape
[
0
])]
label_decode_out_list
=
[[]
for
_
in
range
(
pred_idxs
.
shape
[
0
])]
for
i
in
range
(
pred_idxs
.
shape
[
0
]):
for
j
in
range
(
pred_idxs
.
shape
[
1
]):
if
label
[
i
,
j
]
!=
-
100
:
label_decode_out_list
[
i
].
append
(
self
.
id2label_map
[
label
[
i
,
j
]])
decode_out_list
[
i
].
append
(
self
.
id2label_map
[
pred_idxs
[
i
,
j
]])
return
decode_out_list
,
label_decode_out_list
def
_infer
(
self
,
preds
,
attention_masks
,
segment_offset_ids
,
ocr_infos
):
results
=
[]
for
pred
,
attention_mask
,
segment_offset_id
,
ocr_info
in
zip
(
preds
,
attention_masks
,
segment_offset_ids
,
ocr_infos
):
pred
=
np
.
argmax
(
pred
,
axis
=
1
)
pred
=
[
self
.
id2label_map
[
idx
]
for
idx
in
pred
]
for
idx
in
range
(
len
(
segment_offset_id
)):
if
idx
==
0
:
start_id
=
0
else
:
start_id
=
segment_offset_id
[
idx
-
1
]
end_id
=
segment_offset_id
[
idx
]
curr_pred
=
pred
[
start_id
:
end_id
]
curr_pred
=
[
self
.
label2id_map_for_draw
[
p
]
for
p
in
curr_pred
]
if
len
(
curr_pred
)
<=
0
:
pred_id
=
0
else
:
counts
=
np
.
bincount
(
curr_pred
)
pred_id
=
np
.
argmax
(
counts
)
ocr_info
[
idx
][
"pred_id"
]
=
int
(
pred_id
)
ocr_info
[
idx
][
"pred"
]
=
self
.
id2label_map_for_show
[
int
(
pred_id
)]
results
.
append
(
ocr_info
)
return
results
ppocr/utils/save_load.py
View file @
a323fce6
...
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
...
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
load_model
(
config
,
model
,
optimizer
=
None
):
def
load_model
(
config
,
model
,
optimizer
=
None
,
model_type
=
'det'
):
"""
"""
load model from checkpoint or pretrained_model
load model from checkpoint or pretrained_model
"""
"""
...
@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
...
@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
checkpoints
=
global_config
.
get
(
'checkpoints'
)
checkpoints
=
global_config
.
get
(
'checkpoints'
)
pretrained_model
=
global_config
.
get
(
'pretrained_model'
)
pretrained_model
=
global_config
.
get
(
'pretrained_model'
)
best_model_dict
=
{}
best_model_dict
=
{}
if
model_type
==
'vqa'
:
checkpoints
=
config
[
'Architecture'
][
'Backbone'
][
'checkpoints'
]
# load vqa method metric
if
checkpoints
:
if
os
.
path
.
exists
(
os
.
path
.
join
(
checkpoints
,
'metric.states'
)):
with
open
(
os
.
path
.
join
(
checkpoints
,
'metric.states'
),
'rb'
)
as
f
:
states_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
best_model_dict
=
states_dict
.
get
(
'best_model_dict'
,
{})
if
'epoch'
in
states_dict
:
best_model_dict
[
'start_epoch'
]
=
states_dict
[
'epoch'
]
+
1
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
if
optimizer
is
not
None
:
if
checkpoints
[
-
1
]
in
[
'/'
,
'
\\
'
]:
checkpoints
=
checkpoints
[:
-
1
]
if
os
.
path
.
exists
(
checkpoints
+
'.pdopt'
):
optim_dict
=
paddle
.
load
(
checkpoints
+
'.pdopt'
)
optimizer
.
set_state_dict
(
optim_dict
)
else
:
logger
.
warning
(
"{}.pdopt is not exists, params of optimizer is not loaded"
.
format
(
checkpoints
))
return
best_model_dict
if
checkpoints
:
if
checkpoints
:
if
checkpoints
.
endswith
(
'.pdparams'
):
if
checkpoints
.
endswith
(
'.pdparams'
):
checkpoints
=
checkpoints
.
replace
(
'.pdparams'
,
''
)
checkpoints
=
checkpoints
.
replace
(
'.pdparams'
,
''
)
...
@@ -127,6 +154,7 @@ def save_model(model,
...
@@ -127,6 +154,7 @@ def save_model(model,
optimizer
,
optimizer
,
model_path
,
model_path
,
logger
,
logger
,
config
,
is_best
=
False
,
is_best
=
False
,
prefix
=
'ppocr'
,
prefix
=
'ppocr'
,
**
kwargs
):
**
kwargs
):
...
@@ -135,13 +163,20 @@ def save_model(model,
...
@@ -135,13 +163,20 @@ def save_model(model,
"""
"""
_mkdir_if_not_exist
(
model_path
,
logger
)
_mkdir_if_not_exist
(
model_path
,
logger
)
model_prefix
=
os
.
path
.
join
(
model_path
,
prefix
)
model_prefix
=
os
.
path
.
join
(
model_path
,
prefix
)
paddle
.
save
(
model
.
state_dict
(),
model_prefix
+
'.pdparams'
)
paddle
.
save
(
optimizer
.
state_dict
(),
model_prefix
+
'.pdopt'
)
paddle
.
save
(
optimizer
.
state_dict
(),
model_prefix
+
'.pdopt'
)
if
config
[
'Architecture'
][
"model_type"
]
!=
'vqa'
:
paddle
.
save
(
model
.
state_dict
(),
model_prefix
+
'.pdparams'
)
metric_prefix
=
model_prefix
else
:
if
config
[
'Global'
][
'distributed'
]:
model
.
_layers
.
backbone
.
model
.
save_pretrained
(
model_prefix
)
else
:
model
.
backbone
.
model
.
save_pretrained
(
model_prefix
)
metric_prefix
=
os
.
path
.
join
(
model_prefix
,
'metric'
)
# save metric and config
# save metric and config
with
open
(
model_prefix
+
'.states'
,
'wb'
)
as
f
:
pickle
.
dump
(
kwargs
,
f
,
protocol
=
2
)
if
is_best
:
if
is_best
:
with
open
(
metric_prefix
+
'.states'
,
'wb'
)
as
f
:
pickle
.
dump
(
kwargs
,
f
,
protocol
=
2
)
logger
.
info
(
'save best model is to {}'
.
format
(
model_prefix
))
logger
.
info
(
'save best model is to {}'
.
format
(
model_prefix
))
else
:
else
:
logger
.
info
(
"save model in {}"
.
format
(
model_prefix
))
logger
.
info
(
"save model in {}"
.
format
(
model_prefix
))
ppocr/utils/utility.py
View file @
a323fce6
...
@@ -77,4 +77,22 @@ def check_and_read_gif(img_path):
...
@@ -77,4 +77,22 @@ def check_and_read_gif(img_path):
frame
=
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_GRAY2RGB
)
frame
=
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_GRAY2RGB
)
imgvalue
=
frame
[:,
:,
::
-
1
]
imgvalue
=
frame
[:,
:,
::
-
1
]
return
imgvalue
,
True
return
imgvalue
,
True
return
None
,
False
return
None
,
False
\ No newline at end of file
def
load_vqa_bio_label_maps
(
label_map_path
):
with
open
(
label_map_path
,
"r"
,
encoding
=
'utf-8'
)
as
fin
:
lines
=
fin
.
readlines
()
lines
=
[
line
.
strip
()
for
line
in
lines
]
if
"O"
not
in
lines
:
lines
.
insert
(
0
,
"O"
)
labels
=
[]
for
line
in
lines
:
if
line
==
"O"
:
labels
.
append
(
"O"
)
else
:
labels
.
append
(
"B-"
+
line
)
labels
.
append
(
"I-"
+
line
)
label2id_map
=
{
label
:
idx
for
idx
,
label
in
enumerate
(
labels
)}
id2label_map
=
{
idx
:
label
for
idx
,
label
in
enumerate
(
labels
)}
return
label2id_map
,
id2label_map
ppocr/utils/visual.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
numpy
as
np
from
PIL
import
Image
,
ImageDraw
,
ImageFont
def
draw_ser_results
(
image
,
ocr_results
,
font_path
=
"doc/fonts/simfang.ttf"
,
font_size
=
18
):
np
.
random
.
seed
(
2021
)
color
=
(
np
.
random
.
permutation
(
range
(
255
)),
np
.
random
.
permutation
(
range
(
255
)),
np
.
random
.
permutation
(
range
(
255
)))
color_map
=
{
idx
:
(
color
[
0
][
idx
],
color
[
1
][
idx
],
color
[
2
][
idx
])
for
idx
in
range
(
1
,
255
)
}
if
isinstance
(
image
,
np
.
ndarray
):
image
=
Image
.
fromarray
(
image
)
elif
isinstance
(
image
,
str
)
and
os
.
path
.
isfile
(
image
):
image
=
Image
.
open
(
image
).
convert
(
'RGB'
)
img_new
=
image
.
copy
()
draw
=
ImageDraw
.
Draw
(
img_new
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
for
ocr_info
in
ocr_results
:
if
ocr_info
[
"pred_id"
]
not
in
color_map
:
continue
color
=
color_map
[
ocr_info
[
"pred_id"
]]
text
=
"{}: {}"
.
format
(
ocr_info
[
"pred"
],
ocr_info
[
"text"
])
draw_box_txt
(
ocr_info
[
"bbox"
],
text
,
draw
,
font
,
font_size
,
color
)
img_new
=
Image
.
blend
(
image
,
img_new
,
0.5
)
return
np
.
array
(
img_new
)
def
draw_box_txt
(
bbox
,
text
,
draw
,
font
,
font_size
,
color
):
# draw ocr results outline
bbox
=
((
bbox
[
0
],
bbox
[
1
]),
(
bbox
[
2
],
bbox
[
3
]))
draw
.
rectangle
(
bbox
,
fill
=
color
)
# draw ocr results
start_y
=
max
(
0
,
bbox
[
0
][
1
]
-
font_size
)
tw
=
font
.
getsize
(
text
)[
0
]
draw
.
rectangle
(
[(
bbox
[
0
][
0
]
+
1
,
start_y
),
(
bbox
[
0
][
0
]
+
tw
+
1
,
start_y
+
font_size
)],
fill
=
(
0
,
0
,
255
))
draw
.
text
((
bbox
[
0
][
0
]
+
1
,
start_y
),
text
,
fill
=
(
255
,
255
,
255
),
font
=
font
)
def
draw_re_results
(
image
,
result
,
font_path
=
"doc/fonts/simfang.ttf"
,
font_size
=
18
):
np
.
random
.
seed
(
0
)
if
isinstance
(
image
,
np
.
ndarray
):
image
=
Image
.
fromarray
(
image
)
elif
isinstance
(
image
,
str
)
and
os
.
path
.
isfile
(
image
):
image
=
Image
.
open
(
image
).
convert
(
'RGB'
)
img_new
=
image
.
copy
()
draw
=
ImageDraw
.
Draw
(
img_new
)
font
=
ImageFont
.
truetype
(
font_path
,
font_size
,
encoding
=
"utf-8"
)
color_head
=
(
0
,
0
,
255
)
color_tail
=
(
255
,
0
,
0
)
color_line
=
(
0
,
255
,
0
)
for
ocr_info_head
,
ocr_info_tail
in
result
:
draw_box_txt
(
ocr_info_head
[
"bbox"
],
ocr_info_head
[
"text"
],
draw
,
font
,
font_size
,
color_head
)
draw_box_txt
(
ocr_info_tail
[
"bbox"
],
ocr_info_tail
[
"text"
],
draw
,
font
,
font_size
,
color_tail
)
center_head
=
(
(
ocr_info_head
[
'bbox'
][
0
]
+
ocr_info_head
[
'bbox'
][
2
])
//
2
,
(
ocr_info_head
[
'bbox'
][
1
]
+
ocr_info_head
[
'bbox'
][
3
])
//
2
)
center_tail
=
(
(
ocr_info_tail
[
'bbox'
][
0
]
+
ocr_info_tail
[
'bbox'
][
2
])
//
2
,
(
ocr_info_tail
[
'bbox'
][
1
]
+
ocr_info_tail
[
'bbox'
][
3
])
//
2
)
draw
.
line
([
center_head
,
center_tail
],
fill
=
color_line
,
width
=
5
)
img_new
=
Image
.
blend
(
image
,
img_new
,
0.5
)
return
np
.
array
(
img_new
)
ppstructure/docs/model_list.md
View file @
a323fce6
...
@@ -24,8 +24,8 @@
...
@@ -24,8 +24,8 @@
|模型名称|模型简介|推理模型大小|下载地址|
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
| --- | --- | --- | --- |
|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|
[
推理模型 coming soon
](
)
/
[
训练模型
](
https://paddleocr.bj.bcebos.com/pplayout/
PP-
Layout
_v1.0_ser_pretrained
.tar
)
|
|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|
[
推理模型 coming soon
](
)
/
[
训练模型
](
https://paddleocr.bj.bcebos.com/pplayout/
re_
Layout
XLM_xfun_zh
.tar
)
|
|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|
[
推理模型 coming soon
](
)
/
[
训练模型
](
https://paddleocr.bj.bcebos.com/pplayout/
PP-
Layout
_v1.0_re_pretrained
.tar
)
|
|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|
[
推理模型 coming soon
](
)
/
[
训练模型
](
https://paddleocr.bj.bcebos.com/pplayout/
ser_
Layout
XLM_xfun_zh
.tar
)
|
## 3. KIE模型
## 3. KIE模型
...
...
ppstructure/vqa/README.md
View file @
a323fce6
...
@@ -20,11 +20,11 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
...
@@ -20,11 +20,11 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
我们在
[
XFUN
](
https://github.com/doc-analysis/XFUND
)
的中文数据集上对算法进行了评估,性能如下
我们在
[
XFUN
](
https://github.com/doc-analysis/XFUND
)
的中文数据集上对算法进行了评估,性能如下
| 模型 | 任务 |
f1
| 模型下载地址 |
| 模型 | 任务 |
hmean
| 模型下载地址 |
|:---:|:---:|:---:| :---:|
|:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7
11
3 |
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/
PP-
Layout
_v1.0_re_pretrained
.tar
)
|
| LayoutXLM | RE | 0.7
48
3 |
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/
re_
Layout
XLM_xfun_zh
.tar
)
|
| LayoutXLM | SER | 0.90
56
|
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/
PP-
Layout
_v1.0_ser_pretrained
.tar
)
|
| LayoutXLM | SER | 0.90
38
|
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/
ser_
Layout
XLM_xfun_zh
.tar
)
|
| LayoutLM | SER | 0.7
8
|
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_
ser_pretrained
.tar
)
|
| LayoutLM | SER | 0.7
731
|
[
链接
](
https://paddleocr.bj.bcebos.com/pplayout/
ser_
LayoutLM_
xfun_zh
.tar
)
|
...
@@ -65,10 +65,10 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
...
@@ -65,10 +65,10 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
pip3
install
--upgrade
pip
pip3
install
--upgrade
pip
# GPU安装
# GPU安装
python3
-m
pip
install
paddlepaddle-gpu
=
=
2.2
-i
https://mirror.baidu.com/pypi/simple
python3
-m
pip
install
"
paddlepaddle-gpu
>
=2.2
"
-i
https://mirror.baidu.com/pypi/simple
# CPU安装
# CPU安装
python3
-m
pip
install
paddlepaddle
=
=
2.2
-i
https://mirror.baidu.com/pypi/simple
python3
-m
pip
install
"
paddlepaddle
>
=2.2
"
-i
https://mirror.baidu.com/pypi/simple
```
```
更多需求,请参照
[
安装文档
](
https://www.paddlepaddle.org.cn/install/quick
)
中的说明进行操作。
更多需求,请参照
[
安装文档
](
https://www.paddlepaddle.org.cn/install/quick
)
中的说明进行操作。
...
@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
...
@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
-
**(1)pip快速安装PaddleOCR whl包(仅预测)**
-
**(1)pip快速安装PaddleOCR whl包(仅预测)**
```
bash
```
bash
pip
install
paddleocr
python3
-m
pip
install
paddleocr
```
```
-
**(2)下载VQA源码(预测+训练)**
-
**(2)下载VQA源码(预测+训练)**
...
@@ -93,18 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
...
@@ -93,18 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
```
```
-
**(3)安装
PaddleNLP
**
-
**(3)安装
VQA的`requirements`
**
```
bash
```
bash
pip3
install
"paddlenlp>=2.2.1"
python3
-m
pip
install
-r
ppstructure/vqa/requirements.txt
```
-
**(4)安装VQA的`requirements`**
```
bash
cd
ppstructure/vqa
pip
install
-r
requirements.txt
```
```
## 4. 使用
## 4. 使用
...
@@ -112,6 +104,10 @@ pip install -r requirements.txt
...
@@ -112,6 +104,10 @@ pip install -r requirements.txt
### 4.1 数据和预训练模型准备
### 4.1 数据和预训练模型准备
如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
*
下载处理好的数据集
处理好的XFUN中文数据集下载地址:
[
https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
](
https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
)
。
处理好的XFUN中文数据集下载地址:
[
https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
](
https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
)
。
...
@@ -121,98 +117,62 @@ pip install -r requirements.txt
...
@@ -121,98 +117,62 @@ pip install -r requirements.txt
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
```
如果希望转换XFUN中其他语言的数据集,可以参考
[
XFUN数据转换脚本
](
helper/trans_xfun_data.py
)
。
*
转换数据集
如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
若需进行其他XFUN数据集的训练,可使用下面的命令进行数据集的转换
```
bash
python3 ppstructure/vqa/helper/trans_xfun_data.py
--ori_gt_path
=
path/to/json_path
--output_path
=
path/to/save_path
```
### 4.2 SER任务
### 4.2 SER任务
*
启动训练
启动训练之前,需要修改下面的四个字段
1.
`Train.dataset.data_dir`
:指向训练集图片存放目录
2.
`Train.dataset.label_file_list`
:指向训练集标注文件
3.
`Eval.dataset.data_dir`
:指指向验证集图片存放目录
4.
`Eval.dataset.label_file_list`
:指向验证集标注文件
*
启动训练
```
shell
```
shell
python3.7 train_ser.py
\
CUDA_VISIBLE_DEVICES
=
0 python3 tools/train.py
-c
configs/vqa/ser/layoutxlm.yml
--model_name_or_path
"layoutxlm-base-uncased"
\
--ser_model_type
"LayoutXLM"
\
--train_data_dir
"XFUND/zh_train/image"
\
--train_label_path
"XFUND/zh_train/xfun_normalize_train.json"
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--num_train_epochs
200
\
--eval_steps
10
\
--output_dir
"./output/ser/"
\
--learning_rate
5e-5
\
--warmup_steps
50
\
--evaluate_during_training
\
--seed
2048
```
```
最终会打印出
`precision`
,
`recall`
,
`f1`
等指标,模型和训练日志会保存在
`./output/ser/`
文件夹中。
最终会打印出
`precision`
,
`recall`
,
`hmean`
等指标。
在
`./output/ser_layoutxlm/`
文件夹中会保存训练日志,最优的模型和最新epoch的模型。
*
恢复训练
*
恢复训练
恢复训练需要将之前训练好的模型所在文件夹路径赋值给
`Architecture.Backbone.checkpoints`
字段。
```
shell
```
shell
python3.7 train_ser.py
\
CUDA_VISIBLE_DEVICES
=
0 python3 tools/train.py
-c
configs/vqa/ser/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
path/to/model_dir
--model_name_or_path
"model_path"
\
--ser_model_type
"LayoutXLM"
\
--train_data_dir
"XFUND/zh_train/image"
\
--train_label_path
"XFUND/zh_train/xfun_normalize_train.json"
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--num_train_epochs
200
\
--eval_steps
10
\
--output_dir
"./output/ser/"
\
--learning_rate
5e-5
\
--warmup_steps
50
\
--evaluate_during_training
\
--num_workers
8
\
--seed
2048
\
--resume
```
```
*
评估
*
评估
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
python3 eval_ser.py
\
--model_name_or_path
"PP-Layout_v1.0_ser_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--per_gpu_eval_batch_size
8
\
--num_workers
8
\
--output_dir
"output/ser/"
\
--seed
2048
```
最终会打印出
`precision`
,
`recall`
,
`f1`
等指标
*
使用评估集合中提供的OCR识别结果进行预测
评估需要将待评估的模型所在文件夹路径赋值给
`Architecture.Backbone.checkpoints`
字段。
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 tools/eval.py
-c
configs/vqa/ser/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
path/to/model_dir
python3.7 infer_ser.py
\
--model_name_or_path
"PP-Layout_v1.0_ser_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--output_dir
"output/ser/"
\
--infer_imgs
"XFUND/zh_val/image/"
\
--ocr_json_path
"XFUND/zh_val/xfun_normalize_val.json"
```
```
最终会打印出
`precision`
,
`recall`
,
`hmean`
等指标
最终会在
`output_res`
目录下保存预测结果可视化图像以及预测结果文本文件,文件名为
`infer_results.txt`
。
*
使用
`OCR引擎 + SER`
串联预测
*
使用
`OCR引擎 + SER`
串联
结果
使用
如下命令即可完成
`OCR引擎 + SER`
的
串联
预测
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 tools/infer_vqa_token_ser.py
-c
configs/vqa/ser/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
PP-Layout_v1.0_ser_pretrained/ Global.infer_img
=
ppstructure/vqa/images/input/zh_val_42.jpg
python3.7 infer_ser_e2e.py
\
--model_name_or_path
"PP-Layout_v1.0_ser_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--max_seq_length
512
\
--output_dir
"output/ser_e2e/"
\
--infer_imgs
"images/input/zh_val_0.jpg"
```
```
最终会在
`config.Global.save_res_path`
字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为
`infer_results.txt`
。
*
对
`OCR引擎 + SER`
预测系统进行端到端评估
*
对
`OCR引擎 + SER`
预测系统进行端到端评估
首先使用
`tools/infer_vqa_token_ser.py`
脚本完成数据集的预测,然后使用下面的命令进行评估。
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
export
CUDA_VISIBLE_DEVICES
=
0
python3.7 helper/eval_with_label_end2end.py
--gt_json_path
XFUND/zh_val/xfun_normalize_val.json
--pred_json_path
output_res/infer_results.txt
python3.7 helper/eval_with_label_end2end.py
--gt_json_path
XFUND/zh_val/xfun_normalize_val.json
--pred_json_path
output_res/infer_results.txt
...
@@ -223,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor
...
@@ -223,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor
*
启动训练
*
启动训练
```
shell
启动训练之前,需要修改下面的四个字段
export
CUDA_VISIBLE_DEVICES
=
0
python3 train_re.py
\
--model_name_or_path
"layoutxlm-base-uncased"
\
--train_data_dir
"XFUND/zh_train/image"
\
--train_label_path
"XFUND/zh_train/xfun_normalize_train.json"
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--label_map_path
"labels/labels_ser.txt"
\
--num_train_epochs
200
\
--eval_steps
10
\
--output_dir
"output/re/"
\
--learning_rate
5e-5
\
--warmup_steps
50
\
--per_gpu_train_batch_size
8
\
--per_gpu_eval_batch_size
8
\
--num_workers
8
\
--evaluate_during_training
\
--seed
2048
```
*
恢复训练
1.
`Train.dataset.data_dir`
:指向训练集图片存放目录
2.
`Train.dataset.label_file_list`
:指向训练集标注文件
3.
`Eval.dataset.data_dir`
:指指向验证集图片存放目录
4.
`Eval.dataset.label_file_list`
:指向验证集标注文件
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 tools/train.py
-c
configs/vqa/re/layoutxlm.yml
python3 train_re.py
\
--model_name_or_path
"model_path"
\
--train_data_dir
"XFUND/zh_train/image"
\
--train_label_path
"XFUND/zh_train/xfun_normalize_train.json"
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--label_map_path
"labels/labels_ser.txt"
\
--num_train_epochs
2
\
--eval_steps
10
\
--output_dir
"output/re/"
\
--learning_rate
5e-5
\
--warmup_steps
50
\
--per_gpu_train_batch_size
8
\
--per_gpu_eval_batch_size
8
\
--num_workers
8
\
--evaluate_during_training
\
--seed
2048
\
--resume
```
```
最终会打印出
`precision`
,
`recall`
,
`f1`
等指标,模型和训练日志会保存在
`./output/re/`
文件夹中。
最终会打印出
`precision`
,
`recall`
,
`hmean`
等指标。
在
`./output/re_layoutxlm/`
文件夹中会保存训练日志,最优的模型和最新epoch的模型。
*
恢复训练
恢复训练需要将之前训练好的模型所在文件夹路径赋值给
`Architecture.Backbone.checkpoints`
字段。
*
评估
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 tools/train.py
-c
configs/vqa/re/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
path/to/model_dir
python3 eval_re.py
\
--model_name_or_path
"PP-Layout_v1.0_re_pretrained/"
\
--max_seq_length
512
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--label_map_path
"labels/labels_ser.txt"
\
--output_dir
"output/re/"
\
--per_gpu_eval_batch_size
8
\
--num_workers
8
\
--seed
2048
```
```
最终会打印出
`precision`
,
`recall`
,
`f1`
等指标
*
评估
*
使用评估集合中提供的OCR识别结果进行预测
评估需要将待评估的模型所在文件夹路径赋值给
`Architecture.Backbone.checkpoints`
字段。
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
CUDA_VISIBLE_DEVICES
=
0 python3 tools/eval.py
-c
configs/vqa/re/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
path/to/model_dir
python3 infer_re.py
\
--model_name_or_path
"PP-Layout_v1.0_re_pretrained/"
\
--max_seq_length
512
\
--eval_data_dir
"XFUND/zh_val/image"
\
--eval_label_path
"XFUND/zh_val/xfun_normalize_val.json"
\
--label_map_path
"labels/labels_ser.txt"
\
--output_dir
"output/re/"
\
--per_gpu_eval_batch_size
1
\
--seed
2048
```
```
最终会打印出
`precision`
,
`recall`
,
`hmean`
等指标
最终会在
`output_res`
目录下保存预测结果可视化图像以及预测结果文本文件,文件名为
`infer_results.txt`
。
*
使用
`OCR引擎 + SER + RE`
串联预测
*
使用
`OCR引擎 + SER + RE`
串联结果
使用如下命令即可完成
`OCR引擎 + SER + RE`
的串联预测
```
shell
```
shell
export
CUDA_VISIBLE_DEVICES
=
0
export
CUDA_VISIBLE_DEVICES
=
0
python3.7 infer_ser_re_e2e.py
\
python3 tools/infer_vqa_token_ser_re.py
-c
configs/vqa/re/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
PP-Layout_v1.0_re_pretrained/ Global.infer_img
=
ppstructure/vqa/images/input/zh_val_21.jpg
-c_ser
configs/vqa/ser/layoutxlm.yml
-o_ser
Architecture.Backbone.checkpoints
=
PP-Layout_v1.0_ser_pretrained/
--model_name_or_path
"PP-Layout_v1.0_ser_pretrained/"
\
--re_model_name_or_path
"PP-Layout_v1.0_re_pretrained/"
\
--ser_model_type
"LayoutXLM"
\
--max_seq_length
512
\
--output_dir
"output/ser_re_e2e/"
\
--infer_imgs
"images/input/zh_val_21.jpg"
```
```
最终会在
`config.Global.save_res_path`
字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为
`infer_results.txt`
。
## 参考链接
## 参考链接
-
LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
-
LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
...
...
ppstructure/vqa/eval_re.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
paddle
from
paddlenlp.transformers
import
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForRelationExtraction
from
xfun
import
XFUNDataset
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
from
data_collator
import
DataCollator
from
metric
import
re_score
from
ppocr.utils.logging
import
get_logger
def
cal_metric
(
re_preds
,
re_labels
,
entities
):
gt_relations
=
[]
for
b
in
range
(
len
(
re_labels
)):
rel_sent
=
[]
for
head
,
tail
in
zip
(
re_labels
[
b
][
"head"
],
re_labels
[
b
][
"tail"
]):
rel
=
{}
rel
[
"head_id"
]
=
head
rel
[
"head"
]
=
(
entities
[
b
][
"start"
][
rel
[
"head_id"
]],
entities
[
b
][
"end"
][
rel
[
"head_id"
]])
rel
[
"head_type"
]
=
entities
[
b
][
"label"
][
rel
[
"head_id"
]]
rel
[
"tail_id"
]
=
tail
rel
[
"tail"
]
=
(
entities
[
b
][
"start"
][
rel
[
"tail_id"
]],
entities
[
b
][
"end"
][
rel
[
"tail_id"
]])
rel
[
"tail_type"
]
=
entities
[
b
][
"label"
][
rel
[
"tail_id"
]]
rel
[
"type"
]
=
1
rel_sent
.
append
(
rel
)
gt_relations
.
append
(
rel_sent
)
re_metrics
=
re_score
(
re_preds
,
gt_relations
,
mode
=
"boundaries"
)
return
re_metrics
def
evaluate
(
model
,
eval_dataloader
,
logger
,
prefix
=
""
):
# Eval!
logger
.
info
(
"***** Running evaluation {} *****"
.
format
(
prefix
))
logger
.
info
(
" Num examples = {}"
.
format
(
len
(
eval_dataloader
.
dataset
)))
re_preds
=
[]
re_labels
=
[]
entities
=
[]
eval_loss
=
0.0
model
.
eval
()
for
idx
,
batch
in
enumerate
(
eval_dataloader
):
with
paddle
.
no_grad
():
outputs
=
model
(
**
batch
)
loss
=
outputs
[
'loss'
].
mean
().
item
()
if
paddle
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
"[Eval] process: {}/{}, loss: {:.5f}"
.
format
(
idx
,
len
(
eval_dataloader
),
loss
))
eval_loss
+=
loss
re_preds
.
extend
(
outputs
[
'pred_relations'
])
re_labels
.
extend
(
batch
[
'relations'
])
entities
.
extend
(
batch
[
'entities'
])
re_metrics
=
cal_metric
(
re_preds
,
re_labels
,
entities
)
re_metrics
=
{
"precision"
:
re_metrics
[
"ALL"
][
"p"
],
"recall"
:
re_metrics
[
"ALL"
][
"r"
],
"f1"
:
re_metrics
[
"ALL"
][
"f1"
],
}
model
.
train
()
return
re_metrics
def
eval
(
args
):
logger
=
get_logger
()
label2id_map
,
id2label_map
=
get_bio_label_maps
(
args
.
label_map_path
)
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
LayoutXLMForRelationExtraction
.
from_pretrained
(
args
.
model_name_or_path
)
eval_dataset
=
XFUNDataset
(
tokenizer
,
data_dir
=
args
.
eval_data_dir
,
label_path
=
args
.
eval_label_path
,
label2id_map
=
label2id_map
,
img_size
=
(
224
,
224
),
max_seq_len
=
args
.
max_seq_length
,
pad_token_label_id
=
pad_token_label_id
,
contains_re
=
True
,
add_special_ids
=
False
,
return_attention_mask
=
True
,
load_mode
=
'all'
)
eval_dataloader
=
paddle
.
io
.
DataLoader
(
eval_dataset
,
batch_size
=
args
.
per_gpu_eval_batch_size
,
num_workers
=
args
.
num_workers
,
shuffle
=
False
,
collate_fn
=
DataCollator
())
results
=
evaluate
(
model
,
eval_dataloader
,
logger
)
logger
.
info
(
"eval results: {}"
.
format
(
results
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
eval
(
args
)
ppstructure/vqa/eval_ser.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
random
import
time
import
copy
import
logging
import
argparse
import
paddle
import
numpy
as
np
from
seqeval.metrics
import
classification_report
,
f1_score
,
precision_score
,
recall_score
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
from
xfun
import
XFUNDataset
from
losses
import
SERLoss
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
from
ppocr.utils.logging
import
get_logger
MODELS
=
{
'LayoutXLM'
:
(
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForTokenClassification
),
'LayoutLM'
:
(
LayoutLMTokenizer
,
LayoutLMModel
,
LayoutLMForTokenClassification
)
}
def
eval
(
args
):
logger
=
get_logger
()
print_arguments
(
args
,
logger
)
label2id_map
,
id2label_map
=
get_bio_label_maps
(
args
.
label_map_path
)
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
tokenizer_class
,
base_model_class
,
model_class
=
MODELS
[
args
.
ser_model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
eval_dataset
=
XFUNDataset
(
tokenizer
,
data_dir
=
args
.
eval_data_dir
,
label_path
=
args
.
eval_label_path
,
label2id_map
=
label2id_map
,
img_size
=
(
224
,
224
),
pad_token_label_id
=
pad_token_label_id
,
contains_re
=
False
,
add_special_ids
=
False
,
return_attention_mask
=
True
,
load_mode
=
'all'
)
eval_dataloader
=
paddle
.
io
.
DataLoader
(
eval_dataset
,
batch_size
=
args
.
per_gpu_eval_batch_size
,
num_workers
=
args
.
num_workers
,
use_shared_memory
=
True
,
collate_fn
=
None
,
)
loss_class
=
SERLoss
(
len
(
label2id_map
))
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
loss_class
,
eval_dataloader
,
label2id_map
,
id2label_map
,
pad_token_label_id
,
logger
)
logger
.
info
(
results
)
def
evaluate
(
args
,
model
,
tokenizer
,
loss_class
,
eval_dataloader
,
label2id_map
,
id2label_map
,
pad_token_label_id
,
logger
,
prefix
=
""
):
eval_loss
=
0.0
nb_eval_steps
=
0
preds
=
None
out_label_ids
=
None
model
.
eval
()
for
idx
,
batch
in
enumerate
(
eval_dataloader
):
with
paddle
.
no_grad
():
if
args
.
ser_model_type
==
'LayoutLM'
:
if
'image'
in
batch
:
batch
.
pop
(
'image'
)
labels
=
batch
.
pop
(
'labels'
)
outputs
=
model
(
**
batch
)
if
args
.
ser_model_type
==
'LayoutXLM'
:
outputs
=
outputs
[
0
]
loss
=
loss_class
(
labels
,
outputs
,
batch
[
'attention_mask'
])
loss
=
loss
.
mean
()
if
paddle
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
"[Eval]process: {}/{}, loss: {:.5f}"
.
format
(
idx
,
len
(
eval_dataloader
),
loss
.
numpy
()[
0
]))
eval_loss
+=
loss
.
item
()
nb_eval_steps
+=
1
if
preds
is
None
:
preds
=
outputs
.
numpy
()
out_label_ids
=
labels
.
numpy
()
else
:
preds
=
np
.
append
(
preds
,
outputs
.
numpy
(),
axis
=
0
)
out_label_ids
=
np
.
append
(
out_label_ids
,
labels
.
numpy
(),
axis
=
0
)
eval_loss
=
eval_loss
/
nb_eval_steps
preds
=
np
.
argmax
(
preds
,
axis
=
2
)
# label_map = {i: label.upper() for i, label in enumerate(labels)}
out_label_list
=
[[]
for
_
in
range
(
out_label_ids
.
shape
[
0
])]
preds_list
=
[[]
for
_
in
range
(
out_label_ids
.
shape
[
0
])]
for
i
in
range
(
out_label_ids
.
shape
[
0
]):
for
j
in
range
(
out_label_ids
.
shape
[
1
]):
if
out_label_ids
[
i
,
j
]
!=
pad_token_label_id
:
out_label_list
[
i
].
append
(
id2label_map
[
out_label_ids
[
i
][
j
]])
preds_list
[
i
].
append
(
id2label_map
[
preds
[
i
][
j
]])
results
=
{
"loss"
:
eval_loss
,
"precision"
:
precision_score
(
out_label_list
,
preds_list
),
"recall"
:
recall_score
(
out_label_list
,
preds_list
),
"f1"
:
f1_score
(
out_label_list
,
preds_list
),
}
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
"test_gt.txt"
),
"w"
,
encoding
=
'utf-8'
)
as
fout
:
for
lbl
in
out_label_list
:
for
l
in
lbl
:
fout
.
write
(
l
+
"
\t
"
)
fout
.
write
(
"
\n
"
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
"test_pred.txt"
),
"w"
,
encoding
=
'utf-8'
)
as
fout
:
for
lbl
in
preds_list
:
for
l
in
lbl
:
fout
.
write
(
l
+
"
\t
"
)
fout
.
write
(
"
\n
"
)
report
=
classification_report
(
out_label_list
,
preds_list
)
logger
.
info
(
"
\n
"
+
report
)
logger
.
info
(
"***** Eval results %s *****"
,
prefix
)
for
key
in
sorted
(
results
.
keys
()):
logger
.
info
(
" %s = %s"
,
key
,
str
(
results
[
key
]))
model
.
train
()
return
results
,
preds_list
if
__name__
==
"__main__"
:
args
=
parse_args
()
eval
(
args
)
ppstructure/vqa/helper/trans_xfun_data.py
View file @
a323fce6
...
@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None):
...
@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None):
print
(
"===ok===="
)
print
(
"===ok===="
)
transfer_xfun_data
(
"./xfun/zh.val.json"
,
"./xfun_normalize_val.json"
)
def
parser_args
():
import
argparse
parser
=
argparse
.
ArgumentParser
(
description
=
"args for paddleserving"
)
parser
.
add_argument
(
"--ori_gt_path"
,
type
=
str
,
required
=
True
,
help
=
'origin xfun gt path'
)
parser
.
add_argument
(
"--output_path"
,
type
=
str
,
required
=
True
,
help
=
'path to save'
)
args
=
parser
.
parse_args
()
return
args
args
=
parser_args
()
transfer_xfun_data
(
args
.
ori_gt_path
,
args
.
output_path
)
ppstructure/vqa/infer.sh
deleted
100644 → 0
View file @
1ded2ac4
export
CUDA_VISIBLE_DEVICES
=
6
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --max_seq_length 512 \
# --output_dir "output_res_e2e/" \
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --re_model_name_or_path "output/re_test/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e_train/" \
# --infer_imgs "images/input/zh_val_21.jpg"
# python3.7 infer_ser.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --output_dir "ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_21.jpg" \
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
python3.7 infer_ser.py
\
--model_name_or_path
"output/ser_new/best_model"
\
--ser_model_type
"LayoutXLM"
\
--output_dir
"ser_new/"
\
--infer_imgs
"images/input/zh_val_21.jpg"
\
--ocr_json_path
"/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_new/best_model" \
# --ser_model_type "LayoutXLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_new/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3 infer_re.py \
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
# --max_seq_length 512 \
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
# --label_map_path 'labels/labels_ser.txt' \
# --output_dir "output_res" \
# --per_gpu_eval_batch_size 1 \
# --seed 2048
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --re_model_name_or_path "output/re_new/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e/" \
# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
ppstructure/vqa/infer_re.py
deleted
100644 → 0
View file @
1ded2ac4
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
random
import
cv2
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
paddle
from
paddlenlp.transformers
import
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForRelationExtraction
from
xfun
import
XFUNDataset
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
draw_re_results
from
data_collator
import
DataCollator
from
ppocr.utils.logging
import
get_logger
def
infer
(
args
):
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
logger
=
get_logger
()
label2id_map
,
id2label_map
=
get_bio_label_maps
(
args
.
label_map_path
)
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
LayoutXLMForRelationExtraction
.
from_pretrained
(
args
.
model_name_or_path
)
eval_dataset
=
XFUNDataset
(
tokenizer
,
data_dir
=
args
.
eval_data_dir
,
label_path
=
args
.
eval_label_path
,
label2id_map
=
label2id_map
,
img_size
=
(
224
,
224
),
max_seq_len
=
args
.
max_seq_length
,
pad_token_label_id
=
pad_token_label_id
,
contains_re
=
True
,
add_special_ids
=
False
,
return_attention_mask
=
True
,
load_mode
=
'all'
)
eval_dataloader
=
paddle
.
io
.
DataLoader
(
eval_dataset
,
batch_size
=
args
.
per_gpu_eval_batch_size
,
num_workers
=
8
,
shuffle
=
False
,
collate_fn
=
DataCollator
())
# 读取gt的oct数据
ocr_info_list
=
load_ocr
(
args
.
eval_data_dir
,
args
.
eval_label_path
)
for
idx
,
batch
in
enumerate
(
eval_dataloader
):
ocr_info
=
ocr_info_list
[
idx
]
image_path
=
ocr_info
[
'image_path'
]
ocr_info
=
ocr_info
[
'ocr_info'
]
save_img_path
=
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
splitext
(
os
.
path
.
basename
(
image_path
))[
0
]
+
"_re.jpg"
)
logger
.
info
(
"[Infer] process: {}/{}, save result to {}"
.
format
(
idx
,
len
(
eval_dataloader
),
save_img_path
))
with
paddle
.
no_grad
():
outputs
=
model
(
**
batch
)
pred_relations
=
outputs
[
'pred_relations'
]
# 根据entity里的信息,做token解码后去过滤不要的ocr_info
ocr_info
=
filter_bg_by_txt
(
ocr_info
,
batch
,
tokenizer
)
# 进行 relations 到 ocr信息的转换
result
=
[]
used_tail_id
=
[]
for
relations
in
pred_relations
:
for
relation
in
relations
:
if
relation
[
'tail_id'
]
in
used_tail_id
:
continue
if
relation
[
'head_id'
]
not
in
ocr_info
or
relation
[
'tail_id'
]
not
in
ocr_info
:
continue
used_tail_id
.
append
(
relation
[
'tail_id'
])
ocr_info_head
=
ocr_info
[
relation
[
'head_id'
]]
ocr_info_tail
=
ocr_info
[
relation
[
'tail_id'
]]
result
.
append
((
ocr_info_head
,
ocr_info_tail
))
img
=
cv2
.
imread
(
image_path
)
img_show
=
draw_re_results
(
img
,
result
)
cv2
.
imwrite
(
save_img_path
,
img_show
)
def
load_ocr
(
img_folder
,
json_path
):
import
json
d
=
[]
with
open
(
json_path
,
"r"
,
encoding
=
'utf-8'
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
image_name
,
info_str
=
line
.
split
(
"
\t
"
)
info_dict
=
json
.
loads
(
info_str
)
info_dict
[
'image_path'
]
=
os
.
path
.
join
(
img_folder
,
image_name
)
d
.
append
(
info_dict
)
return
d
def
filter_bg_by_txt
(
ocr_info
,
batch
,
tokenizer
):
entities
=
batch
[
'entities'
][
0
]
input_ids
=
batch
[
'input_ids'
][
0
]
new_info_dict
=
{}
for
i
in
range
(
len
(
entities
[
'start'
])):
entitie_head
=
entities
[
'start'
][
i
]
entitie_tail
=
entities
[
'end'
][
i
]
word_input_ids
=
input_ids
[
entitie_head
:
entitie_tail
].
numpy
().
tolist
()
txt
=
tokenizer
.
convert_ids_to_tokens
(
word_input_ids
)
txt
=
tokenizer
.
convert_tokens_to_string
(
txt
)
for
i
,
info
in
enumerate
(
ocr_info
):
if
info
[
'text'
]
==
txt
:
new_info_dict
[
i
]
=
info
return
new_info_dict
def
post_process
(
pred_relations
,
ocr_info
,
img
):
result
=
[]
for
relations
in
pred_relations
:
for
relation
in
relations
:
ocr_info_head
=
ocr_info
[
relation
[
'head_id'
]]
ocr_info_tail
=
ocr_info
[
relation
[
'tail_id'
]]
result
.
append
((
ocr_info_head
,
ocr_info_tail
))
return
result
def
draw_re
(
result
,
image_path
,
output_folder
):
img
=
cv2
.
imread
(
image_path
)
from
matplotlib
import
pyplot
as
plt
for
ocr_info_head
,
ocr_info_tail
in
result
:
cv2
.
rectangle
(
img
,
tuple
(
ocr_info_head
[
'bbox'
][:
2
]),
tuple
(
ocr_info_head
[
'bbox'
][
2
:]),
(
255
,
0
,
0
),
thickness
=
2
)
cv2
.
rectangle
(
img
,
tuple
(
ocr_info_tail
[
'bbox'
][:
2
]),
tuple
(
ocr_info_tail
[
'bbox'
][
2
:]),
(
0
,
0
,
255
),
thickness
=
2
)
center_p1
=
[(
ocr_info_head
[
'bbox'
][
0
]
+
ocr_info_head
[
'bbox'
][
2
])
//
2
,
(
ocr_info_head
[
'bbox'
][
1
]
+
ocr_info_head
[
'bbox'
][
3
])
//
2
]
center_p2
=
[(
ocr_info_tail
[
'bbox'
][
0
]
+
ocr_info_tail
[
'bbox'
][
2
])
//
2
,
(
ocr_info_tail
[
'bbox'
][
1
]
+
ocr_info_tail
[
'bbox'
][
3
])
//
2
]
cv2
.
line
(
img
,
tuple
(
center_p1
),
tuple
(
center_p2
),
(
0
,
255
,
0
),
thickness
=
2
)
plt
.
imshow
(
img
)
plt
.
savefig
(
os
.
path
.
join
(
output_folder
,
os
.
path
.
basename
(
image_path
)),
dpi
=
600
)
# plt.show()
if
__name__
==
"__main__"
:
args
=
parse_args
()
infer
(
args
)
ppstructure/vqa/infer_ser.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
import
json
import
cv2
import
numpy
as
np
from
copy
import
deepcopy
import
paddle
# relative reference
from
vqa_utils
import
parse_args
,
get_image_file_list
,
draw_ser_results
,
get_bio_label_maps
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
MODELS
=
{
'LayoutXLM'
:
(
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForTokenClassification
),
'LayoutLM'
:
(
LayoutLMTokenizer
,
LayoutLMModel
,
LayoutLMForTokenClassification
)
}
def
pad_sentences
(
tokenizer
,
encoded_inputs
,
max_seq_len
=
512
,
pad_to_max_seq_len
=
True
,
return_attention_mask
=
True
,
return_token_type_ids
=
True
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
):
# Padding with larger size, reshape is carried out
max_seq_len
=
(
len
(
encoded_inputs
[
"input_ids"
])
//
max_seq_len
+
1
)
*
max_seq_len
needs_to_be_padded
=
pad_to_max_seq_len
and
\
max_seq_len
and
len
(
encoded_inputs
[
"input_ids"
])
<
max_seq_len
if
needs_to_be_padded
:
difference
=
max_seq_len
-
len
(
encoded_inputs
[
"input_ids"
])
if
tokenizer
.
padding_side
==
'right'
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
+
[
0
]
*
difference
if
return_token_type_ids
:
encoded_inputs
[
"token_type_ids"
]
=
(
encoded_inputs
[
"token_type_ids"
]
+
[
tokenizer
.
pad_token_type_id
]
*
difference
)
if
return_special_tokens_mask
:
encoded_inputs
[
"special_tokens_mask"
]
=
encoded_inputs
[
"special_tokens_mask"
]
+
[
1
]
*
difference
encoded_inputs
[
"input_ids"
]
=
encoded_inputs
[
"input_ids"
]
+
[
tokenizer
.
pad_token_id
]
*
difference
encoded_inputs
[
"bbox"
]
=
encoded_inputs
[
"bbox"
]
+
[[
0
,
0
,
0
,
0
]
]
*
difference
else
:
assert
False
,
"padding_side of tokenizer just supports [
\"
right
\"
] but got {}"
.
format
(
tokenizer
.
padding_side
)
else
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
return
encoded_inputs
def
split_page
(
encoded_inputs
,
max_seq_len
=
512
):
"""
truncate is often used in training process
"""
for
key
in
encoded_inputs
:
encoded_inputs
[
key
]
=
paddle
.
to_tensor
(
encoded_inputs
[
key
])
if
encoded_inputs
[
key
].
ndim
<=
1
:
# for input_ids, att_mask and so on
encoded_inputs
[
key
]
=
encoded_inputs
[
key
].
reshape
([
-
1
,
max_seq_len
])
else
:
# for bbox
encoded_inputs
[
key
]
=
encoded_inputs
[
key
].
reshape
(
[
-
1
,
max_seq_len
,
4
])
return
encoded_inputs
def
preprocess
(
tokenizer
,
ori_img
,
ocr_info
,
img_size
=
(
224
,
224
),
pad_token_label_id
=-
100
,
max_seq_len
=
512
,
add_special_ids
=
False
,
return_attention_mask
=
True
,
):
ocr_info
=
deepcopy
(
ocr_info
)
height
=
ori_img
.
shape
[
0
]
width
=
ori_img
.
shape
[
1
]
img
=
cv2
.
resize
(
ori_img
,
(
224
,
224
)).
transpose
([
2
,
0
,
1
]).
astype
(
np
.
float32
)
segment_offset_id
=
[]
words_list
=
[]
bbox_list
=
[]
input_ids_list
=
[]
token_type_ids_list
=
[]
for
info
in
ocr_info
:
# x1, y1, x2, y2
bbox
=
info
[
"bbox"
]
bbox
[
0
]
=
int
(
bbox
[
0
]
*
1000.0
/
width
)
bbox
[
2
]
=
int
(
bbox
[
2
]
*
1000.0
/
width
)
bbox
[
1
]
=
int
(
bbox
[
1
]
*
1000.0
/
height
)
bbox
[
3
]
=
int
(
bbox
[
3
]
*
1000.0
/
height
)
text
=
info
[
"text"
]
encode_res
=
tokenizer
.
encode
(
text
,
pad_to_max_seq_len
=
False
,
return_attention_mask
=
True
)
if
not
add_special_ids
:
# TODO: use tok.all_special_ids to remove
encode_res
[
"input_ids"
]
=
encode_res
[
"input_ids"
][
1
:
-
1
]
encode_res
[
"token_type_ids"
]
=
encode_res
[
"token_type_ids"
][
1
:
-
1
]
encode_res
[
"attention_mask"
]
=
encode_res
[
"attention_mask"
][
1
:
-
1
]
input_ids_list
.
extend
(
encode_res
[
"input_ids"
])
token_type_ids_list
.
extend
(
encode_res
[
"token_type_ids"
])
bbox_list
.
extend
([
bbox
]
*
len
(
encode_res
[
"input_ids"
]))
words_list
.
append
(
text
)
segment_offset_id
.
append
(
len
(
input_ids_list
))
encoded_inputs
=
{
"input_ids"
:
input_ids_list
,
"token_type_ids"
:
token_type_ids_list
,
"bbox"
:
bbox_list
,
"attention_mask"
:
[
1
]
*
len
(
input_ids_list
),
}
encoded_inputs
=
pad_sentences
(
tokenizer
,
encoded_inputs
,
max_seq_len
=
max_seq_len
,
return_attention_mask
=
return_attention_mask
)
encoded_inputs
=
split_page
(
encoded_inputs
)
fake_bs
=
encoded_inputs
[
"input_ids"
].
shape
[
0
]
encoded_inputs
[
"image"
]
=
paddle
.
to_tensor
(
img
).
unsqueeze
(
0
).
expand
(
[
fake_bs
]
+
list
(
img
.
shape
))
encoded_inputs
[
"segment_offset_id"
]
=
segment_offset_id
return
encoded_inputs
def
postprocess
(
attention_mask
,
preds
,
label_map_path
):
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds
=
np
.
argmax
(
preds
,
axis
=
2
)
_
,
label_map
=
get_bio_label_maps
(
label_map_path
)
preds_list
=
[[]
for
_
in
range
(
preds
.
shape
[
0
])]
# keep batch info
for
i
in
range
(
preds
.
shape
[
0
]):
for
j
in
range
(
preds
.
shape
[
1
]):
if
attention_mask
[
i
][
j
]
==
1
:
preds_list
[
i
].
append
(
label_map
[
preds
[
i
][
j
]])
return
preds_list
def
merge_preds_list_with_ocr_info
(
label_map_path
,
ocr_info
,
segment_offset_id
,
preds_list
):
# must ensure the preds_list is generated from the same image
preds
=
[
p
for
pred
in
preds_list
for
p
in
pred
]
label2id_map
,
_
=
get_bio_label_maps
(
label_map_path
)
for
key
in
label2id_map
:
if
key
.
startswith
(
"I-"
):
label2id_map
[
key
]
=
label2id_map
[
"B"
+
key
[
1
:]]
id2label_map
=
dict
()
for
key
in
label2id_map
:
val
=
label2id_map
[
key
]
if
key
==
"O"
:
id2label_map
[
val
]
=
key
if
key
.
startswith
(
"B-"
)
or
key
.
startswith
(
"I-"
):
id2label_map
[
val
]
=
key
[
2
:]
else
:
id2label_map
[
val
]
=
key
for
idx
in
range
(
len
(
segment_offset_id
)):
if
idx
==
0
:
start_id
=
0
else
:
start_id
=
segment_offset_id
[
idx
-
1
]
end_id
=
segment_offset_id
[
idx
]
curr_pred
=
preds
[
start_id
:
end_id
]
curr_pred
=
[
label2id_map
[
p
]
for
p
in
curr_pred
]
if
len
(
curr_pred
)
<=
0
:
pred_id
=
0
else
:
counts
=
np
.
bincount
(
curr_pred
)
pred_id
=
np
.
argmax
(
counts
)
ocr_info
[
idx
][
"pred_id"
]
=
int
(
pred_id
)
ocr_info
[
idx
][
"pred"
]
=
id2label_map
[
pred_id
]
return
ocr_info
@
paddle
.
no_grad
()
def
infer
(
args
):
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
# init token and model
tokenizer_class
,
base_model_class
,
model_class
=
MODELS
[
args
.
ser_model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
model_class
.
from_pretrained
(
args
.
model_name_or_path
)
model
.
eval
()
# load ocr results json
ocr_results
=
dict
()
with
open
(
args
.
ocr_json_path
,
"r"
,
encoding
=
'utf-8'
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
img_name
,
json_info
=
line
.
split
(
"
\t
"
)
ocr_results
[
os
.
path
.
basename
(
img_name
)]
=
json
.
loads
(
json_info
)
# get infer img list
infer_imgs
=
get_image_file_list
(
args
.
infer_imgs
)
# loop for infer
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
"infer_results.txt"
),
"w"
,
encoding
=
'utf-8'
)
as
fout
:
for
idx
,
img_path
in
enumerate
(
infer_imgs
):
save_img_path
=
os
.
path
.
join
(
args
.
output_dir
,
os
.
path
.
basename
(
img_path
))
print
(
"process: [{}/{}], save result to {}"
.
format
(
idx
,
len
(
infer_imgs
),
save_img_path
))
img
=
cv2
.
imread
(
img_path
)
ocr_info
=
ocr_results
[
os
.
path
.
basename
(
img_path
)][
"ocr_info"
]
inputs
=
preprocess
(
tokenizer
=
tokenizer
,
ori_img
=
img
,
ocr_info
=
ocr_info
,
max_seq_len
=
args
.
max_seq_length
)
if
args
.
ser_model_type
==
'LayoutLM'
:
preds
=
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
elif
args
.
ser_model_type
==
'LayoutXLM'
:
preds
=
model
(
input_ids
=
inputs
[
"input_ids"
],
bbox
=
inputs
[
"bbox"
],
image
=
inputs
[
"image"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
attention_mask
=
inputs
[
"attention_mask"
])
preds
=
preds
[
0
]
preds
=
postprocess
(
inputs
[
"attention_mask"
],
preds
,
args
.
label_map_path
)
ocr_info
=
merge_preds_list_with_ocr_info
(
args
.
label_map_path
,
ocr_info
,
inputs
[
"segment_offset_id"
],
preds
)
fout
.
write
(
img_path
+
"
\t
"
+
json
.
dumps
(
{
"ocr_info"
:
ocr_info
,
},
ensure_ascii
=
False
)
+
"
\n
"
)
img_res
=
draw_ser_results
(
img
,
ocr_info
)
cv2
.
imwrite
(
save_img_path
,
img_res
)
return
if
__name__
==
"__main__"
:
args
=
parse_args
()
infer
(
args
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment